374 lines
16 KiB
Python
374 lines
16 KiB
Python
#!/usr/bin/env python3
|
|
import asyncio, json, time, uuid, io, sys, os, logging
|
|
from pathlib import Path
|
|
from dataclasses import dataclass
|
|
from typing import Optional, List
|
|
import httpx
|
|
import uvicorn
|
|
from fastapi import FastAPI, UploadFile, File, HTTPException, Form, Request
|
|
from fastapi.responses import JSONResponse, StreamingResponse
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from pydantic import BaseModel
|
|
|
|
ROOT_DIR = Path(__file__).resolve().parent

# Make the helper script directories importable (presumably where gateway_auth
# and prompt_expander live). Each directory that exists is pushed to the FRONT
# of sys.path, so ROOT_DIR/scripts -- inserted last -- is searched before
# ROOT_DIR/comfy_engine/scripts.
for scripts_dir in (
    ROOT_DIR / "comfy_engine" / "scripts",
    ROOT_DIR / "scripts",
):
    if not scripts_dir.exists():
        continue
    sys.path.insert(0, str(scripts_dir))
|
|
|
|
# gateway_auth is mandatory: without it the gateway cannot load or check its API key.
try:
    from gateway_auth import is_gateway_request_authorized, load_gateway_api_key
except ImportError as import_error:
    raise RuntimeError("Dream Weaver gateway_auth.py is required on PYTHONPATH") from import_error
|
|
|
|
# Prefer the LLM-backed prompt expander. When it is missing, fall back to a
# deterministic template so the gateway can still render.
try:
    from prompt_expander import expand_prompt, ROOM_CONTEXTS, ExpandedPrompt

    LLM_AVAILABLE = True
except ImportError as exc:
    LLM_AVAILABLE = False
    # Log through the gateway's named logger, not the module-level
    # ``logging.warning``. That function implicitly calls ``logging.basicConfig()``
    # when the root logger has no handlers, which would make the gateway's own
    # ``basicConfig(level=INFO, format=...)`` later in this module a no-op and
    # silently drop all INFO logging.
    logging.getLogger("DreamWeaverGateway").warning(
        "prompt_expander unavailable; using deterministic fallback expansion: %s", exc
    )

    class ExpandedPrompt(BaseModel):
        """Minimal local stand-in used when prompt_expander cannot be imported."""

        style_name: str
        positive_prompt: str
        negative_prompt: str
        # Sampler defaults; normalize_expanded_prompt clamps them before rendering.
        steps: int = 28
        cfg: float = 7.0
        denoise: float = 0.72

    # The fallback has no per-room context table.
    ROOM_CONTEXTS = {}

    def expand_prompt(keywords: List[str], room_type: str) -> ExpandedPrompt:
        """Build a fixed-template interior-design prompt without an LLM.

        Args:
            keywords: style keywords; defaults to "modern, photorealistic" when empty.
            room_type: snake_case room name; underscores become spaces and a blank
                value becomes "living room".

        Returns:
            An ExpandedPrompt using the model's default sampler settings.
        """
        pretty_room = room_type.replace("_", " ").strip() or "living room"
        pretty_keywords = ", ".join(keywords) if keywords else "modern, photorealistic"
        return ExpandedPrompt(
            style_name="Fallback Prompt Expansion",
            positive_prompt=(
                f"photorealistic premium {pretty_room} interior design, {pretty_keywords}, "
                "natural lighting, realistic materials, architect-grade composition"
            ),
            negative_prompt=(
                "worst quality, low quality, blurry, distorted perspective, "
                "people, watermark, text, duplicate objects"
            ),
        )
|
|
|
|
|
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger("DreamWeaverGateway")

