"""
oracle/prompt_orchestrator.py

Accepts a user prompt, assembles context, calls the Nemoclaw model runtime
(or uses a deterministic fallback), validates the generated plan via policy,
triggers the data access gateway, and produces a PromptExecution.
"""
from __future__ import annotations

import logging
import os
import uuid
import json
from datetime import datetime, timezone
from typing import Any

from .policy_service import PolicyContext, PolicyService
from .canvas_service import canvas_service
from .data_access_gateway import data_access_gateway
from .persona_service import persona_service
from backend.services.nemoclaw_runtime import nemoclaw_runtime

try:
    import asyncpg  # type: ignore
except Exception:  # pragma: no cover
    asyncpg = None  # type: ignore

logger = logging.getLogger(__name__)

_NEMOCLAW_URL = os.getenv("NEMOCLAW_API_URL", "")
_NEMOCLAW_API_KEY = os.getenv("NEMOCLAW_API_KEY", "")
_DB_URL = os.getenv("DATABASE_URL", "")

# Single source of truth for the semantic model version stamped on plans,
# execution records and Nemoclaw requests.
_SEMANTIC_MODEL_VERSION = "oracle_semantic_v2026_04_08_01"

policy_svc = PolicyService()


def _now() -> str:
    """Return the current UTC time as an ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()


def _parse_timestamp(value: str | datetime | None) -> datetime | None:
    """Convert an ISO-8601 string (as produced by ``_now``) to a ``datetime``.

    ``None`` and ``datetime`` instances are returned unchanged. Used when
    binding timestamp parameters for the database driver, which expects
    ``datetime`` objects rather than strings.
    """
    if value is None or isinstance(value, datetime):
        return value
    return datetime.fromisoformat(value)


# ── Execution store ───────────────────────────────────────────────────────────

# In-memory execution records keyed by execution id. Always written, whether or
# not a database is configured.
_DEMO_EXECUTIONS: dict[str, dict[str, Any]] = {}


def _db_ready() -> bool:
    """Return True when a non-placeholder DATABASE_URL is set and asyncpg imported."""
    return bool(_DB_URL and not _DB_URL.startswith("PLACEHOLDER") and asyncpg is not None)


# ── Semantic intent detection (simplified) ────────────────────────────────────

_INTENT_KEYWORDS: dict[str, list[str]] = {
    "pipeline_board": ["pipeline", "stage", "kanban", "deal", "funnel"],
    "bar_chart": ["bar", "compare", "source", "channel", "distribution", "ranked", "lead", "whale"],
    "geo_map": ["map", "geographic", "location", "district", "region", "area", "dubai"],
    "table": ["table", "list", "broker", "performance", "leaderboard", "rank", "top"],
    "line_chart": ["trend", "time", "monthly", "weekly", "absorption", "forecast"],
    "kpi_tile": ["kpi", "total", "summary", "attainment", "quota", "how many"],
    "activity_stream": ["timeline", "activity", "history", "follow-up", "queue", "contact"],
}

# Dataset queried for each component type in the deterministic plan.
_DATASET_MAP: dict[str, str] = {
    "pipeline_board": "deals",
    "bar_chart": "lead_daily_snapshot",
    "geo_map": "lead_geo_interest_rollup",
    "table": "broker_performance",
    "line_chart": "inventory_absorption",
    "kpi_tile": "oracle_aggregated_metric",
    "activity_stream": "lead_activity_log",
}


def _detect_component_types(prompt: str) -> list[str]:
    """Return the component types whose keywords occur in *prompt*.

    Matching is a case-insensitive substring test. Types are returned in
    ``_INTENT_KEYWORDS`` order; when nothing matches, ``["bar_chart"]`` is
    returned.
    """
    lower = prompt.lower()
    types = [
        comp_type
        for comp_type, keywords in _INTENT_KEYWORDS.items()
        if any(keyword in lower for keyword in keywords)
    ]
    return types or ["bar_chart"]


def _build_demo_retrieval_plan(
    prompt: str,
    tenant_id: str,
    actor_role: str,
) -> dict[str, Any]:
    """
    Deterministic plan builder for demo mode.
    Produces a valid retrieval plan that passes policy validation.

    One component entry is emitted per detected component type. The row limit
    is 50 for the ``senior_broker`` / ``junior_broker`` roles and 200 otherwise.
    """
    row_limit = 50 if actor_role in ("senior_broker", "junior_broker") else 200

    components: list[dict[str, Any]] = []
    for comp_type in _detect_component_types(prompt):
        dataset = _DATASET_MAP.get(comp_type, "aggregated_results")
        components.append(
            {
                "suggestedType": comp_type,
                "dataset": dataset,
                "privacyTier": "standard",
                "rowLimit": row_limit,
                "joins": [],
                "queryTemplate": f"SELECT * FROM {dataset} WHERE tenant_id = :tenant_id LIMIT :limit",
                "queryParameters": {"tenant_id": tenant_id, "limit": row_limit},
            }
        )

    return {
        "planId": str(uuid.uuid4()),
        "components": components,
        "semanticModelVersion": _SEMANTIC_MODEL_VERSION,
        "intentClass": "analytical",
    }


# ── Shared component fragments ────────────────────────────────────────────────
# Each helper returns a fresh dict so components never share mutable state.


def _default_access_controls() -> dict[str, Any]:
    """Return the access-control block attached to every generated component."""
    return {
        "visibilityScope": "private",
        "allowedRoles": [
            "senior_broker",
            "sales_director",
            "marketing_operator",
            "data_steward",
            "compliance_reviewer",
            "platform_admin",
        ],
        "redactionPolicy": "none",
    }


def _default_style_signature() -> dict[str, Any]:
    """Return the style signature attached to every generated component."""
    return {
        "theme": "velocity_glass",
        "paletteToken": "ocean_signal",
        "motionProfile": "calm_reveal",
        "density": "comfortable",
        "radiusScale": "lg",
        "typographyScale": "balanced",
    }


def _passing_validation_state() -> dict[str, Any]:
    """Return the all-``pass`` validation state attached to generated components."""
    return {
        "schema": "pass",
        "policy": "pass",
        "a11y": "pass",
        "performance": "pass",
        "status": "validated",
    }


def _empty_bindings() -> dict[str, Any]:
    """Return a data-bindings block with no dimensions, measures, series or filters."""
    return {"dimensions": [], "measures": [], "series": [], "filters": []}


def _prompt_provenance(execution_id: str, branch_id: str, actor_id: str) -> dict[str, Any]:
    """Return the provenance block for a prompt-generated component."""
    return {
        "originType": "prompt_generated",
        "promptExecutionId": execution_id,
        "sourceBranchId": branch_id,
        "createdBy": actor_id,
        "createdAt": _now(),
    }


class PromptOrchestrator:
    """
    Orchestrates the full prompt-to-canvas pipeline:
    1. Intent classification
    2. Retrieval plan construction (Nemoclaw or fallback)
    3. Policy validation
    4. Component plan construction
    5. Execution record persistence
    """

    async def execute(
        self,
        *,
        tenant_id: str,
        page_id: str,
        branch_id: str,
        actor_id: str,
        actor_role: str,
        prompt: str,
        conversation_context: list[dict[str, str]] | None = None,
        client_request_id: str,
        placement_mode: str = "append_after_last_visible_component",
    ) -> dict[str, Any]:
        """
        Full orchestration flow. Returns a PromptExecution dict.

        The returned dict's ``status`` is ``"failed"`` when any component of
        the retrieval plan is rejected by policy (the policy errors are then
        appended to ``warnings``), and ``"completed"`` otherwise. A failure to
        commit the canvas revision is logged and recorded as a warning but
        does not fail the execution.

        Raises:
            Exception: whatever ``_persist_execution``, the persona service,
                the workflow dispatch builder or the data access gateway
                raise; only the Nemoclaw call and the revision commit are
                guarded.
        """
        execution_id = str(uuid.uuid4())
        warnings: list[str] = []
        ctx = PolicyContext(
            tenant_id=tenant_id,
            actor_id=actor_id,
            actor_role=actor_role,
        )

        # The hosted runtime is only usable when both URL and key are configured.
        use_nemoclaw = bool(_NEMOCLAW_URL and _NEMOCLAW_API_KEY)

        execution: dict[str, Any] = {
            "executionId": execution_id,
            "tenantId": tenant_id,
            "pageId": page_id,
            "branchId": branch_id,
            "actorId": actor_id,
            "prompt": prompt,
            "intentClass": "analytical",
            "status": "planning",
            "modelRuntime": "nemoclaw_hosted" if use_nemoclaw else "deterministic_fallback",
            "semanticModelVersion": _SEMANTIC_MODEL_VERSION,
            "warnings": warnings,
            "componentsCreated": [],
            "clientRequestId": client_request_id,
            "createdAt": _now(),
        }
        await self._persist_execution(execution)

        # ── Step 1: Build retrieval plan ──────────────────────────────────────
        if use_nemoclaw:
            try:
                retrieval_plan = await self._call_nemoclaw(prompt, conversation_context or [], ctx)
                execution["status"] = "validated"
            except Exception as exc:
                logger.warning("ORCH Nemoclaw call failed, using fallback: %s", exc)
                warnings.append(f"Model runtime unavailable ({exc}); using deterministic fallback.")
                # Record the runtime that actually produced the plan.
                execution["modelRuntime"] = "deterministic_fallback"
                retrieval_plan = _build_demo_retrieval_plan(prompt, tenant_id, actor_role)
        else:
            retrieval_plan = _build_demo_retrieval_plan(prompt, tenant_id, actor_role)

        execution["retrievalPlan"] = retrieval_plan

        persona_plan = await persona_service.plan_for_prompt(
            prompt=prompt,
            tenant_id=tenant_id,
            actor_role=actor_role,
        )
        execution["personaPlan"] = persona_plan
        execution["workflowDispatch"] = nemoclaw_runtime.build_workflow_dispatch(
            prompt=prompt,
            tenant_id=tenant_id,
            actor_role=actor_role,
            # Tolerate a missing key, matching _persona_text_canvas.
            component_templates=persona_plan.get("recommendedTemplates", []),
        )

        # ── Step 2: Policy validation ─────────────────────────────────────────
        policy_errors: list[str] = []
        for component_plan in retrieval_plan.get("components", []):
            result = policy_svc.validate_retrieval_plan(component_plan, ctx)
            if not result.passed:
                policy_errors.extend(result.errors)
            if result.warnings:
                warnings.extend(result.warnings)

        if policy_errors:
            execution["status"] = "failed"
            execution["warnings"] = warnings + policy_errors
            execution["completedAt"] = _now()
            logger.warning(
                "ORCH policy_denial execution_id=%s actor=%s errors=%s",
                execution_id,
                actor_id,
                policy_errors,
            )
            # Persist the terminal state so the stored record does not stay
            # stuck at its pre-validation status.
            await self._persist_execution(execution)
            return execution

        execution["status"] = "executing"
        await self._persist_execution(execution)

        # ── Step 3: Build visualization plan (component descriptors) ──────────
        viz_plan = await self._build_visualization_plan(
            retrieval_plan=retrieval_plan,
            prompt=prompt,
            execution_id=execution_id,
            actor_id=actor_id,
            tenant_id=tenant_id,
            branch_id=branch_id,
            placement_mode=placement_mode,
            ctx=ctx,
            persona_plan=persona_plan,
        )
        execution["visualizationPlan"] = viz_plan

        # ── Step 4: Commit revision ───────────────────────────────────────────
        new_components = viz_plan.get("components", [])
        execution["componentsCreated"] = [c["componentId"] for c in new_components]

        # Commit a revision bump with the new components
        try:
            page = await canvas_service.get_page(page_id, tenant_id)
            if page:
                revision = await canvas_service.commit_revision(
                    page_id=page_id,
                    tenant_id=tenant_id,
                    actor_id=actor_id,
                    commit_kind="prompt",
                    commit_summary=f"Oracle: {prompt[:80]}",
                    components=page.get("components", []) + new_components,
                    execution_id=execution_id,
                    idempotency_key=client_request_id,
                )
                execution["headRevision"] = revision["revisionNumber"]
        except Exception as exc:
            logger.warning("ORCH revision_commit failed (non-fatal): %s", exc)
            warnings.append("Revision commit deferred — will retry on next sync.")

        execution["status"] = "completed"
        execution["summary"] = self._generate_summary(prompt, viz_plan)
        execution["completedAt"] = _now()
        execution["warnings"] = warnings

        await self._persist_execution(execution)
        return execution

    async def _build_visualization_plan(
        self,
        *,
        retrieval_plan: dict[str, Any],
        prompt: str,
        execution_id: str,
        actor_id: str,
        tenant_id: str,
        branch_id: str,
        placement_mode: str,
        ctx: PolicyContext,
        persona_plan: dict[str, Any],
    ) -> dict[str, Any]:
        """Converts a retrieval plan into a list of CanvasComponent descriptors.

        The first component is always the persona text canvas. Each entry of
        ``retrieval_plan["components"]`` is then executed through the data
        access gateway; an entry that yields warnings and no rows becomes an
        ``errorNotice`` component instead of a data-bound one.

        ``placement_mode`` is accepted for interface compatibility but is not
        read here.

        Returns:
            ``{"components": [...]}`` in placement order.
        """
        components = [
            self._persona_text_canvas(
                execution_id=execution_id,
                actor_id=actor_id,
                branch_id=branch_id,
                prompt=prompt,
                persona_plan=persona_plan,
            )
        ]

        base_order = 900  # Append after existing components
        full_width_types = ("pipeline_board", "table", "geo_map")

        for position, plan in enumerate(retrieval_plan.get("components", []), start=1):
            ctype = plan["suggestedType"]
            dataset = plan["dataset"]
            component_id = str(uuid.uuid4())
            order_index = base_order + position * 100

            query_result = await data_access_gateway.execute_component_plan(plan, ctx, prompt)
            data_rows = query_result.rows

            # Warnings with no rows means the live query produced nothing
            # renderable: surface an error notice in that slot instead.
            if query_result.warnings and not data_rows:
                components.append(
                    self._error_component(
                        component_id=component_id,
                        execution_id=execution_id,
                        actor_id=actor_id,
                        branch_id=branch_id,
                        dataset=dataset,
                        warnings=query_result.warnings,
                        order_index=order_index,
                    )
                )
                continue

            components.append(
                {
                    "componentId": component_id,
                    "type": self._map_type(ctype),
                    "title": self._generate_title(prompt, ctype),
                    "description": f'Generated from: "{prompt[:80]}"',
                    "dataSourceDescriptor": {
                        "descriptorId": str(uuid.uuid4()),
                        "sourceType": "postgres",
                        "connectorId": "velocity-core-postgres",
                        "dataset": dataset,
                        "authContextRef": f"authctx_{actor_id}_scope",
                        "queryTemplate": plan.get(
                            "queryTemplate",
                            f"SELECT * FROM {dataset} WHERE tenant_id = :tenant_id",
                        ),
                        "queryParameters": plan.get("queryParameters", {"tenant_id": tenant_id}),
                        "rowLimit": plan.get("rowLimit", 50),
                        "privacyTier": plan.get("privacyTier", "standard"),
                        "cachePolicy": {"mode": "ttl", "ttlSeconds": 120},
                    },
                    "visualizationParameters": self._default_viz_params(ctype, data_rows),
                    "dataBindings": self._default_bindings(ctype),
                    "version": 1,
                    "lifecycleState": "active",
                    "provenance": _prompt_provenance(execution_id, branch_id, actor_id),
                    "renderingHints": self._rendering_hints(ctype),
                    "layout": {
                        "orderIndex": order_index,
                        "sectionId": "sec_prompt_generated",
                        "widthMode": "full" if ctype in full_width_types else "half",
                        "minHeightPx": 300,
                        "stickyHeader": False,
                    },
                    "accessControls": _default_access_controls(),
                    "styleSignature": _default_style_signature(),
                    "validationState": _passing_validation_state(),
                    "auditLog": [f"aud_{execution_id}_create"],
                    "dataRows": data_rows,
                }
            )

        return {"components": components}

    @staticmethod
    def _persona_text_canvas(
        *,
        execution_id: str,
        actor_id: str,
        branch_id: str,
        prompt: str,
        persona_plan: dict[str, Any],
    ) -> dict[str, Any]:
        """Build the ``textCanvas`` component that leads every generated plan.

        Its content echoes the prompt and lists the persona plan's
        ``recommendedTemplates`` (or a placeholder phrase when there are none).
        """
        recommended = ", ".join(persona_plan.get("recommendedTemplates", [])) or "no direct template matches"
        content = (
            f"Oracle received: {prompt}\n\n"
            f"Reusable templates: {recommended}\n\n"
            "Execution policy: query live CRM data first, reuse matching templates, "
            "synthesize missing UI blocks, then dispatch the required ComfyUI-backed workflow."
        )
        return {
            "componentId": str(uuid.uuid4()),
            "type": "textCanvas",
            "title": "Oracle Planning Notes",
            "description": "Persona-driven guidance generated before data-bound components.",
            "dataSourceDescriptor": {
                "descriptorId": str(uuid.uuid4()),
                "sourceType": "inline",
                "connectorId": "oracle-persona",
                "dataset": "oracle_persona_plan",
                "authContextRef": f"authctx_{actor_id}_scope",
                "queryTemplate": "",
                "queryParameters": {},
                "rowLimit": 1,
                "privacyTier": "standard",
            },
            "visualizationParameters": {
                "content": content,
                "widthMode": "full",
                "adjustableHeight": True,
            },
            "dataBindings": _empty_bindings(),
            "version": 1,
            "lifecycleState": "active",
            "provenance": _prompt_provenance(execution_id, branch_id, actor_id),
            "renderingHints": {"estimatedHeightPx": 180, "skeletonVariant": "text", "virtualizationPriority": 4},
            "layout": {
                # Sits between the base order (900) and the first data-bound
                # component (1000) assigned by _build_visualization_plan.
                "orderIndex": 910,
                "sectionId": "sec_prompt_generated",
                "widthMode": "full",
                "minHeightPx": 180,
                "stickyHeader": False,
            },
            "accessControls": _default_access_controls(),
            "styleSignature": _default_style_signature(),
            "validationState": _passing_validation_state(),
            "auditLog": [f"aud_{execution_id}_persona"],
            "dataRows": [],
        }

    @staticmethod
    def _map_type(plan_type: str) -> str:
        """Map a snake_case plan type to its camelCase component type (default ``barChart``)."""
        mapping = {
            "pipeline_board": "pipelineBoard",
            "bar_chart": "barChart",
            "geo_map": "geoMap",
            "table": "table",
            "line_chart": "lineChart",
            "kpi_tile": "kpiTile",
            "activity_stream": "activityStream",
        }
        return mapping.get(plan_type, "barChart")

    @staticmethod
    def _generate_title(prompt: str, comp_type: str) -> str:
        """Return the fixed display title for *comp_type*; *prompt* is not used."""
        labels = {
            "pipeline_board": "Pipeline View",
            "bar_chart": "Comparative Analysis",
            "geo_map": "Geographic Distribution",
            "table": "Performance Table",
            "line_chart": "Trend Analysis",
            "kpi_tile": "Key Metric",
            "activity_stream": "Activity Stream",
        }
        return labels.get(comp_type, "Oracle Canvas Component")

    @staticmethod
    def _default_viz_params(comp_type: str, rows: list[dict[str, Any]]) -> dict[str, Any]:
        """Return default visualization parameters for *comp_type*.

        Only ``kpi_tile`` reads *rows*: its label, trend and comparison label
        come from the first row when one exists. Unknown types yield ``{}``.
        """
        # An empty mapping reproduces the no-rows defaults via .get().
        first_row: dict[str, Any] = rows[0] if rows else {}
        defaults: dict[str, dict[str, Any]] = {
            "bar_chart": {"xAxis": "category", "yAxis": "value", "sort": "desc", "showLabels": True, "legend": False},
            "line_chart": {"showPoints": True, "smooth": True},
            "kpi_tile": {
                "label": first_row.get("metric_label", "Result"),
                "trend": str(first_row.get("trend_value", "")),
                "comparisonLabel": first_row.get("comparison_label", ""),
            },
            "geo_map": {
                "mapStyle": "dubai_district_heat",
                "intensityField": "lead_count",
                "interactive": True,
                "tooltipFields": ["district", "lead_count", "avg_qd_score"],
            },
            "table": {
                "rankBy": "revenue_generated",
                "showTopBadge": True,
                "columns": ["name", "deals_closed", "revenue_generated"],
            },
            "pipeline_board": {"showValue": True, "colorByStage": True},
            "activity_stream": {"showUrgencyIndicator": True},
        }
        return defaults.get(comp_type, {})

    @staticmethod
    def _default_bindings(comp_type: str) -> dict[str, Any]:
        """Return empty data bindings; *comp_type* is currently not consulted."""
        return _empty_bindings()

    @staticmethod
    def _rendering_hints(comp_type: str) -> dict[str, Any]:
        """Return estimated height, skeleton variant and virtualization priority for *comp_type*."""
        priority_map = {
            "pipeline_board": ("pipeline", 9),
            "bar_chart": ("chart", 8),
            "geo_map": ("map", 9),
            "table": ("table", 7),
            "line_chart": ("chart", 8),
            "kpi_tile": ("kpi", 6),
            "activity_stream": ("table", 8),
        }
        height_map = {
            "pipeline_board": 400,
            "bar_chart": 320,
            "geo_map": 420,
            "table": 300,
            "line_chart": 320,
            "kpi_tile": 140,
            "activity_stream": 360,
        }
        skeleton, priority = priority_map.get(comp_type, ("chart", 7))
        return {
            "estimatedHeightPx": height_map.get(comp_type, 300),
            "skeletonVariant": skeleton,
            "virtualizationPriority": priority,
        }

    @staticmethod
    def _generate_summary(prompt: str, viz_plan: dict[str, Any]) -> str:
        """Return a one-line summary: component count plus the prompt truncated to 60 chars."""
        count = len(viz_plan.get("components", []))
        short_prompt = prompt[:60] + ("…" if len(prompt) > 60 else "")
        return f'Generated {count} component{"s" if count != 1 else ""} for: "{short_prompt}"'

    @staticmethod
    def _error_component(
        *,
        component_id: str,
        execution_id: str,
        actor_id: str,
        branch_id: str,
        dataset: str,
        warnings: list[str],
        order_index: int,
    ) -> dict[str, Any]:
        """Build an ``errorNotice`` component for a dataset that returned no rows.

        The notice message joins at most the first two *warnings* with ``" | "``.
        """
        return {
            "componentId": component_id,
            "type": "errorNotice",
            "title": f"{dataset} unavailable",
            "description": "Oracle could not render live data for this component.",
            "dataSourceDescriptor": {
                "descriptorId": str(uuid.uuid4()),
                "sourceType": "postgres",
                "connectorId": "velocity-core-postgres",
                "dataset": dataset,
                "authContextRef": f"authctx_{actor_id}_scope",
                "queryTemplate": "",
                "queryParameters": {},
                "rowLimit": 0,
                "privacyTier": "standard",
            },
            "visualizationParameters": {
                "errorCode": "oracle_live_query_failed",
                "message": " | ".join(warnings[:2]),
                "severity": "warning",
                "retryable": True,
            },
            "dataBindings": _empty_bindings(),
            "version": 1,
            "lifecycleState": "active",
            "provenance": _prompt_provenance(execution_id, branch_id, actor_id),
            "renderingHints": {"estimatedHeightPx": 140, "skeletonVariant": "generic", "virtualizationPriority": 5},
            "layout": {
                "orderIndex": order_index,
                "sectionId": "sec_prompt_generated",
                "widthMode": "full",
                "minHeightPx": 140,
                "stickyHeader": False,
            },
            "accessControls": _default_access_controls(),
            "styleSignature": _default_style_signature(),
            "validationState": _passing_validation_state(),
            "auditLog": [f"aud_{execution_id}_error"],
            "dataRows": [],
        }

    async def _call_nemoclaw(
        self,
        prompt: str,
        context: list[dict[str, str]],
        ctx: PolicyContext,
    ) -> dict[str, Any]:
        """
        Calls the Nemoclaw hosted model endpoint.
        Raises on failure so the orchestrator can fall back to demo.

        POSTs to ``{NEMOCLAW_API_URL}/v1/oracle/plan`` with a bearer token and
        a 30 second timeout, and returns the decoded JSON body.
        """
        # Imported lazily so the module loads without httpx when only the
        # deterministic fallback is used.
        import httpx  # type: ignore

        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.post(
                f"{_NEMOCLAW_URL}/v1/oracle/plan",
                headers={"Authorization": f"Bearer {_NEMOCLAW_API_KEY}"},
                json={
                    "prompt": prompt,
                    "conversationContext": context,
                    "tenantId": ctx.tenant_id,
                    "actorRole": ctx.actor_role,
                    "semanticModelVersion": _SEMANTIC_MODEL_VERSION,
                },
            )
            resp.raise_for_status()
            return resp.json()  # type: ignore[no-any-return]

    async def get_execution(self, execution_id: str) -> dict[str, Any] | None:
        """Return the in-memory execution record for *execution_id*, or ``None``."""
        return _DEMO_EXECUTIONS.get(execution_id)

    async def _persist_execution(self, execution: dict[str, Any]) -> None:
        """Store *execution* in memory and, when a database is configured, upsert it.

        The in-memory store is always updated first. The database write opens
        a dedicated connection, upserts on ``execution_id`` and closes the
        connection even when the statement fails.

        Raises:
            Exception: connection or statement errors from asyncpg propagate
                to the caller.
        """
        _DEMO_EXECUTIONS[execution["executionId"]] = execution
        if not _db_ready():
            return

        assert asyncpg is not None  # type narrowing only; _db_ready() already checked
        conn = await asyncpg.connect(_DB_URL)
        try:
            await conn.execute(
                """
                INSERT INTO oracle_prompt_executions (
                    execution_id, tenant_id, page_id, branch_id, actor_id, prompt,
                    intent_class, status, model_runtime, semantic_model_version,
                    retrieval_plan, visualization_plan, warnings, summary,
                    components_created, client_request_id, created_at, completed_at
                ) VALUES (
                    $1::uuid, $2, $3::uuid, $4, $5, $6,
                    $7, $8, $9, $10,
                    $11::jsonb, $12::jsonb, $13::text[], $14,
                    $15::text[], $16, $17::timestamptz, $18::timestamptz
                )
                ON CONFLICT (execution_id) DO UPDATE SET
                    status = EXCLUDED.status,
                    model_runtime = EXCLUDED.model_runtime,
                    retrieval_plan = EXCLUDED.retrieval_plan,
                    visualization_plan = EXCLUDED.visualization_plan,
                    warnings = EXCLUDED.warnings,
                    summary = EXCLUDED.summary,
                    components_created = EXCLUDED.components_created,
                    completed_at = EXCLUDED.completed_at
                """,
                execution["executionId"],
                execution["tenantId"],
                execution["pageId"],
                execution["branchId"],
                execution["actorId"],
                execution["prompt"],
                execution["intentClass"],
                execution["status"],
                execution["modelRuntime"],
                execution["semanticModelVersion"],
                json.dumps(execution.get("retrievalPlan") or {}),
                json.dumps(execution.get("visualizationPlan") or {}),
                execution.get("warnings", []),
                execution.get("summary"),
                execution.get("componentsCreated", []),
                execution.get("clientRequestId"),
                # The record keeps ISO strings; timestamptz parameters must be
                # bound as datetime objects.
                _parse_timestamp(execution["createdAt"]),
                _parse_timestamp(execution.get("completedAt")),
            )
        finally:
            await conn.close()


# ── Singleton ─────────────────────────────────────────────────────────────────

prompt_orchestrator = PromptOrchestrator()