"""Runtime LLM gateway: single chat calls and in-process batch jobs.

Requests go to an OpenAI-style chat-completions endpoint
(``/v1/chat/completions``; SGLang by default, configurable via environment
variables). Batch jobs run as background asyncio tasks and, when a database
pool is supplied, are mirrored into the ``workflow_agent_runs`` table.
"""

from __future__ import annotations

import asyncio
import json
import logging
import os
import uuid
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Any

import httpx

logger = logging.getLogger("velocity.runtime_llm")

SGLANG_BASE_URL = os.getenv(
    "SGLANG_BASE_URL",
    os.getenv("LLM_BASE_URL", os.getenv("OLLAMA_BASE_URL", "https://llm.desineuron.in")),
).rstrip("/")
SGLANG_CHAT_URL = os.getenv("SGLANG_CHAT_URL", f"{SGLANG_BASE_URL}/v1/chat/completions")
SGLANG_MODELS_URL = os.getenv("SGLANG_MODELS_URL", f"{SGLANG_BASE_URL}/v1/models")
SGLANG_DEFAULT_MODEL = os.getenv(
    "SGLANG_MODEL",
    os.getenv("OLLAMA_MODEL", "qwen3.6:35b-a3b"),
)
SGLANG_API_TOKEN = os.getenv("SGLANG_API_TOKEN", "")
RUNTIME_LLM_TIMEOUT_S = float(os.getenv("RUNTIME_LLM_TIMEOUT_S", "90.0"))
RUNTIME_LLM_CONCURRENCY = int(os.getenv("RUNTIME_LLM_BATCH_CONCURRENCY", "2"))

_DEFAULT_TEMPERATURE = 0.2


def _utc_now() -> datetime:
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.now(UTC)


def _utc_iso() -> str:
    """Return the current UTC time as an ISO-8601 string."""
    return _utc_now().isoformat()


def _parse_iso(value: Any) -> datetime | None:
    """Parse an ISO-8601 timestamp (a trailing ``Z`` is accepted).

    Returns ``None`` for empty, non-string, or unparseable input; a
    ``datetime`` is returned unchanged.
    """
    if isinstance(value, datetime):
        return value
    if not value or not isinstance(value, str):
        return None
    try:
        return datetime.fromisoformat(value.replace("Z", "+00:00"))
    except ValueError:
        return None


def _iso_or_none(value: Any) -> str | None:
    """Render a DB timestamp column as an ISO string, or ``None`` if NULL."""
    if value is None:
        return None
    return value.isoformat() if isinstance(value, datetime) else str(value)


def _decode_json_object(value: Any) -> dict[str, Any]:
    """Coerce a ``json``/``jsonb`` column value into a dict.

    asyncpg returns ``json``/``jsonb`` columns as ``str`` unless a type codec
    is registered on the connection, so both the raw text and an already
    decoded mapping are accepted. Anything that is not a JSON object yields
    an empty dict.
    """
    if value is None:
        return {}
    if isinstance(value, (str, bytes, bytearray)):
        try:
            value = json.loads(value)
        except ValueError:  # includes JSONDecodeError and UnicodeDecodeError
            logger.warning("runtime_llm could not decode stored JSON payload")
            return {}
    return dict(value) if isinstance(value, dict) else {}


@dataclass
class RuntimeProvider:
    """Connection settings for one chat-completions backend."""

    provider_id: str
    base_url: str
    chat_url: str
    default_model: str
    auth_token: str | None = None
    supports_batch: bool = True
    # Model-listing endpoint; ``None`` falls back to ``SGLANG_MODELS_URL``.
    models_url: str | None = None

    @property
    def headers(self) -> dict[str, str]:
        """HTTP headers for requests, with a Bearer token when one is configured."""
        headers = {"Content-Type": "application/json"}
        if self.auth_token:
            headers["Authorization"] = f"Bearer {self.auth_token}"
        return headers