# Primary ComfyUI endpoint. COMFYUI_URL takes precedence over COMFY_URL.
COMFY = (os.environ.get("COMFYUI_URL") or os.environ.get("COMFY_URL") or "http://127.0.0.1:8188").rstrip("/")
# Worker pool from comma-separated COMFYUI_URLS (defaults to COMFY). The trailing
# ``or [COMFY]`` keeps at least one worker when the variable is set but blank
# (e.g. ``COMFYUI_URLS=`` or ``COMFYUI_URLS=,``). Otherwise the pool would be empty
# and every generation request would fail with a 503.
COMFY_URLS = [
    item.strip().rstrip("/")
    for item in os.environ.get("COMFYUI_URLS", COMFY).split(",")
    if item.strip()
] or [COMFY]
# TLS verification stays on unless explicitly disabled with 0/false/no/off.
COMFY_TLS_VERIFY = os.environ.get("COMFYUI_TLS_VERIFY", "true").strip().lower() not in {"0", "false", "no", "off"}
GATEWAY_API_KEY = load_gateway_api_key()
# How long background polling waits for a render, and how often it asks ComfyUI.
POLL_TIMEOUT_SECONDS = int(os.environ.get("DREAM_WEAVER_POLL_TIMEOUT_SECONDS", "1800"))
POLL_INTERVAL_SECONDS = float(os.environ.get("DREAM_WEAVER_POLL_INTERVAL_SECONDS", "2"))
# Target size the input photo is rescaled to (ImageScaleToTotalPixels node).
INPUT_MEGAPIXELS = float(os.environ.get("DREAM_WEAVER_INPUT_MEGAPIXELS", "0.75"))
# Upper bound on sampler steps; normalize_expanded_prompt clamps to it.
MAX_RENDER_STEPS = int(os.environ.get("DREAM_WEAVER_MAX_STEPS", "18"))
# Checkpoints resolve_checkpoint tries first (case-insensitive match).
PREFERRED_CHECKPOINTS = [
    "realvisxlV50_v50LightningBakedvae.safetensors",
    "realvisxlV50Lightning_v50Lightning.safetensors",
]
|
|
|
|
app = FastAPI(title="Dream Weaver API v2", version="2.0.0")
# Any origin may call the gateway from a browser. Protected endpoints rely on the
# gateway API key header (ensure_gateway_auth) rather than on CORS.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# In-memory job registry keyed by job_id. It is lost on restart, and nothing in
# this module removes entries.
jobs: dict = {}
# One lock per worker id; serialises upload + queue submission to a single worker.
worker_locks: dict[str, asyncio.Lock] = {}
# Serialises worker selection so a freshly assigned job is counted by the next pick.
assignment_lock = asyncio.Lock()
# Job states that count toward a worker's local load when balancing.
ACTIVE_JOB_STATUSES = {"uploading", "processing"}
|
|
|
|
@dataclass(frozen=True)
class ComfyWorker:
    """Immutable handle for one ComfyUI backend in the worker pool.

    Frozen, so instances are hashable and safe to share between requests.
    """

    # Identifier; COMFY_WORKERS assigns "comfy-<index>". It keys worker_locks and job records.
    id: str
    # Base URL; endpoint paths such as "/queue" are appended directly to it.
    url: str
|
|
|
|
# One worker per configured URL, with ids assigned by position: "comfy-0", "comfy-1", ...
COMFY_WORKERS = [
    ComfyWorker(id=f"comfy-{position}", url=base_url)
    for position, base_url in enumerate(COMFY_URLS)
]

# Pre-register a lock per worker, leaving any lock already present untouched.
for worker in COMFY_WORKERS:
    if worker.id not in worker_locks:
        worker_locks[worker.id] = asyncio.Lock()
|
|
|
|
def comfy_client(timeout: float = 30) -> httpx.AsyncClient:
    """Create an AsyncClient for talking to ComfyUI.

    The client honours COMFY_TLS_VERIFY and follows redirects. Use it as an
    ``async with`` context so the connection pool is closed.
    """
    options = {"timeout": timeout, "verify": COMFY_TLS_VERIFY, "follow_redirects": True}
    return httpx.AsyncClient(**options)
|
|
|
|
def worker_by_id(worker_id: str) -> ComfyWorker:
    """Look up a configured worker by id.

    Raises:
        HTTPException: 500 when no worker has that id.
    """
    found = next((candidate for candidate in COMFY_WORKERS if candidate.id == worker_id), None)
    if found is None:
        raise HTTPException(status_code=500, detail=f"Dream Weaver worker {worker_id} is not configured")
    return found
