# Forked from sagnik/Project_Velocity
# Co-authored-by: Sagnik <sagnik7896@gmail.com>
# Reviewed-on: sagnik/Project_Velocity#31
# 141 lines, 4.8 KiB, Python
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
|
from pydantic import BaseModel, Field
|
|
|
|
from backend.auth.dependencies import UserPrincipal, get_current_user
|
|
from backend.services.runtime_llm_service import runtime_llm_service
|
|
|
|
# Router collecting the runtime LLM endpoints declared below (providers, chat,
# batch, jobs). Route paths are relative; any URL prefix is applied wherever
# the application includes this router, which is not visible in this file.
router = APIRouter()
|
|
|
|
|
|
class ChatMessage(BaseModel):
    """A single turn of a chat conversation.

    Validation rules:
      * ``role`` must be exactly ``system``, ``user`` or ``assistant``.
      * ``content`` must contain at least one character.
    """

    # Speaker of the turn; restricted to the three roles in the pattern.
    role: str = Field(
        ...,
        pattern="^(system|user|assistant)$",
    )
    # Text of the turn; the empty string is rejected.
    content: str = Field(
        ...,
        min_length=1,
    )
|
|
|
|
|
|
class RuntimeChatRequest(BaseModel):
    """Request body accepted by the ``/chat`` endpoint.

    Only ``messages`` is required. Every other field has a default, and the
    ``/chat`` handler forwards all of them to ``runtime_llm_service.chat``.
    """

    # Provider / model selection; ``None`` is forwarded unchanged to the
    # service (how the service resolves ``None`` is not visible here).
    provider: str | None = None
    model: str | None = None

    # Optional system prompt, passed separately from ``messages``.
    system_prompt: str | None = None

    # Conversation turns; each entry is validated as a ``ChatMessage``.
    messages: list[ChatMessage]

    # Sampling temperature, inclusive range 0.0 .. 2.0, defaulting to 0.2.
    temperature: float = Field(
        default=0.2,
        ge=0.0,
        le=2.0,
    )

    # Either ``"json"`` or ``"text"`` when given; ``None`` by default.
    response_format: str | None = Field(
        default=None,
        pattern="^(json|text)$",
    )

    # Free-form caller metadata; a fresh empty dict per instance by default.
    metadata: dict[str, Any] = Field(default_factory=dict)
|
|
|
|
|
|
class BatchItemRequest(BaseModel):
    """One entry of a batch submission.

    An item carries its own conversation and generation settings but no
    provider or model; those are declared once on ``RuntimeBatchRequest``.
    """

    # Caller-chosen identifier for this item (required, no format enforced).
    request_id: str

    # Conversation turns; each entry is validated as a ``ChatMessage``.
    messages: list[ChatMessage]

    # Optional system prompt for this item.
    system_prompt: str | None = None

    # Sampling temperature, inclusive range 0.0 .. 2.0, defaulting to 0.2.
    temperature: float = Field(
        default=0.2,
        ge=0.0,
        le=2.0,
    )

    # Either ``"json"`` or ``"text"`` when given; ``None`` by default.
    response_format: str | None = Field(
        default=None,
        pattern="^(json|text)$",
    )

    # Free-form per-item metadata; a fresh empty dict per instance by default.
    metadata: dict[str, Any] = Field(default_factory=dict)
|
|
|
|
|
|
class RuntimeBatchRequest(BaseModel):
    """Request body accepted by the ``/batch`` endpoint.

    ``job_type`` and ``items`` are required; the remaining fields are
    optional. The ``/batch`` handler forwards them to
    ``runtime_llm_service.submit_batch``.
    """

    # Provider / model shared by every item; ``None`` is forwarded unchanged.
    provider: str | None = None
    model: str | None = None

    # Label for the job, between 1 and 128 characters long.
    job_type: str = Field(
        ...,
        min_length=1,
        max_length=128,
    )

    # Free-form job-level metadata; a fresh empty dict per instance by default.
    metadata: dict[str, Any] = Field(default_factory=dict)

    # Between 1 and 128 items per submission.
    items: list[BatchItemRequest] = Field(
        ...,
        min_length=1,
        max_length=128,
    )
|
|
|
|
|
|
def _normalize_user(user: UserPrincipal) -> dict[str, str]:
    """Reduce *user* to the two fields recorded alongside a request.

    Args:
        user: Principal whose ``user_id`` and ``role`` attributes are read.

    Returns:
        A new dict with the keys ``"user_id"`` and ``"role"`` (in that
        order), holding the attribute values of the same name.
    """
    return dict(user_id=user.user_id, role=user.role)
|
|
|
|
|
|
@router.get("/providers", summary="List configured runtime LLM providers and models")
async def list_runtime_providers(_: UserPrincipal = Depends(get_current_user)) -> dict:
    """Return whatever ``runtime_llm_service.list_providers()`` reports.

    The ``get_current_user`` dependency is resolved for every call, but its
    result is not used by the handler itself.

    Returns:
        ``{"status": "ok", "data": <service result>}``.
    """
    providers = await runtime_llm_service.list_providers()
    return {"status": "ok", "data": providers}
