129 lines
4.8 KiB
Python
129 lines
4.8 KiB
Python
import logging
|
|
from pathlib import Path
|
|
from typing import Any, Dict
|
|
|
|
import httpx
|
|
|
|
from app.core.config import settings
|
|
|
|
# Module-level logger; ComfyClient.health_check reports failed probes through it.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ComfyClient:
|
|
def __init__(self, base_url: str | None = None):
|
|
self.base_url = (base_url or settings.COMFYUI_BASE_URL).rstrip("/")
|
|
self._client = httpx.AsyncClient(timeout=120.0)
|
|
|
|
async def close(self) -> None:
|
|
await self._client.aclose()
|
|
|
|
async def health_check(self) -> bool:
|
|
for endpoint in ("/system_stats", "/"):
|
|
try:
|
|
response = await self._client.get(f"{self.base_url}{endpoint}")
|
|
if response.status_code == 200:
|
|
return True
|
|
except Exception as exc:
|
|
logger.warning("ComfyUI health check failed at %s: %s", endpoint, exc)
|
|
return False
|
|
|
|
async def upload_image(self, file_path: str, filename: str) -> str:
|
|
with open(file_path, "rb") as handle:
|
|
files = {"image": (filename, handle, "application/octet-stream")}
|
|
response = await self._client.post(f"{self.base_url}/upload/image", files=files)
|
|
response.raise_for_status()
|
|
data = response.json()
|
|
return data.get("name", filename)
|
|
|
|
async def upload_media(self, file_path: str, filename: str, media_type: str) -> str:
|
|
endpoint = {
|
|
"image": "/upload/image",
|
|
"pose_sheet": "/upload/image",
|
|
"video": "/upload/video",
|
|
"audio": "/upload/audio",
|
|
}.get(media_type)
|
|
field_name = {
|
|
"image": "image",
|
|
"pose_sheet": "image",
|
|
"video": "video",
|
|
"audio": "audio",
|
|
}.get(media_type)
|
|
|
|
if not endpoint or not field_name:
|
|
raise ValueError(f"Unsupported ComfyUI upload media type: {media_type}")
|
|
|
|
mime_type = "application/octet-stream"
|
|
suffix = Path(filename).suffix.lower()
|
|
if media_type in ("image", "pose_sheet"):
|
|
mime_type = {
|
|
".jpg": "image/jpeg",
|
|
".jpeg": "image/jpeg",
|
|
".png": "image/png",
|
|
".webp": "image/webp",
|
|
}.get(suffix, mime_type)
|
|
elif media_type == "video":
|
|
mime_type = {
|
|
".mp4": "video/mp4",
|
|
".webm": "video/webm",
|
|
".mov": "video/quicktime",
|
|
}.get(suffix, mime_type)
|
|
elif media_type == "audio":
|
|
mime_type = {
|
|
".mp3": "audio/mpeg",
|
|
".mp4": "audio/mp4",
|
|
".wav": "audio/wav",
|
|
".ogg": "audio/ogg",
|
|
}.get(suffix, mime_type)
|
|
|
|
with open(file_path, "rb") as handle:
|
|
files = {field_name: (filename, handle, mime_type)}
|
|
response = await self._client.post(f"{self.base_url}{endpoint}", files=files)
|
|
response.raise_for_status()
|
|
|
|
data = response.json()
|
|
return data.get("name", filename)
|
|
|
|
async def submit_prompt(self, workflow: Dict[str, Any], client_id: str | None = None) -> str:
|
|
payload: Dict[str, Any] = {"prompt": workflow}
|
|
if client_id:
|
|
payload["client_id"] = client_id
|
|
response = await self._client.post(f"{self.base_url}/prompt", json=payload)
|
|
if response.is_error:
|
|
detail = response.text
|
|
raise RuntimeError(f"ComfyUI prompt submission failed ({response.status_code}): {detail}")
|
|
data = response.json()
|
|
prompt_id = data.get("prompt_id")
|
|
if not prompt_id:
|
|
raise RuntimeError(f"No prompt_id returned by ComfyUI: {data}")
|
|
return prompt_id
|
|
|
|
async def get_history(self, prompt_id: str) -> Dict[str, Any]:
|
|
response = await self._client.get(f"{self.base_url}/history/{prompt_id}")
|
|
response.raise_for_status()
|
|
data = response.json()
|
|
return data.get(prompt_id, {})
|
|
|
|
async def get_history_all(self) -> Dict[str, Any]:
|
|
response = await self._client.get(f"{self.base_url}/history")
|
|
response.raise_for_status()
|
|
return response.json()
|
|
|
|
async def get_queue(self) -> Dict[str, Any]:
|
|
response = await self._client.get(f"{self.base_url}/queue")
|
|
response.raise_for_status()
|
|
return response.json()
|
|
|
|
async def get_object_info(self, node_name: str) -> Dict[str, Any]:
|
|
response = await self._client.get(f"{self.base_url}/object_info/{node_name}")
|
|
response.raise_for_status()
|
|
return response.json().get(node_name, {})
|
|
|
|
async def download_output(self, filename: str, subfolder: str = "", folder_type: str = "output") -> bytes:
|
|
params = {"filename": filename, "subfolder": subfolder, "type": folder_type}
|
|
response = await self._client.get(f"{self.base_url}/view", params=params)
|
|
response.raise_for_status()
|
|
return response.content
|
|
|
|
|
|
# Shared module-level client. Constructing it here reads settings.COMFYUI_BASE_URL
# and creates the underlying httpx.AsyncClient at import time.
# NOTE(review): presumably the application's shutdown path should
# `await comfy_client.close()` to release connections — confirm it does.
comfy_client = ComfyClient()
|