444 lines
16 KiB
Python
444 lines
16 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import uuid
|
|
from dataclasses import dataclass
|
|
from datetime import UTC, datetime
|
|
from typing import Any
|
|
|
|
import httpx
|
|
|
|
logger = logging.getLogger("velocity.runtime_llm")

# Base URL of the chat backend. Resolution order: SGLANG_BASE_URL, then
# LLM_BASE_URL, then OLLAMA_BASE_URL, then a hard-coded default. Trailing
# slashes are stripped so the derived URLs below do not contain "//".
SGLANG_BASE_URL = os.getenv(
    "SGLANG_BASE_URL",
    os.getenv("LLM_BASE_URL", os.getenv("OLLAMA_BASE_URL", "https://llm.desineuron.in")),
).rstrip("/")
SGLANG_CHAT_URL = os.getenv("SGLANG_CHAT_URL", f"{SGLANG_BASE_URL}/v1/chat/completions")
SGLANG_MODELS_URL = os.getenv("SGLANG_MODELS_URL", f"{SGLANG_BASE_URL}/v1/models")
# Model used when a request does not name one: SGLANG_MODEL, then OLLAMA_MODEL.
SGLANG_DEFAULT_MODEL = os.getenv(
    "SGLANG_MODEL",
    os.getenv("OLLAMA_MODEL", "qwen3.6:35b-a3b"),
)
# Bearer token for the backend; an empty string means no Authorization header.
SGLANG_API_TOKEN = os.getenv("SGLANG_API_TOKEN", "")

# Per-request HTTP timeout, in seconds, for chat completions.
RUNTIME_LLM_TIMEOUT_S = float(os.getenv("RUNTIME_LLM_TIMEOUT_S", "90.0"))
# Maximum concurrent chat calls within one batch job. Clamped to >= 1:
# asyncio.Semaphore(0) would block every batch forever, and a negative value
# makes asyncio.Semaphore raise ValueError when a batch starts.
RUNTIME_LLM_CONCURRENCY = max(1, int(os.getenv("RUNTIME_LLM_BATCH_CONCURRENCY", "2")))
|
|
|
|
|
|
def _utc_now() -> datetime:
|
|
return datetime.now(UTC)
|
|
|
|
|
|
def _utc_iso() -> str:
    """Return the current UTC time formatted by ``datetime.isoformat``."""
    stamp = _utc_now()
    return stamp.isoformat()
|
|
|
|
|
|
@dataclass
class RuntimeProvider:
    """Connection settings for one runtime LLM chat backend."""

    provider_id: str  # identifier callers pass to select this provider
    base_url: str
    chat_url: str  # endpoint that chat requests are POSTed to
    default_model: str  # model used when a request does not specify one
    auth_token: str | None = None  # falsy value -> no Authorization header
    supports_batch: bool = True

    @property
    def headers(self) -> dict[str, str]:
        """Build JSON request headers, adding a bearer token when one is set."""
        auth_header = {"Authorization": f"Bearer {self.auth_token}"} if self.auth_token else {}
        return {"Content-Type": "application/json", **auth_header}
