Files
Project_Velocity/dw_gateway_v2_min.py

243 lines
9.9 KiB
Python

#!/usr/bin/env python3
import asyncio, json, time, uuid, io, sys, os, logging
from pathlib import Path
from typing import Optional, List
import httpx
import uvicorn
from fastapi import FastAPI, UploadFile, File, HTTPException, Form, Request
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
# Directory that contains this gateway script.
ROOT_DIR = Path(__file__).resolve().parent
# Helper modules live in ./scripts when that folder exists, otherwise in ./comfy_engine/scripts.
SCRIPTS_DIR = (
    ROOT_DIR / "scripts"
    if (ROOT_DIR / "scripts").exists()
    else ROOT_DIR / "comfy_engine" / "scripts"
)
# Put the helper directory first on the import path so prompt_expander / gateway_auth resolve from it.
sys.path.insert(0, str(SCRIPTS_DIR))
try:
    # Optional LLM-backed prompt expansion; the gateway still works without it.
    from prompt_expander import expand_prompt, expand_prompt_simple, ROOM_CONTEXTS, ExpandedPrompt
    LLM_AVAILABLE = True
except ImportError as exc:
    LLM_AVAILABLE = False
    logging.warning("prompt_expander not found — LLM expansion disabled (%s)", exc)

    class ExpandedPrompt(BaseModel):
        """Minimal stand-in for prompt_expander.ExpandedPrompt used by the fallback expander."""

        style_name: str
        positive_prompt: str
        negative_prompt: str
        steps: int = 28
        cfg: float = 7.0
        denoise: float = 0.72

    ROOM_CONTEXTS = {}

    def expand_prompt(keywords: List[str], room_type: str) -> ExpandedPrompt:
        """Build a fixed template prompt from keywords and a room type.

        Used only when prompt_expander cannot be imported.

        Args:
            keywords: Style keywords; an empty list falls back to "modern, photorealistic".
            room_type: Snake-case room name such as "living_room"; blank falls back to "living room".

        Returns:
            An ExpandedPrompt with the default steps/cfg/denoise values.
        """
        pretty_room = room_type.replace("_", " ").strip() or "living room"
        pretty_keywords = ", ".join(keywords) if keywords else "modern, photorealistic"
        return ExpandedPrompt(
            style_name="Fallback Prompt Expansion",
            positive_prompt=(
                f"photorealistic premium {pretty_room} interior design, {pretty_keywords}, "
                "natural lighting, realistic materials, architect-grade composition"
            ),
            negative_prompt=(
                "worst quality, low quality, blurry, distorted perspective, "
                "people, watermark, text, duplicate objects"
            ),
        )

    expand_prompt_simple = expand_prompt

# Gateway authentication is mandatory, so it is imported once, outside the optional block:
# a missing gateway_auth must fail on its own instead of first swapping in the fallback
# prompt expander and logging a misleading "prompt_expander not found" warning.
from gateway_auth import load_gateway_api_key, is_gateway_request_authorized
# force=True: when prompt_expander fails to import, the module-level logging.warning()
# in the fallback branch has already auto-configured the root logger with defaults.
# Without force, this call would then be a silent no-op and lose the INFO level and format.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s", force=True)
logger = logging.getLogger("DreamWeaverGateway")
# ComfyUI base URL: COMFYUI_URL, then COMFY_URL, then localhost; trailing slashes stripped
# so endpoint paths can be appended as f"{COMFY}/...".
COMFY = (os.environ.get("COMFYUI_URL") or os.environ.get("COMFY_URL") or "http://127.0.0.1:8188").rstrip("/")
# TLS verification stays on unless COMFYUI_TLS_VERIFY is 0/false/no/off (case-insensitive).
COMFY_TLS_VERIFY = os.environ.get("COMFYUI_TLS_VERIFY", "true").strip().lower() not in {"0", "false", "no", "off"}
# Key checked by ensure_gateway_auth(); health() reports auth as required whenever it is not None.
GATEWAY_API_KEY = load_gateway_api_key()
# Checkpoint filenames resolve_checkpoint() tries first, in priority order (matched case-insensitively).
PREFERRED_CHECKPOINTS = [
    "realvisxlV50_v50LightningBakedvae.safetensors",
    "realvisxlV50Lightning_v50Lightning.safetensors",
]
app = FastAPI(title="Dream Weaver API v2", version="2.0.0")
# Wide-open CORS: any origin, method and header may call the gateway.
_CORS_OPTIONS = {"allow_origins": ["*"], "allow_methods": ["*"], "allow_headers": ["*"]}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
# In-memory job registry keyed by job_id.
# NOTE(review): nothing visible here removes entries, so it grows for the process lifetime.
jobs: dict = dict()
def comfy_client(timeout: float = 30) -> httpx.AsyncClient:
    """Return a new httpx.AsyncClient configured for ComfyUI requests.

    The client honours COMFY_TLS_VERIFY and follows redirects. Callers own the
    client and use it as an async context manager so it is closed afterwards.

    Args:
        timeout: Request timeout in seconds passed to httpx.
    """
    client_options = dict(timeout=timeout, verify=COMFY_TLS_VERIFY, follow_redirects=True)
    return httpx.AsyncClient(**client_options)