|
|
|
|
async def list_comfy_checkpoints(worker: Optional[ComfyWorker] = None) -> list[str]:
    """Return the checkpoint filenames a worker reports via ``/models/checkpoints``.

    Defaults to the first configured worker. Non-string entries are dropped, and a
    non-list response yields ``[]``. HTTP errors propagate from ``raise_for_status``.
    """
    target = worker if worker else COMFY_WORKERS[0]
    async with comfy_client(timeout=10) as client:
        reply = await client.get(f"{target.url}/models/checkpoints")
        reply.raise_for_status()
        body = reply.json()
    if not isinstance(body, list):
        return []
    return [name for name in body if isinstance(name, str)]
|
|
|
|
async def resolve_checkpoint(worker: ComfyWorker) -> str:
    """Pick the checkpoint to render with on ``worker``.

    The first entry of PREFERRED_CHECKPOINTS present on the worker wins, matched
    case-insensitively. Otherwise the first checkpoint the worker lists is used.

    Raises:
        HTTPException: 503 when the worker reports no checkpoints at all.
    """
    available = await list_comfy_checkpoints(worker)
    if not available:
        raise HTTPException(
            status_code=503,
            detail=(
                "ComfyUI is online but has no checkpoint models installed. "
                "Hydrate RealVisXL into ComfyUI/models/checkpoints before generating."
            ),
        )
    by_lowercase = {name.lower(): name for name in available}
    preferred_hits = (by_lowercase.get(wanted.lower()) for wanted in PREFERRED_CHECKPOINTS)
    return next((hit for hit in preferred_hits if hit), available[0])
|
|
|
|
def gateway_urls(job_id: str) -> dict:
    """Return the client-facing status-poll and result-download paths for a job."""
    return dict(
        poll_url=f"/dream-weaver/status/{job_id}",
        result_url=f"/dream-weaver/result/{job_id}",
    )
|
|
|
|
def ensure_gateway_auth(request: Request) -> None:
    """Reject the request unless gateway_auth accepts its headers for GATEWAY_API_KEY.

    Raises:
        HTTPException: 401 when the key is missing or invalid.
    """
    if not is_gateway_request_authorized(request.headers, GATEWAY_API_KEY):
        raise HTTPException(status_code=401, detail="Dream Weaver gateway API key is required or invalid.")
|
|
|
|
async def worker_queue_size(worker: ComfyWorker) -> int:
    """Return running + pending items from the worker's ``/queue`` endpoint.

    A non-dict response counts as an empty queue. HTTP errors propagate.
    """
    async with comfy_client(timeout=5) as client:
        reply = await client.get(f"{worker.url}/queue")
        reply.raise_for_status()
        body = reply.json()
    if not isinstance(body, dict):
        return 0
    return sum(len(body.get(section) or []) for section in ("queue_running", "queue_pending"))
|
|
|
|
def local_worker_load(worker: ComfyWorker) -> int:
    """Count this gateway's jobs assigned to ``worker`` that are still in an active status."""
    active = 0
    for record in jobs.values():
        if record.get("worker_id") != worker.id:
            continue
        if record.get("status") in ACTIVE_JOB_STATUSES:
            active += 1
    return active
|
|
|
|
async def choose_worker() -> ComfyWorker:
    """Pick the least-busy healthy worker.

    Workers that report no checkpoints, or whose probes raise, are skipped. The rest
    are ranked by (active jobs this gateway assigned there, ComfyUI queue length,
    worker id), and the lowest wins.

    Raises:
        HTTPException: 503 listing each worker's problem when none is usable.
    """
    ranked: list[tuple[int, int, str, ComfyWorker]] = []
    problems: list[str] = []
    for worker in COMFY_WORKERS:
        try:
            if not await list_comfy_checkpoints(worker):
                problems.append(f"{worker.id} has no checkpoints")
                continue
            gateway_load = local_worker_load(worker)
            ranked.append((gateway_load, await worker_queue_size(worker), worker.id, worker))
        except Exception as exc:
            problems.append(f"{worker.id} unhealthy: {exc}")

    if ranked:
        return min(ranked, key=lambda entry: entry[:3])[3]
    detail = "; ".join(problems) if problems else "No ComfyUI workers are configured"
    raise HTTPException(status_code=503, detail=f"No healthy Dream Weaver workers. {detail}")
