""" oracle/plan_verifier.py Verify planned SQL before execution and optionally repair common semantic errors. """ from __future__ import annotations import json import logging import re from dataclasses import dataclass, field from typing import Any from .semantic_catalog import build_semantic_context_for_planner logger = logging.getLogger(__name__) _DESTRUCTIVE = re.compile( r"\b(insert|update|delete|drop|alter|truncate|copy|create|grant|revoke|call|execute|do|merge)\b", re.IGNORECASE, ) _BAD_TIMESTAMP_PATTERNS: list[tuple[str, str]] = [ ("edge_communication_events", "timestamp"), ("crm_property_interests", "last_discussed_at"), ("crm_property_interests", "last_interaction"), ] _BAD_SCORE_PATTERNS: list[tuple[str, str]] = [ ("crm_people", "engagement_score"), ("crm_leads", "engagement_score"), ("intel_interactions", "engagement_score"), ("crm_people", "qd_score"), ("crm_leads", "qd_score"), ] _HALLUCINATED_COLUMNS: list[tuple[str, str]] = [ ("intel_interactions", "broker_id"), ("intel_interactions", "sentiment"), ("crm_leads", "last_contacted_at"), ("crm_people", "last_contact"), ] @dataclass class VerificationViolation: rule: str detail: str severity: str @dataclass class VerificationResult: passed: bool sql: str original_sql: str violations: list[VerificationViolation] = field(default_factory=list) was_repaired: bool = False repair_attempted: bool = False repair_failed: bool = False notes: list[str] = field(default_factory=list) class PlanVerifier: def verify(self, sql: str, prompt: str, detected_intents: list[str], row_limit: int) -> VerificationResult: del prompt, detected_intents violations: list[VerificationViolation] = [] sql_lower = sql.lower() if _DESTRUCTIVE.search(sql): violations.append( VerificationViolation( rule="destructive_dml", detail="SQL contains a write or DDL statement.", severity="blocking", ) ) for table, column in _BAD_TIMESTAMP_PATTERNS: if table in sql_lower and column in sql_lower: violations.append( VerificationViolation( rule="deprecated_timestamp", detail=( f"SQL references {table}.{column}, which is sparse or deprecated. " "Use intel_interactions.happened_at or read_last_contacted.last_contacted_at." ), severity="blocking", ) ) for table, column in _BAD_SCORE_PATTERNS: if table in sql_lower and column in sql_lower: violations.append( VerificationViolation( rule="wrong_score_column", detail=( f"SQL references {table}.{column}, which is not the QD source of truth. " "Use intel_qd_scores.current_value." ), severity="blocking", ) ) for table, column in _HALLUCINATED_COLUMNS: if table in sql_lower and column in sql_lower: violations.append( VerificationViolation( rule="hallucinated_column", detail=f"SQL references {table}.{column}, which does not exist in the live schema.", severity="blocking", ) ) if "limit" not in sql_lower: violations.append( VerificationViolation( rule="missing_limit", detail=f"SQL has no LIMIT clause; executor will enforce row cap {row_limit}.", severity="warning", ) ) if re.search(r"\bselect\s+\*\b", sql_lower) and sql_lower.count("join") > 1: violations.append( VerificationViolation( rule="select_star_join", detail="SELECT * with multiple JOINs may create noisy wide rows.", severity="warning", ) ) blocking = [violation for violation in violations if violation.severity == "blocking"] return VerificationResult( passed=len(blocking) == 0, sql=sql, original_sql=sql, violations=violations, ) async def verify_and_repair( self, sql: str, prompt: str, detected_intents: list[str], row_limit: int, llm_service: Any | None = None, ) -> VerificationResult: result = self.verify(sql, prompt, detected_intents, row_limit) if result.passed: return result blocking = [violation for violation in result.violations if violation.severity == "blocking"] if not blocking: return result result.repair_attempted = True if llm_service is None: result.repair_failed = True result.notes.append("No LLM service available for SQL repair.") return result try: repaired_sql = await self._repair_sql( sql=sql, prompt=prompt, violations=blocking, detected_intents=detected_intents, row_limit=row_limit, llm_service=llm_service, ) except Exception as exc: logger.warning("plan_verifier repair failed: %s", exc) result.repair_failed = True result.notes.append(f"Repair failed: {exc}") return result recheck = self.verify(repaired_sql, prompt, detected_intents, row_limit) recheck.original_sql = sql recheck.was_repaired = True recheck.repair_attempted = True recheck.notes.append( "Repaired violations: " + ", ".join(violation.rule for violation in blocking) ) return recheck async def _repair_sql( self, *, sql: str, prompt: str, violations: list[VerificationViolation], detected_intents: list[str], row_limit: int, llm_service: Any, ) -> str: semantic_ctx = build_semantic_context_for_planner(detected_intents, max_concepts=4) violation_text = "\n".join(f"- [{violation.rule}] {violation.detail}" for violation in violations) response = await llm_service.chat( provider_id="sglang", model=None, system_prompt=( "You are Oracle's SQL repair agent. " "Fix only the listed violations. Return strict JSON with key 'sql'." ), messages=[ { "role": "user", "content": ( f"Original prompt: {prompt}\n\n" f"Semantic catalog:\n{semantic_ctx}\n\n" f"Violations:\n{violation_text}\n\n" f"Broken SQL:\n{sql}\n\n" f"Row cap: {row_limit}\n\n" "Return JSON: {\"sql\": \"\"}" ), } ], temperature=0.0, response_format="json", metadata={"agent": "oracle_plan_verifier_repair"}, ) message = response.get("message") or {} parsed = message.get("parsedJson") if not isinstance(parsed, dict): content = message.get("content") or "{}" parsed = json.loads(content) if isinstance(content, str) else {} repaired = str(parsed.get("sql") or "").strip() if not repaired: raise ValueError("Repair LLM returned empty SQL.") return repaired plan_verifier = PlanVerifier()