forked from sagnik/Project_Velocity
Co-authored-by: Sayan Datta <sayan@Sayans-MacBook-Air.local> Reviewed-on: sagnik/Project_Velocity#41
243 lines
9.9 KiB
Python
243 lines
9.9 KiB
Python
#!/usr/bin/env python3
|
|
import asyncio, json, time, uuid, io, sys, os, logging
|
|
from pathlib import Path
|
|
from typing import Optional, List
|
|
import httpx
|
|
import uvicorn
|
|
from fastapi import FastAPI, UploadFile, File, HTTPException, Form, Request
|
|
from fastapi.responses import JSONResponse, StreamingResponse
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from pydantic import BaseModel
|
|
|
|
# Directory containing this gateway module; helper paths are resolved from it.
ROOT_DIR = Path(__file__).resolve().parent

# Prefer <root>/scripts; otherwise use <root>/comfy_engine/scripts (the
# fallback is selected without checking that it exists).
SCRIPTS_DIR = (
    ROOT_DIR / "scripts"
    if (ROOT_DIR / "scripts").exists()
    else ROOT_DIR / "comfy_engine" / "scripts"
)

# Put the scripts directory first on the import path so the helper modules
# imported below (prompt_expander, gateway_auth) are resolved from it.
sys.path.insert(0, str(SCRIPTS_DIR))
|
|
|
|
# gateway_auth is needed whether or not the optional prompt expander is
# installed, so it is imported outside the try. A missing gateway_auth now
# fails with its own ImportError instead of first being reported as
# "prompt_expander not found" and swapping in the fallback expander.
from gateway_auth import load_gateway_api_key, is_gateway_request_authorized

try:
    # Optional LLM-backed prompt expansion, resolved from SCRIPTS_DIR.
    from prompt_expander import expand_prompt, expand_prompt_simple, ROOM_CONTEXTS, ExpandedPrompt

    LLM_AVAILABLE = True
except ImportError as exc:
    LLM_AVAILABLE = False
    # Log through a named logger rather than the root-level logging.warning():
    # the module-level logging helpers call logging.basicConfig() when the root
    # logger has no handlers, which would turn any later
    # logging.basicConfig(level=..., format=...) call into a silent no-op.
    logging.getLogger("DreamWeaverGateway").warning(
        "prompt_expander not found — LLM expansion disabled (%s)", exc
    )

    class ExpandedPrompt(BaseModel):
        """Prompt bundle consumed by build_workflow (local fallback definition)."""

        style_name: str  # also used to build the SaveImage filename prefix
        positive_prompt: str
        negative_prompt: str
        steps: int = 28  # KSampler steps
        cfg: float = 7.0  # KSampler CFG scale
        denoise: float = 0.72  # KSampler denoise strength for the img2img pass

    # No room-specific contexts are available without prompt_expander.
    ROOM_CONTEXTS = {}

    def expand_prompt(keywords: List[str], room_type: str) -> ExpandedPrompt:
        """Template-based stand-in for the LLM prompt expander.

        Args:
            keywords: Style keywords; an empty list becomes
                "modern, photorealistic".
            room_type: Room identifier such as "living_room"; underscores become
                spaces and a blank value becomes "living room".

        Returns:
            An ExpandedPrompt with a fixed negative prompt and the default
            sampler settings.
        """
        pretty_room = room_type.replace("_", " ").strip() or "living room"
        pretty_keywords = ", ".join(keywords) if keywords else "modern, photorealistic"
        return ExpandedPrompt(
            style_name="Fallback Prompt Expansion",
            positive_prompt=(
                f"photorealistic premium {pretty_room} interior design, {pretty_keywords}, "
                "natural lighting, realistic materials, architect-grade composition"
            ),
            negative_prompt=(
                "worst quality, low quality, blurry, distorted perspective, "
                "people, watermark, text, duplicate objects"
            ),
        )

    expand_prompt_simple = expand_prompt
|
|
|
|
# Root logging: INFO level, one timestamped line per record.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger("DreamWeaverGateway")

