#!/usr/bin/env python3
"""Dream Weaver API gateway.

Accepts a room photo plus comma-separated style keywords, expands them into a
Stable Diffusion prompt, and submits an img2img workflow to one of the
configured ComfyUI workers. Job state is kept in this process's memory
(``jobs``); a background task polls ComfyUI until the render finishes.
Clients poll ``/dream-weaver/status/{job_id}`` and download the image from
``/dream-weaver/result/{job_id}``.
"""
import asyncio
import io
import json
import logging
import os
import sys
import time
import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional

import httpx
import uvicorn
from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel

ROOT_DIR = Path(__file__).resolve().parent

# Make the Dream Weaver helper scripts importable from either the bundled
# comfy_engine/scripts directory or a sibling scripts/ directory. Each existing
# directory is inserted at the front of sys.path, so when both exist the one
# inserted last (scripts/) is searched first.
for scripts_dir in (ROOT_DIR / "comfy_engine" / "scripts", ROOT_DIR / "scripts"):
    if scripts_dir.exists():
        sys.path.insert(0, str(scripts_dir))

try:
    from gateway_auth import load_gateway_api_key, is_gateway_request_authorized
except ImportError as exc:
    raise RuntimeError("Dream Weaver gateway_auth.py is required on PYTHONPATH") from exc

try:
    from prompt_expander import expand_prompt, ROOM_CONTEXTS, ExpandedPrompt

    LLM_AVAILABLE = True
except ImportError as exc:
    LLM_AVAILABLE = False
    logging.warning("prompt_expander unavailable; using deterministic fallback expansion: %s", exc)

    class ExpandedPrompt(BaseModel):
        """Fallback prompt model mirroring the fields this gateway reads."""

        style_name: str
        positive_prompt: str
        negative_prompt: str
        steps: int = 28
        cfg: float = 7.0
        denoise: float = 0.72

    ROOM_CONTEXTS = {}

    def expand_prompt(keywords: List[str], room_type: str) -> ExpandedPrompt:
        """Build a fixed-template prompt from *keywords* and *room_type* (no LLM)."""
        pretty_room = room_type.replace("_", " ").strip() or "living room"
        pretty_keywords = ", ".join(keywords) if keywords else "modern, photorealistic"
        return ExpandedPrompt(
            style_name="Fallback Prompt Expansion",
            positive_prompt=(
                f"photorealistic premium {pretty_room} interior design, {pretty_keywords}, "
                "natural lighting, realistic materials, architect-grade composition"
            ),
            negative_prompt=(
                "worst quality, low quality, blurry, distorted perspective, "
                "people, watermark, text, duplicate objects"
            ),
        )


logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger("DreamWeaverGateway")

COMFY = (
    os.environ.get("COMFYUI_URL")
    or os.environ.get("COMFY_URL")
    or "http://127.0.0.1:8188"
).rstrip("/")
# Comma-separated worker list. An unset *or blank* COMFYUI_URLS falls back to
# the single COMFY URL so the worker pool is never empty (several code paths
# index COMFY_WORKERS[0]).
COMFY_URLS = [
    item.strip().rstrip("/")
    for item in os.environ.get("COMFYUI_URLS", COMFY).split(",")
    if item.strip()
] or [COMFY]
COMFY_TLS_VERIFY = os.environ.get("COMFYUI_TLS_VERIFY", "true").strip().lower() not in {"0", "false", "no", "off"}
GATEWAY_API_KEY = load_gateway_api_key()
POLL_TIMEOUT_SECONDS = int(os.environ.get("DREAM_WEAVER_POLL_TIMEOUT_SECONDS", "1800"))
POLL_INTERVAL_SECONDS = float(os.environ.get("DREAM_WEAVER_POLL_INTERVAL_SECONDS", "2"))
INPUT_MEGAPIXELS = float(os.environ.get("DREAM_WEAVER_INPUT_MEGAPIXELS", "0.75"))
MAX_RENDER_STEPS = int(os.environ.get("DREAM_WEAVER_MAX_STEPS", "18"))
PREFERRED_CHECKPOINTS = [
    "realvisxlV50_v50LightningBakedvae.safetensors",
    "realvisxlV50Lightning_v50Lightning.safetensors",
]