|
|
|
|
async def upload_to_comfy(worker: ComfyWorker, data: bytes, filename: str) -> str:
    """Upload image bytes to the worker's ``/upload/image``, overwriting any same-named file.

    Returns:
        The stored image name reported by ComfyUI, for use in a LoadImage node.
    """
    form_files = {"image": (filename, data, "image/jpeg")}
    async with comfy_client(timeout=30) as client:
        reply = await client.post(f"{worker.url}/upload/image", files=form_files, data={"overwrite": "true"})
        reply.raise_for_status()
        return reply.json()["name"]
|
|
|
|
def normalize_expanded_prompt(expanded: "ExpandedPrompt") -> "ExpandedPrompt":
    """Clamp sampler settings IN PLACE and return the same object.

    - steps: 6..MAX_RENDER_STEPS (falsy -> MAX_RENDER_STEPS)
    - cfg: 3.0..7.0 (falsy -> 6.0)
    - denoise: 0.45..0.72 (falsy -> 0.65)
    """
    def _clamp(value, low, high):
        return max(low, min(value, high))

    expanded.steps = _clamp(int(expanded.steps or MAX_RENDER_STEPS), 6, MAX_RENDER_STEPS)
    expanded.cfg = _clamp(float(expanded.cfg or 6.0), 3.0, 7.0)
    expanded.denoise = _clamp(float(expanded.denoise or 0.65), 0.45, 0.72)
    return expanded
|
|
|
|
def build_workflow(img_name: str, expanded: "ExpandedPrompt", ckpt_name: str, seed: Optional[int] = None) -> dict:
    """Build the ComfyUI API-format img2img graph for one render.

    Node chain: checkpoint loader (1), load image (2), rescale to INPUT_MEGAPIXELS (9),
    VAE encode (5), KSampler (6, conditioned by text encodes 3/4), VAE decode (7),
    SaveImage (8).

    Args:
        img_name: image name returned by ComfyUI's /upload/image.
        expanded: prompts and sampler settings; clamped in place by normalize_expanded_prompt.
        ckpt_name: checkpoint filename for CheckpointLoaderSimple.
        seed: KSampler seed. A random value is drawn when omitted; pass one to reproduce a render.

    Returns:
        Workflow dict keyed by node id, suitable for ComfyUI's /prompt endpoint.
    """
    expanded = normalize_expanded_prompt(expanded)
    if seed is None:
        # uuid4 draws from os.urandom. The previous ``int(time.time()) % 999983``
        # handed every job submitted in the same second the identical seed.
        seed = uuid.uuid4().int % 999983
    # style_name may come from an LLM. Keep only filename-safe characters so that
    # "/" or other separators cannot be interpreted as a path in filename_prefix.
    # Spaces still map to "_" as before.
    safe_style = "".join(
        ch if ch.isalnum() or ch in "-_." else "_" for ch in expanded.style_name
    )[:30]
    return {
        "1": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": ckpt_name}},
        "2": {"class_type": "LoadImage", "inputs": {"image": img_name, "upload": "image"}},
        "9": {"class_type": "ImageScaleToTotalPixels", "inputs": {"image": ["2", 0], "upscale_method": "lanczos", "megapixels": INPUT_MEGAPIXELS, "resolution_steps": 8}},
        "3": {"class_type": "CLIPTextEncode", "inputs": {"text": expanded.positive_prompt, "clip": ["1", 1]}},
        "4": {"class_type": "CLIPTextEncode", "inputs": {"text": expanded.negative_prompt, "clip": ["1", 1]}},
        "5": {"class_type": "VAEEncode", "inputs": {"pixels": ["9", 0], "vae": ["1", 2]}},
        "6": {"class_type": "KSampler", "inputs": {"model": ["1", 0], "positive": ["3", 0], "negative": ["4", 0], "latent_image": ["5", 0], "seed": seed, "steps": expanded.steps, "cfg": expanded.cfg, "sampler_name": "dpmpp_2m", "scheduler": "karras", "denoise": expanded.denoise}},
        "7": {"class_type": "VAEDecode", "inputs": {"samples": ["6", 0], "vae": ["1", 2]}},
        "8": {"class_type": "SaveImage", "inputs": {"images": ["7", 0], "filename_prefix": f"dw_{safe_style}"}},
    }
