# Files › Project_Velocity/backend/oracle/natural_db_agent.py
# 334 lines · 14 KiB · Python

"""
Natural DB-first Oracle agent.
The LLM can plan arbitrary analytical SELECT statements over the full public
Velocity app schema. The executor enforces only a read-only SQL contract and a
UI row cap; write paths stay behind typed API endpoints.
"""
from __future__ import annotations

import json
import logging
import os
import re
from dataclasses import dataclass
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from typing import Any
from uuid import UUID

from backend.services.runtime_llm_service import runtime_llm_service

try:
    import asyncpg  # type: ignore
except Exception:  # pragma: no cover
    asyncpg = None  # type: ignore
logger = logging.getLogger(__name__)

# Keywords that must not appear (outside string literals) in planner SQL. ``into``
# catches ``SELECT ... INTO new_table``, which creates a table despite starting with SELECT.
DESTRUCTIVE_SQL = re.compile(
    r"\b(insert|update|delete|drop|alter|truncate|copy|create|grant|revoke|call|execute|do|merge|into)\b",
    re.IGNORECASE,
)
# Captures the relation name after FROM/JOIN; schema-qualified names are allowed.
TABLE_REF_RE = re.compile(r"\b(?:from|join)\s+([a-zA-Z_][\w.]*)(?:\s|$)", re.IGNORECASE)
# Captures CTE names: the first one after WITH [RECURSIVE], later ones after a comma
# (no word boundary is required before the comma, so "), b AS (" matches), with an
# optional [NOT] MATERIALIZED hint before the opening parenthesis.
CTE_NAME_RE = re.compile(
    r"(?:\bwith\s+(?:recursive\s+)?|,\s*)([a-zA-Z_]\w*)\s+as\s*(?:(?:not\s+)?materialized\s*)?\(",
    re.IGNORECASE,
)
def _json_safe(value: Any) -> Any:
if isinstance(value, (datetime, date)):
return value.isoformat()
if isinstance(value, Decimal):
return float(value)
if isinstance(value, (list, tuple)):
return [_json_safe(v) for v in value]
if isinstance(value, dict):
return {str(k): _json_safe(v) for k, v in value.items()}
return value
def db_ready() -> bool:
    """Return True when asyncpg is importable and ``connect_db()`` has credentials to try.

    Checks the same sources ``connect_db()`` consults, in the same order:
    ``ORACLE_READ_DATABASE_URL``, the ``VELOCITY_DB_READ_{NAME,USER,PASSWORD}`` trio,
    ``DATABASE_URL``, and the ``VELOCITY_DB_{NAME,USER,PASSWORD}`` trio. URL variables
    that are empty or start with ``PLACEHOLDER`` do not count as configured.
    """
    if asyncpg is None:
        return False

    def url_configured(name: str) -> bool:
        value = os.getenv(name, "")
        return bool(value) and not value.startswith("PLACEHOLDER")

    def all_set(*names: str) -> bool:
        return all(os.getenv(name) for name in names)

    return (
        url_configured("ORACLE_READ_DATABASE_URL")
        # connect_db() uses this read-only trio before DATABASE_URL, so it must count here too.
        or all_set("VELOCITY_DB_READ_NAME", "VELOCITY_DB_READ_USER", "VELOCITY_DB_READ_PASSWORD")
        or url_configured("DATABASE_URL")
        or all_set("VELOCITY_DB_NAME", "VELOCITY_DB_USER", "VELOCITY_DB_PASSWORD")
    )
