# Files
# Project_Velocity/backend/services/runtime_llm_service.py
# 2026-04-23 01:20:21 +05:30
#
# 444 lines
# 16 KiB
# Python

from __future__ import annotations
import asyncio
import json
import logging
import os
import uuid
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Any
import httpx
logger = logging.getLogger("velocity.runtime_llm")

# Endpoint configuration. The base URL comes from the first of
# SGLANG_BASE_URL, LLM_BASE_URL, OLLAMA_BASE_URL that is present in the
# environment, else the hosted default. Trailing slashes are trimmed so the
# derived endpoint URLs below do not contain "//".
SGLANG_BASE_URL = os.environ.get(
    "SGLANG_BASE_URL",
    os.environ.get(
        "LLM_BASE_URL",
        os.environ.get("OLLAMA_BASE_URL", "https://llm.desineuron.in"),
    ),
).rstrip("/")

# The chat and model-listing endpoints can each be overridden; by default they
# are the OpenAI-compatible /v1 routes under the base URL.
SGLANG_CHAT_URL = os.environ.get("SGLANG_CHAT_URL", SGLANG_BASE_URL + "/v1/chat/completions")
SGLANG_MODELS_URL = os.environ.get("SGLANG_MODELS_URL", SGLANG_BASE_URL + "/v1/models")

# Model used when a request does not name one (OLLAMA_MODEL is a legacy fallback).
SGLANG_DEFAULT_MODEL = os.environ.get(
    "SGLANG_MODEL",
    os.environ.get("OLLAMA_MODEL", "qwen3.6:35b-a3b"),
)

# Optional bearer token; the empty default means no token is configured.
SGLANG_API_TOKEN = os.environ.get("SGLANG_API_TOKEN", "")

# HTTP timeout in seconds for each chat request.
RUNTIME_LLM_TIMEOUT_S = float(os.environ.get("RUNTIME_LLM_TIMEOUT_S", "90.0"))

# How many items of one batch job may call the LLM at the same time.
RUNTIME_LLM_CONCURRENCY = int(os.environ.get("RUNTIME_LLM_BATCH_CONCURRENCY", "2"))
def _utc_now() -> datetime:
return datetime.now(UTC)
def _utc_iso() -> str:
    """Return the current UTC time as an ISO-8601 string (``+00:00`` offset)."""
    current = _utc_now()
    return current.isoformat()
