Files
Velocity-OS/media-engine/gateway.py

374 lines
16 KiB
Python

#!/usr/bin/env python3
import asyncio, json, time, uuid, io, sys, os, logging
from pathlib import Path
from dataclasses import dataclass
from typing import Optional, List
import httpx
import uvicorn
from fastapi import FastAPI, UploadFile, File, HTTPException, Form, Request
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
# Directory that contains this gateway module; helper-script folders are
# resolved relative to it.
ROOT_DIR = Path(__file__).resolve().parent

# Put every helper-script directory that exists at the front of sys.path so
# the project-local modules imported further down can be found. Each hit is
# inserted at index 0, so a directory listed later in the tuple ends up ahead
# of an earlier one.
for scripts_dir in (ROOT_DIR / "comfy_engine" / "scripts", ROOT_DIR / "scripts"):
    if not scripts_dir.exists():
        continue
    sys.path.insert(0, str(scripts_dir))
# gateway_auth supplies the API-key loader and the request check used by
# ensure_gateway_auth(); it is mandatory, so a missing module aborts startup
# with an explicit message instead of a bare ImportError.
try:
    from gateway_auth import load_gateway_api_key, is_gateway_request_authorized
except ImportError as import_error:
    raise RuntimeError(
        "Dream Weaver gateway_auth.py is required on PYTHONPATH"
    ) from import_error
# prompt_expander is optional. When it cannot be imported the gateway keeps
# working with the deterministic, template-based expansion defined below, and
# LLM_AVAILABLE records which path is active (reported by the health endpoint).
try:
    from prompt_expander import expand_prompt, ROOM_CONTEXTS, ExpandedPrompt
    LLM_AVAILABLE = True
except ImportError as exc:
    LLM_AVAILABLE = False
    # Log through the named logger rather than the module-level
    # logging.warning(): that helper implicitly calls logging.basicConfig()
    # when the root logger has no handlers, which would turn this module's
    # later basicConfig(level=..., format=...) call into a silent no-op.
    logging.getLogger("DreamWeaverGateway").warning(
        "prompt_expander unavailable; using deterministic fallback expansion: %s", exc
    )

    class ExpandedPrompt(BaseModel):
        """Fallback stand-in for prompt_expander.ExpandedPrompt.

        Carries the prompt texts and sampler settings that build_workflow()
        reads. The numeric defaults are clamped again by
        normalize_expanded_prompt() before use.
        """

        style_name: str
        positive_prompt: str
        negative_prompt: str
        steps: int = 28
        cfg: float = 7.0
        denoise: float = 0.72

    # No room-specific context is available without prompt_expander.
    ROOM_CONTEXTS = {}

    def expand_prompt(keywords: List[str], room_type: str) -> ExpandedPrompt:
        """Build a prompt from fixed templates, without any LLM call.

        Args:
            keywords: Style keywords; when empty, "modern, photorealistic" is used.
            room_type: Room identifier; underscores become spaces and a blank
                value falls back to "living room".

        Returns:
            An ExpandedPrompt with templated positive/negative prompts and the
            model's default sampler settings.
        """
        pretty_room = room_type.replace("_", " ").strip() or "living room"
        pretty_keywords = ", ".join(keywords) if keywords else "modern, photorealistic"
        return ExpandedPrompt(
            style_name="Fallback Prompt Expansion",
            positive_prompt=(
                f"photorealistic premium {pretty_room} interior design, {pretty_keywords}, "
                "natural lighting, realistic materials, architect-grade composition"
            ),
            negative_prompt=(
                "worst quality, low quality, blurry, distorted perspective, "
                "people, watermark, text, duplicate objects"
            ),
        )
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger("DreamWeaverGateway")

# Primary ComfyUI endpoint: COMFYUI_URL, then COMFY_URL, then the local default.
COMFY = (os.environ.get("COMFYUI_URL") or os.environ.get("COMFY_URL") or "http://127.0.0.1:8188").rstrip("/")

# Worker pool endpoints from the comma-separated COMFYUI_URLS variable.
# An unset, empty, or separator-only value falls back to the primary endpoint
# so the pool is never empty (several call sites index COMFY_WORKERS[0]);
# entries that normalise to an empty string are dropped.
COMFY_URLS = [
    url
    for url in (
        item.strip().rstrip("/")
        for item in (os.environ.get("COMFYUI_URLS") or COMFY).split(",")
    )
    if url
] or [COMFY]