async def connect_db() -> Any:
    """Open a single asyncpg connection for the Oracle agent.

    Credential sources are tried in order; the first usable one wins:

    1. ``ORACLE_READ_DATABASE_URL`` (skipped when empty or starting with ``PLACEHOLDER``);
    2. ``VELOCITY_DB_READ_{NAME,USER,PASSWORD}``, with host/port from
       ``VELOCITY_DB_READ_HOST``/``_PORT``, else ``VELOCITY_DB_HOST``/``_PORT``,
       else ``127.0.0.1``/``5432``;
    3. ``DATABASE_URL`` (same placeholder rule as 1);
    4. ``VELOCITY_DB_{NAME,USER,PASSWORD}`` with ``VELOCITY_DB_HOST``/``_PORT``.

    Raises:
        RuntimeError: asyncpg is not installed.
        ValueError: a configured port is not an integer.
        KeyError: source 4 is reached without its NAME/USER/PASSWORD variables.
    """
    if asyncpg is None:
        raise RuntimeError("asyncpg is not installed.")

    def usable(url: str) -> bool:
        return bool(url) and not url.startswith("PLACEHOLDER")

    oracle_url = os.getenv("ORACLE_READ_DATABASE_URL", "")
    if usable(oracle_url):
        return await asyncpg.connect(oracle_url)

    read_trio = ("VELOCITY_DB_READ_NAME", "VELOCITY_DB_READ_USER", "VELOCITY_DB_READ_PASSWORD")
    if all(os.getenv(var) for var in read_trio):
        fallback_host = os.getenv("VELOCITY_DB_HOST", "127.0.0.1")
        read_host = os.getenv("VELOCITY_DB_READ_HOST", fallback_host)
        fallback_port = os.getenv("VELOCITY_DB_PORT", "5432")
        read_port = int(os.getenv("VELOCITY_DB_READ_PORT", fallback_port))
        return await asyncpg.connect(
            host=read_host,
            port=read_port,
            database=os.environ["VELOCITY_DB_READ_NAME"],
            user=os.environ["VELOCITY_DB_READ_USER"],
            password=os.environ["VELOCITY_DB_READ_PASSWORD"],
        )

    primary_url = os.getenv("DATABASE_URL", "")
    if usable(primary_url):
        return await asyncpg.connect(primary_url)

    primary_host = os.getenv("VELOCITY_DB_HOST", "127.0.0.1")
    primary_port = int(os.getenv("VELOCITY_DB_PORT", "5432"))
    return await asyncpg.connect(
        host=primary_host,
        port=primary_port,
        database=os.environ["VELOCITY_DB_NAME"],
        user=os.environ["VELOCITY_DB_USER"],
        password=os.environ["VELOCITY_DB_PASSWORD"],
    )
@dataclass
class NaturalQueryResult:
    """Outcome of one natural-language query: the SQL that ran plus UI-ready rows."""

    prompt: str
    sql: str
    title: str
    summary: str
    columns: list[str]
    rows: list[dict[str, Any]]
    row_count: int
    source_tables: list[str]
    component_type: str
    warnings: list[str]

    def as_dict(self) -> dict[str, Any]:
        """Return the result as a camelCase-keyed dict (values are the same objects, not copies)."""
        # Field order fixes the key order of the serialized payload.
        ordered_fields = (
            "prompt",
            "sql",
            "title",
            "summary",
            "columns",
            "rows",
            "row_count",
            "source_tables",
            "component_type",
            "warnings",
        )
        wire_names = {
            "row_count": "rowCount",
            "source_tables": "sourceTables",
            "component_type": "componentType",
        }
        return {wire_names.get(field, field): getattr(self, field) for field in ordered_fields}