|
|
|
|
async def queue_prompt(worker: ComfyWorker, workflow: dict) -> str:
    """Submit a workflow to the worker's ``/prompt`` endpoint and return its prompt_id.

    Raises:
        HTTPException: 502 carrying ComfyUI's error body (re-serialised as JSON when
            possible) when the worker answers with a 4xx/5xx status.
    """
    payload = {"prompt": workflow, "client_id": str(uuid.uuid4())}
    async with comfy_client(timeout=30) as client:
        reply = await client.post(f"{worker.url}/prompt", json=payload)
    if reply.status_code < 400:
        return reply.json()["prompt_id"]
    try:
        detail = json.dumps(reply.json())
    except Exception:
        # Non-JSON error body: fall back to the raw text.
        detail = reply.text
    raise HTTPException(status_code=502, detail=f"ComfyUI rejected Dream Weaver workflow: {detail}")
|
|
|
|
def extract_comfy_error(history_entry: dict) -> Optional[str]:
    """Return a readable error for a failed ComfyUI history entry, else ``None``.

    Only entries whose ``status.status_str`` is ``"error"`` produce a message. The
    most recent ``execution_error`` message is formatted as ``"<node>: <message>"``.
    When no usable error message is present, a generic text is returned.
    """
    status = history_entry.get("status") if isinstance(history_entry, dict) else None
    if not isinstance(status, dict) or status.get("status_str") != "error":
        return None
    messages = status.get("messages") or []
    if not isinstance(messages, (list, tuple)):
        return "ComfyUI execution failed"
    for entry in reversed(messages):
        # Entries are expected to be [kind, payload] pairs. Skip malformed ones
        # instead of letting a ValueError escape into the polling task.
        if not isinstance(entry, (list, tuple)) or len(entry) != 2:
            continue
        kind, payload = entry
        if kind == "execution_error" and isinstance(payload, dict):
            node_type = payload.get("node_type") or payload.get("node_id") or "ComfyUI"
            message = payload.get("exception_message") or payload.get("exception_type") or "ComfyUI execution failed"
            return f"{node_type}: {message}"
    return "ComfyUI execution failed"
|
|
|
|
async def poll_result(worker: ComfyWorker, prompt_id: str, timeout: int = POLL_TIMEOUT_SECONDS):
    """Poll ``/history/{prompt_id}`` on ``worker`` until the render finishes, fails, or times out.

    Returns:
        ``(image_record, None)`` with the first output image on success, or
        ``(None, message)`` when ComfyUI reports an execution error or ``timeout``
        seconds elapse.
    """
    # monotonic() cannot jump with wall-clock adjustments (NTP/DST), which could
    # otherwise cut a render short or stretch the wait.
    deadline = time.monotonic() + timeout
    async with comfy_client(timeout=10) as client:
        while time.monotonic() < deadline:
            try:
                r = await client.get(f"{worker.url}/history/{prompt_id}")
            except httpx.TransportError as exc:
                # One dropped connection or request timeout should not abandon a
                # render that may still be running; retry on the next tick.
                logger.warning("Polling %s for prompt %s failed: %s", worker.id, prompt_id, exc)
            else:
                if r.status_code == 200:
                    h = r.json().get(prompt_id, {})
                    err = extract_comfy_error(h)
                    if err:
                        return None, err
                    imgs = [img for nd in h.get("outputs", {}).values() for img in nd.get("images", [])]
                    if imgs:
                        return imgs[0], None
            await asyncio.sleep(POLL_INTERVAL_SECONDS)
    return None, f"timeout after {timeout} seconds"