# TLS verification for ComfyUI requests is on unless COMFYUI_TLS_VERIFY is one
# of the listed "false" spellings.
COMFY_TLS_VERIFY = os.environ.get("COMFYUI_TLS_VERIFY", "true").strip().lower() not in {"0", "false", "no", "off"}

# Value returned by gateway_auth.load_gateway_api_key(); the health endpoint
# reports auth as required when this is not None.
GATEWAY_API_KEY = load_gateway_api_key()

# Polling budget and cadence used while waiting for a render to finish.
POLL_TIMEOUT_SECONDS = int(os.environ.get("DREAM_WEAVER_POLL_TIMEOUT_SECONDS", "1800"))
POLL_INTERVAL_SECONDS = float(os.environ.get("DREAM_WEAVER_POLL_INTERVAL_SECONDS", "2"))

# Megapixel target passed to the ImageScaleToTotalPixels workflow node.
INPUT_MEGAPIXELS = float(os.environ.get("DREAM_WEAVER_INPUT_MEGAPIXELS", "0.75"))

# Upper bound applied to sampler steps by normalize_expanded_prompt().
MAX_RENDER_STEPS = int(os.environ.get("DREAM_WEAVER_MAX_STEPS", "18"))

# Checkpoints tried first, in order, by resolve_checkpoint() (case-insensitive).
PREFERRED_CHECKPOINTS = [
    "realvisxlV50_v50LightningBakedvae.safetensors",
    "realvisxlV50Lightning_v50Lightning.safetensors",
]

app = FastAPI(title="Dream Weaver API v2", version="2.0.0")
# CORS is fully open: any origin, method and header is accepted.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])

# In-memory job registry keyed by job id; contents are lost on restart.
jobs: dict = {}
# One lock per worker id, used around the upload/queue sequence for that worker.
worker_locks: dict[str, asyncio.Lock] = {}
# Held while a worker is chosen and recorded on the job.
assignment_lock = asyncio.Lock()
# Job statuses that local_worker_load() counts as load on a worker.
ACTIVE_JOB_STATUSES = {"uploading", "processing"}
@dataclass(frozen=True)
class ComfyWorker:
    """Immutable descriptor of one ComfyUI backend the gateway can dispatch to."""

    # Identifier used as the key into worker_locks and stored on jobs as "worker_id".
    id: str
    # Base URL (no trailing slash) that ComfyUI request paths are appended to.
    url: str
# One worker descriptor per configured URL; ids are positional ("comfy-0",
# "comfy-1", ...), so they follow the order of COMFY_URLS.
COMFY_WORKERS = [
    ComfyWorker(id=f"comfy-{position}", url=endpoint)
    for position, endpoint in enumerate(COMFY_URLS)
]

# Make sure every configured worker has a lock entry; an existing entry is kept.
for worker in COMFY_WORKERS:
    if worker.id not in worker_locks:
        worker_locks[worker.id] = asyncio.Lock()
def comfy_client(timeout: float = 30) -> httpx.AsyncClient:
    """Create an async HTTP client configured for talking to ComfyUI.

    Args:
        timeout: Timeout in seconds handed to httpx.

    Returns:
        A new httpx.AsyncClient that honours COMFY_TLS_VERIFY and follows
        redirects. Callers are expected to use it as an async context manager.
    """
    client_options = {
        "timeout": timeout,
        "verify": COMFY_TLS_VERIFY,
        "follow_redirects": True,
    }
    return httpx.AsyncClient(**client_options)
def worker_by_id(worker_id: str) -> ComfyWorker:
    """Look up a configured worker by its id.

    Raises:
        HTTPException: 500 when no configured worker has that id.
    """
    found = next(
        (candidate for candidate in COMFY_WORKERS if candidate.id == worker_id),
        None,
    )
    if found is None:
        raise HTTPException(status_code=500, detail=f"Dream Weaver worker {worker_id} is not configured")
    return found
async def list_comfy_checkpoints(worker: Optional[ComfyWorker] = None) -> list[str]:
    """Fetch the checkpoint names a ComfyUI worker reports.

    Args:
        worker: Worker to query; the first configured worker when omitted.

    Returns:
        The string entries of the JSON list served at ``/models/checkpoints``,
        or an empty list when the payload is not a list.

    Raises:
        httpx.HTTPError: On transport failure or an error status code.
    """
    target = worker if worker else COMFY_WORKERS[0]
    async with comfy_client(timeout=10) as client:
        response = await client.get(f"{target.url}/models/checkpoints")
        response.raise_for_status()
        body = response.json()
        if not isinstance(body, list):
            return []
        # Non-string entries are ignored rather than treated as an error.
        return [name for name in body if isinstance(name, str)]