app = FastAPI(title="Dream Weaver API v2", version="2.0.0")
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])

# job_id -> job record (status, timestamps, worker, prompt_id, output/error).
jobs: dict = {}
# One lock per worker: serializes upload + queueing on a given ComfyUI instance.
worker_locks: dict[str, asyncio.Lock] = {}
# Serializes worker selection so concurrent requests see each other's assignment.
assignment_lock = asyncio.Lock()
ACTIVE_JOB_STATUSES = {"uploading", "processing"}
# Strong references to in-flight polling tasks; the event loop only keeps weak
# references, so an unreferenced task could be garbage-collected mid-run.
_background_tasks: set[asyncio.Task] = set()


@dataclass(frozen=True)
class ComfyWorker:
    """A configured ComfyUI backend."""

    id: str
    url: str


COMFY_WORKERS = [
    ComfyWorker(id=f"comfy-{index}", url=url)
    for index, url in enumerate(COMFY_URLS)
]
for worker in COMFY_WORKERS:
    worker_locks.setdefault(worker.id, asyncio.Lock())


def comfy_client(timeout: float = 30) -> httpx.AsyncClient:
    """Create an HTTP client for talking to ComfyUI with the configured TLS policy."""
    return httpx.AsyncClient(timeout=timeout, verify=COMFY_TLS_VERIFY, follow_redirects=True)


def worker_by_id(worker_id: str) -> ComfyWorker:
    """Look up a configured worker; raise HTTP 500 if *worker_id* is unknown."""
    for worker in COMFY_WORKERS:
        if worker.id == worker_id:
            return worker
    raise HTTPException(status_code=500, detail=f"Dream Weaver worker {worker_id} is not configured")


async def list_comfy_checkpoints(worker: Optional[ComfyWorker] = None) -> list[str]:
    """Return the checkpoint filenames *worker* (default: first worker) reports.

    Raises ``httpx.HTTPError`` on transport failure or a non-2xx response.
    Non-list payloads and non-string items are ignored.
    """
    worker = worker or COMFY_WORKERS[0]
    async with comfy_client(timeout=10) as client:
        response = await client.get(f"{worker.url}/models/checkpoints")
        response.raise_for_status()
        payload = response.json()
    if isinstance(payload, list):
        return [item for item in payload if isinstance(item, str)]
    return []


async def resolve_checkpoint(worker: ComfyWorker) -> str:
    """Pick the checkpoint to render with on *worker*.

    Prefers the first entry of PREFERRED_CHECKPOINTS that is installed
    (case-insensitive match), otherwise the first checkpoint listed.

    Raises:
        HTTPException(503): the worker has no checkpoints installed.
    """
    checkpoints = await list_comfy_checkpoints(worker)
    if not checkpoints:
        raise HTTPException(
            status_code=503,
            detail=(
                "ComfyUI is online but has no checkpoint models installed. "
                "Hydrate RealVisXL into ComfyUI/models/checkpoints before generating."
            ),
        )
    lower_lookup = {item.lower(): item for item in checkpoints}
    for preferred in PREFERRED_CHECKPOINTS:
        match = lower_lookup.get(preferred.lower())
        if match:
            return match
    return checkpoints[0]


def gateway_urls(job_id: str) -> dict:
    """Return the relative status and result URLs for *job_id*."""
    return {
        "poll_url": f"/dream-weaver/status/{job_id}",
        "result_url": f"/dream-weaver/result/{job_id}",
    }


def ensure_gateway_auth(request: Request) -> None:
    """Raise HTTP 401 unless the request carries a valid gateway API key."""
    if is_gateway_request_authorized(request.headers, GATEWAY_API_KEY):
        return
    raise HTTPException(status_code=401, detail="Dream Weaver gateway API key is required or invalid.")


async def worker_queue_size(worker: ComfyWorker) -> int:
    """Return running + pending prompt count from *worker*'s ``/queue`` endpoint."""
    async with comfy_client(timeout=5) as client:
        response = await client.get(f"{worker.url}/queue")
        response.raise_for_status()
        payload = response.json()
    running = payload.get("queue_running") if isinstance(payload, dict) else []
    pending = payload.get("queue_pending") if isinstance(payload, dict) else []
    return len(running or []) + len(pending or [])