class RuntimeLLMService:
    """Send chat requests and manage background batch jobs.

    Job records are kept in ``self._jobs`` for the life of the process and,
    when a ``pool`` is passed, upserted into ``workflow_agent_runs``.
    NOTE(review): ``self._jobs`` is never pruned, so long-lived processes
    accumulate every job (including full results) in memory.
    """

    def __init__(self) -> None:
        self._jobs: dict[str, dict[str, Any]] = {}
        # Strong references to in-flight batch tasks. The event loop only keeps
        # weak references, so an unreferenced task can be garbage-collected
        # before it finishes.
        self._tasks: set[asyncio.Task[None]] = set()

    def _provider_catalog(self) -> list[RuntimeProvider]:
        """Build the configured providers (currently only ``sglang``)."""
        if not SGLANG_CHAT_URL:
            return []
        return [
            RuntimeProvider(
                provider_id="sglang",
                base_url=SGLANG_BASE_URL,
                chat_url=SGLANG_CHAT_URL,
                default_model=SGLANG_DEFAULT_MODEL,
                auth_token=SGLANG_API_TOKEN or None,
                models_url=SGLANG_MODELS_URL,
            )
        ]

    def get_provider(self, provider_id: str | None) -> RuntimeProvider:
        """Resolve a provider by id, defaulting to ``sglang`` when none is given.

        The legacy ids ``"ollama"`` and ``"nemoclaw"`` are aliased to
        ``"sglang"``.

        Raises:
            ValueError: if the id is unknown or no providers are configured.
        """
        providers = {provider.provider_id: provider for provider in self._provider_catalog()}
        if provider_id in {"ollama", "nemoclaw"}:
            provider_id = "sglang"
        if provider_id:
            provider = providers.get(provider_id)
            if provider is None:
                raise ValueError(f"Unknown provider '{provider_id}'.")
            return provider
        if "sglang" in providers:
            return providers["sglang"]
        raise ValueError("No runtime LLM providers are configured.")

    async def list_providers(self) -> list[dict[str, Any]]:
        """Probe each provider's model-listing endpoint and report its status.

        A provider whose probe fails is reported as ``"offline"`` with the
        error text. Its model list then contains only its default model.
        """
        providers: list[dict[str, Any]] = []
        for provider in self._provider_catalog():
            models: list[str] = [provider.default_model]
            status = "offline"
            error: str | None = None
            try:
                async with httpx.AsyncClient(timeout=10.0) as client:
                    response = await client.get(
                        provider.models_url or SGLANG_MODELS_URL,
                        headers=provider.headers,
                    )
                    response.raise_for_status()
                    payload = response.json()
                models = [
                    str(item.get("id", "")).strip()
                    for item in payload.get("data", [])
                    if item.get("id")
                ]
                if provider.default_model not in models:
                    models.insert(0, provider.default_model)
                status = "online"
            except Exception as exc:  # pragma: no cover - network/runtime dependent
                # Best-effort health probe: report the failure instead of raising.
                error = str(exc)
            providers.append(
                {
                    "id": provider.provider_id,
                    "status": status,
                    "baseUrl": provider.base_url,
                    "defaultModel": provider.default_model,
                    "models": models,
                    "supportsBatch": provider.supports_batch,
                    "error": error,
                }
            )
        return providers

    async def chat(
        self,
        *,
        provider_id: str | None,
        model: str | None,
        system_prompt: str | None,
        messages: list[dict[str, str]],
        temperature: float = 0.2,
        response_format: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Send one chat-completion request and normalize the reply.

        Args:
            provider_id: provider to use (``None`` selects the default).
            model: model name; ``None`` uses the provider's default model.
            system_prompt: if set, prepended as a ``system`` message.
            messages: chat messages forwarded as-is.
            temperature: sampling temperature.
            response_format: ``"json"`` requests a JSON object and also
                parses the reply into ``message.parsedJson``.
            metadata: echoed back unchanged in the result.

        Returns:
            A dict with provider, model, assistant message text (plus parsed
            JSON when requested), the upstream ``usage`` block, metadata and
            a completion timestamp.

        Raises:
            ValueError: for an unknown provider.
            httpx.HTTPError: on transport errors or non-2xx responses.
        """
        provider = self.get_provider(provider_id)
        selected_model = model or provider.default_model
        prepared_messages = list(messages)
        if system_prompt:
            prepared_messages = [{"role": "system", "content": system_prompt}] + prepared_messages

        payload: dict[str, Any] = {
            "model": selected_model,
            "messages": prepared_messages,
            "temperature": temperature,
        }
        if response_format == "json":
            payload["response_format"] = {"type": "json_object"}

        async with httpx.AsyncClient(timeout=RUNTIME_LLM_TIMEOUT_S) as client:
            response = await client.post(provider.chat_url, json=payload, headers=provider.headers)
            response.raise_for_status()
            body = response.json()

        # Only the first choice is used; a missing/empty ``choices`` yields "".
        choice = (body.get("choices") or [{}])[0]
        message = choice.get("message") or {}
        text = self._extract_text(message.get("content"))

        # json.loads may legitimately return a list/scalar, so this is Any.
        parsed_json: Any = None
        if response_format == "json":
            try:
                parsed_json = json.loads(text) if text else {}
            except json.JSONDecodeError:
                parsed_json = None

        return {
            "provider": provider.provider_id,
            "model": selected_model,
            "message": {
                "role": "assistant",
                "content": text,
                "parsedJson": parsed_json,
            },
            "usage": body.get("usage"),
            "metadata": metadata or {},
            "completedAt": _utc_iso(),
        }

    @staticmethod
    def _normalize_item(idx: int, item: dict[str, Any]) -> dict[str, Any]:
        """Fill defaults for one batch item (``idx`` is its 0-based position)."""
        temperature = item.get("temperature")
        return {
            "request_id": str(item.get("request_id") or f"item_{idx + 1}"),
            "messages": item.get("messages") or [],
            "system_prompt": item.get("system_prompt"),
            # An explicit ``None`` falls back to the default instead of float(None).
            "temperature": _DEFAULT_TEMPERATURE if temperature is None else float(temperature),
            "response_format": item.get("response_format"),
            "metadata": item.get("metadata") or {},
        }

    async def submit_batch(
        self,
        *,
        provider_id: str | None,
        model: str | None,
        job_type: str,
        items: list[dict[str, Any]],
        metadata: dict[str, Any] | None,
        pool: Any | None = None,
        actor_id: str | None = None,
    ) -> dict[str, Any]:
        """Register a batch job and start processing it in the background.

        The job is stored in memory (and persisted when ``pool`` is given)
        with status ``"queued"``. Then a background task is scheduled to run
        the items. This method returns without waiting for that task.

        Raises:
            ValueError: for an unknown provider.
        """
        provider = self.get_provider(provider_id)
        selected_model = model or provider.default_model
        job_id = str(uuid.uuid4())
        created_at = _utc_iso()
        normalized_items = [self._normalize_item(idx, item) for idx, item in enumerate(items)]

        job_record: dict[str, Any] = {
            "job_id": job_id,
            "provider": provider.provider_id,
            "model": selected_model,
            "job_type": job_type,
            "status": "queued",
            "submitted_count": len(normalized_items),
            "completed_count": 0,
            "failed_count": 0,
            "metadata": metadata or {},
            "items": normalized_items,
            "results": [],
            "created_at": created_at,
            "updated_at": created_at,
            "started_at": None,
            "completed_at": None,
            "actor_id": actor_id,
        }
        self._jobs[job_id] = job_record
        await self._persist_job(job_record, pool=pool)

        # Keep a strong reference until the task finishes (see __init__).
        task = asyncio.create_task(self._run_batch(job_id, pool=pool))
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)

        return {
            "job_id": job_id,
            "status": job_record["status"],
            "provider": provider.provider_id,
            "model": selected_model,
            "submitted_count": len(normalized_items),
            "created_at": created_at,
        }

    async def _run_batch(self, job_id: str, *, pool: Any | None = None) -> None:
        """Background entry point for a batch job.

        Nothing awaits this task, so an unexpected error (e.g. a DB failure
        while persisting) would otherwise vanish and leave the job stuck in
        ``"running"``. Such errors are logged and the job is marked failed.
        """
        job = self._jobs.get(job_id)
        if not job:
            return
        try:
            await self._process_batch(job_id, job, pool=pool)
        except Exception:
            logger.exception("runtime_llm batch job crashed job=%s", job_id)
            if job.get("completed_at") is None:
                now = _utc_iso()
                job["status"] = "failed"
                job["completed_at"] = now
                job["updated_at"] = now
                try:
                    await self._persist_job(job, pool=pool)
                except Exception:
                    logger.exception("runtime_llm could not persist failed status job=%s", job_id)

    async def _process_batch(self, job_id: str, job: dict[str, Any], *, pool: Any | None) -> None:
        """Run every item of ``job`` with bounded concurrency and record the outcome."""
        now = _utc_iso()
        job["status"] = "running"
        job["started_at"] = now
        job["updated_at"] = now
        await self._persist_job(job, pool=pool)

        # A non-positive concurrency setting would block every item forever.
        semaphore = asyncio.Semaphore(max(1, RUNTIME_LLM_CONCURRENCY))

        async def _execute_item(item: dict[str, Any]) -> dict[str, Any]:
            """Run one item; failures become a ``"failed"`` result, never an exception."""
            async with semaphore:
                temperature = item.get("temperature")
                try:
                    response = await self.chat(
                        provider_id=job["provider"],
                        model=job["model"],
                        system_prompt=item.get("system_prompt"),
                        messages=item.get("messages") or [],
                        temperature=_DEFAULT_TEMPERATURE if temperature is None else float(temperature),
                        response_format=item.get("response_format"),
                        metadata=item.get("metadata") or {},
                    )
                    return {
                        "request_id": item["request_id"],
                        "status": "completed",
                        "response": response,
                        "error": None,
                    }
                except Exception as exc:  # pragma: no cover - network/runtime dependent
                    logger.error("runtime_llm batch item failed job=%s request=%s error=%s", job_id, item["request_id"], exc)
                    return {
                        "request_id": item["request_id"],
                        "status": "failed",
                        "response": None,
                        "error": str(exc),
                    }

        results = await asyncio.gather(*(_execute_item(item) for item in job["items"]))
        job["results"] = results
        job["completed_count"] = sum(1 for result in results if result["status"] == "completed")
        job["failed_count"] = sum(1 for result in results if result["status"] == "failed")
        if job["failed_count"] == 0:
            job["status"] = "completed"
        elif job["completed_count"] == 0:
            job["status"] = "failed"
        else:
            job["status"] = "completed_with_errors"
        finished = _utc_iso()
        job["completed_at"] = finished
        job["updated_at"] = finished
        await self._persist_job(job, pool=pool)

    async def get_job(self, job_id: str, *, pool: Any | None = None) -> dict[str, Any] | None:
        """Return a job record from memory, falling back to the DB when ``pool`` is given.

        Records loaded from the DB are cached in memory.
        """
        if job_id in self._jobs:
            return self._jobs[job_id]
        if pool is not None:
            loaded = await self._load_job_from_db(job_id, pool=pool)
            if loaded:
                self._jobs[job_id] = loaded
                return loaded
        return None

    async def list_job_results(self, job_id: str, *, pool: Any | None = None) -> list[dict[str, Any]] | None:
        """Return a copy of a job's per-item results, or ``None`` if the job is unknown."""
        job = await self.get_job(job_id, pool=pool)
        if not job:
            return None
        return list(job.get("results") or [])

    async def _persist_job(self, job: dict[str, Any], *, pool: Any | None = None) -> None:
        """Upsert ``job`` into ``workflow_agent_runs`` (no-op without a pool).

        Timestamps are bound as ``datetime`` objects: the placeholders are
        cast to ``timestamptz``, and asyncpg's codec for that type rejects
        plain strings.
        """
        if pool is None:
            return
        async with pool.acquire() as conn:
            await conn.execute(
                """
                INSERT INTO workflow_agent_runs (
                    run_id, agent_name, trigger_type, trigger_ref, input_payload,
                    output_payload, status, duration_ms, error_detail, started_at, completed_at
                ) VALUES (
                    $1::uuid, 'runtime_llm', $2, $3, $4::jsonb,
                    $5::jsonb, $6, $7, $8, $9::timestamptz, $10::timestamptz
                )
                ON CONFLICT (run_id) DO UPDATE SET
                    input_payload = EXCLUDED.input_payload,
                    output_payload = EXCLUDED.output_payload,
                    status = EXCLUDED.status,
                    duration_ms = EXCLUDED.duration_ms,
                    error_detail = EXCLUDED.error_detail,
                    started_at = EXCLUDED.started_at,
                    completed_at = EXCLUDED.completed_at
                """,
                job["job_id"],
                job["job_type"],
                job.get("actor_id"),
                json.dumps(
                    {
                        "provider": job["provider"],
                        "model": job["model"],
                        "metadata": job.get("metadata") or {},
                        "items": job.get("items") or [],
                    }
                ),
                json.dumps(
                    {
                        "results": job.get("results") or [],
                        "submitted_count": job.get("submitted_count", 0),
                        "completed_count": job.get("completed_count", 0),
                        "failed_count": job.get("failed_count", 0),
                        "created_at": job.get("created_at"),
                        "updated_at": job.get("updated_at"),
                    }
                ),
                job["status"],
                self._duration_ms(job.get("started_at"), job.get("completed_at")),
                self._job_error_detail(job),
                _parse_iso(job.get("started_at")),
                _parse_iso(job.get("completed_at")),
            )

    async def _load_job_from_db(self, job_id: str, *, pool: Any) -> dict[str, Any] | None:
        """Rebuild an in-memory job record from its ``workflow_agent_runs`` row.

        Returns ``None`` when no ``runtime_llm`` row exists for ``job_id``.
        """
        async with pool.acquire() as conn:
            row = await conn.fetchrow(
                """
                SELECT
                    run_id::text AS job_id,
                    trigger_type AS job_type,
                    trigger_ref AS actor_id,
                    input_payload,
                    output_payload,
                    status,
                    started_at,
                    completed_at
                FROM workflow_agent_runs
                WHERE run_id = $1::uuid AND agent_name = 'runtime_llm'
                """,
                job_id,
            )
        if not row:
            return None

        # jsonb columns may arrive as raw JSON text (asyncpg default) or as dicts.
        input_payload = _decode_json_object(row["input_payload"])
        output_payload = _decode_json_object(row["output_payload"])
        started_at = _iso_or_none(row["started_at"])
        completed_at = _iso_or_none(row["completed_at"])
        return {
            "job_id": row["job_id"],
            "provider": input_payload.get("provider"),
            "model": input_payload.get("model"),
            "job_type": row["job_type"],
            "status": row["status"],
            "submitted_count": int(output_payload.get("submitted_count", len(input_payload.get("items") or []))),
            "completed_count": int(output_payload.get("completed_count", 0)),
            "failed_count": int(output_payload.get("failed_count", 0)),
            "metadata": input_payload.get("metadata") or {},
            "items": input_payload.get("items") or [],
            "results": output_payload.get("results") or [],
            "created_at": output_payload.get("created_at") or started_at,
            "updated_at": output_payload.get("updated_at") or completed_at,
            "started_at": started_at,
            "completed_at": completed_at,
            "actor_id": row["actor_id"],
        }

    @staticmethod
    def _extract_text(content: Any) -> str:
        """Flatten a message ``content`` into plain text.

        A string is returned as-is. For a list of parts, the string ``text``
        fields of dict parts are joined with newlines and the result is
        stripped. Anything else is converted with ``str`` (``None`` gives "").
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts: list[str] = []
            for part in content:
                if isinstance(part, dict):
                    text = part.get("text")
                    if isinstance(text, str):
                        parts.append(text)
            return "\n".join(parts).strip()
        return str(content or "")

    @staticmethod
    def _duration_ms(started_at: str | None, completed_at: str | None) -> int | None:
        """Milliseconds between two ISO timestamps (clamped at 0), or ``None``.

        ``None`` is returned when either timestamp is missing or unparseable,
        or when one is naive and the other timezone-aware.
        """
        start = _parse_iso(started_at)
        end = _parse_iso(completed_at)
        if start is None or end is None:
            return None
        try:
            delta = end - start
        except TypeError:  # naive vs aware datetimes cannot be subtracted
            return None
        return max(0, int(delta.total_seconds() * 1000))

    @staticmethod
    def _job_error_detail(job: dict[str, Any]) -> str | None:
        """Summarize up to five failed items as ``"request_id: error"`` pairs."""
        failed = [result for result in job.get("results") or [] if result.get("status") == "failed"]
        if not failed:
            return None
        return "; ".join(f"{item.get('request_id')}: {item.get('error')}" for item in failed[:5])


# Module-level singleton shared by importers.
runtime_llm_service = RuntimeLLMService()