|
|
|
|
async def background_poll(job_id: str, worker_id: str, prompt_id: str):
    """Wait for a queued render and record its outcome on ``jobs[job_id]``.

    Runs as a detached task. On success the job becomes ``"done"`` with the output
    image record; otherwise it becomes ``"error"`` with a message. Both outcomes set
    ``completed``.
    """
    try:
        worker = worker_by_id(worker_id)
        img, err = await poll_result(worker, prompt_id)
    except Exception as exc:
        # Nobody awaits this task, so an escaping exception would be lost and the
        # job would stay "processing" forever. Record the failure instead.
        logger.exception("Dream Weaver job %s: polling worker %s failed", job_id, worker_id)
        img, err = None, f"polling failed: {exc}"
    job = jobs.get(job_id)
    if job is None:
        return
    if img:
        job.update({"status": "done", "output": img, "completed": time.time()})
    else:
        job.update({"status": "error", "error": str(err), "completed": time.time()})
|
|
|
|
@app.get("/health")
@app.get("/dream-weaver/health")
async def health():
    """Report gateway configuration plus a live probe of every ComfyUI worker.

    The endpoint is unauthenticated. Each worker is asked for its checkpoint list
    and queue depth. Any exception from either probe marks that worker offline and
    includes the error text.
    """
    reports = []
    for comfy_worker in COMFY_WORKERS:
        try:
            found = await list_comfy_checkpoints(comfy_worker)
            depth = await worker_queue_size(comfy_worker)
        except Exception as exc:
            reports.append({
                "id": comfy_worker.id,
                "url": comfy_worker.url,
                "online": False,
                "checkpoint_ready": False,
                "checkpoint_count": 0,
                "queue_size": None,
                "error": str(exc),
            })
            continue
        reports.append({
            "id": comfy_worker.id,
            "url": comfy_worker.url,
            "online": True,
            "checkpoint_ready": bool(found),
            "checkpoint_count": len(found),
            "queue_size": depth,
            "available_checkpoints": found[:12],
        })

    ready = [report for report in reports if report["online"] and report["checkpoint_ready"]]
    sample = ready[0]["available_checkpoints"] if ready else []
    return {
        "status": "ok",
        "comfyui": bool(ready),
        "comfyui_url": COMFY_WORKERS[0].url if COMFY_WORKERS else COMFY,
        "comfyui_urls": [w.url for w in COMFY_WORKERS],
        "checkpoint_ready": bool(ready),
        "checkpoint_count": max((report["checkpoint_count"] for report in reports), default=0),
        "preferred_checkpoints": PREFERRED_CHECKPOINTS,
        "available_checkpoints": sample[:12],
        "workers": reports,
        "ready_worker_count": len(ready),
        "llm_expansion": LLM_AVAILABLE,
        "input_megapixels": INPUT_MEGAPIXELS,
        "max_render_steps": MAX_RENDER_STEPS,
        "poll_timeout_seconds": POLL_TIMEOUT_SECONDS,
        "version": "2.0.0",
        "auth_required": GATEWAY_API_KEY is not None,
        "auth_scheme": "x-dream-weaver-api-key"
    }
|
|
|
|
@app.get("/dream-weaver/status/{job_id}")
async def status(job_id: str, request: Request):
    """Return a job's record without the raw ``output`` field.

    Adds ``elapsed_seconds`` (since ``created``) and ``ready``. When the job is done,
    the poll/result URLs are added as well.

    Raises:
        HTTPException: 401 on bad credentials, 404 for an unknown job id.
    """
    ensure_gateway_auth(request)
    job = jobs.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    snapshot = dict(job)
    snapshot.pop("output", None)
    if "created" in job:
        snapshot["elapsed_seconds"] = round(time.time() - float(job["created"]), 2)
    finished = job.get("status") == "done"
    snapshot["ready"] = finished
    if finished:
        snapshot.update(gateway_urls(job_id))
    return snapshot
|
|
|
|
# Strong references to in-flight polling tasks. The event loop keeps only weak
# references to tasks, so an unreferenced create_task() result may be garbage-
# collected before it finishes. Each task removes itself here when it completes.
_background_tasks: set = set()