def local_worker_load(worker: ComfyWorker) -> int:
    """Count this gateway's jobs assigned to *worker* that are still active."""
    return sum(
        1
        for job in jobs.values()
        if job.get("worker_id") == worker.id and job.get("status") in ACTIVE_JOB_STATUSES
    )


async def _probe_worker(
    worker: ComfyWorker,
) -> tuple[Optional[tuple[int, int, str, ComfyWorker]], Optional[str]]:
    """Check one worker's health; return ``(ranking_tuple, None)`` or ``(None, error_text)``."""
    try:
        checkpoints = await list_comfy_checkpoints(worker)
        if not checkpoints:
            return None, f"{worker.id} has no checkpoints"
        load = local_worker_load(worker)
        queue_size = await worker_queue_size(worker)
    except Exception as exc:
        return None, f"{worker.id} unhealthy: {exc}"
    return (load, queue_size, worker.id, worker), None


async def choose_worker() -> ComfyWorker:
    """Pick the least-loaded healthy worker.

    Workers are probed concurrently (this runs under ``assignment_lock``, so
    latency here blocks other submissions). Ranking: this gateway's active job
    count, then ComfyUI queue length, then worker id as a stable tiebreak.

    Raises:
        HTTPException(503): no worker is reachable with a checkpoint installed.
    """
    probes = await asyncio.gather(*(_probe_worker(worker) for worker in COMFY_WORKERS))
    candidates = [candidate for candidate, _ in probes if candidate is not None]
    errors = [error for _, error in probes if error is not None]
    if not candidates:
        detail = "; ".join(errors) if errors else "No ComfyUI workers are configured"
        raise HTTPException(status_code=503, detail=f"No healthy Dream Weaver workers. {detail}")
    candidates.sort(key=lambda item: (item[0], item[1], item[2]))
    return candidates[0][3]


async def upload_to_comfy(worker: ComfyWorker, data: bytes, filename: str) -> str:
    """Upload image bytes to *worker* and return the name ComfyUI stored it under."""
    async with comfy_client(timeout=30) as client:
        r = await client.post(
            f"{worker.url}/upload/image",
            files={"image": (filename, data, "image/jpeg")},
            data={"overwrite": "true"},
        )
        r.raise_for_status()
        return r.json()["name"]


def normalize_expanded_prompt(expanded: "ExpandedPrompt") -> "ExpandedPrompt":
    """Clamp sampler settings in place to the ranges this gateway allows.

    steps -> [6, MAX_RENDER_STEPS], cfg -> [3.0, 7.0], denoise -> [0.45, 0.72];
    falsy values are replaced by MAX_RENDER_STEPS / 6.0 / 0.65 before clamping.
    Returns the same (mutated) object.
    """
    expanded.steps = max(6, min(int(expanded.steps or MAX_RENDER_STEPS), MAX_RENDER_STEPS))
    expanded.cfg = max(3.0, min(float(expanded.cfg or 6.0), 7.0))
    expanded.denoise = max(0.45, min(float(expanded.denoise or 0.65), 0.72))
    return expanded


def _safe_filename_fragment(text: str, limit: int = 30) -> str:
    """Reduce *text* to alphanumerics, ``-`` and ``_`` (others become ``_``), truncated to *limit*.

    The style name may come from an LLM, so it must not be able to inject path
    separators or dots into ComfyUI's SaveImage filename prefix.
    """
    cleaned = "".join(ch if ch.isalnum() or ch in "-_" else "_" for ch in text.replace(" ", "_"))
    return cleaned[:limit]