async def resolve_checkpoint(worker: ComfyWorker) -> str:
    """Pick the checkpoint to render with on ``worker``.

    The first PREFERRED_CHECKPOINTS entry the worker has (compared
    case-insensitively) wins; otherwise the worker's first checkpoint is used.

    Raises:
        HTTPException: 503 when the worker reports no checkpoints at all.
    """
    available = await list_comfy_checkpoints(worker)
    if not available:
        raise HTTPException(
            status_code=503,
            detail=(
                "ComfyUI is online but has no checkpoint models installed. "
                "Hydrate RealVisXL into ComfyUI/models/checkpoints before generating."
            ),
        )

    # Map lower-cased name -> name exactly as the worker spells it, so the
    # value returned is always one the worker recognises.
    by_lowercase = {}
    for name in available:
        by_lowercase[name.lower()] = name

    preferred_hits = (by_lowercase.get(wanted.lower()) for wanted in PREFERRED_CHECKPOINTS)
    return next((hit for hit in preferred_hits if hit), available[0])
def gateway_urls(job_id: str) -> dict:
    """Return the relative status and result URLs for a job.

    Returns:
        A dict with ``poll_url`` and ``result_url`` keys, in that order.
    """
    urls = {}
    urls["poll_url"] = f"/dream-weaver/status/{job_id}"
    urls["result_url"] = f"/dream-weaver/result/{job_id}"
    return urls
def ensure_gateway_auth(request: Request) -> None:
    """Reject the request unless gateway_auth accepts its headers.

    Raises:
        HTTPException: 401 when is_gateway_request_authorized() returns a
            falsy value for the request headers and the configured key.
    """
    authorized = is_gateway_request_authorized(request.headers, GATEWAY_API_KEY)
    if not authorized:
        raise HTTPException(status_code=401, detail="Dream Weaver gateway API key is required or invalid.")
async def worker_queue_size(worker: ComfyWorker) -> int:
    """Return how many prompts a worker reports as running plus pending.

    Reads ``queue_running`` and ``queue_pending`` from the worker's ``/queue``
    response; a missing or null field counts as empty, and a payload that is
    not a JSON object counts as zero.

    Raises:
        httpx.HTTPError: On transport failure or an error status code.
    """
    async with comfy_client(timeout=5) as client:
        response = await client.get(f"{worker.url}/queue")
        response.raise_for_status()
        payload = response.json()
        if not isinstance(payload, dict):
            return 0
        return sum(
            len(payload.get(field) or [])
            for field in ("queue_running", "queue_pending")
        )
def local_worker_load(worker: ComfyWorker) -> int:
    """Count this gateway's jobs on ``worker`` whose status is still active.

    Only the in-memory ``jobs`` registry is consulted; a job counts when its
    ``worker_id`` matches and its status is in ACTIVE_JOB_STATUSES.
    """
    active = 0
    for record in jobs.values():
        if record.get("worker_id") != worker.id:
            continue
        if record.get("status") in ACTIVE_JOB_STATUSES:
            active += 1
    return active
async def choose_worker() -> ComfyWorker:
    """Select the least-loaded healthy worker.

    Workers are probed one after another. A worker qualifies when it reports
    at least one checkpoint and its queue can be read. Qualifying workers are
    ranked by (jobs tracked locally, remote queue size, worker id).

    Raises:
        HTTPException: 503 when no worker qualifies; the detail lists the
            reason recorded for each rejected worker.
    """
    ranked: list[tuple[tuple[int, int, str], ComfyWorker]] = []
    problems: list[str] = []

    for worker in COMFY_WORKERS:
        try:
            checkpoints = await list_comfy_checkpoints(worker)
            if not checkpoints:
                problems.append(f"{worker.id} has no checkpoints")
                continue
            # Local load is sampled before the remote queue is awaited.
            local_load = local_worker_load(worker)
            remote_queue = await worker_queue_size(worker)
        except Exception as exc:
            # Any probe failure only disqualifies this worker.
            problems.append(f"{worker.id} unhealthy: {exc}")
        else:
            ranked.append(((local_load, remote_queue, worker.id), worker))

    if not ranked:
        detail = "; ".join(problems) if problems else "No ComfyUI workers are configured"
        raise HTTPException(status_code=503, detail=f"No healthy Dream Weaver workers. {detail}")

    _, best = min(ranked, key=lambda entry: entry[0])
    return best
