355 lines
12 KiB
Python
355 lines
12 KiB
Python
"""
|
|
backend/services/nemoclaw_client.py - NemoClaw inference client.
|
|
|
|
Production path:
|
|
1. Shared SGLang / OpenAI-compatible coding runtime.
|
|
|
|
Compatibility:
|
|
- Legacy NEMOCLAW_* env names are still honored.
|
|
- Legacy OLLAMA_BASE_URL can still seed the base URL, but Ollama is no longer
|
|
a production fallback path.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from typing import Optional
|
|
|
|
import httpx
|
|
|
|
logger = logging.getLogger("velocity.nemoclaw")


def _env_first(*names: str, default: str) -> str:
    """Return the value of the first environment variable in *names* that is set.

    Behaves like a chain of nested ``os.getenv(name, <next lookup>)`` calls:
    a variable that is present but empty still wins over later names, and
    *default* is used only when none of the names is set.
    """
    for name in names:
        value = os.environ.get(name)
        if value is not None:
            return value
    return default


# Request tuning; NEMOCLAW_TIMEOUT is the default timeout of _nemoclaw_chat.
NEMOCLAW_TIMEOUT = float(os.environ.get("NEMOCLAW_TIMEOUT_S", "45.0"))
NEMOCLAW_TEMPERATURE = float(os.environ.get("NEMOCLAW_TEMPERATURE", "0.2"))

# Endpoint resolution: the SGLANG_* names take precedence, then the legacy
# NEMOCLAW_* / LLM_* / OLLAMA_* names, then the built-in default.
SGLANG_BASE_URL = _env_first(
    "SGLANG_BASE_URL",
    "NEMOCLAW_BASE_URL",
    "LLM_BASE_URL",
    "OLLAMA_BASE_URL",
    default="https://llm.desineuron.in",
).rstrip("/")
SGLANG_CHAT_URL = _env_first(
    "SGLANG_CHAT_URL",
    "NEMOCLAW_CHAT_URL",
    default=f"{SGLANG_BASE_URL}/v1/chat/completions",
)
SGLANG_MODELS_URL = _env_first(
    "SGLANG_MODELS_URL",
    default=f"{SGLANG_BASE_URL}/v1/models",
)
SGLANG_MODEL = _env_first(
    "SGLANG_MODEL",
    "NEMOCLAW_MODEL",
    "OLLAMA_MODEL",
    default="qwen3.6:35b-a3b",
)
# Sent as a Bearer token when non-empty.
SGLANG_API_TOKEN = _env_first("SGLANG_API_TOKEN", "NEMOCLAW_API_TOKEN", default="")

# First directory searched for "<name>.md" prompt files.
_PROMPT_DIR = os.environ.get("NEMOCLAW_PROMPT_DIR", "/opt/dlami/nvme/nemoclaw/prompts")
|
|
|
|
|
|
def _load_system_prompt(name: str) -> str:
    """Return the system prompt registered under *name*.

    Sources are tried in order:

    1. ``<_PROMPT_DIR>/<name>.md``
    2. ``../nemoclaw_prompts/<name>.md`` relative to this module
    3. the inline ``_PROMPTS`` table (empty string when *name* is unknown)

    For prompt files, every line that starts with ``#`` is dropped and the
    remaining text is stripped of leading/trailing whitespace.

    A candidate that exists but cannot be opened or read (permission denied,
    the path is a directory, a parent component is a regular file, ...) is
    logged and skipped, so the lookup still reaches the next source instead
    of failing the whole inference call.
    """
    file_name = f"{name}.md"
    candidates = (
        os.path.join(_PROMPT_DIR, file_name),
        os.path.join(os.path.dirname(__file__), "..", "nemoclaw_prompts", file_name),
    )
    for path in candidates:
        try:
            with open(path, encoding="utf-8") as handle:
                raw_text = handle.read()
        except FileNotFoundError:
            # A missing file is the expected case; stay quiet and move on.
            continue
        except OSError as exc:
            logger.warning("Prompt file %s is unreadable (%s); trying next source.", path, exc)
            continue
        kept_lines = [line for line in raw_text.splitlines() if not line.startswith("#")]
        return "\n".join(kept_lines).strip()

    logger.warning("Prompt '%s' not found, using inline fallback.", name)
    return _PROMPTS.get(name, "")
|
|
|
|
|
|
# Inline system prompts, used by _load_system_prompt when no prompt file is
# found on disk. Each entry is a newline-joined sequence of prompt lines; the
# last line pins the JSON shape the model is asked to answer with.
_PROMPTS = {
    "qd_calculator": "\n".join(
        (
            "You are a behavioral intelligence analyst for a luxury real estate sales platform.",
            "Compute a Quantum Dynamics score between 1 and 100 using blend shapes, CRM context, "
            "and the active scene label when present.",
            'Respond with JSON only: {"qd_score": <int>, "reasoning": "<one sentence>", "confidence": <float>}',
        )
    ),
    "lead_tagger": "\n".join(
        (
            "You are a lead intelligence analyst. Classify a real estate lead as HNI or NRI.",
            'Respond with JSON only: {"tags_to_add": [...], "tags_to_remove": []}',
        )
    ),
    "cctv_profiler": "\n".join(
        (
            "You are a visitor profiling analyst for a luxury real estate development CCTV system.",
            'Respond with JSON only: {"wealth_indicator": "HNI"|"standard"|"unknown", '
            '"vehicle_class": "luxury"|"standard"|"unknown", "tags_to_add": [...], "notes": "<string>"}',
        )
    ),
}
|
|
|
|
|
|
@dataclass
class QDResult:
    """Outcome of one Quantum Dynamics scoring request (see ``score_qd``)."""

    # Score; score_qd clamps it to the 1..100 range before building the result.
    qd_score: int
    # Model-supplied explanation; empty string when the model gave none.
    reasoning: str
    # Model-supplied confidence value.
    confidence: float
|
|
|
|
|
|
@dataclass
class TagResult:
    """Tag changes proposed for a lead (see ``tag_lead``).

    Both lists default to a fresh empty list per instance.
    """

    # Tags the model wants attached to the lead.
    tags_to_add: list[str] = field(default_factory=list)
    # Tags the model wants removed from the lead.
    tags_to_remove: list[str] = field(default_factory=list)
|
|
|
|
|
|
@dataclass
class CCTVProfileResult:
    """Visitor profile produced by ``profile_cctv_visitor``."""

    # profile_cctv_visitor uses "unknown" for both fields when the call fails
    # or the model omits them.
    wealth_indicator: str
    vehicle_class: str
    # Tags suggested by the model; a fresh empty list per instance by default.
    tags_to_add: list[str] = field(default_factory=list)
    # Free-text remarks from the model.
    notes: str = ""
|
|
|
|
|
|
async def _attempt_chat(
    *,
    label: str,
    url: str,
    model: str,
    system_content: str,
    user_content: str,
    timeout: float,
    headers: dict[str, str],
) -> dict:
    """POST one chat-completion request to *url* and return the parsed reply.

    The request carries a system and a user message, the module-wide
    ``NEMOCLAW_TEMPERATURE``, a ``json_object`` response format and a
    1024-token cap. *label* is only used in the debug log line.

    Errors are not translated here: httpx transport/status errors, lookup
    errors on an unexpected response shape, and ``json.JSONDecodeError`` from
    ``_parse_model_response`` all propagate to the caller.
    """
    conversation = [
        {"role": "system", "content": system_content},
        {"role": "user", "content": user_content},
    ]
    request_body = {
        "model": model,
        "messages": conversation,
        "temperature": NEMOCLAW_TEMPERATURE,
        "response_format": {"type": "json_object"},
        "max_tokens": 1024,
    }

    async with httpx.AsyncClient(timeout=timeout) as session:
        reply = await session.post(url, json=request_body, headers=headers)
        reply.raise_for_status()
        decoded = reply.json()

    # OpenAI-style layout: first choice -> message -> content.
    first_choice = decoded["choices"][0]
    raw_content = first_choice["message"]["content"]
    logger.debug("NemoClaw response via %s: %s", label, raw_content[:200])
    return _parse_model_response(raw_content)
|
|
|
|
|
|
# Per-field regex fallbacks for replies that are not valid JSON. Each scalar
# entry is (key, pattern capturing the raw value, converter for the capture).
_SCALAR_FIELD_PATTERNS = (
    ("qd_score", re.compile(r'"qd_score"\s*:\s*(\d+)'), int),
    ("confidence", re.compile(r'"confidence"\s*:\s*([0-9]*\.?[0-9]+)'), float),
    ("reasoning", re.compile(r'"reasoning"\s*:\s*"([^"]*)"'), str),
    ("wealth_indicator", re.compile(r'"wealth_indicator"\s*:\s*"([^"]*)"'), str),
    ("vehicle_class", re.compile(r'"vehicle_class"\s*:\s*"([^"]*)"'), str),
    ("notes", re.compile(r'"notes"\s*:\s*"([^"]*)"'), str),
)
# List-valued keys: the pattern captures the text between the brackets.
_LIST_FIELD_PATTERNS = (
    ("tags_to_add", re.compile(r'"tags_to_add"\s*:\s*\[(.*?)\]', flags=re.S)),
    ("tags_to_remove", re.compile(r'"tags_to_remove"\s*:\s*\[(.*?)\]', flags=re.S)),
)
_QUOTED_ITEM = re.compile(r'"([^"]+)"')


def _extract_text(raw_content: object) -> str:
    """Flatten a chat-completion ``content`` value into plain text.

    * ``str`` is returned unchanged.
    * ``list`` is treated as a list of content parts: the ``"text"`` value of
      every dict item that has a string ``"text"`` is joined with newlines and
      stripped; all other items are ignored.
    * anything else is passed through ``str()``.
    """
    if isinstance(raw_content, str):
        return raw_content
    if isinstance(raw_content, list):
        parts = [
            item["text"]
            for item in raw_content
            if isinstance(item, dict) and isinstance(item.get("text"), str)
        ]
        return "\n".join(parts).strip()
    return str(raw_content)


def _loads_object(text: str) -> Optional[dict]:
    """Decode *text* as JSON; return it only if it is a JSON object, else ``None``."""
    try:
        decoded = json.loads(text)
    except json.JSONDecodeError:
        return None
    return decoded if isinstance(decoded, dict) else None


def _recover_known_fields(text: str) -> dict:
    """Pull the known reply keys out of malformed JSON with per-field regexes."""
    recovered: dict[str, object] = {}
    for key, pattern, convert in _SCALAR_FIELD_PATTERNS:
        match = pattern.search(text)
        if match:
            recovered[key] = convert(match.group(1))
    for key, pattern in _LIST_FIELD_PATTERNS:
        match = pattern.search(text)
        if match:
            recovered[key] = _QUOTED_ITEM.findall(match.group(1))
    return recovered


def _parse_model_response(raw_content: object) -> dict:
    """Turn a model reply into a dict, tolerating wrapped or malformed JSON.

    Strategies, in order:

    1. empty/blank text -> ``{}``
    2. the whole text decodes to a JSON object
    3. the span from the first ``{`` to the last ``}`` decodes to a JSON object
       (covers prose or code fences around the payload)
    4. regex recovery of the known keys (logged as a warning)

    Only JSON *objects* are accepted. A reply that is valid JSON but not an
    object (array, number, string, ``null``) used to be returned as-is, which
    made callers fail with ``AttributeError`` on ``.get``; it is now treated
    like any other unusable reply.

    Raises:
        json.JSONDecodeError: when none of the strategies yields a dict.
    """
    text = _extract_text(raw_content).strip()
    if not text:
        return {}

    whole = _loads_object(text)
    if whole is not None:
        return whole

    start, end = text.find("{"), text.rfind("}")
    if start != -1 and end > start:
        embedded = _loads_object(text[start : end + 1])
        if embedded is not None:
            return embedded

    recovered = _recover_known_fields(text)
    if recovered:
        logger.warning("Recovered partial NemoClaw JSON payload from malformed model output.")
        return recovered
    raise json.JSONDecodeError("Unable to parse model JSON", text, 0)
|
|
|
|
|
|
async def _nemoclaw_chat(
    system_content: str,
    user_content: str,
    timeout: float = NEMOCLAW_TIMEOUT,
) -> dict:
    """Run one inference call against the configured SGLang chat endpoint.

    Adds a Bearer ``Authorization`` header when ``SGLANG_API_TOKEN`` is set,
    delegates the request to ``_attempt_chat`` and logs the elapsed time on
    success.

    Every failure is reported as ``RuntimeError`` (with the original exception
    chained) so callers have a single exception type to handle:

    * no endpoint configured
    * connection failure or timeout
    * any other httpx request failure (read/write errors, protocol errors, ...)
      - these previously escaped as raw httpx exceptions
    * non-2xx HTTP status
    * a response that is not the expected JSON shape
    """
    if not SGLANG_CHAT_URL:
        raise RuntimeError(
            "No NemoClaw inference endpoint is configured. Set SGLANG_BASE_URL or NEMOCLAW_BASE_URL."
        )

    headers = {"Content-Type": "application/json"}
    if SGLANG_API_TOKEN:
        headers["Authorization"] = f"Bearer {SGLANG_API_TOKEN}"

    t_start = time.monotonic()
    try:
        result = await _attempt_chat(
            label="sglang",
            url=SGLANG_CHAT_URL,
            model=SGLANG_MODEL,
            system_content=system_content,
            user_content=user_content,
            timeout=timeout,
            headers=headers,
        )
    except (httpx.ConnectError, httpx.TimeoutException) as exc:
        raise RuntimeError(f"NemoClaw SGLang endpoint unreachable: {exc}") from exc
    except httpx.HTTPStatusError as exc:
        raise RuntimeError(
            f"NemoClaw SGLang HTTP {exc.response.status_code}: {exc.response.text[:300]}"
        ) from exc
    except httpx.RequestError as exc:
        # Remaining request-level failures; include the class name because
        # many httpx errors carry an empty message.
        raise RuntimeError(
            f"NemoClaw SGLang request failed: {type(exc).__name__}: {exc}"
        ) from exc
    except (KeyError, IndexError, TypeError, json.JSONDecodeError) as exc:
        raise RuntimeError(f"NemoClaw SGLang returned invalid JSON: {exc}") from exc

    logger.info(
        "NemoClaw inference via sglang model=%s elapsed=%.2fs",
        SGLANG_MODEL,
        time.monotonic() - t_start,
    )
    return result
|
|
|
|
|
|
def _number_from_reply(data: dict, key: str, caster, fallback):
    """Read ``data[key]`` as a number via *caster* (``int`` or ``float``).

    Returns ``caster(fallback)`` when the key is absent, and also - with a
    warning - when the model supplied something that cannot be converted
    (``null``, free text, a list) or a non-finite value (NaN / Infinity, which
    ``json.loads`` accepts).
    """
    candidate = data.get(key, fallback)
    try:
        number = caster(candidate)
    except (TypeError, ValueError, OverflowError):
        number = None
    # `number == number` is False only for NaN.
    if number is not None and number == number and abs(number) != float("inf"):
        return number
    logger.warning(
        "NemoClaw reply has unusable %s=%r; falling back to %r.", key, candidate, fallback
    )
    return caster(fallback)


async def score_qd(
    *,
    lead_id: str,
    batch_id: str,
    blend_shapes: dict[str, float],
    video_ts_ms: int,
    scene_label: Optional[str] = None,
    crm_context: dict,
    current_qd_score: Optional[int] = None,
) -> QDResult:
    """Ask the model for a Quantum Dynamics score for one lead.

    The request context (ids, timestamp, scene label, current score, CRM
    context and blend shapes) is sent to the model as a JSON document using
    the ``qd_calculator`` system prompt.

    Returns:
        A ``QDResult`` whose ``qd_score`` is clamped to 1..100. A missing or
        unusable ``qd_score`` falls back to ``current_qd_score`` (or 50), a
        missing or unusable ``confidence`` falls back to 0.7.

    Raises:
        RuntimeError: when the inference call fails (propagated from
            ``_nemoclaw_chat``) or the reply is not a JSON object.
    """
    system_prompt = _load_system_prompt("qd_calculator")
    user_content = json.dumps(
        {
            "lead_id": lead_id,
            "batch_id": batch_id,
            "video_ts_ms": video_ts_ms,
            "scene_label": scene_label,
            "current_qd_score": current_qd_score,
            "crm_context": crm_context,
            "blend_shapes": blend_shapes,
        },
        indent=2,
    )
    data = await _nemoclaw_chat(system_prompt, user_content)
    if not isinstance(data, dict):
        # Keep the failure mode uniform with _nemoclaw_chat instead of an
        # AttributeError on `.get` below.
        raise RuntimeError(
            f"NemoClaw SGLang returned a non-object JSON payload: {type(data).__name__}"
        )

    raw_score = _number_from_reply(data, "qd_score", int, current_qd_score or 50)
    return QDResult(
        qd_score=max(1, min(100, raw_score)),
        reasoning=str(data.get("reasoning", "")),
        confidence=_number_from_reply(data, "confidence", float, 0.7),
    )
|
|
|
|
|
|
def _string_tags(value: object) -> list[str]:
    """Normalise a model-supplied tag field to a list of non-blank strings.

    A bare string is treated as a single tag; non-list values (``None``,
    numbers, dicts) yield an empty list; non-string or blank items are dropped
    and surrounding whitespace is trimmed.
    """
    if isinstance(value, str):
        value = [value]
    if not isinstance(value, (list, tuple)):
        return []
    return [tag.strip() for tag in value if isinstance(tag, str) and tag.strip()]


async def tag_lead(
    *,
    lead_id: str,
    phone: str,
    budget: Optional[str],
    message_text: str,
) -> TagResult:
    """Ask the model which tags to add to / remove from a lead.

    Best-effort: any inference failure, or a reply that is not a JSON object,
    is logged and yields an empty ``TagResult`` instead of raising. Tag fields
    from the model are validated so the result always holds lists of strings,
    even when the model answers with a bare string or ``null``.
    """
    system_prompt = _load_system_prompt("lead_tagger")
    user_content = (
        f"Lead ID: {lead_id}\n"
        f"Phone: {phone}\n"
        f"Budget indicator: {budget or 'unknown'}\n"
        f"First message: {message_text}"
    )
    try:
        data = await _nemoclaw_chat(system_prompt, user_content)
    except Exception as exc:
        logger.error("Lead tagging failed for %s: %s", lead_id, exc)
        return TagResult()

    if not isinstance(data, dict):
        logger.error(
            "Lead tagging failed for %s: %s",
            lead_id,
            f"unexpected reply type {type(data).__name__}",
        )
        return TagResult()

    return TagResult(
        tags_to_add=_string_tags(data.get("tags_to_add")),
        tags_to_remove=_string_tags(data.get("tags_to_remove")),
    )
|
|
|
|
|
|
def _profile_text(value: object, default: str) -> str:
    """Return *value* when it is a non-blank string, otherwise *default*."""
    if isinstance(value, str) and value.strip():
        return value
    return default


def _profile_tags(value: object) -> list[str]:
    """Normalise the model's ``tags_to_add`` field to a list of non-blank strings."""
    if isinstance(value, str):
        value = [value]
    if not isinstance(value, (list, tuple)):
        return []
    return [tag.strip() for tag in value if isinstance(tag, str) and tag.strip()]


