forked from sagnik/Project_Velocity
437 lines
16 KiB
Python
437 lines
16 KiB
Python
"""
|
|
oracle/plan_verifier.py
|
|
|
|
Verify planned SQL before execution and optionally repair common semantic errors.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import re
|
|
from dataclasses import dataclass, field
|
|
from typing import Any
|
|
|
|
from .semantic_catalog import VALID_QD_SCORE_TYPES, build_semantic_context_for_planner
|
|
|
|
logger = logging.getLogger(__name__)

# Keywords of statements that write data or change schema/permissions. The
# verifier treats any whole-word, case-insensitive occurrence in the SQL text
# as a blocking violation.
# NOTE: the regex runs on the raw SQL, so matches inside string literals count too.
_DESTRUCTIVE_KEYWORDS = (
    "insert", "update", "delete", "drop", "alter", "truncate", "copy",
    "create", "grant", "revoke", "call", "execute", "do", "merge",
)
_DESTRUCTIVE = re.compile(r"\b(" + "|".join(_DESTRUCTIVE_KEYWORDS) + r")\b", re.IGNORECASE)

# (table, column) pairs reported as ``deprecated_timestamp`` when the prompt has
# a contact-recency intent.
_BAD_TIMESTAMP_PATTERNS: list[tuple[str, str]] = [
    ("edge_communication_events", "timestamp"),
    ("crm_property_interests", "last_discussed_at"),
    ("crm_property_interests", "last_interaction"),
]

# (table, column) pairs reported as ``wrong_score_column``: score-like columns
# that are not intel_qd_scores.current_value.
_BAD_SCORE_PATTERNS: list[tuple[str, str]] = [
    *((table, "engagement_score") for table in ("crm_people", "crm_leads", "intel_interactions")),
    *((table, "qd_score") for table in ("crm_people", "crm_leads")),
]

# (table, column) pairs reported as ``hallucinated_column``.
_HALLUCINATED_COLUMNS: list[tuple[str, str]] = [
    ("intel_interactions", "broker_id"),
    ("intel_interactions", "sentiment"),
    ("crm_leads", "last_contacted_at"),
    ("crm_people", "last_contact"),
    *(
        ("read_last_contacted", column)
        for column in ("last_contacted_at", "days_since_last_contact", "staleness_label")
    ),
]

# Planner intents that ask about contact recency.
_CONTACT_INTENTS = {"last_contacted", "timeline"}
|
|
|
|
|
|
# Spelled-out counts accepted as row limits ("top five leads").
_LIMIT_NUMBER_WORDS: dict[str, int] = {
    "one": 1,
    "two": 2,
    "three": 3,
    "four": 4,
    "five": 5,
    "six": 6,
    "seven": 7,
    "eight": 8,
    "nine": 9,
    "ten": 10,
    "eleven": 11,
    "twelve": 12,
    "fifteen": 15,
    "twenty": 20,
}

# Words that typically introduce a requested row count.
_LIMIT_TRIGGER = r"\b(?:top|last|latest|recent|first|show|which|give me)\s+"

# A count directly followed by a time unit ("last 3 months", "last two weeks")
# describes a time window, not a row count, so it must not become the LIMIT.
_NOT_TIME_UNIT = r"(?!\s*(?:day|week|month|year|hour|quarter)s?\b)"

_LIMIT_NUMERIC_RE = re.compile(_LIMIT_TRIGGER + r"(\d{1,4})\b" + _NOT_TIME_UNIT)
_LIMIT_WORD_RE = re.compile(
    _LIMIT_TRIGGER + "(" + "|".join(_LIMIT_NUMBER_WORDS) + r")\b" + _NOT_TIME_UNIT
)


def _extract_limit_from_prompt(prompt: str, default: int) -> int:
    """Return the row count requested in *prompt*, clamped to ``[1, default]``.

    A count is recognised after a trigger word ("top 5", "show 10",
    "give me three"). Digits are preferred over spelled-out numbers. Counts
    that are immediately followed by a time unit ("last 3 months") are ignored
    because they describe a look-back window. When no count is found,
    *default* is returned unchanged.

    Args:
        prompt: The user's natural-language question.
        default: The maximum (and fallback) row count.

    Returns:
        The requested row count, at least 1 and at most *default*, or
        *default* when the prompt does not ask for a specific count.
    """
    lowered = prompt.lower()

    numeric_match = _LIMIT_NUMERIC_RE.search(lowered)
    if numeric_match:
        return max(1, min(int(numeric_match.group(1)), default))

    word_match = _LIMIT_WORD_RE.search(lowered)
    if word_match:
        return max(1, min(_LIMIT_NUMBER_WORDS[word_match.group(1)], default))

    return default
|
|
|
|
|
|
def _canonical_qd_sql(prompt: str, row_limit: int) -> str:
    """Build the canonical "rank people by overall QD score" query for *prompt*.

    - The row count comes from ``_extract_limit_from_prompt`` and is capped at
      *row_limit*.
    - The words "lowest", "bottom", "weakest" or "least" (but not "at least")
      switch the ordering to ascending; otherwise it is descending.
    - A trailing "in <name>" phrase (optionally ending in "?") that does not
      look like a time window becomes a case-insensitive substring filter on
      ``crm_property_interests.project_name``.

    Args:
        prompt: The user's natural-language question.
        row_limit: Maximum number of rows to return.

    Returns:
        A single read-only SQL statement.
    """
    limit = _extract_limit_from_prompt(prompt, row_limit)
    lowered = prompt.lower()

    # Whole-word match so "slowest" or "at least 3 visits" do not flip the order.
    ascending = re.search(r"\b(?:lowest|bottom|weakest)\b|(?<!at )\bleast\b", lowered)
    direction = "ASC" if ascending else "DESC"

    project_filter = ""
    project_match = re.search(r"\bin\s+([A-Za-z0-9][A-Za-z0-9 .&'-]{2,80})(?:\?|$)", prompt)
    if project_match:
        # Drop trailing sentence punctuation: "in Godrej Woods." -> "Godrej Woods".
        project_name = project_match.group(1).strip().rstrip(".").strip()
        looks_like_time_window = re.search(
            r"\b(?:last|past|recent|recently|today|yesterday|hour|hours|day|days|"
            r"week|weeks|month|months|quarter|quarters|year|years)\b",
            project_name,
            re.IGNORECASE,
        )
        if project_name and not looks_like_time_window:
            # The capture class above admits no backslash, so doubling single
            # quotes is enough to keep the literal well-formed.
            escaped = project_name.replace("'", "''")
            # EXISTS rather than JOIN: a person with several matching interest
            # rows must still appear only once in the ranking.
            project_filter = (
                "AND EXISTS (SELECT 1 FROM crm_property_interests pi "
                "WHERE pi.person_id = p.person_id "
                f"AND pi.project_name ILIKE '%{escaped}%') "
            )

    return (
        "SELECT p.full_name, p.primary_email, p.primary_phone, "
        "q.current_value AS qd_score, q.score_type, q.computed_at "
        "FROM intel_qd_scores q "
        "JOIN crm_people p ON p.person_id = q.person_id "
        "WHERE q.score_type = 'overall' "
        f"{project_filter}"
        f"ORDER BY q.current_value {direction} "
        f"LIMIT {limit}"
    )
|
|
|
|
|
|
def _canonical_recent_contact_sql(prompt: str, row_limit: int) -> str:
    """Build the canonical "recently contacted people" query for *prompt*.

    The look-back window defaults to 3 months:

    - "last/past/recent N <unit>" sets it to N units (at least 1).
    - A bare "last/past <unit>" ("last month", "past week") sets it to one unit.

    Rows are ordered by overall QD score (people without one last), then by
    most recent contact. The row count comes from
    ``_extract_limit_from_prompt`` and is capped at *row_limit*.

    Args:
        prompt: The user's natural-language question.
        row_limit: Maximum number of rows to return.

    Returns:
        A single read-only SQL statement.
    """
    limit = _extract_limit_from_prompt(prompt, row_limit)
    lowered = prompt.lower()

    interval = "3 months"
    counted = re.search(
        r"\b(?:last|past|recent)\s+(\d{1,3})\s+(day|days|week|weeks|month|months|year|years)\b",
        lowered,
    )
    bare = re.search(r"\b(?:last|past)\s+(day|week|month|year)\b", lowered)
    if counted:
        count, unit = counted.groups()
        # A zero-length window could never match; widen it to one unit.
        interval = f"{max(1, int(count))} {unit}"
    elif bare:
        interval = f"1 {bare.group(1)}"
    # ``interval`` only ever holds digits plus a fixed unit word, so it is safe
    # to interpolate into the SQL literal below.

    return (
        "SELECT p.full_name, p.primary_email, p.primary_phone, "
        "lc.last_contact_at, lc.last_channel, lc.days_since_contact, "
        "q.current_value AS qd_score "
        "FROM read_last_contacted lc "
        "JOIN crm_people p ON p.person_id = lc.person_id "
        "LEFT JOIN intel_qd_scores q ON q.person_id = p.person_id AND q.score_type = 'overall' "
        f"WHERE lc.last_contact_at >= NOW() - INTERVAL '{interval}' "
        "ORDER BY q.current_value DESC NULLS LAST, lc.last_contact_at DESC "
        f"LIMIT {limit}"
    )
|
|
|
|
|
|
def _semantic_rule_repair(
    *,
    prompt: str,
    detected_intents: list[str],
    row_limit: int,
    violations: list[VerificationViolation],
) -> str | None:
    """Pick a canonical replacement query for well-known violation patterns.

    Deterministic fallback used by ``PlanVerifier.verify_and_repair``.

    Returns:
        - The canonical QD ranking query when the ``qd_score`` intent coincides
          with a ``wrong_score_column`` or ``impossible_score_type`` violation.
          This case wins when both apply.
        - The canonical recent-contact query when a contact intent coincides
          with a ``deprecated_timestamp`` or ``hallucinated_column`` violation.
        - ``None`` otherwise.
    """
    fired = {item.rule for item in violations}

    qd_rules = {"wrong_score_column", "impossible_score_type"}
    if "qd_score" in detected_intents and not fired.isdisjoint(qd_rules):
        return _canonical_qd_sql(prompt, row_limit)

    contact_rules = {"deprecated_timestamp", "hallucinated_column"}
    if not _CONTACT_INTENTS.isdisjoint(detected_intents) and not fired.isdisjoint(contact_rules):
        return _canonical_recent_contact_sql(prompt, row_limit)

    return None
|
|
|
|
|
|
# ``score_type = '<literal>'``, bare or alias-qualified (``q.score_type``).
# The leading \b keeps columns such as ``raw_score_type`` from matching.
_SCORE_TYPE_EQ_RE = re.compile(r"\bscore_type\s*=\s*'([^']+)'", re.IGNORECASE)
# ``score_type IN ('a', 'b', ...)``; the parenthesised list may span lines.
_SCORE_TYPE_IN_RE = re.compile(r"\bscore_type\s+in\s*\(([^)]*)\)", re.IGNORECASE | re.DOTALL)
# A non-empty single-quoted literal inside an IN list.
_QUOTED_LITERAL_RE = re.compile(r"'([^']+)'")


def _extract_score_type_literals(sql: str) -> list[str]:
    """Return every string literal *sql* compares ``score_type`` against.

    Both ``score_type = '...'`` equality filters and ``score_type IN (...)``
    lists are collected. Equality literals come first, then IN-list literals,
    each in source order. Negated comparisons (``<>``, ``!=``) are not
    collected.

    Args:
        sql: The SQL statement to scan (any case).

    Returns:
        The literal values, without quotes and with their original case.
    """
    literals = [match.group(1) for match in _SCORE_TYPE_EQ_RE.finditer(sql)]
    for match in _SCORE_TYPE_IN_RE.finditer(sql):
        literals.extend(_QUOTED_LITERAL_RE.findall(match.group(1)))
    return literals
|
|
|
|
|
|
def _references_table(sql_lower: str, table: str) -> bool:
    """Return True when *table* appears as a row source in *sql_lower*.

    A table counts when it follows ``FROM`` or ``JOIN``, or a comma in an
    implicit comma join (``FROM a, crm_people p``). It may be qualified with
    the ``public.`` schema. The name must end at a word boundary, so
    ``crm_people_archive`` does not count as ``crm_people``.

    Args:
        sql_lower: The SQL statement, already lower-cased.
        table: Unquoted table name to look for.

    Returns:
        True if the table is referenced as a row source.
    """
    pattern = rf"(?:\b(?:from|join)\s+|,\s*)(?:public\.)?{re.escape(table.lower())}\b"
    return re.search(pattern, sql_lower) is not None
|
|
|
|
|
|
# Words that may follow a table reference but are SQL keywords, not aliases.
_NON_ALIAS_KEYWORDS = frozenset(
    {
        "as", "on", "where", "join", "left", "right", "inner", "outer", "full",
        "cross", "natural", "using", "group", "order", "having", "limit",
        "offset", "union", "intersect", "except", "window", "fetch", "for",
        "lateral", "tablesample",
    }
)


def _aliases_for_table(sql: str, table: str) -> set[str]:
    """Return the names under which *table* can be referenced in *sql*.

    The result always contains *table* itself. It also contains any alias
    given after a ``FROM``/``JOIN`` reference or after a comma in an implicit
    comma join, e.g. ``FROM crm_people p`` or ``FROM crm_people AS p``. The
    table reference may carry a ``public.`` schema qualifier. Keywords that
    directly follow the table (``WHERE``, ``ORDER``, ``ON``, ...) are not
    mistaken for aliases. Matching is case-insensitive; aliases keep their
    original case.

    Args:
        sql: The SQL statement (any case).
        table: Unquoted table name.

    Returns:
        A set of qualifier names for the table.
    """
    aliases = {table}
    pattern = re.compile(
        rf"(?:\b(?:from|join)\s+|,\s*)(?:public\.)?{re.escape(table)}\b"
        r"(?:\s+(?:as\s+)?([a-zA-Z_][a-zA-Z0-9_]*))?",
        re.IGNORECASE,
    )
    for match in pattern.finditer(sql):
        alias = match.group(1)
        if alias and alias.lower() not in _NON_ALIAS_KEYWORDS:
            aliases.add(alias)
    return aliases
|
|
|
|
|
|
def _references_column(sql: str, sql_lower: str, table: str, column: str) -> bool:
    """Return True if *sql* uses ``<qualifier>.<column>`` for a name of *table*.

    Only qualified references are detected (e.g. ``p.qd_score`` or
    ``crm_people.qd_score``); a bare ``qd_score`` is not checked. The table
    must also appear as a row source (see ``_references_table``).

    Args:
        sql: The SQL statement in its original case.
        sql_lower: The same statement lower-cased, used for the table lookup.
        table: Table whose aliases qualify the column.
        column: Column name to look for.

    Returns:
        True when a qualified reference to the column is found.
    """
    if not _references_table(sql_lower, table):
        return False
    column_re = re.escape(column)
    return any(
        re.search(rf"\b{re.escape(alias)}\.{column_re}\b", sql, re.IGNORECASE) is not None
        for alias in _aliases_for_table(sql, table)
    )
|
|
|
|
|
|
@dataclass
class VerificationViolation:
    """One problem the verifier found in a planned SQL statement.

    Attributes:
        rule: Identifier of the check that fired, e.g. ``"destructive_dml"``
            or ``"missing_limit"``.
        detail: Human-readable explanation; it is also quoted in the LLM
            repair prompt.
        severity: ``"blocking"`` (makes verification fail) or ``"warning"``
            (reported only) as emitted by the verifier.
    """

    rule: str
    detail: str
    severity: str
|
|
|
|
|
|
@dataclass
class VerificationResult:
    """Outcome of verifying, and possibly repairing, one SQL statement.

    Attributes:
        passed: Whether the statement passed; the verifier sets this to True
            exactly when no violation has severity ``"blocking"``.
        sql: The statement that was checked last (the replacement SQL when
            a repair was made).
        original_sql: The SQL that was first submitted for verification.
        violations: All violations found in ``sql``, blocking and warning.
        was_repaired: True when ``sql`` is a repaired replacement.
        repair_attempted: True once any repair was tried.
        repair_failed: True when a repair could not produce usable SQL.
        notes: Free-form messages describing what happened.
    """

    passed: bool
    sql: str
    original_sql: str
    # Mutable defaults use default_factory so each result owns its own list.
    violations: list[VerificationViolation] = field(default_factory=list)
    was_repaired: bool = False
    repair_attempted: bool = False
    repair_failed: bool = False
    notes: list[str] = field(default_factory=list)
|
|
|
|
|
|
class PlanVerifier:
    """Check planner-generated SQL for known-bad patterns and optionally repair it.

    ``verify`` only inspects the SQL text. ``verify_and_repair`` additionally
    tries to fix blocking violations in two stages:

    1. An LLM repair agent, when one is supplied.
    2. The deterministic canonical queries from ``_semantic_rule_repair``, used
       when the LLM is unavailable, raises, or returns SQL that still fails
       verification.
    """

    def verify(self, sql: str, prompt: str, detected_intents: list[str], row_limit: int) -> VerificationResult:
        """Run every static check against *sql* and report the violations.

        Args:
            sql: The planned SQL statement.
            prompt: The user prompt; not used by the static checks.
            detected_intents: Planner intent labels. The deprecated-timestamp
                check only runs when one of ``_CONTACT_INTENTS`` is present.
            row_limit: Row cap quoted in the missing-LIMIT warning.

        Returns:
            A ``VerificationResult`` with ``sql`` and ``original_sql`` both set
            to *sql*. ``passed`` is False if any violation is blocking.
        """
        del prompt  # Accepted for a uniform signature with verify_and_repair.
        violations: list[VerificationViolation] = []
        sql_lower = sql.lower()

        def flag(rule: str, detail: str, severity: str = "blocking") -> None:
            # Record one violation; only the LIMIT/SELECT * checks are warnings.
            violations.append(VerificationViolation(rule=rule, detail=detail, severity=severity))

        if _DESTRUCTIVE.search(sql):
            flag("destructive_dml", "SQL contains a write or DDL statement.")

        # These timestamp sources are only wrong for contact-recency questions.
        if set(detected_intents).intersection(_CONTACT_INTENTS):
            for table, column in _BAD_TIMESTAMP_PATTERNS:
                if _references_column(sql, sql_lower, table, column):
                    flag(
                        "deprecated_timestamp",
                        f"SQL references {table}.{column}, which is sparse or deprecated. "
                        "Use intel_interactions.happened_at or read_last_contacted.last_contact_at.",
                    )

        valid_score_types = {value.lower() for value in VALID_QD_SCORE_TYPES}
        for literal in _extract_score_type_literals(sql):
            if literal.lower() not in valid_score_types:
                flag(
                    "impossible_score_type",
                    f"SQL filters intel_qd_scores.score_type with impossible value '{literal}'. "
                    "Valid values are: " + ", ".join(VALID_QD_SCORE_TYPES) + ". "
                    "For generic QD prompts, use score_type = 'overall'.",
                )

        for table, column in _BAD_SCORE_PATTERNS:
            if _references_column(sql, sql_lower, table, column):
                flag(
                    "wrong_score_column",
                    f"SQL references {table}.{column}, which is not the QD source of truth. "
                    "Use intel_qd_scores.current_value.",
                )

        for table, column in _HALLUCINATED_COLUMNS:
            if _references_column(sql, sql_lower, table, column):
                flag(
                    "hallucinated_column",
                    f"SQL references {table}.{column}, which does not exist in the live schema.",
                )

        # Whole-word match: a substring test would treat identifiers such as
        # ``credit_limit`` as a LIMIT clause.
        if re.search(r"\blimit\b", sql_lower) is None:
            flag(
                "missing_limit",
                f"SQL has no LIMIT clause; executor will enforce row cap {row_limit}.",
                severity="warning",
            )

        # No \b after "*": "*" is not a word character, so a trailing \b would
        # require a letter right after the star and never match "SELECT * FROM".
        # JOINs are counted as whole words so ``joined_at`` is not a JOIN.
        selects_star = re.search(r"\bselect\s+(?:distinct\s+)?\*", sql_lower)
        if selects_star and len(re.findall(r"\bjoin\b", sql_lower)) > 1:
            flag(
                "select_star_join",
                "SELECT * with multiple JOINs may create noisy wide rows.",
                severity="warning",
            )

        blocking = [violation for violation in violations if violation.severity == "blocking"]
        return VerificationResult(
            passed=not blocking,
            sql=sql,
            original_sql=sql,
            violations=violations,
        )

    async def verify_and_repair(
        self,
        sql: str,
        prompt: str,
        detected_intents: list[str],
        row_limit: int,
        llm_service: Any | None = None,
    ) -> VerificationResult:
        """Verify *sql* and, if a blocking violation is found, try to repair it.

        Repair order:
          1. LLM repair via ``_repair_sql`` when *llm_service* is given.
          2. The canonical-query fallback from ``_semantic_rule_repair`` when
             the LLM is missing, raises, or returns SQL that still fails.

        Args:
            sql: The planned SQL statement.
            prompt: The user's question; passed to the repair stages.
            detected_intents: Planner intent labels.
            row_limit: Row cap for the checks and the canonical queries.
            llm_service: Object exposing an async ``chat(...)``; optional.

        Returns:
            The original result when nothing blocking was found. Otherwise, the
            result for the repaired SQL, with ``was_repaired`` True and
            ``original_sql`` set to *sql*. If no repair was made, the original
            result is returned with ``repair_failed`` True. ``repair_failed``
            is True whenever the returned result does not pass.
        """
        result = self.verify(sql, prompt, detected_intents, row_limit)
        blocking = [violation for violation in result.violations if violation.severity == "blocking"]
        if not blocking:
            return result

        result.repair_attempted = True
        repair_kwargs = {
            "sql": sql,
            "prompt": prompt,
            "detected_intents": detected_intents,
            "row_limit": row_limit,
            "blocking": blocking,
        }

        if llm_service is None:
            result.notes.append("No LLM service available for SQL repair.")
            fallback = self._semantic_fallback(**repair_kwargs, prior_notes=result.notes)
            if fallback is not None:
                return fallback
            result.repair_failed = True
            return result

        try:
            repaired_sql = await self._repair_sql(
                sql=sql,
                prompt=prompt,
                violations=blocking,
                detected_intents=detected_intents,
                row_limit=row_limit,
                llm_service=llm_service,
            )
        except Exception as exc:  # Any transport/LLM/parse failure -> deterministic fallback.
            logger.warning("plan_verifier repair failed: %s", exc)
            result.notes.append(f"Repair failed: {exc}")
            fallback = self._semantic_fallback(**repair_kwargs, prior_notes=result.notes)
            if fallback is not None:
                return fallback
            result.repair_failed = True
            return result

        recheck = self.verify(repaired_sql, prompt, detected_intents, row_limit)
        recheck.original_sql = sql
        recheck.was_repaired = True
        recheck.repair_attempted = True
        recheck.notes.append(
            "Repaired violations: " + ", ".join(violation.rule for violation in blocking)
        )
        if recheck.passed:
            return recheck

        fallback = self._semantic_fallback(**repair_kwargs, prior_notes=[])
        if fallback is not None:
            return fallback
        recheck.repair_failed = True
        return recheck

    def _semantic_fallback(
        self,
        *,
        sql: str,
        prompt: str,
        detected_intents: list[str],
        row_limit: int,
        blocking: list[VerificationViolation],
        prior_notes: list[str],
    ) -> VerificationResult | None:
        """Verify the canonical replacement query for *blocking*, if one applies.

        Returns None when ``_semantic_rule_repair`` has no canonical query for
        these violations. Otherwise returns the verification result of that
        query, marked as repaired. *prior_notes* are copied in ahead of the
        fallback note.
        """
        semantic_sql = _semantic_rule_repair(
            prompt=prompt,
            detected_intents=detected_intents,
            row_limit=row_limit,
            violations=blocking,
        )
        if not semantic_sql:
            return None

        semantic_recheck = self.verify(semantic_sql, prompt, detected_intents, row_limit)
        semantic_recheck.original_sql = sql
        semantic_recheck.was_repaired = True
        semantic_recheck.repair_attempted = True
        semantic_recheck.repair_failed = not semantic_recheck.passed
        semantic_recheck.notes.extend(prior_notes)
        semantic_recheck.notes.append(
            "Semantic rule repair applied: " + ", ".join(violation.rule for violation in blocking)
        )
        return semantic_recheck

    async def _repair_sql(
        self,
        *,
        sql: str,
        prompt: str,
        violations: list[VerificationViolation],
        detected_intents: list[str],
        row_limit: int,
        llm_service: Any,
    ) -> str:
        """Ask the LLM repair agent for a corrected version of *sql*.

        The request carries the prompt, semantic catalog context, hard rules,
        canonical examples, the violations and the broken SQL.

        Returns:
            The non-empty ``sql`` value from the model's JSON reply.

        Raises:
            ValueError: The reply is not a JSON object, is not valid JSON
                (``json.JSONDecodeError``), or has an empty ``sql``.
            Exception: Anything raised by ``llm_service.chat`` propagates.
        """
        semantic_ctx = build_semantic_context_for_planner(detected_intents, max_concepts=4)
        violation_text = "\n".join(f"- [{violation.rule}] {violation.detail}" for violation in violations)
        hard_rules = (
            "Hard repair rules:\n"
            "- crm_people is identity only. It has no QD score source-of-truth column.\n"
            "- For QD score prompts, use intel_qd_scores.current_value and join crm_people on person_id.\n"
            "- Valid intel_qd_scores.score_type values are: "
            + ", ".join(VALID_QD_SCORE_TYPES)
            + ".\n"
            "- Never use score_type = 'QD'. For generic QD prompts use score_type = 'overall'.\n"
            "- For recent contact prompts, use read_last_contacted.last_contact_at or intel_interactions.happened_at.\n"
            "- Never use edge_communication_events.timestamp or crm_property_interests.last_discussed_at for contact recency."
        )
        canonical_examples = (
            "Canonical repair examples:\n"
            "Generic QD ranking:\n"
            "SELECT p.full_name, p.primary_email, p.primary_phone, q.current_value AS qd_score, q.score_type, q.computed_at "
            "FROM intel_qd_scores q JOIN crm_people p ON p.person_id = q.person_id "
            "WHERE q.score_type = 'overall' ORDER BY q.current_value DESC LIMIT 8;\n"
            "Recent contact ranking:\n"
            "SELECT p.full_name, p.primary_email, lc.last_contact_at, lc.last_channel, q.current_value AS qd_score "
            "FROM read_last_contacted lc JOIN crm_people p ON p.person_id = lc.person_id "
            "LEFT JOIN intel_qd_scores q ON q.person_id = p.person_id AND q.score_type = 'overall' "
            "WHERE lc.last_contact_at >= NOW() - INTERVAL '3 months' "
            "ORDER BY q.current_value DESC NULLS LAST LIMIT 10;"
        )

        response = await llm_service.chat(
            provider_id="sglang",
            model=None,
            system_prompt=(
                "You are Oracle's SQL repair agent. "
                "Fix only the listed violations. Return strict JSON with key 'sql'."
            ),
            messages=[
                {
                    "role": "user",
                    "content": (
                        f"Original prompt: {prompt}\n\n"
                        f"Semantic catalog:\n{semantic_ctx}\n\n"
                        f"{hard_rules}\n\n"
                        f"{canonical_examples}\n\n"
                        f"Violations:\n{violation_text}\n\n"
                        f"Broken SQL:\n{sql}\n\n"
                        f"Row cap: {row_limit}\n\n"
                        "Return JSON: {\"sql\": \"<corrected SQL>\"}"
                    ),
                }
            ],
            temperature=0.0,
            response_format="json",
            metadata={"agent": "oracle_plan_verifier_repair"},
        )

        message = response.get("message") or {}
        parsed = message.get("parsedJson")
        if not isinstance(parsed, dict):
            content = message.get("content") or "{}"
            if isinstance(content, str):
                text = content.strip()
                # Models sometimes wrap the JSON in a ```json fence despite response_format.
                fenced = re.fullmatch(r"```(?:json)?\s*(.*?)\s*```", text, re.DOTALL | re.IGNORECASE)
                if fenced:
                    text = fenced.group(1)
                parsed = json.loads(text)
            else:
                parsed = {}
        if not isinstance(parsed, dict):
            raise ValueError("Repair LLM returned JSON that is not an object.")

        repaired = str(parsed.get("sql") or "").strip()
        if not repaired:
            raise ValueError("Repair LLM returned empty SQL.")
        return repaired
|
|
|
|
|
|
# Shared module-level instance. PlanVerifier defines no __init__ and stores no
# attributes on self, so a single instance can be reused across calls.
plan_verifier: PlanVerifier = PlanVerifier()
|