@dataclass
class RuntimeProvider:
provider_id: str
base_url: str
chat_url: str
default_model: str
auth_token: str | None = None
supports_batch: bool = True
@property
def headers(self) -> dict[str, str]:
headers = {"Content-Type": "application/json"}
if self.auth_token:
headers["Authorization"] = f"Bearer {self.auth_token}"
return headers
class RuntimeLLMService:
def __init__(self) -> None:
self._jobs: dict[str, dict[str, Any]] = {}
def _provider_catalog(self) -> list[RuntimeProvider]:
if not SGLANG_CHAT_URL:
return []
return [
RuntimeProvider(
provider_id="sglang",
base_url=SGLANG_BASE_URL,
chat_url=SGLANG_CHAT_URL,
default_model=SGLANG_DEFAULT_MODEL,
auth_token=SGLANG_API_TOKEN or None,
)
]
def get_provider(self, provider_id: str | None) -> RuntimeProvider:
providers = {provider.provider_id: provider for provider in self._provider_catalog()}
if provider_id in {"ollama", "nemoclaw"}:
provider_id = "sglang"
if provider_id:
provider = providers.get(provider_id)
if provider is None:
raise ValueError(f"Unknown provider '{provider_id}'.")
return provider
if "sglang" in providers:
return providers["sglang"]
raise ValueError("No runtime LLM providers are configured.")
async def list_providers(self) -> list[dict[str, Any]]:
providers: list[dict[str, Any]] = []
for provider in self._provider_catalog():
models: list[str] = [provider.default_model]
status = "offline"
error: str | None = None
try:
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(SGLANG_MODELS_URL, headers=provider.headers)
response.raise_for_status()
payload = response.json()
models = [
str(item.get("id", "")).strip()
for item in payload.get("data", [])
if item.get("id")
]
if provider.default_model not in models:
models.insert(0, provider.default_model)
status = "online"
except Exception as exc: # pragma: no cover - network/runtime dependent
error = str(exc)
providers.append(
{
"id": provider.provider_id,
"status": status,
"baseUrl": provider.base_url,
"defaultModel": provider.default_model,
"models": models,
"supportsBatch": provider.supports_batch,
"error": error,
}
)
return providers
async def chat(
self,
*,
provider_id: str | None,
model: str | None,
system_prompt: str | None,
messages: list[dict[str, str]],
temperature: float = 0.2,
response_format: str | None = None,
metadata: dict[str, Any] | None = None,
) -> dict[str, Any]:
provider = self.get_provider(provider_id)
selected_model = model or provider.default_model
prepared_messages = list(messages)
if system_prompt:
prepared_messages = [{"role": "system", "content": system_prompt}] + prepared_messages
payload: dict[str, Any] = {
"model": selected_model,
"messages": prepared_messages,
"temperature": temperature,
}
if response_format == "json":
payload["response_format"] = {"type": "json_object"}
async with httpx.AsyncClient(timeout=RUNTIME_LLM_TIMEOUT_S) as client:
response = await client.post(provider.chat_url, json=payload, headers=provider.headers)
response.raise_for_status()
body = response.json()
choice = (body.get("choices") or [{}])[0]
message = choice.get("message") or {}
content = message.get("content")
text = self._extract_text(content)
parsed_json: dict[str, Any] | None = None
if response_format == "json":
try:
parsed_json = json.loads(text) if text else {}
except json.JSONDecodeError:
parsed_json = None
return {
"provider": provider.provider_id,
"model": selected_model,
"message": {
"role": "assistant",
"content": text,
"parsedJson": parsed_json,
},
"usage": body.get("usage"),
"metadata": metadata or {},
"completedAt": _utc_iso(),
}
async def submit_batch(
self,
*,
provider_id: str | None,
model: str | None,
job_type: str,
items: list[dict[str, Any]],
metadata: dict[str, Any] | None,
pool: Any | None = None,
actor_id: str | None = None,
) -> dict[str, Any]:
provider = self.get_provider(provider_id)
selected_model = model or provider.default_model
job_id = str(uuid.uuid4())
created_at = _utc_iso()
normalized_items = [
{
"request_id": str(item.get("request_id") or f"item_{idx+1}"),
"messages": item.get("messages") or [],
"system_prompt": item.get("system_prompt"),
"temperature": float(item.get("temperature", 0.2)),
"response_format": item.get("response_format"),
"metadata": item.get("metadata") or {},
}
for idx, item in enumerate(items)
]
job_record = {
"job_id": job_id,
"provider": provider.provider_id,
"model": selected_model,
"job_type": job_type,
"status": "queued",
"submitted_count": len(normalized_items),
"completed_count": 0,
"failed_count": 0,
"metadata": metadata or {},
"items": normalized_items,
"results": [],
"created_at": created_at,
"updated_at": created_at,
"started_at": None,
"completed_at": None,
"actor_id": actor_id,
}
self._jobs[job_id] = job_record
await self._persist_job(job_record, pool=pool)
asyncio.create_task(self._run_batch(job_id, pool=pool))
return {
"job_id": job_id,
"status": job_record["status"],
"provider": provider.provider_id,
"model": selected_model,
"submitted_count": len(normalized_items),
"created_at": created_at,
}
async def _run_batch(self, job_id: str, *, pool: Any | None = None) -> None:
job = self._jobs.get(job_id)
if not job:
return
job["status"] = "running"
job["started_at"] = _utc_iso()
job["updated_at"] = _utc_iso()
await self._persist_job(job, pool=pool)
semaphore = asyncio.Semaphore(RUNTIME_LLM_CONCURRENCY)
async def _execute_item(item: dict[str, Any]) -> dict[str, Any]:
async with semaphore:
try:
response = await self.chat(
provider_id=job["provider"],
model=job["model"],
system_prompt=item.get("system_prompt"),
messages=item.get("messages") or [],
temperature=float(item.get("temperature", 0.2)),
response_format=item.get("response_format"),
metadata=item.get("metadata") or {},
)
return {
"request_id": item["request_id"],
"status": "completed",
"response": response,
"error": None,
}
except Exception as exc: # pragma: no cover - network/runtime dependent
logger.error("runtime_llm batch item failed job=%s request=%s error=%s", job_id, item["request_id"], exc)
return {
"request_id": item["request_id"],
"status": "failed",
"response": None,
"error": str(exc),
}
results = await asyncio.gather(*[_execute_item(item) for item in job["items"]])
job["results"] = results
job["completed_count"] = sum(1 for result in results if result["status"] == "completed")
job["failed_count"] = sum(1 for result in results if result["status"] == "failed")
job["status"] = "completed" if job["failed_count"] == 0 else ("failed" if job["completed_count"] == 0 else "completed_with_errors")
job["completed_at"] = _utc_iso()
job["updated_at"] = _utc_iso()
await self._persist_job(job, pool=pool)
async def get_job(self, job_id: str, *, pool: Any | None = None) -> dict[str, Any] | None:
if job_id in self._jobs:
return self._jobs[job_id]
if pool is not None:
loaded = await self._load_job_from_db(job_id, pool=pool)
if loaded:
self._jobs[job_id] = loaded
return loaded
return None
async def list_job_results(self, job_id: str, *, pool: Any | None = None) -> list[dict[str, Any]] | None:
job = await self.get_job(job_id, pool=pool)
if not job:
return None
return list(job.get("results") or [])
async def _persist_job(self, job: dict[str, Any], *, pool: Any | None = None) -> None:
if pool is None:
return
async with pool.acquire() as conn:
await conn.execute(
"""
INSERT INTO workflow_agent_runs (
run_id,
agent_name,
trigger_type,
trigger_ref,
input_payload,
output_payload,
status,
duration_ms,
error_detail,
started_at,
completed_at
)
VALUES (
$1::uuid,
'runtime_llm',
$2,
$3,
$4::jsonb,
$5::jsonb,
$6,
$7,
$8,
$9::timestamptz,
$10::timestamptz
)
ON CONFLICT (run_id)
DO UPDATE SET
input_payload = EXCLUDED.input_payload,
output_payload = EXCLUDED.output_payload,
status = EXCLUDED.status,
duration_ms = EXCLUDED.duration_ms,
error_detail = EXCLUDED.error_detail,
started_at = EXCLUDED.started_at,
completed_at = EXCLUDED.completed_at
""",
job["job_id"],
job["job_type"],
job.get("actor_id"),
json.dumps(
{
"provider": job["provider"],
"model": job["model"],
"metadata": job.get("metadata") or {},
"items": job.get("items") or [],
}
),
json.dumps(
{
"results": job.get("results") or [],
"submitted_count": job.get("submitted_count", 0),
"completed_count": job.get("completed_count", 0),
"failed_count": job.get("failed_count", 0),
"created_at": job.get("created_at"),
"updated_at": job.get("updated_at"),
}
),
job["status"],
self._duration_ms(job.get("started_at"), job.get("completed_at")),
self._job_error_detail(job),
job.get("started_at"),
job.get("completed_at"),
)
async def _load_job_from_db(self, job_id: str, *, pool: Any) -> dict[str, Any] | None:
async with pool.acquire() as conn:
row = await conn.fetchrow(
"""
SELECT
run_id::text AS job_id,
trigger_type AS job_type,
trigger_ref AS actor_id,
input_payload,
output_payload,
status,
started_at,
completed_at
FROM workflow_agent_runs
WHERE run_id = $1::uuid AND agent_name = 'runtime_llm'
""",
job_id,
)
if not row:
return None
input_payload = dict(row["input_payload"] or {})
output_payload = dict(row["output_payload"] or {})
return {
"job_id": row["job_id"],
"provider": input_payload.get("provider"),
"model": input_payload.get("model"),
"job_type": row["job_type"],
"status": row["status"],
"submitted_count": int(output_payload.get("submitted_count", len(input_payload.get("items") or []))),
"completed_count": int(output_payload.get("completed_count", 0)),
"failed_count": int(output_payload.get("failed_count", 0)),
"metadata": input_payload.get("metadata") or {},
"items": input_payload.get("items") or [],
"results": output_payload.get("results") or [],
"created_at": output_payload.get("created_at") or (row["started_at"].isoformat() if row["started_at"] else None),
"updated_at": output_payload.get("updated_at") or (row["completed_at"].isoformat() if row["completed_at"] else None),
"started_at": row["started_at"].isoformat() if row["started_at"] else None,
"completed_at": row["completed_at"].isoformat() if row["completed_at"] else None,
"actor_id": row["actor_id"],
}
@staticmethod
def _extract_text(content: Any) -> str:
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for part in content:
if isinstance(part, dict):
text = part.get("text")
if isinstance(text, str):
parts.append(text)
return "\n".join(parts).strip()
return str(content or "")
@staticmethod
def _duration_ms(started_at: str | None, completed_at: str | None) -> int | None:
if not started_at or not completed_at:
return None
try:
start = datetime.fromisoformat(started_at.replace("Z", "+00:00"))
end = datetime.fromisoformat(completed_at.replace("Z", "+00:00"))
except ValueError:
return None
return max(0, int((end - start).total_seconds() * 1000))
@staticmethod
def _job_error_detail(job: dict[str, Any]) -> str | None:
failed = [result for result in job.get("results") or [] if result.get("status") == "failed"]
if not failed:
return None
return "; ".join(f"{item.get('request_id')}: {item.get('error')}" for item in failed[:5])
runtime_llm_service = RuntimeLLMService()