|
|
|
|
|
|
@router.post("/chat", summary="Execute a single runtime LLM chat completion")
async def runtime_chat(
    payload: RuntimeChatRequest,
    user: UserPrincipal = Depends(get_current_user),
) -> dict:
    """Run one chat completion through ``runtime_llm_service.chat``.

    The validated request fields are forwarded as-is. The metadata sent to
    the service is the caller's metadata plus a ``"requested_by"`` entry
    describing the current user.

    Returns:
        ``{"status": "ok", "data": <service result>}``.
    """
    # Copy first, then assign: a client-supplied "requested_by" key is always
    # overwritten with the server-derived identity.
    call_metadata = dict(payload.metadata)
    call_metadata["requested_by"] = _normalize_user(user)

    # The service receives plain dicts rather than model instances.
    conversation = [turn.model_dump() for turn in payload.messages]

    completion = await runtime_llm_service.chat(
        provider_id=payload.provider,
        model=payload.model,
        system_prompt=payload.system_prompt,
        messages=conversation,
        temperature=payload.temperature,
        response_format=payload.response_format,
        metadata=call_metadata,
    )
    return {"status": "ok", "data": completion}
|
|
|
|
|
|
@router.post("/batch", status_code=status.HTTP_202_ACCEPTED, summary="Submit a persisted runtime LLM batch job")
async def runtime_batch(
    payload: RuntimeBatchRequest,
    request: Request,
    user: UserPrincipal = Depends(get_current_user),
) -> dict:
    """Submit a batch job through ``runtime_llm_service.submit_batch``.

    The route answers with HTTP 202. The job metadata sent to the service is
    the caller's metadata plus a ``"requested_by"`` entry describing the
    current user, and the user's id is also passed as ``actor_id``.

    Returns:
        ``{"status": "ok", "data": <service result>}``.
    """
    # The pool is optional: ``None`` is handed to the service when the
    # application state has no ``db_pool`` attribute.
    db_pool = getattr(request.app.state, "db_pool", None)

    # Copy first, then assign: a client-supplied "requested_by" key is always
    # overwritten with the server-derived identity.
    job_metadata = dict(payload.metadata)
    job_metadata["requested_by"] = _normalize_user(user)

    # The service receives plain dicts rather than model instances.
    batch_items = [entry.model_dump() for entry in payload.items]

    submission = await runtime_llm_service.submit_batch(
        provider_id=payload.provider,
        model=payload.model,
        job_type=payload.job_type,
        items=batch_items,
        metadata=job_metadata,
        pool=db_pool,
        actor_id=user.user_id,
    )
    return {"status": "ok", "data": submission}
|
|
|
|
|
|
@router.get("/jobs/{job_id}", summary="Get runtime LLM batch job status")
async def get_runtime_job(
    job_id: str,
    request: Request,
    _: UserPrincipal = Depends(get_current_user),
) -> dict:
    """Return the status summary of one batch job.

    The job is looked up with ``runtime_llm_service.get_job``. Only a fixed
    set of fields is exposed, plus ``metadata``, which falls back to an empty
    dict when it is missing or falsy.

    Returns:
        ``{"status": "ok", "data": {...job summary...}}``.

    Raises:
        HTTPException: 404 when the service returns a falsy job.
    """
    # NOTE(review): the current user is resolved but not compared with the
    # job's owner here - confirm access is restricted elsewhere if required.
    db_pool = getattr(request.app.state, "db_pool", None)
    job = await runtime_llm_service.get_job(job_id, pool=db_pool)
    if not job:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Runtime LLM job '{job_id}' not found.",
        )

    # Fields copied verbatim from the job record, in response order.
    exposed_fields = (
        "job_id",
        "status",
        "provider",
        "model",
        "job_type",
        "submitted_count",
        "completed_count",
        "failed_count",
        "created_at",
        "started_at",
        "completed_at",
    )
    summary = {name: job[name] for name in exposed_fields}
    summary["metadata"] = job.get("metadata") or {}
    return {"status": "ok", "data": summary}
|
|
|
|
|
|
@router.get("/jobs/{job_id}/results", summary="Get runtime LLM batch job item results")
async def get_runtime_job_results(
    job_id: str,
    request: Request,
    _: UserPrincipal = Depends(get_current_user),
) -> dict:
    """Return the per-item results of one batch job.

    Results come from ``runtime_llm_service.list_job_results``. Only ``None``
    is treated as "job not found"; an empty result set is returned normally
    with a count of zero.

    Returns:
        ``{"status": "ok", "data": <results>, "meta": {"count": <len>}}``.

    Raises:
        HTTPException: 404 when the service returns ``None``.
    """
    db_pool = getattr(request.app.state, "db_pool", None)
    item_results = await runtime_llm_service.list_job_results(job_id, pool=db_pool)
    if item_results is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Runtime LLM job '{job_id}' not found.",
        )

    return {
        "status": "ok",
        "data": item_results,
        "meta": {"count": len(item_results)},
    }
|