99 lines
4.8 KiB
Python
99 lines
4.8 KiB
Python
#!/usr/bin/env python3
|
|
import asyncio, json, time, uuid, io, sys, os, logging
|
|
from pathlib import Path
|
|
from typing import Optional, List
|
|
import httpx
|
|
import uvicorn
|
|
from fastapi import FastAPI, UploadFile, File, HTTPException, Form, BackgroundTasks
|
|
from fastapi.responses import JSONResponse, StreamingResponse
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from pydantic import BaseModel
|
|
|
|
# Folder named "scripts" that sits next to this file.
SCRIPTS_DIR = Path(__file__).parent.joinpath("scripts")

# Put that folder at the front of the import search path so modules stored
# there resolve by bare name (presumably including prompt_expander -- confirm).
sys.path.insert(0, str(SCRIPTS_DIR))
|
|
|
|
# Configure logging BEFORE anything emits a record. A module-level
# logging.warning() on an unconfigured root logger installs a default handler,
# after which basicConfig() is a silent no-op -- so the level/format chosen
# here would be ignored whenever the optional import below fails.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger("DreamWeaverGateway")

# prompt_expander is optional; LLM_AVAILABLE records whether it was importable.
try:
    from prompt_expander import expand_prompt, expand_prompt_simple, ROOM_CONTEXTS, ExpandedPrompt
    LLM_AVAILABLE = True
except ImportError:
    LLM_AVAILABLE = False
    logger.warning("prompt_expander not found — LLM expansion disabled")

# Base URL of the ComfyUI server this gateway forwards work to.
COMFY = "http://127.0.0.1:8188"

app = FastAPI(title="Dream Weaver API v2", version="2.0.0")
# CORS is fully open: any origin, method and header is accepted.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])

# In-memory job registry keyed by job id; contents are lost on restart.
jobs: dict = {}
|
|
|
|
async def upload_to_comfy(data: bytes, filename: str) -> str:
    """Send image bytes to ComfyUI and return the name it stored them under.

    Posts a multipart form to ``{COMFY}/upload/image``. The part is always
    labelled ``image/jpeg`` and the ``overwrite`` form field is ``"true"``.

    Args:
        data: Raw bytes of the image to upload.
        filename: File name presented to ComfyUI for the upload.

    Returns:
        The ``"name"`` entry of ComfyUI's JSON reply.

    Raises:
        httpx.HTTPStatusError: If ComfyUI replies with an error status.
    """
    endpoint = f"{COMFY}/upload/image"
    multipart = {"image": (filename, data, "image/jpeg")}
    form_fields = {"overwrite": "true"}

    async with httpx.AsyncClient(timeout=30) as session:
        response = await session.post(endpoint, files=multipart, data=form_fields)
        response.raise_for_status()
        reply = response.json()
    return reply["name"]
|
|
|
|
def build_workflow(img_name: str, expanded: "ExpandedPrompt") -> dict:
    """Build the ComfyUI node graph for one image-to-image request.

    Nodes, keyed by string id: "1" checkpoint loader, "2" image loader,
    "3"/"4" positive/negative text encoders, "5" VAE encode, "6" KSampler,
    "7" VAE decode, "8" image saver.

    Args:
        img_name: Name of the input image as known to ComfyUI.
        expanded: Object supplying ``positive_prompt``, ``negative_prompt``,
            ``steps``, ``cfg``, ``denoise`` and ``style_name``.

    Returns:
        The workflow dictionary, with a sampler seed derived from the current
        Unix time (whole seconds) modulo 999983.
    """

    def link(node_id: str, slot: int) -> list:
        # A fresh [node_id, output_slot] reference per use, never shared.
        return [node_id, slot]

    def node(class_type: str, **inputs) -> dict:
        # Keyword order is preserved, so "inputs" keeps its declared order.
        return {"class_type": class_type, "inputs": inputs}

    graph = {}
    graph["1"] = node("CheckpointLoaderSimple", ckpt_name="realvisxlV50_v50LightningBakedvae.safetensors")
    graph["2"] = node("LoadImage", image=img_name, upload="image")
    graph["3"] = node("CLIPTextEncode", text=expanded.positive_prompt, clip=link("1", 1))
    graph["4"] = node("CLIPTextEncode", text=expanded.negative_prompt, clip=link("1", 1))
    graph["5"] = node("VAEEncode", pixels=link("2", 0), vae=link("1", 2))
    graph["6"] = node(
        "KSampler",
        model=link("1", 0),
        positive=link("3", 0),
        negative=link("4", 0),
        latent_image=link("5", 0),
        seed=int(time.time()) % 999983,
        steps=expanded.steps,
        cfg=expanded.cfg,
        sampler_name="dpmpp_2m",
        scheduler="karras",
        denoise=expanded.denoise,
    )
    graph["7"] = node("VAEDecode", samples=link("6", 0), vae=link("1", 2))

    # Output prefix: style name with spaces -> underscores, capped at 30 chars.
    style_slug = expanded.style_name.replace(' ', '_')[:30]
    graph["8"] = node("SaveImage", images=link("7", 0), filename_prefix=f"dw_{style_slug}")
    return graph
