Files
Project_Velocity/backend/oracle/plan_verifier.py

236 lines
7.9 KiB
Python

"""
oracle/plan_verifier.py
Verify planned SQL before execution and optionally repair common semantic errors.
"""
from __future__ import annotations
import json
import logging
import re
from dataclasses import dataclass, field
from typing import Any
from .semantic_catalog import build_semantic_context_for_planner
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Keywords that make the verifier treat a statement as a write or DDL
# statement. Order is kept stable because it determines the compiled pattern.
_DESTRUCTIVE_KEYWORDS = (
    "insert",
    "update",
    "delete",
    "drop",
    "alter",
    "truncate",
    "copy",
    "create",
    "grant",
    "revoke",
    "call",
    "execute",
    "do",
    "merge",
)

# Case-insensitive match of any keyword above as a whole word (``\b`` on both
# sides), so identifiers such as ``created_at`` do not match ``create``.
_DESTRUCTIVE = re.compile(
    r"\b(" + "|".join(_DESTRUCTIVE_KEYWORDS) + r")\b",
    re.IGNORECASE,
)
# Each table below is a list of (table, column) pairs. List order is
# significant: violations are reported in this order.

# Timestamp columns the verifier reports as "sparse or deprecated".
_BAD_TIMESTAMP_PATTERNS: list[tuple[str, str]] = [
    ("edge_communication_events", "timestamp"),
    *(
        ("crm_property_interests", column)
        for column in ("last_discussed_at", "last_interaction")
    ),
]

# Score columns the verifier reports as "not the QD source of truth".
_BAD_SCORE_PATTERNS: list[tuple[str, str]] = [
    *(
        (table, "engagement_score")
        for table in ("crm_people", "crm_leads", "intel_interactions")
    ),
    *((table, "qd_score") for table in ("crm_people", "crm_leads")),
]

# Columns the verifier reports as not existing in the live schema.
_HALLUCINATED_COLUMNS: list[tuple[str, str]] = [
    *(("intel_interactions", column) for column in ("broker_id", "sentiment")),
    ("crm_leads", "last_contacted_at"),
    ("crm_people", "last_contact"),
]
@dataclass
class VerificationViolation:
    """One rule hit found while verifying a SQL statement."""

    # Short machine-readable rule id, e.g. "destructive_dml" or "missing_limit".
    rule: str
    # Human-readable explanation of what was found.
    detail: str
    # "blocking" or "warning" are the values produced in this module; only
    # "blocking" violations make verification fail.
    severity: str
@dataclass
class VerificationResult:
    """Outcome of verifying (and possibly repairing) one SQL statement."""

    # True when verification found no "blocking" violation.
    passed: bool
    # The SQL text this result describes.
    sql: str
    # The SQL text as first submitted; equal to ``sql`` unless a repair step
    # replaced the statement.
    original_sql: str
    # Every violation found, blocking and warning alike.
    violations: list[VerificationViolation] = field(default_factory=list)
    # Repair bookkeeping, populated by PlanVerifier.verify_and_repair.
    was_repaired: bool = False
    repair_attempted: bool = False
    repair_failed: bool = False
    # Free-form notes about the verification/repair run.
    notes: list[str] = field(default_factory=list)