def build_workflow(img_name: str, expanded: "ExpandedPrompt", ckpt_name: str) -> dict:
    """Build the ComfyUI API-format img2img graph for an uploaded image.

    Graph: load checkpoint + image -> rescale to INPUT_MEGAPIXELS -> VAE encode
    -> KSampler with the (clamped) prompt settings -> VAE decode -> SaveImage.
    """
    expanded = normalize_expanded_prompt(expanded)
    return {
        "1": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": ckpt_name}},
        "2": {"class_type": "LoadImage", "inputs": {"image": img_name, "upload": "image"}},
        "9": {
            "class_type": "ImageScaleToTotalPixels",
            "inputs": {
                "image": ["2", 0],
                "upscale_method": "lanczos",
                "megapixels": INPUT_MEGAPIXELS,
                "resolution_steps": 8,
            },
        },
        "3": {"class_type": "CLIPTextEncode", "inputs": {"text": expanded.positive_prompt, "clip": ["1", 1]}},
        "4": {"class_type": "CLIPTextEncode", "inputs": {"text": expanded.negative_prompt, "clip": ["1", 1]}},
        "5": {"class_type": "VAEEncode", "inputs": {"pixels": ["9", 0], "vae": ["1", 2]}},
        "6": {
            "class_type": "KSampler",
            "inputs": {
                "model": ["1", 0],
                "positive": ["3", 0],
                "negative": ["4", 0],
                "latent_image": ["5", 0],
                "seed": int(time.time()) % 999983,
                "steps": expanded.steps,
                "cfg": expanded.cfg,
                "sampler_name": "dpmpp_2m",
                "scheduler": "karras",
                "denoise": expanded.denoise,
            },
        },
        "7": {"class_type": "VAEDecode", "inputs": {"samples": ["6", 0], "vae": ["1", 2]}},
        "8": {
            "class_type": "SaveImage",
            "inputs": {
                "images": ["7", 0],
                "filename_prefix": f"dw_{_safe_filename_fragment(expanded.style_name)}",
            },
        },
    }


async def queue_prompt(worker: ComfyWorker, workflow: dict) -> str:
    """Submit *workflow* to *worker* and return ComfyUI's prompt_id.

    Raises:
        HTTPException(502): ComfyUI rejected the workflow (status >= 400).
    """
    async with comfy_client(timeout=30) as client:
        r = await client.post(f"{worker.url}/prompt", json={"prompt": workflow, "client_id": str(uuid.uuid4())})
        if r.status_code >= 400:
            detail = r.text
            try:
                detail = json.dumps(r.json())
            except Exception:
                pass  # keep the raw response text when the body is not JSON
            raise HTTPException(status_code=502, detail=f"ComfyUI rejected Dream Weaver workflow: {detail}")
        return r.json()["prompt_id"]


def extract_comfy_error(history_entry: dict) -> Optional[str]:
    """Return a readable error if a ComfyUI history entry reports failure, else None.

    Only entries whose ``status.status_str`` is ``"error"`` count as failures;
    the most recent ``execution_error`` message, if any, names the node and
    exception.
    """
    status = history_entry.get("status") if isinstance(history_entry, dict) else None
    if not isinstance(status, dict) or status.get("status_str") != "error":
        return None
    for message in reversed(status.get("messages") or []):
        # Messages are expected to be [kind, payload] pairs; skip malformed ones
        # rather than letting an unpacking error abort the poll loop.
        if not isinstance(message, (list, tuple)) or len(message) != 2:
            continue
        kind, payload = message
        if kind == "execution_error" and isinstance(payload, dict):
            node_type = payload.get("node_type") or payload.get("node_id") or "ComfyUI"
            message_text = payload.get("exception_message") or payload.get("exception_type") or "ComfyUI execution failed"
            return f"{node_type}: {message_text}"
    return "ComfyUI execution failed"


async def _fetch_history_entry(client: httpx.AsyncClient, worker: ComfyWorker, prompt_id: str) -> Optional[dict]:
    """Fetch this prompt's history entry; None if not available yet or on a transient error."""
    try:
        response = await client.get(f"{worker.url}/history/{prompt_id}")
        if response.status_code != 200:
            return None
        payload = response.json()
    except (httpx.HTTPError, ValueError) as exc:
        logger.warning("Transient error polling %s for prompt %s: %s", worker.id, prompt_id, exc)
        return None
    entry = payload.get(prompt_id) if isinstance(payload, dict) else None
    return entry if isinstance(entry, dict) else None