async def list_comfy_checkpoints() -> list[str]:
    """Fetch the checkpoint filenames ComfyUI reports at GET {COMFY}/models/checkpoints.

    Returns:
        The string entries of a JSON list reply; any other payload shape yields [].

    Raises:
        httpx.HTTPStatusError: ComfyUI answered with a non-2xx status
            (transport errors from httpx also propagate unchanged).
    """
    async with comfy_client(timeout=10) as client:
        reply = await client.get(f"{COMFY}/models/checkpoints")
        reply.raise_for_status()
        body = reply.json()
    if not isinstance(body, list):
        return []
    return list(filter(lambda entry: isinstance(entry, str), body))
async def resolve_checkpoint() -> str:
    """Pick the checkpoint to run: the first PREFERRED_CHECKPOINTS hit, else the first installed one.

    Preferred names are compared case-insensitively, but the name is returned
    exactly as ComfyUI spells it.

    Raises:
        HTTPException(503): ComfyUI lists no checkpoints at all.
    """
    available = await list_comfy_checkpoints()
    if not available:
        raise HTTPException(
            status_code=503,
            detail=(
                "ComfyUI is online but has no checkpoint models installed. "
                "Hydrate RealVisXL into ComfyUI/models/checkpoints before generating."
            ),
        )
    by_lowercase = {name.lower(): name for name in available}
    preferred_hits = (
        by_lowercase[wanted.lower()]
        for wanted in PREFERRED_CHECKPOINTS
        if wanted.lower() in by_lowercase
    )
    return next(preferred_hits, available[0])
def gateway_urls(job_id: str) -> dict:
    """Return the relative polling and result URLs clients use for *job_id*."""
    return dict(
        poll_url=f"/dream-weaver/status/{job_id}",
        result_url=f"/dream-weaver/result/{job_id}",
    )
def ensure_gateway_auth(request: Request) -> None:
    """Reject the request with HTTP 401 unless gateway_auth accepts its headers.

    The decision is delegated to is_gateway_request_authorized(headers, GATEWAY_API_KEY).
    """
    if not is_gateway_request_authorized(request.headers, GATEWAY_API_KEY):
        raise HTTPException(status_code=401, detail="Dream Weaver gateway API key is required or invalid.")
async def upload_to_comfy(data: bytes, filename: str) -> str:
    """Upload raw image bytes to ComfyUI's /upload/image endpoint.

    The multipart part is always labelled image/jpeg and sent with overwrite=true.

    Returns:
        The "name" field of ComfyUI's JSON reply, which the caller feeds to LoadImage.

    Raises:
        httpx.HTTPStatusError: ComfyUI answered with a non-2xx status.
    """
    multipart = {"image": (filename, data, "image/jpeg")}
    form_fields = {"overwrite": "true"}
    async with comfy_client(timeout=30) as client:
        reply = await client.post(f"{COMFY}/upload/image", files=multipart, data=form_fields)
        reply.raise_for_status()
        stored = reply.json()
    return stored["name"]