async def upload_to_comfy(worker: ComfyWorker, data: bytes, filename: str) -> str:
    """Upload image bytes to a worker and return the name ComfyUI stored them under.

    The part is always sent with content type ``image/jpeg`` and the
    ``overwrite`` form field set to ``"true"``.

    Raises:
        httpx.HTTPError: On transport failure or an error status code.
        KeyError: When the JSON reply has no ``name`` field.
    """
    multipart = {"image": (filename, data, "image/jpeg")}
    form_fields = {"overwrite": "true"}
    async with comfy_client(timeout=30) as client:
        reply = await client.post(f"{worker.url}/upload/image", files=multipart, data=form_fields)
        reply.raise_for_status()
        return reply.json()["name"]
def normalize_expanded_prompt(expanded: "ExpandedPrompt") -> "ExpandedPrompt":
    """Clamp the sampler settings on ``expanded`` in place and return it.

    Ranges: steps 6..MAX_RENDER_STEPS, cfg 3.0..7.0, denoise 0.45..0.72.
    A falsy setting is replaced by MAX_RENDER_STEPS, 6.0 and 0.65
    respectively before clamping.
    """
    def clamp(value, low, high):
        # Cap at ``high`` first, then raise to ``low``.
        return max(low, min(value, high))

    expanded.steps = clamp(int(expanded.steps or MAX_RENDER_STEPS), 6, MAX_RENDER_STEPS)
    expanded.cfg = clamp(float(expanded.cfg or 6.0), 3.0, 7.0)
    expanded.denoise = clamp(float(expanded.denoise or 0.65), 0.45, 0.72)
    return expanded