# ComfyUI base URL: COMFYUI_URL wins over COMFY_URL, then the local default.
# The trailing "/" is stripped so paths can be appended as f"{COMFY}/...".
_comfy_url_env = os.environ.get("COMFYUI_URL") or os.environ.get("COMFY_URL")
COMFY = (_comfy_url_env or "http://127.0.0.1:8188").rstrip("/")

# TLS verification toward ComfyUI stays on unless COMFYUI_TLS_VERIFY is one of
# these values (compared after strip() and lower()).
_TLS_VERIFY_OFF_VALUES = frozenset({"0", "false", "no", "off"})
_tls_setting = os.environ.get("COMFYUI_TLS_VERIFY", "true").strip().lower()
COMFY_TLS_VERIFY = _tls_setting not in _TLS_VERIFY_OFF_VALUES

# Gateway API key that incoming requests are checked against (None disables the requirement
# as reported by /health's "auth_required").
GATEWAY_API_KEY = load_gateway_api_key()

# Checkpoint filenames preferred, in order, when choosing a model.
PREFERRED_CHECKPOINTS = [
    "realvisxlV50_v50LightningBakedvae.safetensors",
    "realvisxlV50Lightning_v50Lightning.safetensors",
]
|
|
|
|
app = FastAPI(title="Dream Weaver API v2", version="2.0.0")