def build_workflow(img_name: str, expanded: "ExpandedPrompt", ckpt_name: str) -> dict:
    """Build the ComfyUI API-format img2img graph for one Dream Weaver job.

    Node chain: checkpoint loader (1), load image (2), positive/negative CLIP
    text encodes (3/4), VAE encode of the input image (5), KSampler (6),
    VAE decode (7), SaveImage (8). Links are ``[node_id, output_index]`` pairs.

    Args:
        img_name: Name of the image previously uploaded to ComfyUI.
        expanded: Prompt bundle providing positive/negative prompts, steps, cfg,
            denoise and a style_name used in the output filename prefix.
        ckpt_name: Checkpoint filename for CheckpointLoaderSimple.

    Returns:
        The workflow dict, ready to be posted to ComfyUI's /prompt endpoint.
    """
    # style_name may be LLM-generated. Keep only filename-safe characters so that
    # "/", "..", ":" and similar cannot turn the SaveImage prefix into extra path
    # components or an invalid filename. Spaces still become "_" as before.
    safe_style = "".join(
        ch if ch.isalnum() or ch in "-_" else "_" for ch in expanded.style_name
    )[:30]
    return {
        "1": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": ckpt_name}},
        "2": {"class_type": "LoadImage", "inputs": {"image": img_name, "upload": "image"}},
        "3": {"class_type": "CLIPTextEncode", "inputs": {"text": expanded.positive_prompt, "clip": ["1", 1]}},
        "4": {"class_type": "CLIPTextEncode", "inputs": {"text": expanded.negative_prompt, "clip": ["1", 1]}},
        "5": {"class_type": "VAEEncode", "inputs": {"pixels": ["2", 0], "vae": ["1", 2]}},
        "6": {
            "class_type": "KSampler",
            "inputs": {
                "model": ["1", 0],
                "positive": ["3", 0],
                "negative": ["4", 0],
                "latent_image": ["5", 0],
                # Time-derived seed: jobs queued within the same second share a seed.
                "seed": int(time.time()) % 999983,
                "steps": expanded.steps,
                "cfg": expanded.cfg,
                "sampler_name": "dpmpp_2m",
                "scheduler": "karras",
                "denoise": expanded.denoise,
            },
        },
        "7": {"class_type": "VAEDecode", "inputs": {"samples": ["6", 0], "vae": ["1", 2]}},
        "8": {"class_type": "SaveImage", "inputs": {"images": ["7", 0], "filename_prefix": f"dw_{safe_style}"}},
    }
async def queue_prompt(workflow: dict) -> str:
    """Submit *workflow* to ComfyUI's /prompt endpoint and return its prompt_id.

    A fresh random client_id is sent with every submission.

    Raises:
        HTTPException(502): ComfyUI answered with an error status (the detail
            carries its JSON or raw body), or a success reply without a usable
            ``prompt_id``.
        httpx.HTTPError: the request itself failed (propagated unchanged).
    """
    async with comfy_client(timeout=30) as client:
        r = await client.post(f"{COMFY}/prompt", json={"prompt": workflow, "client_id": str(uuid.uuid4())})
        if r.status_code >= 400:
            detail = r.text
            try:
                detail = json.dumps(r.json())
            except ValueError:
                # Non-JSON error body: keep the raw text as the detail.
                pass
            raise HTTPException(status_code=502, detail=f"ComfyUI rejected Dream Weaver workflow: {detail}")
        try:
            return r.json()["prompt_id"]
        except (ValueError, KeyError, TypeError) as exc:
            # A 2xx reply without a prompt_id is a protocol error, not a gateway bug.
            raise HTTPException(
                status_code=502,
                detail=f"ComfyUI returned an unexpected /prompt response: {r.text}",
            ) from exc
