fix: Oracle Canvas JSON Component Generation planning and orchestration logic
This commit is contained in:
235
backend/oracle/plan_verifier.py
Normal file
235
backend/oracle/plan_verifier.py
Normal file
@@ -0,0 +1,235 @@
|
||||
"""
|
||||
oracle/plan_verifier.py
|
||||
|
||||
Verify planned SQL before execution and optionally repair common semantic errors.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from .semantic_catalog import build_semantic_context_for_planner
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Statement keywords that mark SQL as a write / DDL / procedural operation.
# They are matched case-insensitively as whole words anywhere in the SQL text,
# so a keyword that appears inside a string literal is matched as well.
_DESTRUCTIVE_KEYWORDS = (
    "insert",
    "update",
    "delete",
    "drop",
    "alter",
    "truncate",
    "copy",
    "create",
    "grant",
    "revoke",
    "call",
    "execute",
    "do",
    "merge",
)
_DESTRUCTIVE = re.compile(
    r"\b(" + "|".join(_DESTRUCTIVE_KEYWORDS) + r")\b",
    flags=re.IGNORECASE,
)

# (table, column) pairs the verifier reports as "sparse or deprecated"
# timestamp sources.
_BAD_TIMESTAMP_PATTERNS: list[tuple[str, str]] = [
    ("edge_communication_events", "timestamp"),
    ("crm_property_interests", "last_discussed_at"),
    ("crm_property_interests", "last_interaction"),
]

# (table, column) pairs the verifier reports as "not the QD source of truth";
# its message points callers at intel_qd_scores.current_value instead.
_BAD_SCORE_PATTERNS: list[tuple[str, str]] = [
    ("crm_people", "engagement_score"),
    ("crm_leads", "engagement_score"),
    ("intel_interactions", "engagement_score"),
    ("crm_people", "qd_score"),
    ("crm_leads", "qd_score"),
]

# (table, column) pairs the verifier reports as not existing in the live
# schema.
_HALLUCINATED_COLUMNS: list[tuple[str, str]] = [
    ("intel_interactions", "broker_id"),
    ("intel_interactions", "sentiment"),
    ("crm_leads", "last_contacted_at"),
    ("crm_people", "last_contact"),
]
|
||||
|
||||
|
||||
@dataclass
class VerificationViolation:
    """One finding produced while checking a SQL statement."""

    # Short machine-readable identifier of the check, e.g. "missing_limit".
    rule: str
    # Human-readable explanation of what was found.
    detail: str
    # The verifier emits "blocking" or "warning".
    severity: str
|
||||
|
||||
|
||||
@dataclass
class VerificationResult:
    """Outcome of verifying (and possibly repairing) one SQL statement."""

    # True when the checked SQL has no blocking violations.
    passed: bool
    # The SQL the violations refer to (the rewrite, if one was produced).
    sql: str
    # The SQL as originally submitted by the caller.
    original_sql: str
    # Every finding for `sql`, blocking and warning alike.
    violations: list[VerificationViolation] = field(default_factory=list)
    # True when `sql` is a rewrite of `original_sql`.
    was_repaired: bool = False
    # True when a repair was attempted, whether or not it worked.
    repair_attempted: bool = False
    # True when a repair was attempted but did not succeed.
    repair_failed: bool = False
    # Free-form remarks about the repair step.
    notes: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
class PlanVerifier:
    """Text-level checks for planned SQL, with an optional LLM repair step.

    ``verify`` inspects the SQL string with regular expressions only; it does
    not parse the statement or touch a database, so every check is a heuristic
    on the words that appear in the text. ``verify_and_repair`` additionally
    asks an LLM service to rewrite SQL that has blocking violations and then
    re-verifies the rewrite.
    """

    @staticmethod
    def _mentions(sql_lower: str, identifier: str) -> bool:
        """Return True if *identifier* occurs in *sql_lower* as a whole word.

        Whole-word matching keeps ``qd_score`` from matching inside
        ``intel_qd_scores`` and ``last_contact`` inside ``last_contacted_at``,
        which a plain substring test would report as violations.
        """
        return re.search(rf"\b{re.escape(identifier)}\b", sql_lower) is not None

    @staticmethod
    def _is_wide_select_star(sql_lower: str) -> bool:
        """Return True for ``SELECT *`` combined with more than one JOIN."""
        # No trailing \b after the star: "*" is not a word character, so a
        # boundary there would only match when an identifier directly follows
        # it and would miss the ordinary "select * from ..." form.
        if re.search(r"\bselect\s+\*", sql_lower) is None:
            return False
        return len(re.findall(r"\bjoin\b", sql_lower)) > 1

    @staticmethod
    def _extract_repaired_sql(response: Any) -> str:
        """Pull the corrected SQL out of the repair LLM's reply.

        Reads ``response["message"]["parsedJson"]`` when that is a dict and
        otherwise decodes ``response["message"]["content"]`` as JSON.

        Raises:
            ValueError: if the reply carries no non-empty string under the
                ``sql`` key (``json.JSONDecodeError`` is a ``ValueError`` too).
        """
        message = response.get("message") or {}
        parsed = message.get("parsedJson")
        if not isinstance(parsed, dict):
            content = message.get("content") or "{}"
            parsed = json.loads(content) if isinstance(content, str) else {}
        # Valid JSON that is not an object (list, string, number) has no "sql".
        candidate = parsed.get("sql") if isinstance(parsed, dict) else None
        repaired = candidate.strip() if isinstance(candidate, str) else ""
        if not repaired:
            raise ValueError("Repair LLM returned empty SQL.")
        return repaired

    def _table_column_violations(
        self,
        sql_lower: str,
        patterns: list[tuple[str, str]],
        rule: str,
        reason: str,
    ) -> list[VerificationViolation]:
        """Build one blocking violation per (table, column) pair named in the SQL.

        A pair counts as referenced when both names occur as whole words
        anywhere in the statement; the column is not tied to the table.
        """
        return [
            VerificationViolation(
                rule=rule,
                detail=f"SQL references {table}.{column}, {reason}",
                severity="blocking",
            )
            for table, column in patterns
            if self._mentions(sql_lower, table) and self._mentions(sql_lower, column)
        ]

    def verify(self, sql: str, prompt: str, detected_intents: list[str], row_limit: int) -> VerificationResult:
        """Check *sql* against the static rules and report every violation.

        Args:
            sql: The SQL statement to inspect.
            prompt: Unused by the static checks.
            detected_intents: Unused by the static checks.
            row_limit: Row cap quoted in the ``missing_limit`` warning.

        Returns:
            A ``VerificationResult`` whose ``passed`` is True when no violation
            has severity ``"blocking"``; ``sql`` and ``original_sql`` are both
            the input statement.
        """
        del prompt, detected_intents
        violations: list[VerificationViolation] = []
        sql_lower = sql.lower()

        # NOTE(review): this keyword search also fires inside string literals
        # (e.g. a filter on the value 'call'), so it can over-report.
        if _DESTRUCTIVE.search(sql):
            violations.append(
                VerificationViolation(
                    rule="destructive_dml",
                    detail="SQL contains a write or DDL statement.",
                    severity="blocking",
                )
            )

        violations.extend(
            self._table_column_violations(
                sql_lower,
                _BAD_TIMESTAMP_PATTERNS,
                "deprecated_timestamp",
                "which is sparse or deprecated. "
                "Use intel_interactions.happened_at or read_last_contacted.last_contacted_at.",
            )
        )
        violations.extend(
            self._table_column_violations(
                sql_lower,
                _BAD_SCORE_PATTERNS,
                "wrong_score_column",
                "which is not the QD source of truth. "
                "Use intel_qd_scores.current_value.",
            )
        )
        violations.extend(
            self._table_column_violations(
                sql_lower,
                _HALLUCINATED_COLUMNS,
                "hallucinated_column",
                "which does not exist in the live schema.",
            )
        )

        # Whole-word test so identifiers such as credit_limit do not count.
        if not self._mentions(sql_lower, "limit"):
            violations.append(
                VerificationViolation(
                    rule="missing_limit",
                    detail=f"SQL has no LIMIT clause; executor will enforce row cap {row_limit}.",
                    severity="warning",
                )
            )

        if self._is_wide_select_star(sql_lower):
            violations.append(
                VerificationViolation(
                    rule="select_star_join",
                    detail="SELECT * with multiple JOINs may create noisy wide rows.",
                    severity="warning",
                )
            )

        has_blocking = any(violation.severity == "blocking" for violation in violations)
        return VerificationResult(
            passed=not has_blocking,
            sql=sql,
            original_sql=sql,
            violations=violations,
        )

    async def verify_and_repair(
        self,
        sql: str,
        prompt: str,
        detected_intents: list[str],
        row_limit: int,
        llm_service: Any | None = None,
    ) -> VerificationResult:
        """Verify *sql* and, if it has blocking violations, try one LLM repair.

        Args:
            sql: The SQL statement to inspect.
            prompt: The user prompt, forwarded to the repair model.
            detected_intents: Intents forwarded to the semantic-catalog lookup.
            row_limit: Row cap forwarded to the checks and the repair model.
            llm_service: Object exposing an awaitable ``chat(...)``; when
                ``None`` no repair is possible.

        Returns:
            The original verification result when nothing blocks (or when the
            repair could not run, with ``repair_failed`` set), otherwise the
            verification result of the rewritten SQL. In the latter case
            ``original_sql`` is the input statement, ``was_repaired`` is True,
            and ``repair_failed`` is True if blocking violations remain.
        """
        result = self.verify(sql, prompt, detected_intents, row_limit)
        blocking = [violation for violation in result.violations if violation.severity == "blocking"]
        if not blocking:
            return result

        result.repair_attempted = True
        if llm_service is None:
            result.repair_failed = True
            result.notes.append("No LLM service available for SQL repair.")
            return result

        try:
            repaired_sql = await self._repair_sql(
                sql=sql,
                prompt=prompt,
                violations=blocking,
                detected_intents=detected_intents,
                row_limit=row_limit,
                llm_service=llm_service,
            )
        except Exception as exc:
            # Repair is best-effort: any failure in the LLM call or in parsing
            # its reply is reported on the result instead of being raised.
            logger.warning("plan_verifier repair failed: %s", exc)
            result.repair_failed = True
            result.notes.append(f"Repair failed: {exc}")
            return result

        recheck = self.verify(repaired_sql, prompt, detected_intents, row_limit)
        recheck.original_sql = sql
        recheck.was_repaired = True
        recheck.repair_attempted = True

        # Credit only the rules that no longer block in the rewritten SQL.
        still_blocking = list(
            dict.fromkeys(
                violation.rule for violation in recheck.violations if violation.severity == "blocking"
            )
        )
        resolved = [
            rule
            for rule in dict.fromkeys(violation.rule for violation in blocking)
            if rule not in still_blocking
        ]
        if resolved:
            recheck.notes.append("Repaired violations: " + ", ".join(resolved))
        if still_blocking:
            recheck.repair_failed = True
            recheck.notes.append(
                "Repair left blocking violations: " + ", ".join(still_blocking)
            )
        return recheck

    async def _repair_sql(
        self,
        *,
        sql: str,
        prompt: str,
        violations: list[VerificationViolation],
        detected_intents: list[str],
        row_limit: int,
        llm_service: Any,
    ) -> str:
        """Ask the LLM service for a corrected version of *sql*.

        Sends the prompt, the semantic-catalog context, the listed violations
        and the broken SQL in one chat request that asks for a JSON reply of
        the form ``{"sql": "..."}``.

        Returns:
            The corrected SQL, stripped of surrounding whitespace.

        Raises:
            ValueError: if the reply contains no usable SQL string.
            Exception: whatever ``llm_service.chat`` itself raises.
        """
        semantic_ctx = build_semantic_context_for_planner(detected_intents, max_concepts=4)
        violation_text = "\n".join(f"- [{violation.rule}] {violation.detail}" for violation in violations)

        response = await llm_service.chat(
            provider_id="sglang",
            model=None,
            system_prompt=(
                "You are Oracle's SQL repair agent. "
                "Fix only the listed violations. Return strict JSON with key 'sql'."
            ),
            messages=[
                {
                    "role": "user",
                    "content": (
                        f"Original prompt: {prompt}\n\n"
                        f"Semantic catalog:\n{semantic_ctx}\n\n"
                        f"Violations:\n{violation_text}\n\n"
                        f"Broken SQL:\n{sql}\n\n"
                        f"Row cap: {row_limit}\n\n"
                        "Return JSON: {\"sql\": \"<corrected SQL>\"}"
                    ),
                }
            ],
            temperature=0.0,
            response_format="json",
            metadata={"agent": "oracle_plan_verifier_repair"},
        )
        return self._extract_repaired_sql(response)


# Shared module-level verifier instance.
plan_verifier = PlanVerifier()
|
||||
Reference in New Issue
Block a user