# (file-viewer header residue: 577 lines · 23 KiB · Python)
"""
|
|
oracle/prompt_orchestrator.py
|
|
Accepts a user prompt, assembles context, calls the Nemoclaw model runtime
|
|
(or uses a deterministic fallback), validates the generated plan via policy,
|
|
triggers the data access gateway, and produces a PromptExecution.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import os
|
|
import uuid
|
|
import json
|
|
from datetime import datetime, timezone
|
|
from typing import Any
|
|
|
|
from .policy_service import PolicyContext, PolicyService
|
|
from .canvas_service import canvas_service
|
|
from .data_access_gateway import data_access_gateway
|
|
|
|
try:
|
|
import asyncpg # type: ignore
|
|
except Exception: # pragma: no cover
|
|
asyncpg = None # type: ignore
|
|
|
|
# Module-level logger; orchestration messages in this module are prefixed "ORCH".
logger = logging.getLogger(__name__)

# Hosted model runtime endpoint and credential, read once at import time.
# Both must be non-empty for the hosted runtime to be attempted.
_NEMOCLAW_URL = os.getenv("NEMOCLAW_API_URL", "")
_NEMOCLAW_API_KEY = os.getenv("NEMOCLAW_API_KEY", "")

# Postgres DSN for execution persistence; _db_ready() decides whether it is usable.
_DB_URL = os.getenv("DATABASE_URL", "")

# Shared policy validator applied to every retrieval-plan component.
policy_svc = PolicyService()
|
|
|
|
|
|
def _now() -> str:
    """Return the current UTC time as an ISO-8601 string (with ``+00:00`` offset)."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()
|
|
|
|
|
|
# ── Execution store ───────────────────────────────────────────────────────────

# In-process store of execution records keyed by executionId. Written by
# _persist_execution (whether or not Postgres is configured) and read by
# get_execution. It lives only in this process, and nothing in this module
# removes entries from it.
_DEMO_EXECUTIONS: dict[str, dict[str, Any]] = {}
|
|
|
|
|
|
def _db_ready() -> bool:
    """
    Report whether Postgres persistence can be used.

    Requires an importable asyncpg driver and a non-empty DATABASE_URL that
    is not a "PLACEHOLDER..." value.
    """
    if asyncpg is None or not _DB_URL:
        return False
    return not _DB_URL.startswith("PLACEHOLDER")
|
|
|
|
|
|
# ── Semantic intent detection (simplified) ────────────────────────────────────

# Component type -> trigger keywords. _detect_component_types lower-cases the
# prompt and selects every type with at least one matching keyword, in this
# dict's insertion order, so keywords must be lower-case to ever match.
# A prompt can match several types at once (e.g. "ranked" also contains the
# table keyword "rank").
_INTENT_KEYWORDS: dict[str, list[str]] = {
    "pipeline_board": ["pipeline", "stage", "kanban", "deal", "funnel"],
    "bar_chart": ["bar", "compare", "source", "channel", "distribution", "ranked", "lead", "whale"],
    "geo_map": ["map", "geographic", "location", "district", "region", "area", "dubai"],
    "table": ["table", "list", "broker", "performance", "leaderboard", "rank", "top"],
    "line_chart": ["trend", "time", "monthly", "weekly", "absorption", "forecast"],
    "kpi_tile": ["kpi", "total", "summary", "attainment", "quota", "how many"],
    "activity_stream": ["timeline", "activity", "history", "follow-up", "queue", "contact"],
}
|
|
|
|
|
|
def _detect_component_types(prompt: str) -> list[str]:
    """
    Return the component types whose intent keywords appear in *prompt*.

    Matching is case-insensitive and anchored at the start of a word, so
    "deal" matches "deals" but not "ideal", and "top" matches "top" but not
    "stop" or "laptop". Plain substring matching produced those false
    positives. Types are returned in ``_INTENT_KEYWORDS`` order; when
    nothing matches, ``["bar_chart"]`` is returned as the default.
    """
    import re

    lower = prompt.lower()
    matched = [
        comp_type
        for comp_type, keywords in _INTENT_KEYWORDS.items()
        # \b only at the front: inflected forms ("deals", "ranked") still
        # match, but keywords buried inside other words do not.
        if any(re.search(r"\b" + re.escape(keyword), lower) for keyword in keywords)
    ]
    return matched or ["bar_chart"]