async def poll_result(prompt_id: str, timeout: int = 300):
    """Poll ComfyUI's /history/{prompt_id} until an output image appears.

    Args:
        prompt_id: Id returned by queue_prompt().
        timeout: Maximum seconds to keep polling (checked every ~2 s).

    Returns:
        ``(image_record, None)`` with the first image entry from the history
        outputs on success. Otherwise ``(None, reason)``, where reason is
        "timeout", or a message when ComfyUI reports an execution error or
        finishes without producing an image.
    """
    # Monotonic clock: immune to wall-clock adjustments during a long poll.
    deadline = time.monotonic() + timeout
    async with comfy_client(timeout=10) as client:
        while time.monotonic() < deadline:
            try:
                r = await client.get(f"{COMFY}/history/{prompt_id}")
            except httpx.HTTPError as exc:
                # Transient network hiccup: keep polling instead of abandoning the job.
                logger.warning("Polling ComfyUI history for %s failed: %s", prompt_id, exc)
                r = None
            if r is not None and r.status_code == 200:
                entry = r.json().get(prompt_id, {})
                imgs = [img for nd in entry.get("outputs", {}).values() for img in nd.get("images", [])]
                if imgs:
                    return imgs[0], None
                # Assumes ComfyUI's history "status" object ({"status_str", "completed", ...});
                # without this check, failed prompts were only reported after the full timeout.
                run_status = entry.get("status") or {}
                if run_status.get("status_str") == "error":
                    return None, "ComfyUI reported an execution error for this prompt"
                if run_status.get("completed"):
                    return None, "ComfyUI finished without producing an image"
            await asyncio.sleep(2)
    return None, "timeout"
async def background_poll(job_id: str, prompt_id: str):
    """Wait for *prompt_id* to finish and record the outcome on ``jobs[job_id]``.

    On success the job becomes ``done`` with its ``output`` image record and a
    ``completed`` timestamp. On failure it becomes ``error`` with a message.
    """
    try:
        img, err = await poll_result(prompt_id)
    except Exception as exc:
        # Task boundary: nothing awaits this task, so an uncaught exception would
        # vanish and leave the job stuck in "processing" forever.
        logger.exception("Polling failed for job %s (prompt %s)", job_id, prompt_id)
        img, err = None, exc
    if img:
        jobs[job_id].update({"status": "done", "output": img, "completed": time.time()})
    else:
        jobs[job_id].update({"status": "error", "error": str(err)})
@app.get("/health")
async def health():
    """Report gateway liveness plus ComfyUI reachability and checkpoint availability.

    Never fails because of ComfyUI. If listing checkpoints raises, the reply has
    ``comfyui: False`` and an empty checkpoint list.
    """
    checkpoints: list[str] = []
    comfy_reachable = True
    try:
        checkpoints = await list_comfy_checkpoints()
    except Exception as exc:
        # Deliberate best-effort: a health probe must always answer, but it must not
        # claim ComfyUI is up when the request to it failed.
        comfy_reachable = False
        logger.debug("ComfyUI checkpoint listing failed during health check: %s", exc)
    return {
        "status": "ok",
        "comfyui": comfy_reachable,
        "comfyui_url": COMFY,
        "checkpoint_ready": bool(checkpoints),
        "checkpoint_count": len(checkpoints),
        "preferred_checkpoints": PREFERRED_CHECKPOINTS,
        "available_checkpoints": checkpoints[:12],
        "llm_expansion": LLM_AVAILABLE,
        "version": "2.0.0",
        "auth_required": GATEWAY_API_KEY is not None,
        "auth_scheme": "x-dream-weaver-api-key",
    }
@app.get("/dream-weaver/status/{job_id}")
async def status(job_id: str, request: Request):
    """Return a job's stored fields (minus the raw ``output`` record) plus a ``ready`` flag.

    Finished jobs also carry the poll/result URLs. Unknown ids get a 404.
    """
    ensure_gateway_auth(request)
    job = jobs.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")
    is_ready = job.get("status") == "done"
    payload = dict(job)
    payload.pop("output", None)
    payload["ready"] = is_ready
    if is_ready:
        payload.update(gateway_urls(job_id))
    return payload