|
|
|
|
|
|
class RuntimeLLMService:
|
|
def __init__(self) -> None:
|
|
self._jobs: dict[str, dict[str, Any]] = {}
|
|
|
|
def _provider_catalog(self) -> list[RuntimeProvider]:
|
|
if not SGLANG_CHAT_URL:
|
|
return []
|
|
return [
|
|
RuntimeProvider(
|
|
provider_id="sglang",
|
|
base_url=SGLANG_BASE_URL,
|
|
chat_url=SGLANG_CHAT_URL,
|
|
default_model=SGLANG_DEFAULT_MODEL,
|
|
auth_token=SGLANG_API_TOKEN or None,
|
|
)
|
|
]
|
|
|
|
def get_provider(self, provider_id: str | None) -> RuntimeProvider:
|
|
providers = {provider.provider_id: provider for provider in self._provider_catalog()}
|
|
if provider_id in {"ollama", "nemoclaw"}:
|
|
provider_id = "sglang"
|
|
if provider_id:
|
|
provider = providers.get(provider_id)
|
|
if provider is None:
|
|
raise ValueError(f"Unknown provider '{provider_id}'.")
|
|
return provider
|
|
|
|
if "sglang" in providers:
|
|
return providers["sglang"]
|
|
raise ValueError("No runtime LLM providers are configured.")
|
|
|
|
async def list_providers(self) -> list[dict[str, Any]]:
|
|
providers: list[dict[str, Any]] = []
|
|
for provider in self._provider_catalog():
|
|
models: list[str] = [provider.default_model]
|
|
status = "offline"
|
|
error: str | None = None
|
|
|
|
try:
|
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
|
response = await client.get(SGLANG_MODELS_URL, headers=provider.headers)
|
|
response.raise_for_status()
|
|
payload = response.json()
|
|
models = [
|
|
str(item.get("id", "")).strip()
|
|
for item in payload.get("data", [])
|
|
if item.get("id")
|
|
]
|
|
if provider.default_model not in models:
|
|
models.insert(0, provider.default_model)
|
|
status = "online"
|
|
except Exception as exc: # pragma: no cover - network/runtime dependent
|
|
error = str(exc)
|
|
|
|
providers.append(
|
|
{
|
|
"id": provider.provider_id,
|
|
"status": status,
|
|
"baseUrl": provider.base_url,
|
|
"defaultModel": provider.default_model,
|
|
"models": models,
|
|
"supportsBatch": provider.supports_batch,
|
|
"error": error,
|
|
}
|
|
)
|
|
return providers
|
|
|
|
async def chat(
|
|
self,
|
|
*,
|
|
provider_id: str | None,
|
|
model: str | None,
|
|
system_prompt: str | None,
|
|
messages: list[dict[str, str]],
|
|
temperature: float = 0.2,
|
|
response_format: str | None = None,
|
|
metadata: dict[str, Any] | None = None,
|
|
) -> dict[str, Any]:
|
|
provider = self.get_provider(provider_id)
|
|
selected_model = model or provider.default_model
|
|
prepared_messages = list(messages)
|
|
if system_prompt:
|
|
prepared_messages = [{"role": "system", "content": system_prompt}] + prepared_messages
|
|
|
|
payload: dict[str, Any] = {
|
|
"model": selected_model,
|
|
"messages": prepared_messages,
|
|
"temperature": temperature,
|
|
}
|
|
if response_format == "json":
|
|
payload["response_format"] = {"type": "json_object"}
|
|
|
|
async with httpx.AsyncClient(timeout=RUNTIME_LLM_TIMEOUT_S) as client:
|
|
response = await client.post(provider.chat_url, json=payload, headers=provider.headers)
|
|
response.raise_for_status()
|
|
body = response.json()
|
|
choice = (body.get("choices") or [{}])[0]
|
|
message = choice.get("message") or {}
|
|
content = message.get("content")
|
|
text = self._extract_text(content)
|
|
parsed_json: dict[str, Any] | None = None
|
|
if response_format == "json":
|
|
try:
|
|
parsed_json = json.loads(text) if text else {}
|
|
except json.JSONDecodeError:
|
|
parsed_json = None
|
|
|
|
return {
|
|
"provider": provider.provider_id,
|
|
"model": selected_model,
|
|
"message": {
|
|
"role": "assistant",
|
|
"content": text,
|
|
"parsedJson": parsed_json,
|
|
},
|
|
"usage": body.get("usage"),
|
|
"metadata": metadata or {},
|
|
"completedAt": _utc_iso(),
|
|
}
|
|
|
|
async def submit_batch(
|
|
self,
|
|
*,
|
|
provider_id: str | None,
|
|
model: str | None,
|
|
job_type: str,
|
|
items: list[dict[str, Any]],
|
|
metadata: dict[str, Any] | None,
|
|
pool: Any | None = None,
|
|
actor_id: str | None = None,
|
|
) -> dict[str, Any]:
|
|
provider = self.get_provider(provider_id)
|
|
selected_model = model or provider.default_model
|
|
job_id = str(uuid.uuid4())
|
|
created_at = _utc_iso()
|
|
normalized_items = [
|
|
{
|
|
"request_id": str(item.get("request_id") or f"item_{idx+1}"),
|
|
"messages": item.get("messages") or [],
|
|
"system_prompt": item.get("system_prompt"),
|
|
"temperature": float(item.get("temperature", 0.2)),
|
|
"response_format": item.get("response_format"),
|
|
"metadata": item.get("metadata") or {},
|
|
}
|
|
for idx, item in enumerate(items)
|
|
]
|
|
|
|
job_record = {
|
|
"job_id": job_id,
|
|
"provider": provider.provider_id,
|
|
"model": selected_model,
|
|
"job_type": job_type,
|
|
"status": "queued",
|
|
"submitted_count": len(normalized_items),
|
|
"completed_count": 0,
|
|
"failed_count": 0,
|
|
"metadata": metadata or {},
|
|
"items": normalized_items,
|
|
"results": [],
|
|
"created_at": created_at,
|
|
"updated_at": created_at,
|
|
"started_at": None,
|
|
"completed_at": None,
|
|
"actor_id": actor_id,
|
|
}
|
|
self._jobs[job_id] = job_record
|
|
await self._persist_job(job_record, pool=pool)
|
|
asyncio.create_task(self._run_batch(job_id, pool=pool))
|
|
return {
|
|
"job_id": job_id,
|
|
"status": job_record["status"],
|
|
"provider": provider.provider_id,
|
|
"model": selected_model,
|
|
"submitted_count": len(normalized_items),
|
|
"created_at": created_at,
|
|
}
|
|
|
|
async def _run_batch(self, job_id: str, *, pool: Any | None = None) -> None:
|
|
job = self._jobs.get(job_id)
|
|
if not job:
|
|
return
|
|
|
|
job["status"] = "running"
|
|
job["started_at"] = _utc_iso()
|
|
job["updated_at"] = _utc_iso()
|
|
await self._persist_job(job, pool=pool)
|
|
|
|
semaphore = asyncio.Semaphore(RUNTIME_LLM_CONCURRENCY)
|
|
|
|
async def _execute_item(item: dict[str, Any]) -> dict[str, Any]:
|
|
async with semaphore:
|
|
try:
|
|
response = await self.chat(
|
|
provider_id=job["provider"],
|
|
model=job["model"],
|
|
system_prompt=item.get("system_prompt"),
|
|
messages=item.get("messages") or [],
|
|
temperature=float(item.get("temperature", 0.2)),
|
|
response_format=item.get("response_format"),
|
|
metadata=item.get("metadata") or {},
|
|
)
|
|
return {
|
|
"request_id": item["request_id"],
|
|
"status": "completed",
|
|
"response": response,
|
|
"error": None,
|
|
}
|
|
except Exception as exc: # pragma: no cover - network/runtime dependent
|
|
logger.error("runtime_llm batch item failed job=%s request=%s error=%s", job_id, item["request_id"], exc)
|
|
return {
|
|
"request_id": item["request_id"],
|
|
"status": "failed",
|
|
"response": None,
|
|
"error": str(exc),
|
|
}
|
|
|
|
results = await asyncio.gather(*[_execute_item(item) for item in job["items"]])
|
|
job["results"] = results
|
|
job["completed_count"] = sum(1 for result in results if result["status"] == "completed")
|
|
job["failed_count"] = sum(1 for result in results if result["status"] == "failed")
|
|
job["status"] = "completed" if job["failed_count"] == 0 else ("failed" if job["completed_count"] == 0 else "completed_with_errors")
|
|
job["completed_at"] = _utc_iso()
|
|
job["updated_at"] = _utc_iso()
|
|
await self._persist_job(job, pool=pool)
|
|
|
|
async def get_job(self, job_id: str, *, pool: Any | None = None) -> dict[str, Any] | None:
|
|
if job_id in self._jobs:
|
|
return self._jobs[job_id]
|
|
if pool is not None:
|
|
loaded = await self._load_job_from_db(job_id, pool=pool)
|
|
if loaded:
|
|
self._jobs[job_id] = loaded
|
|
return loaded
|
|
return None
|
|
|
|
async def list_job_results(self, job_id: str, *, pool: Any | None = None) -> list[dict[str, Any]] | None:
|
|
job = await self.get_job(job_id, pool=pool)
|
|
if not job:
|
|
return None
|
|
return list(job.get("results") or [])
|
|
|
|
async def _persist_job(self, job: dict[str, Any], *, pool: Any | None = None) -> None:
|
|
if pool is None:
|
|
return
|
|
async with pool.acquire() as conn:
|
|
await conn.execute(
|
|
"""
|
|
INSERT INTO workflow_agent_runs (
|
|
run_id,
|
|
agent_name,
|
|
trigger_type,
|
|
trigger_ref,
|
|
input_payload,
|
|
output_payload,
|
|
status,
|
|
duration_ms,
|
|
error_detail,
|
|
started_at,
|
|
completed_at
|
|
)
|
|
VALUES (
|
|
$1::uuid,
|
|
'runtime_llm',
|
|
$2,
|
|
$3,
|
|
$4::jsonb,
|
|
$5::jsonb,
|
|
$6,
|
|
$7,
|
|
$8,
|
|
$9::timestamptz,
|
|
$10::timestamptz
|
|
)
|
|
ON CONFLICT (run_id)
|
|
DO UPDATE SET
|
|
input_payload = EXCLUDED.input_payload,
|
|
output_payload = EXCLUDED.output_payload,
|
|
status = EXCLUDED.status,
|
|
duration_ms = EXCLUDED.duration_ms,
|
|
error_detail = EXCLUDED.error_detail,
|
|
started_at = EXCLUDED.started_at,
|
|
completed_at = EXCLUDED.completed_at
|
|
""",
|
|
job["job_id"],
|
|
job["job_type"],
|
|
job.get("actor_id"),
|
|
json.dumps(
|
|
{
|
|
"provider": job["provider"],
|
|
"model": job["model"],
|
|
"metadata": job.get("metadata") or {},
|
|
"items": job.get("items") or [],
|
|
}
|
|
),
|
|
json.dumps(
|
|
{
|
|
"results": job.get("results") or [],
|
|
"submitted_count": job.get("submitted_count", 0),
|
|
"completed_count": job.get("completed_count", 0),
|
|
"failed_count": job.get("failed_count", 0),
|
|
"created_at": job.get("created_at"),
|
|
"updated_at": job.get("updated_at"),
|
|
}
|
|
),
|
|
job["status"],
|
|
self._duration_ms(job.get("started_at"), job.get("completed_at")),
|
|
self._job_error_detail(job),
|
|
job.get("started_at"),
|
|
job.get("completed_at"),
|
|
)
|
|
|
|
async def _load_job_from_db(self, job_id: str, *, pool: Any) -> dict[str, Any] | None:
|
|
async with pool.acquire() as conn:
|
|
row = await conn.fetchrow(
|
|
"""
|
|
SELECT
|
|
run_id::text AS job_id,
|
|
trigger_type AS job_type,
|
|
trigger_ref AS actor_id,
|
|
input_payload,
|
|
output_payload,
|
|
status,
|
|
started_at,
|
|
completed_at
|
|
FROM workflow_agent_runs
|
|
WHERE run_id = $1::uuid AND agent_name = 'runtime_llm'
|
|
""",
|
|
job_id,
|
|
)
|
|
if not row:
|
|
return None
|
|
|
|
input_payload = dict(row["input_payload"] or {})
|
|
output_payload = dict(row["output_payload"] or {})
|
|
return {
|
|
"job_id": row["job_id"],
|
|
"provider": input_payload.get("provider"),
|
|
"model": input_payload.get("model"),
|
|
"job_type": row["job_type"],
|
|
"status": row["status"],
|
|
"submitted_count": int(output_payload.get("submitted_count", len(input_payload.get("items") or []))),
|
|
"completed_count": int(output_payload.get("completed_count", 0)),
|
|
"failed_count": int(output_payload.get("failed_count", 0)),
|
|
"metadata": input_payload.get("metadata") or {},
|
|
"items": input_payload.get("items") or [],
|
|
"results": output_payload.get("results") or [],
|
|
"created_at": output_payload.get("created_at") or (row["started_at"].isoformat() if row["started_at"] else None),
|
|
"updated_at": output_payload.get("updated_at") or (row["completed_at"].isoformat() if row["completed_at"] else None),
|
|
"started_at": row["started_at"].isoformat() if row["started_at"] else None,
|
|
"completed_at": row["completed_at"].isoformat() if row["completed_at"] else None,
|
|
"actor_id": row["actor_id"],
|
|
}
|
|
|
|
@staticmethod
|
|
def _extract_text(content: Any) -> str:
|
|
if isinstance(content, str):
|
|
return content
|
|
if isinstance(content, list):
|
|
parts: list[str] = []
|
|
for part in content:
|
|
if isinstance(part, dict):
|
|
text = part.get("text")
|
|
if isinstance(text, str):
|
|
parts.append(text)
|
|
return "\n".join(parts).strip()
|
|
return str(content or "")
|
|
|
|
@staticmethod
|
|
def _duration_ms(started_at: str | None, completed_at: str | None) -> int | None:
|
|
if not started_at or not completed_at:
|
|
return None
|
|
try:
|
|
start = datetime.fromisoformat(started_at.replace("Z", "+00:00"))
|
|
end = datetime.fromisoformat(completed_at.replace("Z", "+00:00"))
|
|
except ValueError:
|
|
return None
|
|
return max(0, int((end - start).total_seconds() * 1000))
|
|
|
|
@staticmethod
|
|
def _job_error_detail(job: dict[str, Any]) -> str | None:
|
|
failed = [result for result in job.get("results") or [] if result.get("status") == "failed"]
|
|
if not failed:
|
|
return None
|
|
return "; ".join(f"{item.get('request_id')}: {item.get('error')}" for item in failed[:5])
|
|
|
|
|
|
# Module-level shared service instance. Batch job state is held in memory on
# this object (RuntimeLLMService._jobs).
runtime_llm_service: RuntimeLLMService = RuntimeLLMService()
|