|
|
|
|
async def queue_prompt(workflow: dict) -> str:
    """Submit a workflow to ComfyUI and return the id it was queued under.

    Posts ``{"prompt": workflow, "client_id": <fresh uuid4>}`` as JSON to
    ``{COMFY}/prompt``.

    Args:
        workflow: Node graph to execute.

    Returns:
        The ``"prompt_id"`` entry of ComfyUI's JSON reply.

    Raises:
        httpx.HTTPStatusError: If ComfyUI replies with an error status.
    """
    submission = {"prompt": workflow, "client_id": str(uuid.uuid4())}

    async with httpx.AsyncClient(timeout=30) as session:
        response = await session.post(f"{COMFY}/prompt", json=submission)
        response.raise_for_status()
        reply = response.json()
    return reply["prompt_id"]
|
|
|
|
async def poll_result(prompt_id: str, timeout: int = 300):
    """Poll ComfyUI's history until the prompt yields an image or time runs out.

    Queries ``{COMFY}/history/{prompt_id}`` every two seconds. Transport
    errors on a single request are logged and retried rather than aborting
    the whole poll.

    Args:
        prompt_id: Id returned when the workflow was queued.
        timeout: Maximum number of seconds to keep polling.

    Returns:
        ``(image, None)`` with the first image entry found in the history
        outputs, or ``(None, reason)`` where ``reason`` is ``"timeout"`` or
        ``"execution error"``.
    """
    # Monotonic clock: wall-clock adjustments must not stretch or cut the wait.
    deadline = time.monotonic() + timeout
    async with httpx.AsyncClient(timeout=10) as client:
        while time.monotonic() < deadline:
            try:
                r = await client.get(f"{COMFY}/history/{prompt_id}")
            except httpx.HTTPError as exc:
                # One dropped request should not fail a long-running job.
                logging.getLogger("DreamWeaverGateway").warning(
                    "History poll for %s failed, retrying: %s", prompt_id, exc
                )
                r = None

            if r is not None and r.status_code == 200:
                h = r.json().get(prompt_id, {})
                imgs = [img for nd in h.get("outputs", {}).values() for img in nd.get("images", [])]
                if imgs:
                    return imgs[0], None
                # NOTE(review): assumes ComfyUI reports a failed run as
                # status.status_str == "error" in the history entry -- confirm
                # against the deployed ComfyUI version. When the field is
                # absent this check is inert and polling continues as before.
                run_status = h.get("status") or {}
                if run_status.get("status_str") == "error":
                    return None, "execution error"

            await asyncio.sleep(2)
    return None, "timeout"
|
|
|
|
async def background_poll(job_id: str, prompt_id: str):
    """Wait for a queued prompt and record the outcome on its job entry.

    On success the job gets ``status="done"``, the image entry under
    ``"output"`` and a ``"completed"`` timestamp. Otherwise it gets
    ``status="error"`` and an ``"error"`` message.

    Args:
        job_id: Key of the job in the module-level ``jobs`` registry.
        prompt_id: ComfyUI prompt id to poll for.
    """
    try:
        img, err = await poll_result(prompt_id)
    except Exception as exc:
        # This coroutine runs as a fire-and-forget task: an exception escaping
        # here would leave the job in "processing" forever, so turn it into a
        # recorded error instead. (Cancellation is not an Exception and still
        # propagates.)
        logging.getLogger("DreamWeaverGateway").exception("Polling failed for job %s", job_id)
        img, err = None, (str(exc) or type(exc).__name__)

    if img:
        jobs[job_id].update({"status": "done", "output": img, "completed": time.time()})
    else:
        jobs[job_id].update({"status": "error", "error": str(err)})