# Strong references to in-flight polling tasks. The event loop keeps only weak
# references to tasks, so an unreferenced task may be garbage-collected mid-run.
_BACKGROUND_TASKS: set = set()


@app.post("/dream-weaver")
async def dream_weaver(
    request: Request,
    image: UploadFile = File(...),
    keywords: str = Form(default=""),
    room_type: str = Form(default="living_room"),
):
    """Start an img2img redesign job and return its id plus polling URLs.

    Steps:
    1. Authenticate the request.
    2. Upload the image to ComfyUI.
    3. Expand the comma-separated keywords into prompts in a worker thread.
    4. Pick a checkpoint and queue the workflow.
    5. Poll ComfyUI for the result in a background task.

    Any failure before the workflow is queued marks the job ``error`` (so status
    polling never sees it stuck in ``uploading``) and is re-raised. httpx
    transport/status errors become HTTP 502, and an empty upload becomes HTTP 400.
    """
    ensure_gateway_auth(request)
    job_id = str(uuid.uuid4())
    jobs[job_id] = {"status": "uploading", "created": time.time()}
    try:
        data = await image.read()
        if not data:
            raise HTTPException(status_code=400, detail="Uploaded image is empty.")
        comfy_name = await upload_to_comfy(data, f"dw_{job_id[:8]}.jpg")
        kw_list = [k.strip() for k in keywords.split(",") if k.strip()]
        # Prompt expansion may be slow/blocking (LLM), so keep it off the event loop.
        expanded = await asyncio.to_thread(expand_prompt, keywords=kw_list, room_type=room_type)
        ckpt_name = await resolve_checkpoint()
        jobs[job_id]["checkpoint"] = ckpt_name
        wf = build_workflow(comfy_name, expanded, ckpt_name)
        prompt_id = await queue_prompt(wf)
    except httpx.HTTPError as exc:
        jobs[job_id].update({"status": "error", "error": str(exc)})
        raise HTTPException(status_code=502, detail=f"ComfyUI request failed: {exc}") from exc
    except Exception as exc:
        jobs[job_id].update({"status": "error", "error": str(getattr(exc, "detail", exc))})
        raise
    jobs[job_id].update({"status": "processing", "prompt_id": prompt_id})
    task = asyncio.create_task(background_poll(job_id, prompt_id))
    _BACKGROUND_TASKS.add(task)
    task.add_done_callback(_BACKGROUND_TASKS.discard)
    return {
        "job_id": job_id,
        "status": "processing",
        **gateway_urls(job_id),
    }
@app.get("/dream-weaver/result/{job_id}")
async def result(job_id: str, request: Request):
    """Return the finished image for *job_id*, fetched from ComfyUI's /view endpoint.

    Raises:
        HTTPException(404): the job is unknown, not finished, or has no output record.
        HTTPException(502): fetching the image from ComfyUI failed.
    """
    ensure_gateway_auth(request)
    job = jobs.get(job_id)
    img = job.get("output") if job and job.get("status") == "done" else None
    if not img:
        raise HTTPException(status_code=404, detail="Result not ready")
    try:
        async with comfy_client(timeout=30) as client:
            response = await client.get(
                f"{COMFY}/view",
                params={
                    "filename": img["filename"],
                    "subfolder": img.get("subfolder", ""),
                    "type": img.get("type", "output"),
                },
            )
            response.raise_for_status()
    except httpx.HTTPError as exc:
        # Upstream failure: report it as a gateway error instead of an opaque 500.
        raise HTTPException(status_code=502, detail=f"Failed to fetch result image from ComfyUI: {exc}") from exc
    return StreamingResponse(
        io.BytesIO(response.content),
        media_type="image/png",
        headers={"Content-Disposition": f"attachment; filename=dreamweaver_{job_id[:8]}.png"},
    )
if __name__ == "__main__":
    # Direct execution: serve the gateway on every interface, port 8082.
    uvicorn.run(app, log_level="info", host="0.0.0.0", port=8082)