async def poll_result(worker: ComfyWorker, prompt_id: str, timeout: int = POLL_TIMEOUT_SECONDS):
    """Poll ComfyUI history until the prompt yields an image, fails, or times out.

    Returns:
        ``(image_info, None)`` on success, where ``image_info`` is the first item
        of the first ``images`` list found in the outputs; otherwise
        ``(None, error_message)``.

    Transient HTTP/JSON errors are logged and retried until the deadline so a
    brief worker hiccup does not fail a long render.
    """
    # Monotonic clock: the deadline is unaffected by wall-clock adjustments.
    deadline = time.monotonic() + timeout
    async with comfy_client(timeout=10) as client:
        while time.monotonic() < deadline:
            entry = await _fetch_history_entry(client, worker, prompt_id)
            if entry:
                err = extract_comfy_error(entry)
                if err:
                    return None, err
                imgs = [
                    img
                    for node_output in (entry.get("outputs") or {}).values()
                    if isinstance(node_output, dict)
                    for img in node_output.get("images", [])
                ]
                if imgs:
                    return imgs[0], None
                status = entry.get("status")
                # Finished without images: waiting longer cannot help.
                if isinstance(status, dict) and status.get("completed"):
                    return None, "ComfyUI finished without producing an image"
            await asyncio.sleep(POLL_INTERVAL_SECONDS)
    return None, f"timeout after {timeout} seconds"


def _describe_exception(exc: BaseException) -> str:
    """Render an exception as a short job error message."""
    if isinstance(exc, HTTPException):
        return str(exc.detail)
    if isinstance(exc, asyncio.CancelledError):
        return "request cancelled before the job was queued"
    return str(exc) or type(exc).__name__


async def background_poll(job_id: str, worker_id: str, prompt_id: str):
    """Wait for *prompt_id* on its worker and record the outcome on ``jobs[job_id]``.

    Any unexpected exception is recorded as a job error so the job never stays
    in "processing" (an active status that also skews worker selection).
    """
    try:
        worker = worker_by_id(worker_id)
        img, err = await poll_result(worker, prompt_id)
    except Exception as exc:
        logger.exception("Polling failed for job %s (prompt %s)", job_id, prompt_id)
        img, err = None, _describe_exception(exc)
    if img:
        jobs[job_id].update({"status": "done", "output": img, "completed": time.time()})
    else:
        jobs[job_id].update({"status": "error", "error": str(err), "completed": time.time()})


@app.get("/health")
@app.get("/dream-weaver/health")
async def health():
    """Report per-worker reachability, checkpoints and queue size plus gateway config."""
    worker_health = []
    for worker in COMFY_WORKERS:
        try:
            checkpoints = await list_comfy_checkpoints(worker)
            queue_size = await worker_queue_size(worker)
            worker_health.append({
                "id": worker.id,
                "url": worker.url,
                "online": True,
                "checkpoint_ready": bool(checkpoints),
                "checkpoint_count": len(checkpoints),
                "queue_size": queue_size,
                "available_checkpoints": checkpoints[:12],
            })
        except Exception as exc:
            worker_health.append({
                "id": worker.id,
                "url": worker.url,
                "online": False,
                "checkpoint_ready": False,
                "checkpoint_count": 0,
                "queue_size": None,
                "error": str(exc),
            })
    ready_workers = [worker for worker in worker_health if worker["online"] and worker["checkpoint_ready"]]
    checkpoints = ready_workers[0]["available_checkpoints"] if ready_workers else []
    return {
        "status": "ok",
        "comfyui": bool(ready_workers),
        "comfyui_url": COMFY_WORKERS[0].url if COMFY_WORKERS else COMFY,
        "comfyui_urls": [worker.url for worker in COMFY_WORKERS],
        "checkpoint_ready": bool(ready_workers),
        "checkpoint_count": max((worker["checkpoint_count"] for worker in worker_health), default=0),
        "preferred_checkpoints": PREFERRED_CHECKPOINTS,
        "available_checkpoints": checkpoints[:12],
        "workers": worker_health,
        "ready_worker_count": len(ready_workers),
        "llm_expansion": LLM_AVAILABLE,
        "input_megapixels": INPUT_MEGAPIXELS,
        "max_render_steps": MAX_RENDER_STEPS,
        "poll_timeout_seconds": POLL_TIMEOUT_SECONDS,
        "version": "2.0.0",
        "auth_required": GATEWAY_API_KEY is not None,
        "auth_scheme": "x-dream-weaver-api-key",
    }