|
|
|
|
@app.get("/health")
async def health():
    """Report gateway liveness plus the real state of its dependencies.

    Returns:
        A dict with ``status`` (always ``"ok"`` when the gateway answers),
        ``comfyui`` (whether ComfyUI answered a probe with HTTP 200),
        ``llm_expansion`` (whether prompt_expander was importable) and
        ``version``.
    """
    # Probe the backend instead of reporting a hard-coded True. A short
    # timeout keeps the health endpoint responsive when ComfyUI is down.
    try:
        async with httpx.AsyncClient(timeout=3) as client:
            probe = await client.get(f"{COMFY}/system_stats")
        comfy_ok = probe.status_code == 200
    except httpx.HTTPError:
        comfy_ok = False

    return {"status": "ok", "comfyui": comfy_ok, "llm_expansion": LLM_AVAILABLE, "version": "2.0.0"}
|
|
|
|
@app.get("/dream-weaver/status/{job_id}")
async def status(job_id: str):
    """Return the stored state of a job, minus its raw output entry.

    Args:
        job_id: Id handed out when the job was created.

    Returns:
        A copy of the job record without the ``"output"`` key, plus a
        ``"ready"`` flag that is true once the job's status is ``"done"``.

    Raises:
        HTTPException: 404 when no job is stored under ``job_id``.
    """
    record = jobs.get(job_id)
    if not record:
        raise HTTPException(status_code=404, detail="Job not found")

    public_view = dict(record)
    public_view.pop("output", None)
    public_view["ready"] = record.get("status") == "done"
    return public_view
|
|
|
|
# The event loop holds only weak references to tasks, so an unreferenced
# fire-and-forget task may be garbage-collected mid-flight. Each polling task
# is parked here until it finishes.
_background_tasks: set = set()


@app.post("/dream-weaver")
async def dream_weaver(image: UploadFile = File(...), keywords: str = Form(default=""), room_type: str = Form(default="living_room")):
    """Accept an image plus style keywords and start a ComfyUI render job.

    Uploads the image to ComfyUI, expands the keywords into a full prompt,
    queues the resulting workflow and starts a background task that polls
    for the result.

    Args:
        image: Uploaded source image.
        keywords: Comma-separated style keywords; blanks are dropped.
        room_type: Room identifier passed through to ``expand_prompt``.

    Returns:
        ``{"job_id": <id>, "status": "processing"}``.

    Raises:
        HTTPException: 503 when prompt expansion is unavailable, 400 when the
            uploaded file is empty, 502 when a request to ComfyUI fails.
    """
    if not LLM_AVAILABLE:
        # expand_prompt was never imported; answer cleanly instead of letting
        # a NameError surface as an opaque 500.
        raise HTTPException(status_code=503, detail="Prompt expansion is unavailable")

    job_id = str(uuid.uuid4())
    jobs[job_id] = {"status": "uploading", "created": time.time()}
    try:
        data = await image.read()
        if not data:
            raise HTTPException(status_code=400, detail="Uploaded image is empty")
        comfy_name = await upload_to_comfy(data, f"dw_{job_id[:8]}.jpg")
        kw_list = [k.strip() for k in keywords.split(",") if k.strip()]
        # expand_prompt is run in a worker thread to keep the event loop free.
        expanded = await asyncio.to_thread(expand_prompt, keywords=kw_list, room_type=room_type)
        wf = build_workflow(comfy_name, expanded)
        prompt_id = await queue_prompt(wf)
    except httpx.HTTPError as exc:
        # The caller never received this job id, so drop the entry rather
        # than leave a phantom job stuck in "uploading".
        jobs.pop(job_id, None)
        raise HTTPException(status_code=502, detail=f"ComfyUI request failed: {exc}") from exc
    except BaseException:
        # Same cleanup for every other failure, including cancellation when
        # the client disconnects; the original exception is re-raised as is.
        jobs.pop(job_id, None)
        raise

    jobs[job_id].update({"status": "processing", "prompt_id": prompt_id})
    task = asyncio.create_task(background_poll(job_id, prompt_id))
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return {"job_id": job_id, "status": "processing"}
|
|
|
|
if __name__ == "__main__":
    # Direct execution: serve the app on all interfaces, port 8082.
    server_options = {
        "host": "0.0.0.0",
        "port": 8082,
        "log_level": "info",
    }
    uvicorn.run(app, **server_options)
|