|
|
|
|
|
|
def _build_demo_retrieval_plan(
    prompt: str,
    tenant_id: str,
    actor_role: str,
) -> dict[str, Any]:
    """
    Build a deterministic retrieval plan for demo / fallback mode.

    One component is emitted per type returned by _detect_component_types.
    Broker roles ("senior_broker", "junior_broker") get a 50-row limit; every
    other role gets 200. Each component reads the dataset mapped in
    _DATASET_MAP (or "aggregated_results") and is filtered by tenant_id.
    """
    requested_types = _detect_component_types(prompt)

    limit = 200
    if actor_role in ("senior_broker", "junior_broker"):
        limit = 50

    def _component(comp_type: str) -> dict[str, Any]:
        table = _DATASET_MAP.get(comp_type, "aggregated_results")
        return {
            "suggestedType": comp_type,
            "dataset": table,
            "privacyTier": "standard",
            "rowLimit": limit,
            "joins": [],
            "queryTemplate": f"SELECT * FROM {table} WHERE tenant_id = :tenant_id LIMIT :limit",
            "queryParameters": {"tenant_id": tenant_id, "limit": limit},
        }

    plan_id = str(uuid.uuid4())
    return {
        "planId": plan_id,
        "components": [_component(comp_type) for comp_type in requested_types],
        "semanticModelVersion": "oracle_semantic_v2026_04_08_01",
        "intentClass": "analytical",
    }