@app.post("/dream-weaver")
async def dream_weaver(
    request: Request,
    image: UploadFile = File(...),
    keywords: str = Form(default=""),
    room_type: str = Form(default="living_room")
):
    """Accept a room photo, queue a redesign render on the least-busy worker, and return at once.

    Steps:
        1. Assign a worker.
        2. Upload the photo.
        3. Expand the comma-separated ``keywords`` into prompts.
        4. Queue the workflow.
        5. Start a background poll that records the outcome on the job.

    Returns:
        ``job_id``, ``"processing"`` status and the poll/result URLs.

    Raises:
        HTTPException: 401 on bad credentials, 503 when no worker is usable, 502 when
            ComfyUI rejects the workflow. Any failure after the job record is created
            marks that job ``"error"`` before the exception propagates.
    """
    ensure_gateway_auth(request)
    job_id = str(uuid.uuid4())
    jobs[job_id] = {"status": "uploading", "created": time.time()}
    try:
        data = await image.read()
        async with assignment_lock:
            # Choose and record the worker under one lock, so the next request's
            # local_worker_load() already counts this job.
            worker = await choose_worker()
            jobs[job_id].update({"worker_id": worker.id, "worker_url": worker.url})
        lock = worker_locks.setdefault(worker.id, asyncio.Lock())
        async with lock:
            comfy_name = await upload_to_comfy(worker, data, f"dw_{job_id[:8]}.jpg")
            kw_list = [k.strip() for k in keywords.split(",") if k.strip()]
            # Prompt expansion may block (LLM call), so run it off the event loop.
            expanded = await asyncio.to_thread(expand_prompt, keywords=kw_list, room_type=room_type)
            ckpt_name = await resolve_checkpoint(worker)
            jobs[job_id]["checkpoint"] = ckpt_name
            wf = build_workflow(comfy_name, expanded, ckpt_name)
            prompt_id = await queue_prompt(worker, wf)
    except BaseException as exc:
        # Without this, a failed submission would sit in "uploading" forever and
        # keep counting toward local_worker_load() for its worker. Re-raised below.
        detail = getattr(exc, "detail", None) or str(exc) or type(exc).__name__
        jobs[job_id].update({"status": "error", "error": str(detail), "completed": time.time()})
        raise
    jobs[job_id].update({"status": "processing", "prompt_id": prompt_id})
    task = asyncio.create_task(background_poll(job_id, worker.id, prompt_id))
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return {
        "job_id": job_id,
        "status": "processing",
        **gateway_urls(job_id),
    }
|
|
|
|
@app.get("/dream-weaver/result/{job_id}")
async def result(job_id: str, request: Request):
    """Fetch a finished job's image from its worker's ``/view`` and stream it as PNG.

    Raises:
        HTTPException: 401 on bad credentials, 404 when the job is unknown or not
            done, 502 when the worker cannot return the image (offline, file gone).
    """
    ensure_gateway_auth(request)
    job = jobs.get(job_id)
    if not job or job.get("status") != "done":
        raise HTTPException(status_code=404, detail="Result not ready")

    img = job.get("output")
    if not img:
        raise HTTPException(status_code=404, detail="Result not ready")

    # Resolve the fallback worker lazily. The old eager ``COMFY_WORKERS[0]`` default
    # was evaluated even when the job carries its own worker_id.
    worker_id = job.get("worker_id")
    worker = worker_by_id(worker_id) if worker_id else COMFY_WORKERS[0]
    try:
        async with comfy_client(timeout=30) as client:
            response = await client.get(
                f"{worker.url}/view",
                params={
                    "filename": img["filename"],
                    "subfolder": img.get("subfolder", ""),
                    "type": img.get("type", "output"),
                },
            )
            response.raise_for_status()
    except httpx.HTTPError as exc:
        # Report upstream failures as a gateway error rather than an unhandled 500.
        logger.warning("Fetching result for job %s from %s failed: %s", job_id, worker.id, exc)
        raise HTTPException(
            status_code=502,
            detail=f"Dream Weaver worker {worker.id} could not return the result: {exc}",
        ) from exc

    return StreamingResponse(
        io.BytesIO(response.content),
        media_type="image/png",
        headers={"Content-Disposition": f"attachment; filename=dreamweaver_{job_id[:8]}.png"},
    )
|
|
|
|
if __name__ == "__main__":
    # Listen on every interface, port 8082, when the gateway is run directly as a script.
    uvicorn.run(app, log_level="info", port=8082, host="0.0.0.0")
|