def sanitize_sql(sql: str, row_limit: int) -> tuple[str, list[str], list[str]]:
    """Validate planner SQL as one read-only query and enforce the row cap.

    Checks:
    - Comments are removed; quoted text is left intact.
    - The statement must start with SELECT or WITH.
    - No ``DESTRUCTIVE_SQL`` keyword may appear outside string literals.
    - No further ``;``-separated statement is allowed.

    The row cap is enforced in the SQL itself:
    - A trailing ``LIMIT n`` larger than the cap is lowered to the cap.
    - A query with no trailing ``LIMIT n`` is wrapped as
      ``SELECT * FROM (<query>) AS oracle_capped LIMIT <cap>``.

    Args:
        sql: Raw SQL text returned by the planner.
        row_limit: Maximum rows the query may return; values below 1 are treated as 1.

    Returns:
        ``(sql_to_execute, source_tables, warnings)``. ``source_tables`` lists the
        lowercase, schema-stripped names found after FROM/JOIN, excluding CTE names.

    Raises:
        ValueError: The SQL is not a single read-only SELECT/WITH statement.
    """
    warnings: list[str] = []
    cap = max(1, int(row_limit))

    # Match quoted text first so "--" or "/*" inside a literal is not taken for a
    # comment. Comments become a space so neighbouring tokens never fuse together.
    without_comments = re.sub(
        r"'(?:[^']|'')*'|\"(?:[^\"]|\"\")*\"|--[^\n]*|/\*.*?\*/",
        lambda m: m.group(0) if m.group(0)[0] in "'\"" else " ",
        sql,
        flags=re.DOTALL,
    )
    clean = re.sub(r"[\s;]+$", "", without_comments.strip())
    if not re.match(r"^(select|with)\b", clean, re.IGNORECASE):
        raise ValueError("Oracle SQL agent only accepts SELECT or WITH queries.")

    # Blank out literal/quoted-identifier contents so values such as 'call' or 'update'
    # (e.g. an interaction type) neither trip the blocklist nor look like table names.
    scan = re.sub(r"'(?:[^']|'')*'|\"(?:[^\"]|\"\")*\"", lambda m: m.group(0)[0] * 2, clean)
    if DESTRUCTIVE_SQL.search(scan):
        raise ValueError("Oracle SQL agent blocked non-read SQL.")
    if ";" in scan:
        raise ValueError("Oracle SQL agent only accepts a single SQL statement.")

    cte_names = {name.lower() for name in CTE_NAME_RE.findall(scan)}
    tables: list[str] = []
    for match in TABLE_REF_RE.finditer(scan):
        table = match.group(1).split(".")[-1].strip('"').lower()
        # Skip keywords the pattern can capture, CTE names (not real tables), and repeats.
        if not table or table in {"lateral", "select"} or table in cte_names or table in tables:
            continue
        tables.append(table)

    # A LIMIT at the very end of the statement always applies to the whole result.
    trailing_limit = re.search(r"\blimit\s+(\d+)\s*(?:offset\s+\d+\s*)?$", clean, re.IGNORECASE)
    if trailing_limit is None:
        # No plain trailing LIMIT n (none at all, LIMIT ALL, FETCH FIRST, ...): wrapping
        # caps any valid SELECT without having to parse it.
        clean = f"SELECT * FROM (\n{clean}\n) AS oracle_capped LIMIT {cap}"
    else:
        requested = int(trailing_limit.group(1))
        if requested > cap:
            clean = clean[: trailing_limit.start(1)] + str(cap) + clean[trailing_limit.end(1):]
            warnings.append(f"Requested LIMIT {requested} exceeds the {cap}-row cap; reduced to {cap}.")
    return clean, tables, warnings
def infer_component_type(prompt: str, columns: list[str], rows: list[dict[str, Any]]) -> str:
    """Pick a UI component for a query result from prompt keywords and result shape.

    Returns one of ``"activity_stream"``, ``"kpi_tile"``, ``"line_chart"``,
    ``"bar_chart"`` or ``"table"`` (the fallback). Keyword checks are plain
    substring matches on the lowercased prompt.
    """

    def is_number(value: Any) -> bool:
        # bool is a subclass of int, but a True/False flag is not a metric.
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    lower = prompt.lower()
    if any(term in lower for term in ("timeline", "conversation", "whatsapp", "message", "call", "email", "history")):
        return "activity_stream"
    if len(rows) == 1 and len(columns) <= 5 and any(is_number(rows[0].get(c)) for c in columns):
        return "kpi_tile"
    if any(c.endswith("_at") or c in {"date", "when", "timestamp", "happened_at"} for c in columns):
        if len(rows) > 1 and any(term in lower for term in ("trend", "over time", "timeseries")):
            return "line_chart"
        if any(term in lower for term in ("timeline", "activity", "last", "recent")):
            return "activity_stream"
    # Judge each column by its first non-NULL value; the first row alone may hold NULL.
    numeric_cols = [
        c for c in columns
        if is_number(next((row.get(c) for row in rows if row.get(c) is not None), None))
    ]
    if numeric_cols and any(term in lower for term in ("count", "compare", "distribution", "most", "top", "by ")):
        return "bar_chart"
    return "table"
def _looks_like_property_rollup_prompt(prompt: str) -> bool:
lower = prompt.lower()
property_terms = ("property", "properties", "project", "projects")
aggregate_terms = ("top", "most", "majority", "highest", "popular", "common")
interest_terms = ("interest", "interested", "liked", "preference", "preferences")
return (
any(term in lower for term in property_terms)
and any(term in lower for term in aggregate_terms)
and any(term in lower for term in interest_terms)
)
def title_from_prompt(prompt: str) -> str:
    """Build a short display title from *prompt*.

    Whitespace runs collapse to single spaces, and surrounding spaces and ``?.!``
    are trimmed. The first character is upper-cased, and the result keeps at most
    80 characters of the cleaned text. Blank input yields ``"Oracle Query Result"``.
    """
    cleaned = " ".join(prompt.split()).strip(" ?.!")
    if not cleaned:
        return "Oracle Query Result"
    head, tail = cleaned[:1], cleaned[1:80]
    return head.upper() + tail