class PlanVerifier:
    """Check planner-generated SQL against static rules and optionally repair it.

    ``verify`` is a synchronous text scan of the SQL. ``verify_and_repair``
    additionally asks an LLM service to rewrite SQL that has blocking
    violations and re-verifies whatever comes back.

    Known limitation: a column rule fires when a listed table name and a
    listed column name both occur (as whole identifiers) anywhere in the
    statement; the column is not resolved to the table it belongs to.
    """

    # ``select *``. Deliberately no trailing ``\b``: there is no word boundary
    # between ``*`` and a following space, so ``\*\b`` misses ``select * from``.
    _SELECT_STAR = re.compile(r"\bselect\s+\*")
    # Whole-word keyword matches, so identifiers such as ``credit_limit`` or
    # ``joined_at`` are not mistaken for the keywords.
    _JOIN = re.compile(r"\bjoin\b")
    _LIMIT = re.compile(r"\blimit\b")

    @staticmethod
    def _mentions(sql_lower: str, identifier: str) -> bool:
        """Return True when *identifier* occurs in *sql_lower* as a whole identifier.

        A plain substring test would report ``last_contact`` inside
        ``last_contacted_at`` or ``timestamp`` inside ``current_timestamp``;
        the lookarounds require a non-word character (or the string edge) on
        both sides.
        """
        pattern = rf"(?<!\w){re.escape(identifier)}(?!\w)"
        return re.search(pattern, sql_lower) is not None

    @staticmethod
    def _extract_sql(response: Any) -> str:
        """Pull the repaired SQL string out of an LLM chat response.

        Prefers ``response["message"]["parsedJson"]`` when it is a dict,
        otherwise JSON-decodes ``response["message"]["content"]``.

        Raises:
            ValueError: if the payload is not a JSON object, is not valid JSON
                (``json.JSONDecodeError`` is a ``ValueError``), or carries no
                non-empty string under ``"sql"``.
        """
        message = response.get("message") or {}
        parsed = message.get("parsedJson")
        if not isinstance(parsed, dict):
            content = message.get("content") or "{}"
            parsed = json.loads(content) if isinstance(content, str) else {}
        if not isinstance(parsed, dict):
            # e.g. the model answered with a JSON list or a bare string.
            raise ValueError("Repair LLM returned JSON that is not an object.")
        candidate = parsed.get("sql")
        repaired = candidate.strip() if isinstance(candidate, str) else ""
        if not repaired:
            raise ValueError("Repair LLM returned empty SQL.")
        return repaired

    def verify(self, sql: str, prompt: str, detected_intents: list[str], row_limit: int) -> VerificationResult:
        """Run the static checks on *sql* and report every violation found.

        Args:
            sql: SQL text to check.
            prompt: Unused here; accepted so the signature mirrors
                ``verify_and_repair``.
            detected_intents: Unused here, for the same reason.
            row_limit: Row cap quoted in the ``missing_limit`` warning text.

        Returns:
            A ``VerificationResult`` whose ``passed`` is True when no
            violation has severity ``"blocking"``. ``sql`` and
            ``original_sql`` are both the input text.
        """
        del prompt, detected_intents

        sql_lower = sql.lower()
        violations: list[VerificationViolation] = []

        if _DESTRUCTIVE.search(sql):
            violations.append(
                VerificationViolation(
                    rule="destructive_dml",
                    detail="SQL contains a write or DDL statement.",
                    severity="blocking",
                )
            )

        # (table/column pairs, rule id, detail template); evaluated in order so
        # violations keep a stable, predictable ordering.
        column_rules = (
            (
                _BAD_TIMESTAMP_PATTERNS,
                "deprecated_timestamp",
                "SQL references {table}.{column}, which is sparse or deprecated. "
                "Use intel_interactions.happened_at or read_last_contacted.last_contacted_at.",
            ),
            (
                _BAD_SCORE_PATTERNS,
                "wrong_score_column",
                "SQL references {table}.{column}, which is not the QD source of truth. "
                "Use intel_qd_scores.current_value.",
            ),
            (
                _HALLUCINATED_COLUMNS,
                "hallucinated_column",
                "SQL references {table}.{column}, which does not exist in the live schema.",
            ),
        )
        for pairs, rule, template in column_rules:
            for table, column in pairs:
                if self._mentions(sql_lower, table) and self._mentions(sql_lower, column):
                    violations.append(
                        VerificationViolation(
                            rule=rule,
                            detail=template.format(table=table, column=column),
                            severity="blocking",
                        )
                    )

        if not self._LIMIT.search(sql_lower):
            violations.append(
                VerificationViolation(
                    rule="missing_limit",
                    detail=f"SQL has no LIMIT clause; executor will enforce row cap {row_limit}.",
                    severity="warning",
                )
            )

        join_count = len(self._JOIN.findall(sql_lower))
        if self._SELECT_STAR.search(sql_lower) and join_count > 1:
            violations.append(
                VerificationViolation(
                    rule="select_star_join",
                    detail="SELECT * with multiple JOINs may create noisy wide rows.",
                    severity="warning",
                )
            )

        passed = not any(violation.severity == "blocking" for violation in violations)
        return VerificationResult(
            passed=passed,
            sql=sql,
            original_sql=sql,
            violations=violations,
        )

    async def verify_and_repair(
        self,
        sql: str,
        prompt: str,
        detected_intents: list[str],
        row_limit: int,
        llm_service: Any | None = None,
    ) -> VerificationResult:
        """Verify *sql*; on blocking violations, try one LLM rewrite and re-verify.

        Args:
            sql: SQL text to check.
            prompt: Original user prompt, forwarded to the repair LLM.
            detected_intents: Intents used to build the semantic context for
                the repair LLM.
            row_limit: Row cap, forwarded to ``verify`` and the repair LLM.
            llm_service: Object exposing an awaitable ``chat(...)``; when
                ``None`` no repair can be made.

        Returns:
            * the plain ``verify`` result when it passed;
            * that result with ``repair_attempted`` and ``repair_failed`` set
              when no LLM is available or the repair call raised;
            * otherwise the verification of the rewritten SQL, with
              ``original_sql`` set to the input. ``was_repaired`` is True only
              when the rewrite passes; if it still has blocking violations
              ``repair_failed`` is True instead.
        """
        result = self.verify(sql, prompt, detected_intents, row_limit)
        if result.passed:
            return result

        # ``passed`` is False exactly when at least one blocking violation
        # exists, so this list is never empty here.
        blocking = [violation for violation in result.violations if violation.severity == "blocking"]

        result.repair_attempted = True
        if llm_service is None:
            result.repair_failed = True
            result.notes.append("No LLM service available for SQL repair.")
            return result

        try:
            repaired_sql = await self._repair_sql(
                sql=sql,
                prompt=prompt,
                violations=blocking,
                detected_intents=detected_intents,
                row_limit=row_limit,
                llm_service=llm_service,
            )
        except Exception as exc:
            # Repair is best-effort: any failure of the LLM call or of parsing
            # its answer is reported on the result instead of propagating.
            logger.warning("plan_verifier repair failed: %s", exc)
            result.repair_failed = True
            result.notes.append(f"Repair failed: {exc}")
            return result

        recheck = self.verify(repaired_sql, prompt, detected_intents, row_limit)
        recheck.original_sql = sql
        recheck.repair_attempted = True
        targeted = ", ".join(violation.rule for violation in blocking)
        if recheck.passed:
            recheck.was_repaired = True
            recheck.notes.append("Repaired violations: " + targeted)
        else:
            # The rewrite came back but is still not executable; do not report
            # it as a successful repair.
            remaining = ", ".join(
                violation.rule
                for violation in recheck.violations
                if violation.severity == "blocking"
            )
            recheck.repair_failed = True
            recheck.notes.append(
                "Repair attempted for: " + targeted + "; still blocking: " + remaining
            )
        return recheck

    async def _repair_sql(
        self,
        *,
        sql: str,
        prompt: str,
        violations: list[VerificationViolation],
        detected_intents: list[str],
        row_limit: int,
        llm_service: Any,
    ) -> str:
        """Ask the LLM service to rewrite *sql* so the listed violations go away.

        Returns:
            The stripped SQL string taken from the model's JSON answer.

        Raises:
            ValueError: if the answer holds no usable ``"sql"`` string.
            Exception: anything raised by ``llm_service.chat`` or by building
                the semantic context propagates to the caller.
        """
        semantic_ctx = build_semantic_context_for_planner(detected_intents, max_concepts=4)
        violation_text = "\n".join(f"- [{violation.rule}] {violation.detail}" for violation in violations)
        response = await llm_service.chat(
            provider_id="sglang",
            model=None,
            system_prompt=(
                "You are Oracle's SQL repair agent. "
                "Fix only the listed violations. Return strict JSON with key 'sql'."
            ),
            messages=[
                {
                    "role": "user",
                    "content": (
                        f"Original prompt: {prompt}\n\n"
                        f"Semantic catalog:\n{semantic_ctx}\n\n"
                        f"Violations:\n{violation_text}\n\n"
                        f"Broken SQL:\n{sql}\n\n"
                        f"Row cap: {row_limit}\n\n"
                        "Return JSON: {\"sql\": \"<corrected SQL>\"}"
                    ),
                }
            ],
            temperature=0.0,
            response_format="json",
            metadata={"agent": "oracle_plan_verifier_repair"},
        )
        return self._extract_sql(response)


# Shared module-level verifier instance.
plan_verifier = PlanVerifier()