forked from sagnik/Project_Velocity
236 lines
7.9 KiB
Python
236 lines
7.9 KiB
Python
"""
oracle/plan_verifier.py

Verify planned SQL before execution and optionally repair common semantic errors.
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import re
|
|
from dataclasses import dataclass, field
|
|
from typing import Any
|
|
|
|
from .semantic_catalog import build_semantic_context_for_planner
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Case-insensitive, whole-word match for SQL keywords that write data, change
# schema/privileges, or execute code. The match is keyword-level only: it does
# not parse the statement, so a keyword inside a string literal or comment
# also matches.
_DESTRUCTIVE = re.compile(
    r"\b(insert|update|delete|drop|alter|truncate|copy|create|grant|revoke|call|execute|do|merge)\b",
    re.IGNORECASE,
)
|
|
|
|
# (table, column) pairs that PlanVerifier.verify reports as the blocking
# ``deprecated_timestamp`` violation. Entries are lower-case because they are
# compared against lower-cased SQL.
_BAD_TIMESTAMP_PATTERNS: list[tuple[str, str]] = [
    ("edge_communication_events", "timestamp"),
    ("crm_property_interests", "last_discussed_at"),
    ("crm_property_interests", "last_interaction"),
]
|
|
|
|
# (table, column) pairs that PlanVerifier.verify reports as the blocking
# ``wrong_score_column`` violation. Entries are lower-case because they are
# compared against lower-cased SQL.
_BAD_SCORE_PATTERNS: list[tuple[str, str]] = [
    ("crm_people", "engagement_score"),
    ("crm_leads", "engagement_score"),
    ("intel_interactions", "engagement_score"),
    ("crm_people", "qd_score"),
    ("crm_leads", "qd_score"),
]
|
|
|
|
# (table, column) pairs that PlanVerifier.verify reports as the blocking
# ``hallucinated_column`` violation. Entries are lower-case because they are
# compared against lower-cased SQL.
_HALLUCINATED_COLUMNS: list[tuple[str, str]] = [
    ("intel_interactions", "broker_id"),
    ("intel_interactions", "sentiment"),
    ("crm_leads", "last_contacted_at"),
    ("crm_people", "last_contact"),
]
|
|
|
|
|
|
@dataclass
class VerificationViolation:
    """One rule hit found while checking a SQL statement."""

    # Short machine-readable rule identifier, e.g. "missing_limit".
    rule: str
    # Human-readable explanation of the hit.
    detail: str
    # "blocking" or "warning" are the values produced in this module; only
    # "blocking" violations make a verification fail.
    severity: str
|
|
|
|
|
|
@dataclass
class VerificationResult:
    """Outcome of verifying (and possibly repairing) one SQL statement."""

    # True when no violation with severity "blocking" was found.
    passed: bool
    # The SQL this result describes (the rewritten SQL after a repair).
    sql: str
    # The SQL as originally submitted, before any repair.
    original_sql: str
    # Every violation found, blocking and warning alike.
    violations: list[VerificationViolation] = field(default_factory=list)
    # Repair bookkeeping flags; all start False and are set by the repair flow.
    was_repaired: bool = False
    repair_attempted: bool = False
    repair_failed: bool = False
    # Free-form diagnostic messages accumulated during verification/repair.
    notes: list[str] = field(default_factory=list)
|
|
|
|
|
|
class PlanVerifier:
    """Check planner-generated SQL against static rules and optionally repair it.

    ``verify`` is a string-level check that never touches a database.
    ``verify_and_repair`` additionally asks an LLM service to rewrite SQL that
    has blocking violations and then verifies the rewrite again.
    """

    # Whole-word matchers; all are applied to lower-cased SQL.
    _LIMIT_RE = re.compile(r"\blimit\b")
    _JOIN_RE = re.compile(r"\bjoin\b")
    # Deliberately no trailing ``\b``: there is no word boundary between ``*``
    # and the whitespace that normally follows it, so ``\*\b`` can never match
    # ``select * from ...``.
    _SELECT_STAR_RE = re.compile(r"\bselect\s+\*")

    @staticmethod
    def _mentions(sql_lower: str, identifier: str) -> bool:
        """Return True if *identifier* occurs in *sql_lower* as a whole word.

        Whole-word matching keeps ``last_contact`` from matching
        ``last_contacted_at`` and ``timestamp`` from matching
        ``current_timestamp``.
        """
        return re.search(rf"\b{re.escape(identifier)}\b", sql_lower) is not None

    @classmethod
    def _column_violations(
        cls,
        sql_lower: str,
        patterns: list[tuple[str, str]],
        rule: str,
        explanation: str,
    ) -> list[VerificationViolation]:
        """Build one blocking violation per (table, column) pair present in the SQL.

        This is a co-occurrence heuristic: a pair counts as present when both
        names appear anywhere in the statement, not only as ``table.column``.
        """
        return [
            VerificationViolation(
                rule=rule,
                detail=f"SQL references {table}.{column}, {explanation}",
                severity="blocking",
            )
            for table, column in patterns
            if cls._mentions(sql_lower, table) and cls._mentions(sql_lower, column)
        ]

    def verify(
        self,
        sql: str,
        prompt: str,
        detected_intents: list[str],
        row_limit: int,
    ) -> VerificationResult:
        """Run every static rule against *sql* and collect the violations.

        Args:
            sql: The SQL statement to check.
            prompt: Unused here; accepted so the signature mirrors
                ``verify_and_repair``.
            detected_intents: Unused here; accepted for the same reason.
            row_limit: Row cap quoted in the ``missing_limit`` warning text.

        Returns:
            A ``VerificationResult`` whose ``passed`` flag is True when no
            violation has severity ``"blocking"``. ``sql`` and
            ``original_sql`` are both set to the input.
        """
        del prompt, detected_intents
        violations: list[VerificationViolation] = []
        sql_lower = sql.lower()

        if _DESTRUCTIVE.search(sql):
            violations.append(
                VerificationViolation(
                    rule="destructive_dml",
                    detail="SQL contains a write or DDL statement.",
                    severity="blocking",
                )
            )

        violations.extend(
            self._column_violations(
                sql_lower,
                _BAD_TIMESTAMP_PATTERNS,
                "deprecated_timestamp",
                "which is sparse or deprecated. "
                "Use intel_interactions.happened_at or read_last_contacted.last_contacted_at.",
            )
        )
        violations.extend(
            self._column_violations(
                sql_lower,
                _BAD_SCORE_PATTERNS,
                "wrong_score_column",
                "which is not the QD source of truth. "
                "Use intel_qd_scores.current_value.",
            )
        )
        violations.extend(
            self._column_violations(
                sql_lower,
                _HALLUCINATED_COLUMNS,
                "hallucinated_column",
                "which does not exist in the live schema.",
            )
        )

        # Match LIMIT as a keyword so identifiers such as ``credit_limit`` do
        # not suppress the warning.
        if self._LIMIT_RE.search(sql_lower) is None:
            violations.append(
                VerificationViolation(
                    rule="missing_limit",
                    detail=f"SQL has no LIMIT clause; executor will enforce row cap {row_limit}.",
                    severity="warning",
                )
            )

        # Count JOIN keywords only, not substrings such as ``joined_at``.
        join_count = len(self._JOIN_RE.findall(sql_lower))
        if self._SELECT_STAR_RE.search(sql_lower) and join_count > 1:
            violations.append(
                VerificationViolation(
                    rule="select_star_join",
                    detail="SELECT * with multiple JOINs may create noisy wide rows.",
                    severity="warning",
                )
            )

        has_blocking = any(violation.severity == "blocking" for violation in violations)
        return VerificationResult(
            passed=not has_blocking,
            sql=sql,
            original_sql=sql,
            violations=violations,
        )

    async def verify_and_repair(
        self,
        sql: str,
        prompt: str,
        detected_intents: list[str],
        row_limit: int,
        llm_service: Any | None = None,
    ) -> VerificationResult:
        """Verify *sql* and, if it has blocking violations, try one LLM repair.

        Args:
            sql: The SQL statement to check.
            prompt: The original user prompt, forwarded to the repair request.
            detected_intents: Intents forwarded to the repair request.
            row_limit: Row cap forwarded to verification and the repair request.
            llm_service: Object with an awaitable ``chat(...)`` method; when
                None, no repair is possible.

        Returns:
            The initial result when it passed or no repair could be made (with
            ``repair_attempted``/``repair_failed`` set accordingly), otherwise
            the verification result of the rewritten SQL. ``was_repaired`` is
            True only when the rewritten SQL passes; if it still has blocking
            violations, ``repair_failed`` is set instead.
        """
        result = self.verify(sql, prompt, detected_intents, row_limit)
        blocking = [violation for violation in result.violations if violation.severity == "blocking"]
        if result.passed or not blocking:
            return result

        result.repair_attempted = True
        if llm_service is None:
            result.repair_failed = True
            result.notes.append("No LLM service available for SQL repair.")
            return result

        try:
            repaired_sql = await self._repair_sql(
                sql=sql,
                prompt=prompt,
                violations=blocking,
                detected_intents=detected_intents,
                row_limit=row_limit,
                llm_service=llm_service,
            )
        except Exception as exc:
            # Boundary with an external service: any failure is downgraded to
            # a failed repair so the caller still receives the original result.
            logger.warning("plan_verifier repair failed: %s", exc)
            result.repair_failed = True
            result.notes.append(f"Repair failed: {exc}")
            return result

        recheck = self.verify(repaired_sql, prompt, detected_intents, row_limit)
        recheck.original_sql = sql
        recheck.repair_attempted = True
        if recheck.passed:
            recheck.was_repaired = True
            recheck.notes.append(
                "Repaired violations: " + ", ".join(violation.rule for violation in blocking)
            )
        else:
            # The rewrite still breaks a blocking rule; do not report success.
            recheck.repair_failed = True
            remaining = [
                violation.rule
                for violation in recheck.violations
                if violation.severity == "blocking"
            ]
            recheck.notes.append(
                "Repair did not resolve blocking violations: " + ", ".join(remaining)
            )
        return recheck

    async def _repair_sql(
        self,
        *,
        sql: str,
        prompt: str,
        violations: list[VerificationViolation],
        detected_intents: list[str],
        row_limit: int,
        llm_service: Any,
    ) -> str:
        """Ask the LLM service for a corrected version of *sql*.

        The response is read as a mapping: ``message.parsedJson`` is used when
        it is a dict, otherwise ``message.content`` is parsed as JSON.

        Returns:
            The stripped value of the ``sql`` key from the response JSON.

        Raises:
            ValueError: If the response JSON is not an object, is not valid
                JSON (``json.JSONDecodeError``), or has no non-empty ``sql``.
        """
        semantic_ctx = build_semantic_context_for_planner(detected_intents, max_concepts=4)
        violation_text = "\n".join(
            f"- [{violation.rule}] {violation.detail}" for violation in violations
        )

        response = await llm_service.chat(
            provider_id="sglang",
            model=None,
            system_prompt=(
                "You are Oracle's SQL repair agent. "
                "Fix only the listed violations. Return strict JSON with key 'sql'."
            ),
            messages=[
                {
                    "role": "user",
                    "content": (
                        f"Original prompt: {prompt}\n\n"
                        f"Semantic catalog:\n{semantic_ctx}\n\n"
                        f"Violations:\n{violation_text}\n\n"
                        f"Broken SQL:\n{sql}\n\n"
                        f"Row cap: {row_limit}\n\n"
                        "Return JSON: {\"sql\": \"<corrected SQL>\"}"
                    ),
                }
            ],
            temperature=0.0,
            response_format="json",
            metadata={"agent": "oracle_plan_verifier_repair"},
        )
        message = response.get("message") or {}
        parsed = message.get("parsedJson")
        if not isinstance(parsed, dict):
            content = message.get("content") or "{}"
            parsed = json.loads(content) if isinstance(content, str) else {}
        if not isinstance(parsed, dict):
            # Valid JSON that is not an object (e.g. a list or bare string).
            raise ValueError("Repair LLM returned JSON that is not an object.")
        repaired = str(parsed.get("sql") or "").strip()
        if not repaired:
            raise ValueError("Repair LLM returned empty SQL.")
        return repaired
|
|
|
|
|
|
# Shared module-level PlanVerifier instance, created at import time.
plan_verifier = PlanVerifier()
|