class NaturalDbAgent:
    """Answer natural-language CRM questions with LLM-planned, read-only SQL.

    ``execute_prompt`` does the following:
    - Introspects the ``public`` schema.
    - Asks the runtime LLM for a single SELECT/WITH query.
    - Screens the query with ``sanitize_sql``.
    - Runs it, inside a READ ONLY transaction unless the connection is already in one.
    """

    # Upper bound (characters) on the serialized schema embedded in the planner prompt.
    _SCHEMA_BRIEF_CHARS = 16000

    @staticmethod
    def _quote_ident(name: str) -> str:
        """Quote *name* as a PostgreSQL identifier, doubling any embedded double quote."""
        return '"' + name.replace('"', '""') + '"'

    async def schema_catalog(self, conn: Any | None = None) -> dict[str, Any]:
        """Describe every column in the ``public`` schema plus exact base-table row counts.

        Args:
            conn: Open connection to reuse. When ``None``, one is opened via
                ``connect_db()`` and closed before returning.

        Returns:
            A dict with the keys ``"available"``, ``"tables"`` and ``"allowedTables"``.
            ``tables`` maps each table name to ``{"columns": [...], "rowCount": int | None}``.
            ``rowCount`` is ``None`` for relations that are not base tables, or tables
            that disappeared mid-scan. When no database is configured, ``available`` is
            False and both collections are empty; ``tables`` is still a dict.
        """
        own_conn = conn is None
        if conn is None:
            if not db_ready():
                # Same shape as the success path so callers can always use tables.items().
                return {"available": False, "tables": {}, "allowedTables": []}
            conn = await connect_db()
        try:
            table_rows = await conn.fetch(
                """
                SELECT table_name
                FROM information_schema.tables
                WHERE table_schema = 'public' AND table_type = 'BASE TABLE'
                ORDER BY table_name
                """
            )
            public_tables = [row["table_name"] for row in table_rows]
            column_rows = await conn.fetch(
                """
                SELECT c.table_name, c.column_name, c.data_type, c.udt_name, c.is_nullable
                FROM information_schema.columns c
                WHERE c.table_schema = 'public'
                ORDER BY c.table_name, c.ordinal_position
                """
            )
            # NOTE(review): exact COUNT(*) on every table runs on each execute_prompt call;
            # consider cached or estimated counts if tables grow large.
            counts: dict[str, int | None] = {}
            for table in public_tables:
                # Quote and schema-qualify the name so mixed-case or quote-containing
                # names resolve correctly and search_path cannot redirect the count.
                qualified = f"public.{self._quote_ident(table)}"
                exists = await conn.fetchval("SELECT to_regclass($1)", qualified)
                counts[table] = int(await conn.fetchval(f"SELECT COUNT(*) FROM {qualified}")) if exists else None
            tables: dict[str, dict[str, Any]] = {}
            for row in column_rows:
                entry = tables.setdefault(row["table_name"], {"columns": [], "rowCount": counts.get(row["table_name"])})
                entry["columns"].append({
                    "name": row["column_name"],
                    "dataType": row["data_type"],
                    "udtName": row["udt_name"],
                    "nullable": row["is_nullable"] == "YES",
                })
            return {"available": True, "tables": tables, "allowedTables": public_tables}
        finally:
            if own_conn:
                await conn.close()

    async def data_health(self, conn: Any | None = None) -> dict[str, Any]:
        """Compare live table row counts against the hard-coded expected baseline.

        Returns:
            A dict with per-table ``counts``, the ``expectedSyntheticV2Counts`` baseline,
            ``missingTables``, ``emptyTables``, ``belowExpected``, and ``available``
            (whether the catalog could be read). ``missingTables`` includes catalog
            entries without a row count and expected tables absent from the catalog.
        """
        catalog = await self.schema_catalog(conn)
        expected = {
            "crm_people": 341,
            "crm_leads": 250,
            "crm_opportunities": 400,
            "crm_property_interests": 400,
            "intel_interactions": 1897,
            "intel_messages": 6944,
            "intel_calls": 478,
            "intel_transcripts": 231,
            "intel_emails": 149,
            "intel_visits": 305,
            "intel_reminders": 759,
            "intel_extracted_facts": 1686,
            "read_last_contacted": 250,
            "read_next_best_action": 250,
        }
        tables = catalog.get("tables") or {}
        counts = {table: (meta or {}).get("rowCount") for table, meta in sorted(tables.items())}
        missing = {t for t, count in counts.items() if count is None}
        missing.update(t for t in expected if counts.get(t) is None)
        return {
            "counts": counts,
            "expectedSyntheticV2Counts": expected,
            "missingTables": sorted(missing),
            "emptyTables": [t for t, count in counts.items() if count == 0],
            "belowExpected": {t: {"expected": e, "actual": counts.get(t)} for t, e in expected.items() if (counts.get(t) or 0) < e},
            "available": bool(catalog.get("available")),
        }

    async def execute_prompt(self, prompt: str, *, row_limit: int = 100, conn: Any | None = None) -> NaturalQueryResult:
        """Plan, validate, and run SQL answering *prompt*.

        Args:
            prompt: Natural-language question; must be non-blank.
            row_limit: Maximum rows returned; also given to the planner as the row cap.
            conn: Connection to reuse. When ``None``, one is opened via ``connect_db()``
                and closed afterwards.

        Raises:
            ValueError: Blank prompt, or SQL rejected by ``sanitize_sql``.
            RuntimeError: Database not configured, planning failed, or execution failed.
        """
        if not prompt.strip():
            raise ValueError("Prompt is required.")
        own_conn = conn is None
        if conn is None:
            if not db_ready():
                raise RuntimeError("Database unavailable for Oracle natural query.")
            conn = await connect_db()
        try:
            catalog = await self.schema_catalog(conn)
            plan = await self._plan_sql(prompt, catalog, row_limit)
            return await self._run_plan(conn, prompt, plan, row_limit)
        finally:
            if own_conn:
                await conn.close()

    async def _run_plan(self, conn: Any, prompt: str, plan: dict[str, Any], row_limit: int) -> NaturalQueryResult:
        """Validate and execute a planner result, returning UI-ready rows.

        Raises:
            RuntimeError: The plan carries no SQL, or execution fails.
            ValueError: ``sanitize_sql`` rejects the SQL.
        """
        raw_sql = str(plan.get("sql") or "").strip()
        if not raw_sql:
            raise RuntimeError("Natural SQL planner returned no SQL.")
        sql, tables, warnings = sanitize_sql(raw_sql, row_limit)
        cap = max(1, int(row_limit))
        try:
            if conn.is_in_transaction():
                # The caller already owns a transaction whose access mode we cannot change.
                records = await conn.fetch(sql)
            else:
                # A database-enforced READ ONLY transaction backs up the keyword screen
                # in sanitize_sql, which cannot see writes hidden inside function calls.
                async with conn.transaction(readonly=True):
                    records = await conn.fetch(sql)
        except Exception as exc:
            raise RuntimeError(f"Natural SQL execution failed: {exc}") from exc
        if len(records) >= cap:
            # Safety net even if the SQL itself was not capped; also tells the UI it may be partial.
            records = records[:cap]
            warnings.append(f"Result limited to {cap} rows; more rows may exist.")
        rows = [_json_safe(dict(record)) for record in records]
        columns = list(rows[0].keys()) if rows else []
        component_type = infer_component_type(prompt, columns, rows)
        return NaturalQueryResult(
            prompt=prompt,
            sql=sql,
            title=str(plan.get("title") or title_from_prompt(prompt)),
            summary=str(plan.get("rationale") or f"SQL-backed Oracle result from {', '.join(tables) or 'Velocity CRM'}."),
            columns=columns,
            rows=rows,
            row_count=len(rows),
            source_tables=tables,
            component_type=component_type,
            warnings=warnings,
        )

    @staticmethod
    def _parse_plan_json(message: dict[str, Any]) -> Any:
        """Extract the planner's JSON object from a chat ``message`` payload.

        Prefers the service-provided ``parsedJson`` when it is a dict. Otherwise decodes
        ``content`` as JSON, tolerating a Markdown code fence around it. Non-string
        content is returned as-is.
        """
        parsed = message.get("parsedJson")
        if isinstance(parsed, dict):
            return parsed
        content = message.get("content") or "{}"
        if not isinstance(content, str):
            return content
        text = content.strip()
        fenced = re.fullmatch(r"```(?:json)?\s*(.*?)\s*```", text, re.DOTALL | re.IGNORECASE)
        if fenced:
            text = fenced.group(1)
        return json.loads(text)

    async def _plan_sql(self, prompt: str, catalog: dict[str, Any], row_limit: int) -> dict[str, Any]:
        """Ask the runtime LLM (provider ``sglang``) for a ``{sql, title, rationale}`` plan.

        Raises:
            RuntimeError: No providers are configured, the LLM call or JSON decoding
                fails, or the response has no ``sql``.
        """
        try:
            providers = runtime_llm_service._provider_catalog()
        except Exception:
            # A broken provider registry is treated like an empty one, but leave a trace.
            logger.debug("Runtime LLM provider catalog unavailable", exc_info=True)
            providers = {}
        if not providers:
            raise RuntimeError("No runtime LLM providers are configured for Oracle natural planning.")
        # Compact separators fit more of the schema under the character budget.
        schema_json = json.dumps(catalog.get("tables", {}), default=str, separators=(",", ":"))
        if len(schema_json) > self._SCHEMA_BRIEF_CHARS:
            # The slice can cut mid-table; make the loss visible instead of silent.
            logger.warning(
                "Oracle planner schema brief truncated from %d to %d characters; later tables are hidden from the LLM.",
                len(schema_json),
                self._SCHEMA_BRIEF_CHARS,
            )
        schema_brief = schema_json[: self._SCHEMA_BRIEF_CHARS]
        semantic_rules = "\n".join((
            "Velocity SQL semantics:",
            "- QD score means intel_qd_scores.current_value. Do not use crm_people.engagement_score, crm_leads.engagement_score, or intel_interactions.engagement_score as QD.",
            '- For project/property scoped prompts such as "in Atri Surya Toron", "interested in", "for project", or "for property", use crm_property_interests as the primary scoping table.',
            "- Prefer crm_property_interests.project_name for textual project matching. inventory_projects is optional for enrichment, not the primary client-to-project relationship.",
            "- For client lists scoped to a project, join crm_people to crm_property_interests on person_id and filter project_name case-insensitively.",
            "- For lowest/highest/best/worst QD prompts, sort on intel_qd_scores.current_value ASC/DESC as requested.",
            "- Respect the user-requested cardinality exactly when possible. If the prompt says five/top 5/lowest 5, return LIMIT 5.",
            "- When listing clients, include person identity fields from crm_people such as person_id, full_name, primary_phone, and primary_email.",
            "- When aggregating top properties/projects, group by crm_property_interests.project_name and count DISTINCT person_id.",
            "- You may use any table in the public schema that is relevant to the question.",
            "- Use only read-only PostgreSQL SELECT/CTE queries.",
        ))
        system = (
            "You are Oracle's read-only PostgreSQL planner. Generate one useful SELECT or WITH query "
            "for the user's CRM question. You have access to the full public schema. Return JSON with sql, title, rationale. "
            "Never generate INSERT, UPDATE, DELETE, DDL, COPY, or permission statements."
        )
        try:
            response = await runtime_llm_service.chat(
                provider_id="sglang",
                model=None,
                system_prompt=system,
                messages=[{
                    "role": "user",
                    "content": (
                        f"Schema:\n{schema_brief}\n\n"
                        f"Semantic rules:\n{semantic_rules}\n\n"
                        f"Question:\n{prompt}\n\n"
                        f"Row cap: {row_limit}\n\n"
                        "Return strict JSON with keys: sql, title, rationale."
                    ),
                }],
                temperature=0.05,
                response_format="json",
                metadata={"agent": "oracle_natural_db_agent"},
            )
            message = response.get("message") or {}
            parsed = self._parse_plan_json(message)
        except Exception as exc:
            raise RuntimeError(f"Natural DB planner LLM failed: {exc}") from exc
        if isinstance(parsed, dict) and parsed.get("sql"):
            return parsed
        raise RuntimeError("Natural DB planner returned no valid SQL.")
# Shared module-level agent instance. NaturalDbAgent has no __init__ and its
# methods do not assign instance attributes, so one instance can serve all callers.
natural_db_agent = NaturalDbAgent()