forked from sagnik/Velocity-OS
Initial commit: Velocity-OS migration
This commit is contained in:
373
media-engine/gateway.py
Normal file
373
media-engine/gateway.py
Normal file
@@ -0,0 +1,373 @@
|
||||
#!/usr/bin/env python3
|
||||
import asyncio, json, time, uuid, io, sys, os, logging
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, List
|
||||
import httpx
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, UploadFile, File, HTTPException, Form, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
|
||||
ROOT_DIR = Path(__file__).resolve().parent
|
||||
for scripts_dir in (ROOT_DIR / "comfy_engine" / "scripts", ROOT_DIR / "scripts"):
|
||||
if scripts_dir.exists():
|
||||
sys.path.insert(0, str(scripts_dir))
|
||||
|
||||
try:
|
||||
from gateway_auth import load_gateway_api_key, is_gateway_request_authorized
|
||||
except ImportError as exc:
|
||||
raise RuntimeError("Dream Weaver gateway_auth.py is required on PYTHONPATH") from exc
|
||||
|
||||
try:
|
||||
from prompt_expander import expand_prompt, ROOM_CONTEXTS, ExpandedPrompt
|
||||
LLM_AVAILABLE = True
|
||||
except ImportError as exc:
|
||||
LLM_AVAILABLE = False
|
||||
logging.warning("prompt_expander unavailable; using deterministic fallback expansion: %s", exc)
|
||||
|
||||
class ExpandedPrompt(BaseModel):
|
||||
style_name: str
|
||||
positive_prompt: str
|
||||
negative_prompt: str
|
||||
steps: int = 28
|
||||
cfg: float = 7.0
|
||||
denoise: float = 0.72
|
||||
|
||||
ROOM_CONTEXTS = {}
|
||||
|
||||
def expand_prompt(keywords: List[str], room_type: str) -> ExpandedPrompt:
|
||||
pretty_room = room_type.replace("_", " ").strip() or "living room"
|
||||
pretty_keywords = ", ".join(keywords) if keywords else "modern, photorealistic"
|
||||
return ExpandedPrompt(
|
||||
style_name="Fallback Prompt Expansion",
|
||||
positive_prompt=(
|
||||
f"photorealistic premium {pretty_room} interior design, {pretty_keywords}, "
|
||||
"natural lighting, realistic materials, architect-grade composition"
|
||||
),
|
||||
negative_prompt=(
|
||||
"worst quality, low quality, blurry, distorted perspective, "
|
||||
"people, watermark, text, duplicate objects"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
||||
logger = logging.getLogger("DreamWeaverGateway")
|
||||
COMFY = (os.environ.get("COMFYUI_URL") or os.environ.get("COMFY_URL") or "http://127.0.0.1:8188").rstrip("/")
|
||||
COMFY_URLS = [
|
||||
item.strip().rstrip("/")
|
||||
for item in os.environ.get("COMFYUI_URLS", COMFY).split(",")
|
||||
if item.strip()
|
||||
]
|
||||
COMFY_TLS_VERIFY = os.environ.get("COMFYUI_TLS_VERIFY", "true").strip().lower() not in {"0", "false", "no", "off"}
|
||||
GATEWAY_API_KEY = load_gateway_api_key()
|
||||
POLL_TIMEOUT_SECONDS = int(os.environ.get("DREAM_WEAVER_POLL_TIMEOUT_SECONDS", "1800"))
|
||||
POLL_INTERVAL_SECONDS = float(os.environ.get("DREAM_WEAVER_POLL_INTERVAL_SECONDS", "2"))
|
||||
INPUT_MEGAPIXELS = float(os.environ.get("DREAM_WEAVER_INPUT_MEGAPIXELS", "0.75"))
|
||||
MAX_RENDER_STEPS = int(os.environ.get("DREAM_WEAVER_MAX_STEPS", "18"))
|
||||
PREFERRED_CHECKPOINTS = [
|
||||
"realvisxlV50_v50LightningBakedvae.safetensors",
|
||||
"realvisxlV50Lightning_v50Lightning.safetensors",
|
||||
]
|
||||
|
||||
app = FastAPI(title="Dream Weaver API v2", version="2.0.0")
|
||||
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
|
||||
jobs: dict = {}
|
||||
worker_locks: dict[str, asyncio.Lock] = {}
|
||||
assignment_lock = asyncio.Lock()
|
||||
ACTIVE_JOB_STATUSES = {"uploading", "processing"}
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ComfyWorker:
|
||||
id: str
|
||||
url: str
|
||||
|
||||
COMFY_WORKERS = [
|
||||
ComfyWorker(id=f"comfy-{index}", url=url)
|
||||
for index, url in enumerate(COMFY_URLS)
|
||||
]
|
||||
for worker in COMFY_WORKERS:
|
||||
worker_locks.setdefault(worker.id, asyncio.Lock())
|
||||
|
||||
def comfy_client(timeout: float = 30) -> httpx.AsyncClient:
|
||||
return httpx.AsyncClient(timeout=timeout, verify=COMFY_TLS_VERIFY, follow_redirects=True)
|
||||
|
||||
def worker_by_id(worker_id: str) -> ComfyWorker:
|
||||
for worker in COMFY_WORKERS:
|
||||
if worker.id == worker_id:
|
||||
return worker
|
||||
raise HTTPException(status_code=500, detail=f"Dream Weaver worker {worker_id} is not configured")
|
||||
|
||||
async def list_comfy_checkpoints(worker: Optional[ComfyWorker] = None) -> list[str]:
|
||||
worker = worker or COMFY_WORKERS[0]
|
||||
async with comfy_client(timeout=10) as client:
|
||||
response = await client.get(f"{worker.url}/models/checkpoints")
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
if isinstance(payload, list):
|
||||
return [item for item in payload if isinstance(item, str)]
|
||||
return []
|
||||
|
||||
async def resolve_checkpoint(worker: ComfyWorker) -> str:
|
||||
checkpoints = await list_comfy_checkpoints(worker)
|
||||
if not checkpoints:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=(
|
||||
"ComfyUI is online but has no checkpoint models installed. "
|
||||
"Hydrate RealVisXL into ComfyUI/models/checkpoints before generating."
|
||||
),
|
||||
)
|
||||
lower_lookup = {item.lower(): item for item in checkpoints}
|
||||
for preferred in PREFERRED_CHECKPOINTS:
|
||||
match = lower_lookup.get(preferred.lower())
|
||||
if match:
|
||||
return match
|
||||
return checkpoints[0]
|
||||
|
||||
def gateway_urls(job_id: str) -> dict:
|
||||
return {
|
||||
"poll_url": f"/dream-weaver/status/{job_id}",
|
||||
"result_url": f"/dream-weaver/result/{job_id}",
|
||||
}
|
||||
|
||||
def ensure_gateway_auth(request: Request) -> None:
|
||||
if is_gateway_request_authorized(request.headers, GATEWAY_API_KEY):
|
||||
return
|
||||
raise HTTPException(status_code=401, detail="Dream Weaver gateway API key is required or invalid.")
|
||||
|
||||
async def worker_queue_size(worker: ComfyWorker) -> int:
|
||||
async with comfy_client(timeout=5) as client:
|
||||
response = await client.get(f"{worker.url}/queue")
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
running = payload.get("queue_running") if isinstance(payload, dict) else []
|
||||
pending = payload.get("queue_pending") if isinstance(payload, dict) else []
|
||||
return len(running or []) + len(pending or [])
|
||||
|
||||
def local_worker_load(worker: ComfyWorker) -> int:
|
||||
return sum(
|
||||
1
|
||||
for job in jobs.values()
|
||||
if job.get("worker_id") == worker.id and job.get("status") in ACTIVE_JOB_STATUSES
|
||||
)
|
||||
|
||||
async def choose_worker() -> ComfyWorker:
|
||||
candidates: list[tuple[int, int, str, ComfyWorker]] = []
|
||||
errors: list[str] = []
|
||||
for worker in COMFY_WORKERS:
|
||||
try:
|
||||
checkpoints = await list_comfy_checkpoints(worker)
|
||||
if not checkpoints:
|
||||
errors.append(f"{worker.id} has no checkpoints")
|
||||
continue
|
||||
candidates.append((local_worker_load(worker), await worker_queue_size(worker), worker.id, worker))
|
||||
except Exception as exc:
|
||||
errors.append(f"{worker.id} unhealthy: {exc}")
|
||||
if not candidates:
|
||||
detail = "; ".join(errors) if errors else "No ComfyUI workers are configured"
|
||||
raise HTTPException(status_code=503, detail=f"No healthy Dream Weaver workers. {detail}")
|
||||
candidates.sort(key=lambda item: (item[0], item[1], item[2]))
|
||||
return candidates[0][3]
|
||||
|
||||
async def upload_to_comfy(worker: ComfyWorker, data: bytes, filename: str) -> str:
|
||||
async with comfy_client(timeout=30) as client:
|
||||
r = await client.post(f"{worker.url}/upload/image", files={"image": (filename, data, "image/jpeg")}, data={"overwrite": "true"})
|
||||
r.raise_for_status()
|
||||
return r.json()["name"]
|
||||
|
||||
def normalize_expanded_prompt(expanded: "ExpandedPrompt") -> "ExpandedPrompt":
|
||||
expanded.steps = max(6, min(int(expanded.steps or MAX_RENDER_STEPS), MAX_RENDER_STEPS))
|
||||
expanded.cfg = max(3.0, min(float(expanded.cfg or 6.0), 7.0))
|
||||
expanded.denoise = max(0.45, min(float(expanded.denoise or 0.65), 0.72))
|
||||
return expanded
|
||||
|
||||
def build_workflow(img_name: str, expanded: "ExpandedPrompt", ckpt_name: str) -> dict:
|
||||
expanded = normalize_expanded_prompt(expanded)
|
||||
return {
|
||||
"1": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": ckpt_name}},
|
||||
"2": {"class_type": "LoadImage", "inputs": {"image": img_name, "upload": "image"}},
|
||||
"9": {"class_type": "ImageScaleToTotalPixels", "inputs": {"image": ["2", 0], "upscale_method": "lanczos", "megapixels": INPUT_MEGAPIXELS, "resolution_steps": 8}},
|
||||
"3": {"class_type": "CLIPTextEncode", "inputs": {"text": expanded.positive_prompt, "clip": ["1", 1]}},
|
||||
"4": {"class_type": "CLIPTextEncode", "inputs": {"text": expanded.negative_prompt, "clip": ["1", 1]}},
|
||||
"5": {"class_type": "VAEEncode", "inputs": {"pixels": ["9", 0], "vae": ["1", 2]}},
|
||||
"6": {"class_type": "KSampler", "inputs": {"model": ["1", 0], "positive": ["3", 0], "negative": ["4", 0], "latent_image": ["5", 0], "seed": int(time.time()) % 999983, "steps": expanded.steps, "cfg": expanded.cfg, "sampler_name": "dpmpp_2m", "scheduler": "karras", "denoise": expanded.denoise}},
|
||||
"7": {"class_type": "VAEDecode", "inputs": {"samples": ["6", 0], "vae": ["1", 2]}},
|
||||
"8": {"class_type": "SaveImage", "inputs": {"images": ["7", 0], "filename_prefix": f"dw_{expanded.style_name.replace(' ', '_')[:30]}"}},
|
||||
}
|
||||
|
||||
async def queue_prompt(worker: ComfyWorker, workflow: dict) -> str:
|
||||
async with comfy_client(timeout=30) as client:
|
||||
r = await client.post(f"{worker.url}/prompt", json={"prompt": workflow, "client_id": str(uuid.uuid4())})
|
||||
if r.status_code >= 400:
|
||||
detail = r.text
|
||||
try:
|
||||
detail = json.dumps(r.json())
|
||||
except Exception:
|
||||
pass
|
||||
raise HTTPException(status_code=502, detail=f"ComfyUI rejected Dream Weaver workflow: {detail}")
|
||||
return r.json()["prompt_id"]
|
||||
|
||||
def extract_comfy_error(history_entry: dict) -> Optional[str]:
|
||||
status = history_entry.get("status") if isinstance(history_entry, dict) else None
|
||||
if not isinstance(status, dict):
|
||||
return None
|
||||
if status.get("status_str") != "error":
|
||||
return None
|
||||
messages = status.get("messages") or []
|
||||
for kind, payload in reversed(messages):
|
||||
if kind == "execution_error" and isinstance(payload, dict):
|
||||
node_type = payload.get("node_type") or payload.get("node_id") or "ComfyUI"
|
||||
message = payload.get("exception_message") or payload.get("exception_type") or "ComfyUI execution failed"
|
||||
return f"{node_type}: {message}"
|
||||
return "ComfyUI execution failed"
|
||||
|
||||
async def poll_result(worker: ComfyWorker, prompt_id: str, timeout: int = POLL_TIMEOUT_SECONDS):
|
||||
start = time.time()
|
||||
async with comfy_client(timeout=10) as client:
|
||||
while time.time() - start < timeout:
|
||||
r = await client.get(f"{worker.url}/history/{prompt_id}")
|
||||
if r.status_code == 200:
|
||||
h = r.json().get(prompt_id, {})
|
||||
err = extract_comfy_error(h)
|
||||
if err:
|
||||
return None, err
|
||||
imgs = [img for nd in h.get("outputs", {}).values() for img in nd.get("images", [])]
|
||||
if imgs: return imgs[0], None
|
||||
await asyncio.sleep(POLL_INTERVAL_SECONDS)
|
||||
return None, f"timeout after {timeout} seconds"
|
||||
|
||||
async def background_poll(job_id: str, worker_id: str, prompt_id: str):
|
||||
worker = worker_by_id(worker_id)
|
||||
img, err = await poll_result(worker, prompt_id)
|
||||
if img:
|
||||
jobs[job_id].update({"status": "done", "output": img, "completed": time.time()})
|
||||
else:
|
||||
jobs[job_id].update({"status": "error", "error": str(err), "completed": time.time()})
|
||||
|
||||
@app.get("/health")
|
||||
@app.get("/dream-weaver/health")
|
||||
async def health():
|
||||
worker_health = []
|
||||
for worker in COMFY_WORKERS:
|
||||
try:
|
||||
checkpoints = await list_comfy_checkpoints(worker)
|
||||
queue_size = await worker_queue_size(worker)
|
||||
worker_health.append({
|
||||
"id": worker.id,
|
||||
"url": worker.url,
|
||||
"online": True,
|
||||
"checkpoint_ready": bool(checkpoints),
|
||||
"checkpoint_count": len(checkpoints),
|
||||
"queue_size": queue_size,
|
||||
"available_checkpoints": checkpoints[:12],
|
||||
})
|
||||
except Exception as exc:
|
||||
worker_health.append({
|
||||
"id": worker.id,
|
||||
"url": worker.url,
|
||||
"online": False,
|
||||
"checkpoint_ready": False,
|
||||
"checkpoint_count": 0,
|
||||
"queue_size": None,
|
||||
"error": str(exc),
|
||||
})
|
||||
ready_workers = [worker for worker in worker_health if worker["online"] and worker["checkpoint_ready"]]
|
||||
checkpoints = ready_workers[0]["available_checkpoints"] if ready_workers else []
|
||||
return {
|
||||
"status": "ok",
|
||||
"comfyui": bool(ready_workers),
|
||||
"comfyui_url": COMFY_WORKERS[0].url if COMFY_WORKERS else COMFY,
|
||||
"comfyui_urls": [worker.url for worker in COMFY_WORKERS],
|
||||
"checkpoint_ready": bool(ready_workers),
|
||||
"checkpoint_count": max((worker["checkpoint_count"] for worker in worker_health), default=0),
|
||||
"preferred_checkpoints": PREFERRED_CHECKPOINTS,
|
||||
"available_checkpoints": checkpoints[:12],
|
||||
"workers": worker_health,
|
||||
"ready_worker_count": len(ready_workers),
|
||||
"llm_expansion": LLM_AVAILABLE,
|
||||
"input_megapixels": INPUT_MEGAPIXELS,
|
||||
"max_render_steps": MAX_RENDER_STEPS,
|
||||
"poll_timeout_seconds": POLL_TIMEOUT_SECONDS,
|
||||
"version": "2.0.0",
|
||||
"auth_required": GATEWAY_API_KEY is not None,
|
||||
"auth_scheme": "x-dream-weaver-api-key"
|
||||
}
|
||||
|
||||
@app.get("/dream-weaver/status/{job_id}")
|
||||
async def status(job_id: str, request: Request):
|
||||
ensure_gateway_auth(request)
|
||||
job = jobs.get(job_id)
|
||||
if not job: raise HTTPException(status_code=404, detail="Job not found")
|
||||
res = {k: v for k, v in job.items() if k != "output"}
|
||||
if "created" in job:
|
||||
res["elapsed_seconds"] = round(time.time() - float(job["created"]), 2)
|
||||
res["ready"] = job.get("status") == "done"
|
||||
if res["ready"]:
|
||||
res.update(gateway_urls(job_id))
|
||||
return res
|
||||
|
||||
@app.post("/dream-weaver")
|
||||
async def dream_weaver(
|
||||
request: Request,
|
||||
image: UploadFile = File(...),
|
||||
keywords: str = Form(default=""),
|
||||
room_type: str = Form(default="living_room")
|
||||
):
|
||||
ensure_gateway_auth(request)
|
||||
job_id = str(uuid.uuid4())
|
||||
jobs[job_id] = {"status": "uploading", "created": time.time()}
|
||||
data = await image.read()
|
||||
async with assignment_lock:
|
||||
worker = await choose_worker()
|
||||
jobs[job_id].update({"worker_id": worker.id, "worker_url": worker.url})
|
||||
lock = worker_locks.setdefault(worker.id, asyncio.Lock())
|
||||
async with lock:
|
||||
comfy_name = await upload_to_comfy(worker, data, f"dw_{job_id[:8]}.jpg")
|
||||
kw_list = [k.strip() for k in keywords.split(",") if k.strip()]
|
||||
expanded = await asyncio.to_thread(expand_prompt, keywords=kw_list, room_type=room_type)
|
||||
ckpt_name = await resolve_checkpoint(worker)
|
||||
jobs[job_id]["checkpoint"] = ckpt_name
|
||||
wf = build_workflow(comfy_name, expanded, ckpt_name)
|
||||
prompt_id = await queue_prompt(worker, wf)
|
||||
jobs[job_id].update({"status": "processing", "prompt_id": prompt_id})
|
||||
asyncio.create_task(background_poll(job_id, worker.id, prompt_id))
|
||||
return {
|
||||
"job_id": job_id,
|
||||
"status": "processing",
|
||||
**gateway_urls(job_id),
|
||||
}
|
||||
|
||||
@app.get("/dream-weaver/result/{job_id}")
|
||||
async def result(job_id: str, request: Request):
|
||||
ensure_gateway_auth(request)
|
||||
job = jobs.get(job_id)
|
||||
if not job or job.get("status") != "done":
|
||||
raise HTTPException(status_code=404, detail="Result not ready")
|
||||
|
||||
img = job.get("output")
|
||||
if not img:
|
||||
raise HTTPException(status_code=404, detail="Result not ready")
|
||||
|
||||
worker = worker_by_id(job.get("worker_id", COMFY_WORKERS[0].id))
|
||||
async with comfy_client(timeout=30) as client:
|
||||
response = await client.get(
|
||||
f"{worker.url}/view",
|
||||
params={
|
||||
"filename": img["filename"],
|
||||
"subfolder": img.get("subfolder", ""),
|
||||
"type": img.get("type", "output"),
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return StreamingResponse(
|
||||
io.BytesIO(response.content),
|
||||
media_type="image/png",
|
||||
headers={"Content-Disposition": f"attachment; filename=dreamweaver_{job_id[:8]}.png"},
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(app, host="0.0.0.0", port=8082, log_level="info")
|
||||
Reference in New Issue
Block a user