def build_workflow(img_name: str, expanded: "ExpandedPrompt", ckpt_name: str) -> dict:
    """Assemble the ComfyUI image-to-image graph for one render.

    Args:
        img_name: Name of the uploaded input image on the worker.
        expanded: Prompt texts and sampler settings; clamped in place via
            normalize_expanded_prompt().
        ckpt_name: Checkpoint file name to load.

    Returns:
        A dict of node id -> node definition. Node references are
        ``[node_id, output_index]`` pairs.
    """
    expanded = normalize_expanded_prompt(expanded)

    # Seed is derived from the current wall-clock second.
    seed = int(time.time()) % 999983
    # Spaces become underscores and the style name is cut to 30 characters.
    output_prefix = f"dw_{expanded.style_name.replace(' ', '_')[:30]}"

    checkpoint_loader = {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": ckpt_name}}
    image_loader = {"class_type": "LoadImage", "inputs": {"image": img_name, "upload": "image"}}
    image_scaler = {
        "class_type": "ImageScaleToTotalPixels",
        "inputs": {
            "image": ["2", 0],
            "upscale_method": "lanczos",
            "megapixels": INPUT_MEGAPIXELS,
            "resolution_steps": 8,
        },
    }
    positive_encoder = {"class_type": "CLIPTextEncode", "inputs": {"text": expanded.positive_prompt, "clip": ["1", 1]}}
    negative_encoder = {"class_type": "CLIPTextEncode", "inputs": {"text": expanded.negative_prompt, "clip": ["1", 1]}}
    latent_encoder = {"class_type": "VAEEncode", "inputs": {"pixels": ["9", 0], "vae": ["1", 2]}}
    sampler = {
        "class_type": "KSampler",
        "inputs": {
            "model": ["1", 0],
            "positive": ["3", 0],
            "negative": ["4", 0],
            "latent_image": ["5", 0],
            "seed": seed,
            "steps": expanded.steps,
            "cfg": expanded.cfg,
            "sampler_name": "dpmpp_2m",
            "scheduler": "karras",
            "denoise": expanded.denoise,
        },
    }
    latent_decoder = {"class_type": "VAEDecode", "inputs": {"samples": ["6", 0], "vae": ["1", 2]}}
    image_saver = {"class_type": "SaveImage", "inputs": {"images": ["7", 0], "filename_prefix": output_prefix}}

    return {
        "1": checkpoint_loader,
        "2": image_loader,
        "9": image_scaler,
        "3": positive_encoder,
        "4": negative_encoder,
        "5": latent_encoder,
        "6": sampler,
        "7": latent_decoder,
        "8": image_saver,
    }
async def queue_prompt(worker: ComfyWorker, workflow: dict) -> str:
    """Submit a workflow to a worker's ``/prompt`` endpoint.

    Returns:
        The ``prompt_id`` from the worker's JSON reply.

    Raises:
        HTTPException: 502 when the worker answers with status >= 400; the
            detail carries the reply body (re-serialised JSON when it parses,
            raw text otherwise).
    """
    submission = {"prompt": workflow, "client_id": str(uuid.uuid4())}
    async with comfy_client(timeout=30) as client:
        reply = await client.post(f"{worker.url}/prompt", json=submission)
        if reply.status_code < 400:
            return reply.json()["prompt_id"]

        rejection = reply.text
        try:
            rejection = json.dumps(reply.json())
        except Exception:
            # Body is not JSON; keep the raw text.
            pass
        raise HTTPException(status_code=502, detail=f"ComfyUI rejected Dream Weaver workflow: {rejection}")
def extract_comfy_error(history_entry: dict) -> Optional[str]:
    """Return a one-line error description from a ComfyUI history entry.

    Args:
        history_entry: The history record for one prompt. Anything that is
            not a dict, or has no dict ``status``, is treated as "no error".

    Returns:
        None unless ``status.status_str`` is ``"error"``. Otherwise
        ``"<node>: <message>"`` taken from the most recent ``execution_error``
        message, or the generic ``"ComfyUI execution failed"`` when no such
        message carries a dict payload.
    """
    status = history_entry.get("status") if isinstance(history_entry, dict) else None
    if not isinstance(status, dict) or status.get("status_str") != "error":
        return None

    messages = status.get("messages")
    if not isinstance(messages, (list, tuple)):
        # Missing or unexpected shape: nothing to inspect.
        messages = ()

    # Newest message first, so the latest execution error wins.
    for entry in reversed(messages):
        # Skip malformed entries instead of failing on tuple unpacking; an
        # exception here would otherwise abort the caller's polling.
        if not isinstance(entry, (list, tuple)) or len(entry) < 2:
            continue
        kind, payload = entry[0], entry[1]
        if kind == "execution_error" and isinstance(payload, dict):
            node_type = payload.get("node_type") or payload.get("node_id") or "ComfyUI"
            message = payload.get("exception_message") or payload.get("exception_type") or "ComfyUI execution failed"
            return f"{node_type}: {message}"
    return "ComfyUI execution failed"
async def poll_result(worker: ComfyWorker, prompt_id: str, timeout: int = POLL_TIMEOUT_SECONDS):
    """Poll a worker's history until the prompt yields an image, fails, or times out.

    Args:
        worker: Worker the prompt was queued on.
        prompt_id: Id returned by the worker when the prompt was queued.
        timeout: Overall budget in seconds.

    Returns:
        ``(image, None)`` with the first output image descriptor on success,
        or ``(None, message)`` when ComfyUI reports an execution error or the
        budget runs out.
    """
    # Monotonic clock: the budget must not be affected by wall-clock jumps.
    deadline = time.monotonic() + timeout
    last_problem = None
    async with comfy_client(timeout=10) as client:
        while time.monotonic() < deadline:
            history = None
            try:
                r = await client.get(f"{worker.url}/history/{prompt_id}")
                if r.status_code == 200:
                    history = r.json()
            except (httpx.HTTPError, ValueError) as exc:
                # A transient transport error or non-JSON body must not abort
                # a render that can take many minutes; retry until the deadline.
                last_problem = exc

            if isinstance(history, dict):
                h = history.get(prompt_id, {})
                err = extract_comfy_error(h)
                if err:
                    return None, err
                outputs = h.get("outputs") if isinstance(h, dict) else None
                imgs = [
                    img
                    for nd in (outputs or {}).values()
                    if isinstance(nd, dict)
                    for img in nd.get("images") or []
                ]
                if imgs:
                    return imgs[0], None

            await asyncio.sleep(POLL_INTERVAL_SECONDS)

    if last_problem is not None:
        return None, f"timeout after {timeout} seconds (last polling error: {last_problem})"
    return None, f"timeout after {timeout} seconds"
async def background_poll(job_id: str, worker_id: str, prompt_id: str):
    """Wait for a queued prompt and record the outcome on the job.

    Intended to run as a fire-and-forget task, so it never lets an exception
    escape: any failure while resolving the worker or polling is logged and
    stored on the job as an error. Without this, the job would stay in
    "processing" indefinitely because nothing awaits the task.

    Args:
        job_id: Key of the job in the ``jobs`` registry.
        worker_id: Id of the worker the prompt was queued on.
        prompt_id: Id ComfyUI assigned to the prompt.
    """
    try:
        worker = worker_by_id(worker_id)
        img, err = await poll_result(worker, prompt_id)
    except Exception as exc:
        logger.exception("Dream Weaver job %s: polling failed", job_id)
        img, err = None, f"polling failed: {exc}"

    job = jobs.get(job_id)
    if job is None:
        # Job is no longer tracked; there is nowhere to record the outcome.
        return
    if img:
        job.update({"status": "done", "output": img, "completed": time.time()})
    else:
        job.update({"status": "error", "error": str(err), "completed": time.time()})
@app.get("/health")
@app.get("/dream-weaver/health")
async def health():
    """Report gateway configuration and the state of every ComfyUI worker.

    Each worker is probed in turn for its checkpoints and queue size; any
    probe failure marks that worker offline and records the error text. This
    handler does not call ensure_gateway_auth(). The top-level ``status`` is
    always ``"ok"``; readiness is conveyed by ``comfyui`` and
    ``ready_worker_count``.
    """
    reports = []
    for worker in COMFY_WORKERS:
        report = {"id": worker.id, "url": worker.url}
        try:
            checkpoints = await list_comfy_checkpoints(worker)
            queue_size = await worker_queue_size(worker)
        except Exception as exc:
            report.update({
                "online": False,
                "checkpoint_ready": False,
                "checkpoint_count": 0,
                "queue_size": None,
                "error": str(exc),
            })
        else:
            report.update({
                "online": True,
                "checkpoint_ready": bool(checkpoints),
                "checkpoint_count": len(checkpoints),
                "queue_size": queue_size,
                # Only the first 12 names are exposed.
                "available_checkpoints": checkpoints[:12],
            })
        reports.append(report)

    ready = [entry for entry in reports if entry["online"] and entry["checkpoint_ready"]]
    # Top-level checkpoint list mirrors the first ready worker.
    first_ready_checkpoints = ready[0]["available_checkpoints"] if ready else []
    largest_checkpoint_count = max((entry["checkpoint_count"] for entry in reports), default=0)
    primary_url = COMFY_WORKERS[0].url if COMFY_WORKERS else COMFY

    return {
        "status": "ok",
        "comfyui": bool(ready),
        "comfyui_url": primary_url,
        "comfyui_urls": [worker.url for worker in COMFY_WORKERS],
        "checkpoint_ready": bool(ready),
        "checkpoint_count": largest_checkpoint_count,
        "preferred_checkpoints": PREFERRED_CHECKPOINTS,
        "available_checkpoints": first_ready_checkpoints[:12],
        "workers": reports,
        "ready_worker_count": len(ready),
        "llm_expansion": LLM_AVAILABLE,
        "input_megapixels": INPUT_MEGAPIXELS,
        "max_render_steps": MAX_RENDER_STEPS,
        "poll_timeout_seconds": POLL_TIMEOUT_SECONDS,
        "version": "2.0.0",
        "auth_required": GATEWAY_API_KEY is not None,
        "auth_scheme": "x-dream-weaver-api-key"
    }
@app.get("/dream-weaver/status/{job_id}")
async def status(job_id: str, request: Request):
    """Return the tracked state of a job, without its output descriptor.

    The response is the job record minus ``output``, plus ``elapsed_seconds``
    (time since creation, measured at request time), a ``ready`` flag, and -
    once the job is done - the poll/result URLs.

    Raises:
        HTTPException: 401 when the request is not authorized, 404 when the
            job id is unknown.
    """
    ensure_gateway_auth(request)
    job = jobs.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")

    payload = dict(job)
    payload.pop("output", None)
    if "created" in job:
        payload["elapsed_seconds"] = round(time.time() - float(job["created"]), 2)

    is_done = job.get("status") == "done"
    payload["ready"] = is_done
    if is_done:
        payload.update(gateway_urls(job_id))
    return payload
# Strong references to in-flight polling tasks. The event loop only keeps weak
# references to tasks, so a task nobody else references may be
# garbage-collected before it finishes.
_background_poll_tasks: set = set()


@app.post("/dream-weaver")
async def dream_weaver(
    request: Request,
    image: UploadFile = File(...),
    keywords: str = Form(default=""),
    room_type: str = Form(default="living_room")
):
    """Accept an image and start a render job on the least-loaded worker.

    Steps: authorize, register the job, pick a worker, upload the image,
    expand the prompt, resolve a checkpoint, queue the workflow, then start
    a background task that polls for the result.

    Args:
        request: Incoming request; its headers are checked for the API key.
        image: Uploaded source image.
        keywords: Comma-separated style keywords; blanks are ignored.
        room_type: Room identifier passed to the prompt expander.

    Returns:
        A dict with ``job_id``, ``status`` ("processing") and the poll/result URLs.

    Raises:
        HTTPException: 401 when unauthorized; 502/503 from worker selection,
            checkpoint resolution or queuing.
        Exception: Any other failure is re-raised after being recorded on the job.
    """
    ensure_gateway_auth(request)
    job_id = str(uuid.uuid4())
    jobs[job_id] = {"status": "uploading", "created": time.time()}
    try:
        data = await image.read()
        async with assignment_lock:
            worker = await choose_worker()
            jobs[job_id].update({"worker_id": worker.id, "worker_url": worker.url})
        # NOTE(review): the per-worker lock is assumed not to be nested inside
        # assignment_lock - confirm against the original layout.
        lock = worker_locks.setdefault(worker.id, asyncio.Lock())
        async with lock:
            comfy_name = await upload_to_comfy(worker, data, f"dw_{job_id[:8]}.jpg")
            kw_list = [k.strip() for k in keywords.split(",") if k.strip()]
            # expand_prompt is synchronous; run it off the event loop.
            expanded = await asyncio.to_thread(expand_prompt, keywords=kw_list, room_type=room_type)
            ckpt_name = await resolve_checkpoint(worker)
            jobs[job_id]["checkpoint"] = ckpt_name
            wf = build_workflow(comfy_name, expanded, ckpt_name)
            prompt_id = await queue_prompt(worker, wf)
            jobs[job_id].update({"status": "processing", "prompt_id": prompt_id})
    except (Exception, asyncio.CancelledError) as exc:
        # Move the job out of "uploading": a job left in an active status would
        # count against its worker in local_worker_load() indefinitely.
        reason = getattr(exc, "detail", None) or str(exc) or type(exc).__name__
        jobs[job_id].update({"status": "error", "error": str(reason), "completed": time.time()})
        raise

    poll_task = asyncio.create_task(background_poll(job_id, worker.id, prompt_id))
    _background_poll_tasks.add(poll_task)
    poll_task.add_done_callback(_background_poll_tasks.discard)
    return {
        "job_id": job_id,
        "status": "processing",
        **gateway_urls(job_id),
    }
@app.get("/dream-weaver/result/{job_id}")
async def result(job_id: str, request: Request):
    """Fetch a finished job's image from its worker and return it as a download.

    The image is read in full from the worker's ``/view`` endpoint and served
    with media type ``image/png`` and an attachment file name derived from the
    first eight characters of the job id.

    Raises:
        HTTPException: 401 when unauthorized; 404 when the job is unknown, not
            done, or has no output; 500 when its worker is no longer configured.
        httpx.HTTPError: When the worker cannot be reached or answers with an
            error status.
    """
    ensure_gateway_auth(request)
    job = jobs.get(job_id)
    img = job.get("output") if job and job.get("status") == "done" else None
    if not img:
        raise HTTPException(status_code=404, detail="Result not ready")

    # Jobs without a recorded worker fall back to the first configured one.
    fallback_worker_id = COMFY_WORKERS[0].id
    worker = worker_by_id(job.get("worker_id", fallback_worker_id))

    view_params = {
        "filename": img["filename"],
        "subfolder": img.get("subfolder", ""),
        "type": img.get("type", "output"),
    }
    async with comfy_client(timeout=30) as client:
        response = await client.get(f"{worker.url}/view", params=view_params)
        response.raise_for_status()

    download_headers = {"Content-Disposition": f"attachment; filename=dreamweaver_{job_id[:8]}.png"}
    return StreamingResponse(
        io.BytesIO(response.content),
        media_type="image/png",
        headers=download_headers,
    )
# Script entry point: serve the app with uvicorn on all interfaces, port 8082.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8082, log_level="info")