|
|
|
|
|
|
# Component type -> source dataset (table name) used by the demo retrieval plan.
# Types missing here fall back to "aggregated_results" in
# _build_demo_retrieval_plan. Defining this after that function is fine: the
# name is only looked up when the function is called.
_DATASET_MAP: dict[str, str] = {
    "pipeline_board": "deals",
    "bar_chart": "lead_daily_snapshot",
    "geo_map": "lead_geo_interest_rollup",
    "table": "broker_performance",
    "line_chart": "inventory_absorption",
    "kpi_tile": "oracle_aggregated_metric",
    "activity_stream": "lead_activity_log",
}
|
|
|
|
|
|
class PromptOrchestrator:
    """
    Orchestrates the full prompt-to-canvas pipeline:

    1. Intent classification
    2. Retrieval plan construction (Nemoclaw or deterministic fallback)
    3. Policy validation
    4. Component plan construction (queries via the data access gateway)
    5. Revision commit and execution record persistence

    The orchestrator keeps no per-instance state: execution records live in
    the module-level in-memory store and, when configured, in Postgres.
    """

    # Stamped on execution records and sent with every Nemoclaw request.
    _SEMANTIC_MODEL_VERSION = "oracle_semantic_v2026_04_08_01"

    # Roles that may see prompt-generated components (normal and error notices).
    _ALLOWED_ROLES: tuple[str, ...] = (
        "senior_broker",
        "sales_director",
        "marketing_operator",
        "data_steward",
        "compliance_reviewer",
        "platform_admin",
    )

    async def execute(
        self,
        *,
        tenant_id: str,
        page_id: str,
        branch_id: str,
        actor_id: str,
        actor_role: str,
        prompt: str,
        conversation_context: list[dict[str, str]] | None = None,
        client_request_id: str,
        placement_mode: str = "append_after_last_visible_component",
    ) -> dict[str, Any]:
        """
        Run the full orchestration flow and return a PromptExecution dict.

        Status transitions are ``planning`` -> ``executing`` -> ``completed``.
        The record is marked ``failed`` and persisted in two cases:

        * Policy validation rejects any component. The record is returned.
        * Building the visualization plan raises. The exception is re-raised
          after persisting.

        Revision-commit problems are non-fatal and only add a warning.

        Args:
            tenant_id, page_id, branch_id: target canvas location.
            actor_id, actor_role: caller identity used for policy checks.
            prompt: the user's natural-language request.
            conversation_context: prior turns forwarded to Nemoclaw.
            client_request_id: stored on the record and used as the revision
                idempotency key.
            placement_mode: forwarded to the visualization-plan builder.

        Returns:
            The execution record (also kept in the in-memory store).
        """
        execution_id = str(uuid.uuid4())
        warnings: list[str] = []

        ctx = PolicyContext(
            tenant_id=tenant_id,
            actor_id=actor_id,
            actor_role=actor_role,
        )

        # The hosted runtime is attempted only when BOTH url and key are set;
        # derive the recorded runtime from the same condition so it is accurate.
        use_nemoclaw = bool(_NEMOCLAW_URL and _NEMOCLAW_API_KEY)

        execution: dict[str, Any] = {
            "executionId": execution_id,
            "tenantId": tenant_id,
            "pageId": page_id,
            "branchId": branch_id,
            "actorId": actor_id,
            "prompt": prompt,
            "intentClass": "analytical",
            "status": "planning",
            "modelRuntime": "nemoclaw_hosted" if use_nemoclaw else "deterministic_fallback",
            "semanticModelVersion": self._SEMANTIC_MODEL_VERSION,
            "warnings": warnings,
            "componentsCreated": [],
            "clientRequestId": client_request_id,
            "createdAt": _now(),
        }
        await self._persist_execution(execution)

        # ── Step 1: Build retrieval plan ──────────────────────────────────────
        retrieval_plan: dict[str, Any] | None = None
        if use_nemoclaw:
            try:
                retrieval_plan = await self._call_nemoclaw(prompt, conversation_context or [], ctx)
            except Exception as exc:
                logger.warning("ORCH Nemoclaw call failed, using fallback: %s", exc)
                warnings.append(f"Model runtime unavailable ({exc}); using deterministic fallback.")
                # The plan now comes from the fallback builder; record that.
                execution["modelRuntime"] = "deterministic_fallback"
        if retrieval_plan is None:
            retrieval_plan = _build_demo_retrieval_plan(prompt, tenant_id, actor_role)

        execution["retrievalPlan"] = retrieval_plan

        # ── Step 2: Policy validation ─────────────────────────────────────────
        policy_errors: list[str] = []
        for component_plan in retrieval_plan.get("components", []):
            result = policy_svc.validate_retrieval_plan(component_plan, ctx)
            if not result.passed:
                policy_errors.extend(result.errors)
            if result.warnings:
                warnings.extend(result.warnings)

        if policy_errors:
            execution["status"] = "failed"
            execution["warnings"] = warnings + policy_errors
            execution["completedAt"] = _now()
            logger.warning(
                "ORCH policy_denial execution_id=%s actor=%s errors=%s",
                execution_id, actor_id, policy_errors,
            )
            # Persist the terminal state; otherwise the DB row stays "planning".
            await self._persist_execution(execution)
            return execution

        execution["status"] = "executing"
        await self._persist_execution(execution)

        # ── Step 3: Build visualization plan (component descriptors) ──────────
        try:
            viz_plan = await self._build_visualization_plan(
                retrieval_plan=retrieval_plan,
                prompt=prompt,
                execution_id=execution_id,
                actor_id=actor_id,
                tenant_id=tenant_id,
                branch_id=branch_id,
                placement_mode=placement_mode,
                ctx=ctx,
            )
        except Exception:
            # Don't leave the record stuck in "executing"; surface the error.
            logger.exception("ORCH visualization_plan failed execution_id=%s", execution_id)
            execution["status"] = "failed"
            execution["completedAt"] = _now()
            await self._persist_execution(execution)
            raise
        execution["visualizationPlan"] = viz_plan

        # ── Step 4: Commit revision ───────────────────────────────────────────
        new_components = viz_plan.get("components", [])
        execution["componentsCreated"] = [c["componentId"] for c in new_components]

        try:
            page = await canvas_service.get_page(page_id, tenant_id)
            if page:
                revision = await canvas_service.commit_revision(
                    page_id=page_id,
                    tenant_id=tenant_id,
                    actor_id=actor_id,
                    commit_kind="prompt",
                    commit_summary=f"Oracle: {prompt[:80]}",
                    components=page.get("components", []) + new_components,
                    execution_id=execution_id,
                    idempotency_key=client_request_id,
                )
                execution["headRevision"] = revision["revisionNumber"]
            else:
                # Previously skipped silently; make the missing commit visible.
                logger.warning("ORCH revision_commit skipped: page %s not found", page_id)
                warnings.append("Revision commit skipped — page not found.")
        except Exception as exc:
            logger.warning("ORCH revision_commit failed (non-fatal): %s", exc)
            warnings.append("Revision commit deferred — will retry on next sync.")

        execution["status"] = "completed"
        execution["summary"] = self._generate_summary(prompt, viz_plan)
        execution["completedAt"] = _now()
        execution["warnings"] = warnings
        await self._persist_execution(execution)
        return execution

    async def _build_visualization_plan(
        self,
        *,
        retrieval_plan: dict[str, Any],
        prompt: str,
        execution_id: str,
        actor_id: str,
        tenant_id: str,
        branch_id: str,
        placement_mode: str,
        ctx: PolicyContext,
    ) -> dict[str, Any]:
        """
        Convert each retrieval-plan component into a CanvasComponent descriptor.

        Each component plan is executed through ``data_access_gateway``. A plan
        whose query returns warnings and no rows is rendered as an
        ``errorNotice`` component instead.

        ``placement_mode`` is accepted but not used here. Components are
        ordered after a fixed base index of 900, in steps of 100.
        """
        components: list[dict[str, Any]] = []
        base_order = 900  # Append after existing components

        for i, plan in enumerate(retrieval_plan.get("components", [])):
            ctype = plan["suggestedType"]
            dataset = plan["dataset"]
            component_id = str(uuid.uuid4())
            order_index = base_order + (i + 1) * 100

            query_result = await data_access_gateway.execute_component_plan(plan, ctx, prompt)
            data_rows = query_result.rows

            if query_result.warnings and not data_rows:
                components.append(
                    self._error_component(
                        component_id=component_id,
                        execution_id=execution_id,
                        actor_id=actor_id,
                        branch_id=branch_id,
                        dataset=dataset,
                        warnings=query_result.warnings,
                        order_index=order_index,
                    )
                )
                continue

            components.append({
                "componentId": component_id,
                "type": self._map_type(ctype),
                "title": self._generate_title(prompt, ctype),
                "description": f"Generated from: \"{prompt[:80]}\"",
                "dataSourceDescriptor": {
                    "descriptorId": str(uuid.uuid4()),
                    "sourceType": "postgres",
                    "connectorId": "velocity-core-postgres",
                    "dataset": dataset,
                    "authContextRef": f"authctx_{actor_id}_scope",
                    "queryTemplate": plan.get("queryTemplate", f"SELECT * FROM {dataset} WHERE tenant_id = :tenant_id"),
                    "queryParameters": plan.get("queryParameters", {"tenant_id": tenant_id}),
                    "rowLimit": plan.get("rowLimit", 50),
                    "privacyTier": plan.get("privacyTier", "standard"),
                    "cachePolicy": {"mode": "ttl", "ttlSeconds": 120},
                },
                "visualizationParameters": self._default_viz_params(ctype, data_rows),
                "dataBindings": self._default_bindings(ctype),
                "version": 1,
                "lifecycleState": "active",
                "provenance": {
                    "originType": "prompt_generated",
                    "promptExecutionId": execution_id,
                    "sourceBranchId": branch_id,
                    "createdBy": actor_id,
                    "createdAt": _now(),
                },
                "renderingHints": self._rendering_hints(ctype),
                "layout": {
                    "orderIndex": order_index,
                    "sectionId": "sec_prompt_generated",
                    "widthMode": "full" if ctype in ("pipeline_board", "table", "geo_map") else "half",
                    "minHeightPx": 300,
                    "stickyHeader": False,
                },
                "accessControls": self._access_controls(),
                "styleSignature": self._style_signature(),
                "validationState": self._validation_state(),
                "auditLog": [f"aud_{execution_id}_create"],
                "dataRows": data_rows,
            })

        return {"components": components}

    @staticmethod
    def _map_type(plan_type: str) -> str:
        """Map a snake_case plan type to its camelCase component type (default ``barChart``)."""
        mapping = {
            "pipeline_board": "pipelineBoard",
            "bar_chart": "barChart",
            "geo_map": "geoMap",
            "table": "table",
            "line_chart": "lineChart",
            "kpi_tile": "kpiTile",
            "activity_stream": "activityStream",
        }
        return mapping.get(plan_type, "barChart")

    @staticmethod
    def _generate_title(prompt: str, comp_type: str) -> str:
        """Return a fixed display title for *comp_type*; *prompt* is currently unused."""
        labels = {
            "pipeline_board": "Pipeline View",
            "bar_chart": "Comparative Analysis",
            "geo_map": "Geographic Distribution",
            "table": "Performance Table",
            "line_chart": "Trend Analysis",
            "kpi_tile": "Key Metric",
            "activity_stream": "Activity Stream",
        }
        return labels.get(comp_type, "Oracle Canvas Component")

    @staticmethod
    def _default_viz_params(comp_type: str, rows: list[dict[str, Any]]) -> dict[str, Any]:
        """
        Return default visualization parameters for *comp_type*.

        KPI tiles read ``metric_label``, ``trend_value`` and ``comparison_label``
        from the first row when rows are present. Unknown types get ``{}``.
        """
        defaults: dict[str, dict[str, Any]] = {
            "bar_chart": {"xAxis": "category", "yAxis": "value", "sort": "desc", "showLabels": True, "legend": False},
            "line_chart": {"showPoints": True, "smooth": True},
            "kpi_tile": {
                "label": rows[0].get("metric_label", "Result") if rows else "Result",
                "trend": str(rows[0].get("trend_value", "")) if rows else "",
                "comparisonLabel": rows[0].get("comparison_label", "") if rows else "",
            },
            "geo_map": {"mapStyle": "dubai_district_heat", "intensityField": "lead_count", "interactive": True, "tooltipFields": ["district", "lead_count", "avg_qd_score"]},
            "table": {"rankBy": "revenue_generated", "showTopBadge": True, "columns": ["name", "deals_closed", "revenue_generated"]},
            "pipeline_board": {"showValue": True, "colorByStage": True},
            "activity_stream": {"showUrgencyIndicator": True},
        }
        return defaults.get(comp_type, {})

    @staticmethod
    def _default_bindings(comp_type: str) -> dict[str, Any]:
        """Return empty data bindings (identical for every component type)."""
        return {"dimensions": [], "measures": [], "series": [], "filters": []}

    @staticmethod
    def _rendering_hints(comp_type: str) -> dict[str, Any]:
        """Return skeleton variant, virtualization priority and estimated height for *comp_type*."""
        priority_map = {
            "pipeline_board": ("pipeline", 9), "bar_chart": ("chart", 8),
            "geo_map": ("map", 9), "table": ("table", 7),
            "line_chart": ("chart", 8), "kpi_tile": ("kpi", 6),
            "activity_stream": ("table", 8),
        }
        skeleton, priority = priority_map.get(comp_type, ("chart", 7))
        height_map = {
            "pipeline_board": 400, "bar_chart": 320, "geo_map": 420,
            "table": 300, "line_chart": 320, "kpi_tile": 140, "activity_stream": 360,
        }
        return {
            "estimatedHeightPx": height_map.get(comp_type, 300),
            "skeletonVariant": skeleton,
            "virtualizationPriority": priority,
        }

    @staticmethod
    def _generate_summary(prompt: str, viz_plan: dict[str, Any]) -> str:
        """Summarize how many components were generated, quoting up to 60 prompt characters."""
        count = len(viz_plan.get("components", []))
        short_prompt = prompt[:60] + ("…" if len(prompt) > 60 else "")
        return f'Generated {count} component{"s" if count != 1 else ""} for: "{short_prompt}"'

    @staticmethod
    def _access_controls() -> dict[str, Any]:
        """Return a fresh accessControls dict shared by all generated components."""
        return {
            "visibilityScope": "private",
            "allowedRoles": list(PromptOrchestrator._ALLOWED_ROLES),
            "redactionPolicy": "none",
        }

    @staticmethod
    def _style_signature() -> dict[str, Any]:
        """Return a fresh styleSignature dict shared by all generated components."""
        return {
            "theme": "velocity_glass",
            "paletteToken": "ocean_signal",
            "motionProfile": "calm_reveal",
            "density": "comfortable",
            "radiusScale": "lg",
            "typographyScale": "balanced",
        }

    @staticmethod
    def _validation_state() -> dict[str, Any]:
        """Return a fresh all-passing validationState dict."""
        return {
            "schema": "pass",
            "policy": "pass",
            "a11y": "pass",
            "performance": "pass",
            "status": "validated",
        }

    @staticmethod
    def _error_component(
        *,
        component_id: str,
        execution_id: str,
        actor_id: str,
        branch_id: str,
        dataset: str,
        warnings: list[str],
        order_index: int,
    ) -> dict[str, Any]:
        """
        Build an ``errorNotice`` component shown in place of a component whose
        live query returned no rows and produced warnings.

        The message joins at most the first two warnings.
        """
        return {
            "componentId": component_id,
            "type": "errorNotice",
            "title": f"{dataset} unavailable",
            "description": "Oracle could not render live data for this component.",
            "dataSourceDescriptor": {
                "descriptorId": str(uuid.uuid4()),
                "sourceType": "postgres",
                "connectorId": "velocity-core-postgres",
                "dataset": dataset,
                "authContextRef": f"authctx_{actor_id}_scope",
                "queryTemplate": "",
                "queryParameters": {},
                "rowLimit": 0,
                "privacyTier": "standard",
            },
            "visualizationParameters": {
                "errorCode": "oracle_live_query_failed",
                "message": " | ".join(warnings[:2]),
                "severity": "warning",
                "retryable": True,
            },
            "dataBindings": PromptOrchestrator._default_bindings("errorNotice"),
            "version": 1,
            "lifecycleState": "active",
            "provenance": {
                "originType": "prompt_generated",
                "promptExecutionId": execution_id,
                "sourceBranchId": branch_id,
                "createdBy": actor_id,
                "createdAt": _now(),
            },
            "renderingHints": {"estimatedHeightPx": 140, "skeletonVariant": "generic", "virtualizationPriority": 5},
            "layout": {
                "orderIndex": order_index,
                "sectionId": "sec_prompt_generated",
                "widthMode": "full",
                "minHeightPx": 140,
                "stickyHeader": False,
            },
            "accessControls": PromptOrchestrator._access_controls(),
            "styleSignature": PromptOrchestrator._style_signature(),
            "validationState": PromptOrchestrator._validation_state(),
            "auditLog": [f"aud_{execution_id}_error"],
            "dataRows": [],
        }

    @staticmethod
    def _validate_plan_payload(payload: Any) -> dict[str, Any]:
        """
        Check the minimal shape of a Nemoclaw plan before it is used.

        Later stages iterate ``payload["components"]`` and read
        ``suggestedType`` and ``dataset`` from each entry, so those keys are
        required. Raising here lets ``execute`` fall back to the
        deterministic plan instead of crashing mid-pipeline.

        Raises:
            ValueError: if the payload is not an object, ``components`` is not
                a list, or an entry lacks the required keys.
        """
        if not isinstance(payload, dict):
            raise ValueError(f"Nemoclaw plan must be a JSON object, got {type(payload).__name__}")
        components = payload.get("components", [])
        if not isinstance(components, list):
            raise ValueError("Nemoclaw plan 'components' must be a list")
        for index, component in enumerate(components):
            if not isinstance(component, dict) or not all(
                key in component for key in ("suggestedType", "dataset")
            ):
                raise ValueError(f"Nemoclaw plan component {index} lacks suggestedType/dataset")
        return payload

    async def _call_nemoclaw(
        self,
        prompt: str,
        context: list[dict[str, str]],
        ctx: PolicyContext,
    ) -> dict[str, Any]:
        """
        Call the Nemoclaw hosted planning endpoint and return its plan.

        Raises on transport errors, non-2xx responses, or a malformed plan
        (ValueError) so the orchestrator can fall back to the demo plan.
        """
        import httpx  # type: ignore

        # Tolerate a configured base URL with a trailing slash.
        url = f"{_NEMOCLAW_URL.rstrip('/')}/v1/oracle/plan"
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.post(
                url,
                headers={"Authorization": f"Bearer {_NEMOCLAW_API_KEY}"},
                json={
                    "prompt": prompt,
                    "conversationContext": context,
                    "tenantId": ctx.tenant_id,
                    "actorRole": ctx.actor_role,
                    "semanticModelVersion": self._SEMANTIC_MODEL_VERSION,
                },
            )
            resp.raise_for_status()
            payload = resp.json()
        return self._validate_plan_payload(payload)

    async def get_execution(self, execution_id: str) -> dict[str, Any] | None:
        """Return the in-memory execution record for *execution_id*, or None."""
        return _DEMO_EXECUTIONS.get(execution_id)

    @staticmethod
    def _to_timestamp(value: str | datetime | None) -> datetime | None:
        """
        Convert an ISO-8601 timestamp string (as produced by ``_now``) into a
        ``datetime`` for an asyncpg ``timestamptz`` parameter.

        ``None`` and ``datetime`` values are returned unchanged.
        """
        if value is None or isinstance(value, datetime):
            return value
        return datetime.fromisoformat(value)

    async def _persist_execution(self, execution: dict[str, Any]) -> None:
        """
        Store *execution* in the in-memory store and, when Postgres is
        configured, upsert it into ``oracle_prompt_executions``.

        Records keep timestamps as ISO strings, but they are converted to
        ``datetime`` before binding. asyncpg encodes ``timestamptz`` parameters
        in binary and rejects ``str`` values. Database errors propagate.
        """
        _DEMO_EXECUTIONS[execution["executionId"]] = execution
        if not _db_ready():
            return

        assert asyncpg is not None  # guaranteed by _db_ready(); narrows for type checkers
        conn = await asyncpg.connect(_DB_URL)
        try:
            await conn.execute(
                """
                INSERT INTO oracle_prompt_executions (
                    execution_id, tenant_id, page_id, branch_id, actor_id, prompt, intent_class,
                    status, model_runtime, semantic_model_version, retrieval_plan, visualization_plan,
                    warnings, summary, components_created, client_request_id, created_at, completed_at
                )
                VALUES (
                    $1::uuid, $2, $3::uuid, $4, $5, $6, $7,
                    $8, $9, $10, $11::jsonb, $12::jsonb,
                    $13::text[], $14, $15::text[], $16, $17::timestamptz, $18::timestamptz
                )
                ON CONFLICT (execution_id)
                DO UPDATE SET
                    status = EXCLUDED.status,
                    model_runtime = EXCLUDED.model_runtime,
                    retrieval_plan = EXCLUDED.retrieval_plan,
                    visualization_plan = EXCLUDED.visualization_plan,
                    warnings = EXCLUDED.warnings,
                    summary = EXCLUDED.summary,
                    components_created = EXCLUDED.components_created,
                    completed_at = EXCLUDED.completed_at
                """,
                execution["executionId"],
                execution["tenantId"],
                execution["pageId"],
                execution["branchId"],
                execution["actorId"],
                execution["prompt"],
                execution["intentClass"],
                execution["status"],
                execution["modelRuntime"],
                execution["semanticModelVersion"],
                json.dumps(execution.get("retrievalPlan") or {}),
                json.dumps(execution.get("visualizationPlan") or {}),
                execution.get("warnings", []),
                execution.get("summary"),
                execution.get("componentsCreated", []),
                execution.get("clientRequestId"),
                self._to_timestamp(execution["createdAt"]),
                self._to_timestamp(execution.get("completedAt")),
            )
        finally:
            await conn.close()
|
|
|
|
|
|
# ── Singleton ─────────────────────────────────────────────────────────────────

# Shared module-level instance. PromptOrchestrator defines no __init__ and
# writes no instance attributes, so a single instance can be reused by importers.
prompt_orchestrator = PromptOrchestrator()
|