# Cross-origin calls are allowed from any origin, with any method and header.
# NOTE(review): wildcard CORS leaves access control entirely to the gateway API
# key — confirm this is intended for production deployments.
_CORS_POLICY = dict(allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
app.add_middleware(CORSMiddleware, **_CORS_POLICY)

# In-process job registry: job_id -> mutable state dict (status, created,
# checkpoint, prompt_id, output/error, completed). Nothing in this module evicts
# entries. NOTE(review): grows without bound and is lost on restart.
jobs: dict = {}
|
|
|
|
def comfy_client(timeout: float = 30) -> httpx.AsyncClient:
    """Create a new httpx.AsyncClient for requests to ComfyUI.

    TLS verification follows COMFY_TLS_VERIFY and redirects are followed.
    The caller owns the client; callers in this module use it as an async
    context manager so it is closed after use.
    """
    client_options = {
        "timeout": timeout,
        "verify": COMFY_TLS_VERIFY,
        "follow_redirects": True,
    }
    return httpx.AsyncClient(**client_options)
|
|
|
|
async def list_comfy_checkpoints() -> list[str]:
    """Return the checkpoint names ComfyUI reports at /models/checkpoints.

    Non-string entries are dropped, and a payload that is not a list yields [].
    An HTTP error status propagates from raise_for_status().
    """
    async with comfy_client(timeout=10) as client:
        reply = await client.get(f"{COMFY}/models/checkpoints")
        reply.raise_for_status()
        listing = reply.json()

    if not isinstance(listing, list):
        return []
    return [name for name in listing if isinstance(name, str)]
|
|
|
|
async def resolve_checkpoint() -> str:
    """Choose the checkpoint ComfyUI should load for a Dream Weaver job.

    PREFERRED_CHECKPOINTS are tried in order. Each is matched case-insensitively
    against the names ComfyUI reports, first against the full name and then
    against the final path component. A preferred model placed in a subfolder
    of models/checkpoints is therefore still selected. In both cases the full
    name ComfyUI reported is returned, since that is what CheckpointLoaderSimple
    expects. Without a preferred match, the first reported checkpoint is used.

    Raises:
        HTTPException: 503 when ComfyUI reports no checkpoints at all.
    """
    checkpoints = await list_comfy_checkpoints()
    if not checkpoints:
        raise HTTPException(
            status_code=503,
            detail=(
                "ComfyUI is online but has no checkpoint models installed. "
                "Hydrate RealVisXL into ComfyUI/models/checkpoints before generating."
            ),
        )

    by_name = {item.lower(): item for item in checkpoints}
    # NOTE(review): assumes nested checkpoints are listed as relative paths using
    # "/" or backslash separators — confirm for the deployed ComfyUI build.
    by_basename = {
        item.replace("\\", "/").rsplit("/", 1)[-1].lower(): item for item in checkpoints
    }

    for preferred in PREFERRED_CHECKPOINTS:
        key = preferred.lower()
        match = by_name.get(key) or by_basename.get(key)
        if match:
            return match
    return checkpoints[0]
|
|
|
|
def gateway_urls(job_id: str) -> dict:
    """Return the relative status-polling and result-download paths for *job_id*."""
    status_path = f"/dream-weaver/status/{job_id}"
    download_path = f"/dream-weaver/result/{job_id}"
    return {"poll_url": status_path, "result_url": download_path}
|
|
|
|
def ensure_gateway_auth(request: Request) -> None:
    """Raise HTTP 401 unless the request passes the gateway API-key check.

    The check itself is delegated to is_gateway_request_authorized, which
    receives the request headers and GATEWAY_API_KEY.
    """
    authorized = is_gateway_request_authorized(request.headers, GATEWAY_API_KEY)
    if not authorized:
        raise HTTPException(status_code=401, detail="Dream Weaver gateway API key is required or invalid.")
|
|
|
|
async def upload_to_comfy(data: bytes, filename: str) -> str:
    """Upload image bytes to ComfyUI's /upload/image; return the stored name.

    The multipart part is always labelled image/jpeg, and overwrite=true is
    sent (presumably replacing any existing upload with the same filename).
    An HTTP error status propagates from raise_for_status().
    """
    multipart = {"image": (filename, data, "image/jpeg")}
    form_fields = {"overwrite": "true"}
    async with comfy_client(timeout=30) as client:
        reply = await client.post(f"{COMFY}/upload/image", files=multipart, data=form_fields)
        reply.raise_for_status()
        return reply.json()["name"]
|
|
|
|
def build_workflow(
    img_name: str,
    expanded: "ExpandedPrompt",
    ckpt_name: str,
    seed: Optional[int] = None,
) -> dict:
    """Build the ComfyUI API-format img2img graph for one Dream Weaver job.

    Pipeline: load checkpoint (1) -> load uploaded image (2) -> encode positive
    (3) and negative (4) prompts -> VAE-encode the image (5) -> KSampler with
    dpmpp_2m/karras (6) -> VAE-decode (7) -> SaveImage (8).
    CheckpointLoaderSimple output indices used below are 0=model, 1=clip, 2=vae.

    Args:
        img_name: Image name returned by ComfyUI's upload endpoint.
        expanded: Prompt expansion result; reads positive_prompt,
            negative_prompt, steps, cfg, denoise and style_name.
        ckpt_name: Checkpoint name for CheckpointLoaderSimple.
        seed: KSampler seed. Defaults to the current Unix time modulo 999983,
            so successive calls vary. Pass a value for reproducible runs.

    Returns:
        The workflow dict keyed by node id "1".."8".
    """
    if seed is None:
        seed = int(time.time()) % 999983

    # style_name may be free-form (e.g. produced by the prompt expander), so keep
    # only letters, digits, "-" and "_" before using it as a filename prefix.
    # Separators or dots could otherwise be interpreted as paths by ComfyUI.
    # NOTE(review): confirm against ComfyUI's SaveImage path handling.
    # Plain names with spaces map exactly as before (space -> "_").
    safe_style = "".join(
        ch if ch.isalnum() or ch in "-_" else "_" for ch in expanded.style_name
    )[:30]

    return {
        "1": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": ckpt_name}},
        "2": {"class_type": "LoadImage", "inputs": {"image": img_name, "upload": "image"}},
        "3": {"class_type": "CLIPTextEncode", "inputs": {"text": expanded.positive_prompt, "clip": ["1", 1]}},
        "4": {"class_type": "CLIPTextEncode", "inputs": {"text": expanded.negative_prompt, "clip": ["1", 1]}},
        "5": {"class_type": "VAEEncode", "inputs": {"pixels": ["2", 0], "vae": ["1", 2]}},
        "6": {
            "class_type": "KSampler",
            "inputs": {
                "model": ["1", 0],
                "positive": ["3", 0],
                "negative": ["4", 0],
                "latent_image": ["5", 0],
                "seed": seed,
                "steps": expanded.steps,
                "cfg": expanded.cfg,
                "sampler_name": "dpmpp_2m",
                "scheduler": "karras",
                "denoise": expanded.denoise,
            },
        },
        "7": {"class_type": "VAEDecode", "inputs": {"samples": ["6", 0], "vae": ["1", 2]}},
        "8": {"class_type": "SaveImage", "inputs": {"images": ["7", 0], "filename_prefix": f"dw_{safe_style}"}},
    }
|
|
|
|
async def queue_prompt(workflow: dict) -> str:
    """Submit *workflow* to ComfyUI's /prompt endpoint and return its prompt_id.

    Each submission uses a fresh random client_id.

    Raises:
        HTTPException: 502 when ComfyUI answers with status >= 400. The detail
            embeds ComfyUI's response body, re-serialised as JSON when it parses
            and raw text otherwise.
    """
    submission = {"prompt": workflow, "client_id": str(uuid.uuid4())}
    async with comfy_client(timeout=30) as client:
        reply = await client.post(f"{COMFY}/prompt", json=submission)

    if reply.status_code < 400:
        return reply.json()["prompt_id"]

    try:
        detail = json.dumps(reply.json())
    except Exception:
        # Body was not JSON; fall back to the raw text.
        detail = reply.text
    raise HTTPException(status_code=502, detail=f"ComfyUI rejected Dream Weaver workflow: {detail}")
|
|
|
|
async def _fetch_history_entry(client: httpx.AsyncClient, prompt_id: str) -> Optional[dict]:
    """Fetch ComfyUI's /history entry for *prompt_id*; None when not available yet."""
    try:
        response = await client.get(f"{COMFY}/history/{prompt_id}")
    except httpx.HTTPError as exc:
        # A dropped connection or read timeout counts as "not ready yet", so the
        # caller keeps polling until its own deadline instead of aborting.
        logger.warning("Polling ComfyUI history for prompt %s failed: %s", prompt_id, exc)
        return None
    if response.status_code != 200:
        return None
    try:
        payload = response.json()
    except ValueError:
        return None
    entry = payload.get(prompt_id) if isinstance(payload, dict) else None
    return entry if isinstance(entry, dict) else None


def _history_failure(entry: dict) -> Optional[str]:
    """Return an error description when *entry* reports a finished run without images."""
    status = entry.get("status")
    if not isinstance(status, dict):
        return None
    # NOTE(review): assumes ComfyUI's history "status" shape
    # {"status_str": "success" | "error", "completed": bool,
    #  "messages": [[event_name, data], ...]} — confirm for the ComfyUI version in use.
    if status.get("status_str") == "error":
        for message in status.get("messages") or []:
            if (
                isinstance(message, (list, tuple))
                and len(message) == 2
                and message[0] == "execution_error"
                and isinstance(message[1], dict)
                and message[1].get("exception_message")
            ):
                return f"ComfyUI execution error: {message[1]['exception_message']}"
        return "ComfyUI execution error"
    if status.get("completed"):
        return "ComfyUI finished without producing an image"
    return None


async def poll_result(prompt_id: str, timeout: int = 300, poll_interval: float = 2.0):
    """Poll ComfyUI's history for *prompt_id* until it yields an image.

    Args:
        prompt_id: Id returned by queue_prompt.
        timeout: Overall deadline in seconds, measured with a monotonic clock.
        poll_interval: Seconds to sleep between polls.

    Returns:
        (image_info, None) with the first image entry found in the prompt's
        outputs, or (None, error). The error is "timeout" when the deadline
        passes, or a description when ComfyUI records the prompt as failed or
        finished without images. In that case the result is returned
        immediately instead of waiting for the timeout.

    Transport errors are logged and retried until the deadline rather than
    propagated.
    """
    deadline = time.monotonic() + timeout
    async with comfy_client(timeout=10) as client:
        while time.monotonic() < deadline:
            entry = await _fetch_history_entry(client, prompt_id)
            if entry is not None:
                images = [
                    img
                    for node in entry.get("outputs", {}).values()
                    for img in node.get("images", [])
                ]
                if images:
                    return images[0], None
                failure = _history_failure(entry)
                if failure:
                    return None, failure
            await asyncio.sleep(poll_interval)
    return None, "timeout"
|
|
|
|
async def background_poll(job_id: str, prompt_id: str):
    """Wait for *prompt_id* to finish and record the outcome on jobs[job_id].

    On success the job becomes "done" with the image info in "output" and a
    "completed" timestamp. Otherwise it becomes "error" with a message.

    This runs as a fire-and-forget task, so unexpected exceptions from polling
    are caught, logged and stored as the job's error. Without this, the
    exception would vanish with the task and the job would stay "processing"
    forever.
    """
    try:
        img, err = await poll_result(prompt_id)
    except Exception as exc:
        logger.exception("Polling ComfyUI for job %s (prompt %s) failed", job_id, prompt_id)
        img, err = None, f"polling failed: {exc}"

    if img:
        jobs[job_id].update({"status": "done", "output": img, "completed": time.time()})
    else:
        jobs[job_id].update({"status": "error", "error": str(err)})
|
|
|
|
@app.get("/health")
async def health():
    """Report gateway liveness, ComfyUI reachability and checkpoint readiness.

    "status" is always "ok" (the gateway answered). "comfyui" is True when
    ComfyUI responded to the checkpoint listing, even with an error status,
    and False when the request could not complete. Previously it was
    hard-coded to True.
    """
    checkpoints: list[str] = []
    comfyui_up = True
    try:
        checkpoints = await list_comfy_checkpoints()
    except httpx.HTTPStatusError as exc:
        # ComfyUI answered, but the listing failed: reachable, not ready.
        logger.warning("ComfyUI checkpoint listing at %s failed: %s", COMFY, exc)
    except Exception as exc:
        # Health must keep answering while ComfyUI is down, so log instead of raising.
        logger.warning("ComfyUI health probe against %s failed: %s", COMFY, exc)
        comfyui_up = False

    return {
        "status": "ok",
        "comfyui": comfyui_up,
        "comfyui_url": COMFY,
        "checkpoint_ready": bool(checkpoints),
        "checkpoint_count": len(checkpoints),
        "preferred_checkpoints": PREFERRED_CHECKPOINTS,
        "available_checkpoints": checkpoints[:12],
        "llm_expansion": LLM_AVAILABLE,
        "version": "2.0.0",
        "auth_required": GATEWAY_API_KEY is not None,
        "auth_scheme": "x-dream-weaver-api-key",
    }
|
|
|
|
@app.get("/dream-weaver/status/{job_id}")
async def status(job_id: str, request: Request):
    """Return a job's state without the raw "output" entry.

    Adds a boolean "ready" (status == "done"); ready jobs also carry the
    poll/result URLs.

    Raises:
        HTTPException: 401 on failed gateway auth, 404 for an unknown job id.
    """
    ensure_gateway_auth(request)

    job = jobs.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="Job not found")

    snapshot = dict(job)
    snapshot.pop("output", None)
    is_ready = job.get("status") == "done"
    snapshot["ready"] = is_ready
    if is_ready:
        snapshot.update(gateway_urls(job_id))
    return snapshot
