""" oracle/plan_verifier.py Verify planned SQL before execution and optionally repair common semantic errors. """ from __future__ import annotations import json import logging import re from dataclasses import dataclass, field from typing import Any from .semantic_catalog import VALID_QD_SCORE_TYPES, build_semantic_context_for_planner logger = logging.getLogger(__name__) _DESTRUCTIVE = re.compile( r"\b(insert|update|delete|drop|alter|truncate|copy|create|grant|revoke|call|execute|do|merge)\b", re.IGNORECASE, ) _BAD_TIMESTAMP_PATTERNS: list[tuple[str, str]] = [ ("edge_communication_events", "timestamp"), ("crm_property_interests", "last_discussed_at"), ("crm_property_interests", "last_interaction"), ] _BAD_SCORE_PATTERNS: list[tuple[str, str]] = [ ("crm_people", "engagement_score"), ("crm_leads", "engagement_score"), ("intel_interactions", "engagement_score"), ("crm_people", "qd_score"), ("crm_leads", "qd_score"), ] _HALLUCINATED_COLUMNS: list[tuple[str, str]] = [ ("intel_interactions", "broker_id"), ("intel_interactions", "sentiment"), ("crm_leads", "last_contacted_at"), ("crm_people", "last_contact"), ("read_last_contacted", "last_contacted_at"), ("read_last_contacted", "days_since_last_contact"), ("read_last_contacted", "staleness_label"), ] _CONTACT_INTENTS = {"last_contacted", "timeline"} def _extract_limit_from_prompt(prompt: str, default: int) -> int: lowered = prompt.lower() numeric_match = re.search(r"\b(?:top|last|latest|recent|first|show|which|give me)\s+(\d{1,4})\b", lowered) if numeric_match: return max(1, min(int(numeric_match.group(1)), default)) words = { "one": 1, "two": 2, "three": 3, "four": 4, "five": 5, "six": 6, "seven": 7, "eight": 8, "nine": 9, "ten": 10, "eleven": 11, "twelve": 12, "fifteen": 15, "twenty": 20, } word_match = re.search( r"\b(?:top|last|latest|recent|first|show|which|give me)\s+" r"(one|two|three|four|five|six|seven|eight|nine|ten|eleven|twelve|fifteen|twenty)\b", lowered, ) if word_match: return max(1, 
min(words[word_match.group(1)], default)) return default def _canonical_qd_sql(prompt: str, row_limit: int) -> str: limit = _extract_limit_from_prompt(prompt, row_limit) lowered = prompt.lower() direction = "ASC" if any(token in lowered for token in ("lowest", "least", "bottom", "weakest")) else "DESC" project_filter = "" project_join = "" project_match = re.search(r"\bin\s+([A-Za-z0-9][A-Za-z0-9 .&'-]{2,80})(?:\?|$)", prompt) if project_match: project_name = project_match.group(1).strip() if not re.search(r"\b(last|month|months|week|weeks|day|days|year|years)\b", project_name, re.IGNORECASE): project_join = "JOIN crm_property_interests pi ON pi.person_id = p.person_id " escaped = project_name.replace("'", "''") project_filter = f"AND pi.project_name ILIKE '%{escaped}%' " return ( "SELECT p.full_name, p.primary_email, p.primary_phone, " "q.current_value AS qd_score, q.score_type, q.computed_at " "FROM intel_qd_scores q " "JOIN crm_people p ON p.person_id = q.person_id " f"{project_join}" "WHERE q.score_type = 'overall' " f"{project_filter}" f"ORDER BY q.current_value {direction} " f"LIMIT {limit}" ) def _canonical_recent_contact_sql(prompt: str, row_limit: int) -> str: limit = _extract_limit_from_prompt(prompt, row_limit) interval = "3 months" lowered = prompt.lower() interval_match = re.search(r"\b(?:last|past|recent)\s+(\d{1,3})\s+(day|days|week|weeks|month|months|year|years)\b", lowered) if interval_match: count, unit = interval_match.groups() interval = f"{int(count)} {unit}" return ( "SELECT p.full_name, p.primary_email, p.primary_phone, " "lc.last_contact_at, lc.last_channel, lc.days_since_contact, " "q.current_value AS qd_score " "FROM read_last_contacted lc " "JOIN crm_people p ON p.person_id = lc.person_id " "LEFT JOIN intel_qd_scores q ON q.person_id = p.person_id AND q.score_type = 'overall' " f"WHERE lc.last_contact_at >= NOW() - INTERVAL '{interval}' " "ORDER BY q.current_value DESC NULLS LAST, lc.last_contact_at DESC " f"LIMIT {limit}" ) def 
_semantic_rule_repair( *, prompt: str, detected_intents: list[str], row_limit: int, violations: list[VerificationViolation], ) -> str | None: violation_rules = {violation.rule for violation in violations} if "qd_score" in detected_intents and violation_rules.intersection({"wrong_score_column", "impossible_score_type"}): return _canonical_qd_sql(prompt, row_limit) if set(detected_intents).intersection(_CONTACT_INTENTS) and violation_rules.intersection( {"deprecated_timestamp", "hallucinated_column"} ): return _canonical_recent_contact_sql(prompt, row_limit) return None def _extract_score_type_literals(sql: str) -> list[str]: literals: list[str] = [] eq_pattern = re.compile( r"(?:\b\w+\.)?score_type\s*=\s*'([^']+)'", re.IGNORECASE, ) in_pattern = re.compile( r"(?:\b\w+\.)?score_type\s+in\s*\(([^)]*)\)", re.IGNORECASE | re.DOTALL, ) literals.extend(match.group(1) for match in eq_pattern.finditer(sql)) for match in in_pattern.finditer(sql): literals.extend(re.findall(r"'([^']+)'", match.group(1))) return literals def _references_table(sql_lower: str, table: str) -> bool: return bool(re.search(rf"\b(?:from|join)\s+(?:public\.)?{re.escape(table)}\b", sql_lower)) def _aliases_for_table(sql: str, table: str) -> set[str]: aliases = {table} pattern = re.compile( rf"\b(?:from|join)\s+(?:public\.)?{re.escape(table)}(?:\s+(?:as\s+)?([a-zA-Z_][a-zA-Z0-9_]*))?", re.IGNORECASE, ) for match in pattern.finditer(sql): alias = match.group(1) if alias and alias.lower() not in {"on", "where", "join", "left", "right", "inner", "outer", "full", "cross"}: aliases.add(alias) return aliases def _references_column(sql: str, sql_lower: str, table: str, column: str) -> bool: if not _references_table(sql_lower, table): return False for alias in _aliases_for_table(sql, table): qualified = re.compile(rf"\b{re.escape(alias)}\.{re.escape(column)}\b", re.IGNORECASE) if qualified.search(sql): return True return False @dataclass class VerificationViolation: rule: str detail: str severity: str 
@dataclass
class VerificationResult:
    """Outcome of verifying, and possibly repairing, one SQL statement."""

    passed: bool  # True when ``violations`` contains no "blocking" entry
    sql: str  # the statement the violations describe (the repaired SQL after a repair)
    original_sql: str  # the statement originally submitted for verification
    violations: list[VerificationViolation] = field(default_factory=list)
    was_repaired: bool = False  # True when ``sql`` was produced by a repair step
    repair_attempted: bool = False
    repair_failed: bool = False  # a repair was attempted and the final SQL still fails
    notes: list[str] = field(default_factory=list)  # human-readable repair log