async def profile_cctv_visitor(
    *,
    license_plate: Optional[str],
    zone: str,
    face_description: Optional[str] = None,
    vehicle_description: Optional[str] = None,
) -> CCTVProfileResult:
    """Ask the model to profile a visitor from CCTV-derived attributes.

    The attributes are sent as a JSON document using the ``cctv_profiler``
    system prompt, with a 20 second timeout.

    Best-effort: an inference failure, or a reply that is not a JSON object,
    is logged and yields an "unknown"/"unknown" profile instead of raising.
    Fields from the model are validated: ``wealth_indicator`` and
    ``vehicle_class`` fall back to ``"unknown"`` and ``notes`` to ``""`` when
    missing, blank or not a string; ``tags_to_add`` is always a list of
    strings.
    """
    system_prompt = _load_system_prompt("cctv_profiler")
    user_content = json.dumps(
        {
            "license_plate": license_plate,
            "zone": zone,
            "face_description": face_description,
            "vehicle_description": vehicle_description,
        },
        indent=2,
    )
    try:
        data = await _nemoclaw_chat(system_prompt, user_content, timeout=20.0)
    except Exception as exc:
        logger.error("CCTV profiling failed (zone=%s): %s", zone, exc)
        return CCTVProfileResult(wealth_indicator="unknown", vehicle_class="unknown")

    if not isinstance(data, dict):
        logger.error(
            "CCTV profiling failed (zone=%s): %s",
            zone,
            f"unexpected reply type {type(data).__name__}",
        )
        return CCTVProfileResult(wealth_indicator="unknown", vehicle_class="unknown")

    return CCTVProfileResult(
        wealth_indicator=_profile_text(data.get("wealth_indicator"), "unknown"),
        vehicle_class=_profile_text(data.get("vehicle_class"), "unknown"),
        tags_to_add=_profile_tags(data.get("tags_to_add")),
        notes=_profile_text(data.get("notes"), ""),
    )