|
|
|
|
# Strong references to in-flight polling tasks. The event loop keeps only weak
# references to tasks, so an unreferenced create_task() result may be garbage
# collected before it finishes. Each task removes itself when done.
_background_tasks: set = set()


@app.post("/dream-weaver")
async def dream_weaver(
    request: Request,
    image: UploadFile = File(...),
    keywords: str = Form(default=""),
    room_type: str = Form(default="living_room"),
):
    """Queue an img2img restyle of the uploaded room photo and return its job id.

    Flow:
        1. Authenticate.
        2. Upload the image to ComfyUI.
        3. Expand the comma-separated keywords into prompts (in a worker thread).
        4. Pick a checkpoint.
        5. Queue the workflow.
        6. Start background polling.

    Returns:
        {"job_id", "status": "processing", "poll_url", "result_url"}.

    Raises:
        HTTPException:
            - 401 on failed gateway auth.
            - 400 for an empty upload.
            - 503 when no checkpoint is installed.
            - 502 when ComfyUI rejects the workflow or a ComfyUI request fails.
        Any failure after the job is registered marks it "error", so it does
        not linger in "uploading".
    """
    ensure_gateway_auth(request)

    job_id = str(uuid.uuid4())
    jobs[job_id] = {"status": "uploading", "created": time.time()}

    try:
        data = await image.read()
        if not data:
            raise HTTPException(status_code=400, detail="Uploaded image is empty.")
        comfy_name = await upload_to_comfy(data, f"dw_{job_id[:8]}.jpg")

        kw_list = [k.strip() for k in keywords.split(",") if k.strip()]
        # expand_prompt is synchronous (possibly an LLM call); keep it off the event loop.
        expanded = await asyncio.to_thread(expand_prompt, keywords=kw_list, room_type=room_type)

        ckpt_name = await resolve_checkpoint()
        jobs[job_id]["checkpoint"] = ckpt_name

        prompt_id = await queue_prompt(build_workflow(comfy_name, expanded, ckpt_name))
    except HTTPException as exc:
        jobs[job_id].update({"status": "error", "error": str(exc.detail)})
        raise
    except httpx.HTTPError as exc:
        # Transport/status failures talking to ComfyUI are upstream errors, not 500s.
        jobs[job_id].update({"status": "error", "error": str(exc)})
        raise HTTPException(status_code=502, detail=f"ComfyUI request failed: {exc}") from exc
    except Exception as exc:
        jobs[job_id].update({"status": "error", "error": str(exc)})
        raise

    jobs[job_id].update({"status": "processing", "prompt_id": prompt_id})

    task = asyncio.create_task(background_poll(job_id, prompt_id))
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)

    return {
        "job_id": job_id,
        "status": "processing",
        **gateway_urls(job_id),
    }