class PlanVerifier:
    """Static rule checks for planner SQL, with LLM-based and rule-based repair."""

    # `\*` must not be followed by `\b`: there is no word boundary between `*`
    # and the space in "SELECT * FROM", so such a pattern would never match.
    _SELECT_STAR_RE = re.compile(r"\bselect\s+(?:distinct\s+)?\*", re.IGNORECASE)
    # Whole-word keyword matches, so identifiers like `joined_at` or
    # `credit_limit` are not mistaken for JOIN / LIMIT.
    _JOIN_RE = re.compile(r"\bjoin\b", re.IGNORECASE)
    _LIMIT_RE = re.compile(r"\blimit\b", re.IGNORECASE)

    def verify(self, sql: str, prompt: str, detected_intents: list[str], row_limit: int) -> VerificationResult:
        """Run every static rule against *sql* and report the violations.

        *prompt* is accepted for signature symmetry with ``verify_and_repair``
        and is not used. *row_limit* only appears in the missing-LIMIT warning.
        The result passes when no violation has severity "blocking".
        """
        del prompt
        violations: list[VerificationViolation] = []
        sql_lower = sql.lower()

        if _DESTRUCTIVE.search(sql):
            violations.append(
                VerificationViolation(
                    rule="destructive_dml",
                    detail="SQL contains a write or DDL statement.",
                    severity="blocking",
                )
            )

        # Contact-recency columns are only a problem for contact-style prompts.
        if set(detected_intents) & _CONTACT_INTENTS:
            for table, column in _BAD_TIMESTAMP_PATTERNS:
                if _references_column(sql, sql_lower, table, column):
                    violations.append(
                        VerificationViolation(
                            rule="deprecated_timestamp",
                            detail=(
                                f"SQL references {table}.{column}, which is sparse or deprecated. "
                                "Use intel_interactions.happened_at or read_last_contacted.last_contact_at."
                            ),
                            severity="blocking",
                        )
                    )

        valid_score_types = {value.lower() for value in VALID_QD_SCORE_TYPES}
        # dict.fromkeys drops repeated literals (one violation each) while keeping order.
        for literal in dict.fromkeys(_extract_score_type_literals(sql)):
            if literal.lower() not in valid_score_types:
                violations.append(
                    VerificationViolation(
                        rule="impossible_score_type",
                        detail=(
                            f"SQL filters intel_qd_scores.score_type with impossible value '{literal}'. "
                            "Valid values are: " + ", ".join(VALID_QD_SCORE_TYPES) + ". "
                            "For generic QD prompts, use score_type = 'overall'."
                        ),
                        severity="blocking",
                    )
                )

        for table, column in _BAD_SCORE_PATTERNS:
            if _references_column(sql, sql_lower, table, column):
                violations.append(
                    VerificationViolation(
                        rule="wrong_score_column",
                        detail=(
                            f"SQL references {table}.{column}, which is not the QD source of truth. "
                            "Use intel_qd_scores.current_value."
                        ),
                        severity="blocking",
                    )
                )

        for table, column in _HALLUCINATED_COLUMNS:
            if _references_column(sql, sql_lower, table, column):
                violations.append(
                    VerificationViolation(
                        rule="hallucinated_column",
                        detail=f"SQL references {table}.{column}, which does not exist in the live schema.",
                        severity="blocking",
                    )
                )

        if not self._LIMIT_RE.search(sql):
            violations.append(
                VerificationViolation(
                    rule="missing_limit",
                    detail=f"SQL has no LIMIT clause; executor will enforce row cap {row_limit}.",
                    severity="warning",
                )
            )

        if self._SELECT_STAR_RE.search(sql) and len(self._JOIN_RE.findall(sql)) > 1:
            violations.append(
                VerificationViolation(
                    rule="select_star_join",
                    detail="SELECT * with multiple JOINs may create noisy wide rows.",
                    severity="warning",
                )
            )

        return VerificationResult(
            passed=not any(violation.severity == "blocking" for violation in violations),
            sql=sql,
            original_sql=sql,
            violations=violations,
        )

    @staticmethod
    def _blocking(result: VerificationResult) -> list[VerificationViolation]:
        """Return the blocking violations recorded on *result*."""
        return [violation for violation in result.violations if violation.severity == "blocking"]

    async def verify_and_repair(
        self,
        sql: str,
        prompt: str,
        detected_intents: list[str],
        row_limit: int,
        llm_service: Any | None = None,
    ) -> VerificationResult:
        """Verify *sql*; when it has blocking violations, try to repair it.

        Repair order: first the LLM repair agent (when *llm_service* is given).
        If the LLM is unavailable, raises, or returns SQL that still fails, the
        deterministic ``_semantic_rule_repair`` canonical SQL is tried. The
        returned result has ``repair_failed`` set whenever a repair was
        attempted and the final SQL does not pass.
        """
        result = self.verify(sql, prompt, detected_intents, row_limit)
        blocking = self._blocking(result)
        if result.passed or not blocking:
            return result

        result.repair_attempted = True

        if llm_service is None:
            result.notes.append("No LLM service available for SQL repair.")
            return self._rule_repair_or_fail(
                result,
                sql=sql,
                prompt=prompt,
                detected_intents=detected_intents,
                row_limit=row_limit,
                violations=blocking,
                original_blocking=blocking,
            )

        try:
            repaired_sql = await self._repair_sql(
                sql=sql,
                prompt=prompt,
                violations=blocking,
                detected_intents=detected_intents,
                row_limit=row_limit,
                llm_service=llm_service,
            )
        except Exception as exc:  # LLM boundary: any provider/parse failure falls back to rules
            logger.warning("plan_verifier repair failed: %s", exc)
            result.notes.append(f"Repair failed: {exc}")
            return self._rule_repair_or_fail(
                result,
                sql=sql,
                prompt=prompt,
                detected_intents=detected_intents,
                row_limit=row_limit,
                violations=blocking,
                original_blocking=blocking,
            )

        recheck = self.verify(repaired_sql, prompt, detected_intents, row_limit)
        recheck.original_sql = sql
        recheck.was_repaired = True
        recheck.repair_attempted = True
        recheck.notes.append(
            "Repaired violations: " + ", ".join(violation.rule for violation in blocking)
        )
        if recheck.passed:
            return recheck

        # Let the rule-based fallback see both the original violations and any
        # new ones the LLM introduced.
        return self._rule_repair_or_fail(
            recheck,
            sql=sql,
            prompt=prompt,
            detected_intents=detected_intents,
            row_limit=row_limit,
            violations=blocking + self._blocking(recheck),
            original_blocking=blocking,
        )

    def _rule_repair_or_fail(
        self,
        failed: VerificationResult,
        *,
        sql: str,
        prompt: str,
        detected_intents: list[str],
        row_limit: int,
        violations: list[VerificationViolation],
        original_blocking: list[VerificationViolation],
    ) -> VerificationResult:
        """Apply the deterministic canonical-SQL repair, or mark *failed* as a failed repair.

        *violations* decide whether a rule applies. *original_blocking* names the
        rules in the repair note.
        """
        semantic_sql = _semantic_rule_repair(
            prompt=prompt,
            detected_intents=detected_intents,
            row_limit=row_limit,
            violations=violations,
        )
        if not semantic_sql:
            failed.repair_failed = True
            return failed

        repaired = self.verify(semantic_sql, prompt, detected_intents, row_limit)
        repaired.original_sql = sql
        repaired.was_repaired = True
        repaired.repair_attempted = True
        repaired.repair_failed = not repaired.passed
        # Keep the earlier repair log so callers can see why the fallback ran.
        repaired.notes = [
            *failed.notes,
            "Semantic rule repair applied: " + ", ".join(violation.rule for violation in original_blocking),
        ]
        return repaired

    async def _repair_sql(
        self,
        *,
        sql: str,
        prompt: str,
        violations: list[VerificationViolation],
        detected_intents: list[str],
        row_limit: int,
        llm_service: Any,
    ) -> str:
        """Ask the LLM repair agent to fix *violations* in *sql*; return the new SQL.

        Raises:
            ValueError: the reply is not a JSON object, or its "sql" is empty.
                ``json.JSONDecodeError`` (a ValueError subclass) is raised when
                the content is not valid JSON.
        """
        semantic_ctx = build_semantic_context_for_planner(detected_intents, max_concepts=4)
        violation_text = "\n".join(f"- [{violation.rule}] {violation.detail}" for violation in violations)
        hard_rules = (
            "Hard repair rules:\n"
            "- crm_people is identity only. It has no QD score source-of-truth column.\n"
            "- For QD score prompts, use intel_qd_scores.current_value and join crm_people on person_id.\n"
            "- Valid intel_qd_scores.score_type values are: " + ", ".join(VALID_QD_SCORE_TYPES) + ".\n"
            "- Never use score_type = 'QD'. For generic QD prompts use score_type = 'overall'.\n"
            "- For recent contact prompts, use read_last_contacted.last_contact_at or intel_interactions.happened_at.\n"
            "- Never use edge_communication_events.timestamp or crm_property_interests.last_discussed_at for contact recency."
        )
        canonical_examples = (
            "Canonical repair examples:\n"
            "Generic QD ranking:\n"
            "SELECT p.full_name, p.primary_email, p.primary_phone, q.current_value AS qd_score, q.score_type, q.computed_at "
            "FROM intel_qd_scores q JOIN crm_people p ON p.person_id = q.person_id "
            # NULLS LAST: PostgreSQL otherwise puts unscored rows first under DESC.
            "WHERE q.score_type = 'overall' ORDER BY q.current_value DESC NULLS LAST LIMIT 8;\n"
            "Recent contact ranking:\n"
            "SELECT p.full_name, p.primary_email, lc.last_contact_at, lc.last_channel, q.current_value AS qd_score "
            "FROM read_last_contacted lc JOIN crm_people p ON p.person_id = lc.person_id "
            "LEFT JOIN intel_qd_scores q ON q.person_id = p.person_id AND q.score_type = 'overall' "
            "WHERE lc.last_contact_at >= NOW() - INTERVAL '3 months' "
            "ORDER BY q.current_value DESC NULLS LAST LIMIT 10;"
        )
        response = await llm_service.chat(
            provider_id="sglang",
            model=None,
            system_prompt=(
                "You are Oracle's SQL repair agent. "
                "Fix only the listed violations. Return strict JSON with key 'sql'."
            ),
            messages=[
                {
                    "role": "user",
                    "content": (
                        f"Original prompt: {prompt}\n\n"
                        f"Semantic catalog:\n{semantic_ctx}\n\n"
                        f"{hard_rules}\n\n"
                        f"{canonical_examples}\n\n"
                        f"Violations:\n{violation_text}\n\n"
                        f"Broken SQL:\n{sql}\n\n"
                        f"Row cap: {row_limit}\n\n"
                        "Return JSON: {\"sql\": \"\"}"
                    ),
                }
            ],
            temperature=0.0,
            response_format="json",
            metadata={"agent": "oracle_plan_verifier_repair"},
        )

        message = response.get("message") or {}
        # Prefer the provider's pre-parsed JSON; fall back to parsing raw content.
        parsed = message.get("parsedJson")
        if not isinstance(parsed, dict):
            content = message.get("content") or "{}"
            parsed = json.loads(content) if isinstance(content, str) else {}
        if not isinstance(parsed, dict):
            raise ValueError("Repair LLM returned JSON that is not an object.")
        repaired = str(parsed.get("sql") or "").strip()
        if not repaired:
            raise ValueError("Repair LLM returned empty SQL.")
        return repaired


# Module-level shared instance; PlanVerifier holds no per-instance state.
plan_verifier = PlanVerifier()