|
|
|
|
|
|
async def health_check() -> dict:
    """Probe the configured SGLang endpoints and report the outcome.

    Sends a GET to the models URL followed by a minimal "ping" chat
    completion, both with a 5 second client timeout. Never raises: the
    returned dict always contains ``model``, ``primary_url`` and
    ``models_url``, plus ``sglang`` set to ``"ok"`` or ``"error: <exception>"``.
    """
    request_headers = {"Content-Type": "application/json"}
    if SGLANG_API_TOKEN:
        request_headers["Authorization"] = f"Bearer {SGLANG_API_TOKEN}"

    report: dict[str, str] = {
        "model": SGLANG_MODEL,
        "primary_url": SGLANG_CHAT_URL,
        "models_url": SGLANG_MODELS_URL,
    }
    ping_body = {
        "model": SGLANG_MODEL,
        "messages": [{"role": "user", "content": "ping"}],
        "max_tokens": 5,
    }

    try:
        async with httpx.AsyncClient(timeout=5.0) as probe:
            listing = await probe.get(SGLANG_MODELS_URL, headers=request_headers)
            listing.raise_for_status()
            pong = await probe.post(SGLANG_CHAT_URL, json=ping_body, headers=request_headers)
            pong.raise_for_status()
    except Exception as exc:
        # Deliberately broad: a health probe reports failures, it does not raise.
        report["sglang"] = f"error: {exc}"
    else:
        report["sglang"] = "ok"

    return report
|