@app.get("/dream-weaver/status/{job_id}")
async def status(job_id: str, request: Request):
    """Return the job record (minus raw output info), elapsed time, and URLs once done."""
    ensure_gateway_auth(request)
    job = jobs.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    res = {k: v for k, v in job.items() if k != "output"}
    if "created" in job:
        res["elapsed_seconds"] = round(time.time() - float(job["created"]), 2)
    res["ready"] = job.get("status") == "done"
    if res["ready"]:
        res.update(gateway_urls(job_id))
    return res


@app.post("/dream-weaver")
async def dream_weaver(
    request: Request,
    image: UploadFile = File(...),
    keywords: str = Form(default=""),
    room_type: str = Form(default="living_room"),
):
    """Accept an image + keywords, queue a render on a worker, and return the job id.

    If anything fails before the prompt is queued, the job is marked "error"
    (so it does not linger as an active job) and the original exception is
    re-raised to the client.
    """
    ensure_gateway_auth(request)
    job_id = str(uuid.uuid4())
    jobs[job_id] = {"status": "uploading", "created": time.time()}
    try:
        data = await image.read()
        async with assignment_lock:
            worker = await choose_worker()
            jobs[job_id].update({"worker_id": worker.id, "worker_url": worker.url})
        lock = worker_locks.setdefault(worker.id, asyncio.Lock())
        async with lock:
            comfy_name = await upload_to_comfy(worker, data, f"dw_{job_id[:8]}.jpg")
            kw_list = [k.strip() for k in keywords.split(",") if k.strip()]
            expanded = await asyncio.to_thread(expand_prompt, keywords=kw_list, room_type=room_type)
            ckpt_name = await resolve_checkpoint(worker)
            jobs[job_id]["checkpoint"] = ckpt_name
            wf = build_workflow(comfy_name, expanded, ckpt_name)
            prompt_id = await queue_prompt(worker, wf)
            jobs[job_id].update({"status": "processing", "prompt_id": prompt_id})
    except BaseException as exc:
        # Cleanup only: record the failure, then propagate unchanged.
        jobs[job_id].update({"status": "error", "error": _describe_exception(exc), "completed": time.time()})
        raise
    task = asyncio.create_task(background_poll(job_id, worker.id, prompt_id))
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return {
        "job_id": job_id,
        "status": "processing",
        **gateway_urls(job_id),
    }


@app.get("/dream-weaver/result/{job_id}")
async def result(job_id: str, request: Request):
    """Proxy the finished image from the worker that rendered it.

    Raises:
        HTTPException(404): job unknown or not finished.
        HTTPException(502): the worker could not serve the image.
    """
    ensure_gateway_auth(request)
    job = jobs.get(job_id)
    if not job or job.get("status") != "done":
        raise HTTPException(status_code=404, detail="Result not ready")
    img = job.get("output")
    if not img:
        raise HTTPException(status_code=404, detail="Result not ready")
    worker = worker_by_id(job.get("worker_id", COMFY_WORKERS[0].id))
    try:
        async with comfy_client(timeout=30) as client:
            response = await client.get(
                f"{worker.url}/view",
                params={
                    "filename": img["filename"],
                    "subfolder": img.get("subfolder", ""),
                    "type": img.get("type", "output"),
                },
            )
            response.raise_for_status()
    except httpx.HTTPError as exc:
        raise HTTPException(
            status_code=502,
            detail=f"Failed to fetch Dream Weaver result from {worker.id}: {exc}",
        ) from exc
    return StreamingResponse(
        io.BytesIO(response.content),
        media_type="image/png",
        headers={"Content-Disposition": f"attachment; filename=dreamweaver_{job_id[:8]}.png"},
    )


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8082, log_level="info")