|
|
|
|
@app.get("/dream-weaver/result/{job_id}")
async def result(job_id: str, request: Request):
    """Return the finished job's image as a PNG attachment.

    The image is fetched from ComfyUI's /view endpoint using the output entry
    that background_poll stored on the job.

    Raises:
        HTTPException:
            - 401 on failed gateway auth.
            - 404 ("Result not ready") when the job is unknown, not done, or
              has no output.
            - 502 when the image cannot be fetched from ComfyUI.
    """
    ensure_gateway_auth(request)

    job = jobs.get(job_id)
    img = job.get("output") if job and job.get("status") == "done" else None
    if not img:
        raise HTTPException(status_code=404, detail="Result not ready")

    view_params = {
        "filename": img["filename"],
        "subfolder": img.get("subfolder", ""),
        "type": img.get("type", "output"),
    }
    try:
        async with comfy_client(timeout=30) as client:
            response = await client.get(f"{COMFY}/view", params=view_params)
            response.raise_for_status()
    except httpx.HTTPError as exc:
        # Report ComfyUI/transport failures as a gateway error instead of
        # letting the httpx exception escape as an unhandled 500.
        logger.warning("Fetching result for job %s from ComfyUI failed: %s", job_id, exc)
        raise HTTPException(
            status_code=502, detail="Could not fetch Dream Weaver result from ComfyUI."
        ) from exc

    return StreamingResponse(
        io.BytesIO(response.content),
        media_type="image/png",
        headers={"Content-Disposition": f"attachment; filename=dreamweaver_{job_id[:8]}.png"},
    )
|
|
|
|
def _serve() -> None:
    """Serve the gateway app with uvicorn on 0.0.0.0:8082 at "info" log level."""
    uvicorn.run(app, host="0.0.0.0", port=8082, log_level="info")


if __name__ == "__main__":
    _serve()
|