Merge Conflicts (#41)
Co-authored-by: Sayan Datta <sayan@Sayans-MacBook-Air.local> Reviewed-on: #41
This commit was merged in pull request #41.
This commit is contained in:
@@ -19,12 +19,13 @@ import asyncio, json, time, uuid, io, sys, os, logging, traceback
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
|
||||
import httpx
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, UploadFile, File, HTTPException, Form, BackgroundTasks
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
import httpx
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, UploadFile, File, HTTPException, Form, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
from gateway_auth import load_gateway_api_key, is_gateway_request_authorized
|
||||
|
||||
# Add scripts dir to path so we can import prompt_expander
|
||||
SCRIPTS_DIR = Path(__file__).parent / "scripts"
|
||||
@@ -38,10 +39,48 @@ except ImportError:
|
||||
logging.warning("prompt_expander not found — LLM expansion disabled")
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
||||
logger = logging.getLogger("DreamWeaverGateway")
|
||||
|
||||
COMFY = "http://127.0.0.1:8118"
|
||||
COMFY_ROOT = "/opt/dlami/nvme/ComfyUI"
|
||||
logger = logging.getLogger("DreamWeaverGateway")
|
||||
|
||||
COMFY = (os.environ.get("COMFYUI_URL") or os.environ.get("COMFY_URL") or "http://127.0.0.1:8118").rstrip("/")
|
||||
COMFY_TLS_VERIFY = os.environ.get("COMFYUI_TLS_VERIFY", "true").strip().lower() not in {"0", "false", "no", "off"}
|
||||
COMFY_ROOT = "/opt/dlami/nvme/ComfyUI"
|
||||
GATEWAY_API_KEY = load_gateway_api_key()
|
||||
PREFERRED_CHECKPOINTS = [
|
||||
"realvisxlV50_v50LightningBakedvae.safetensors",
|
||||
"realvisxlV50Lightning_v50Lightning.safetensors",
|
||||
]
|
||||
|
||||
|
||||
def comfy_client(timeout: float = 30) -> httpx.AsyncClient:
|
||||
return httpx.AsyncClient(timeout=timeout, verify=COMFY_TLS_VERIFY, follow_redirects=True)
|
||||
|
||||
|
||||
async def list_comfy_checkpoints() -> list[str]:
|
||||
async with comfy_client(timeout=10) as client:
|
||||
response = await client.get(f"{COMFY}/models/checkpoints")
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
if isinstance(payload, list):
|
||||
return [item for item in payload if isinstance(item, str)]
|
||||
return []
|
||||
|
||||
|
||||
async def resolve_checkpoint() -> str:
|
||||
checkpoints = await list_comfy_checkpoints()
|
||||
if not checkpoints:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=(
|
||||
"ComfyUI is online but has no checkpoint models installed. "
|
||||
"Hydrate RealVisXL into ComfyUI/models/checkpoints before generating."
|
||||
),
|
||||
)
|
||||
lower_lookup = {item.lower(): item for item in checkpoints}
|
||||
for preferred in PREFERRED_CHECKPOINTS:
|
||||
match = lower_lookup.get(preferred.lower())
|
||||
if match:
|
||||
return match
|
||||
return checkpoints[0]
|
||||
|
||||
app = FastAPI(
|
||||
title="Dream Weaver API v2",
|
||||
@@ -51,7 +90,13 @@ app = FastAPI(
|
||||
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
|
||||
|
||||
# In-memory job store (swap for Redis in production)
|
||||
jobs: dict = {}
|
||||
jobs: dict = {}
|
||||
|
||||
|
||||
def ensure_gateway_auth(request: Request) -> None:
|
||||
if is_gateway_request_authorized(request.headers, GATEWAY_API_KEY):
|
||||
return
|
||||
raise HTTPException(status_code=401, detail="Dream Weaver gateway API key is required or invalid.")
|
||||
|
||||
|
||||
# ─── Models ──────────────────────────────────────────────────────────────────
|
||||
@@ -73,8 +118,8 @@ class ExpandResponse(BaseModel):
|
||||
|
||||
|
||||
# ─── ComfyUI helpers ──────────────────────────────────────────────────────────
|
||||
async def upload_to_comfy(data: bytes, filename: str) -> str:
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
async def upload_to_comfy(data: bytes, filename: str) -> str:
|
||||
async with comfy_client(timeout=30) as client:
|
||||
r = await client.post(f"{COMFY}/upload/image",
|
||||
files={"image": (filename, data, "image/jpeg")},
|
||||
data={"overwrite": "true"})
|
||||
@@ -82,11 +127,11 @@ async def upload_to_comfy(data: bytes, filename: str) -> str:
|
||||
return r.json()["name"]
|
||||
|
||||
|
||||
def build_workflow(img_name: str, expanded: "ExpandedPrompt") -> dict:
|
||||
def build_workflow(img_name: str, expanded: "ExpandedPrompt", ckpt_name: str) -> dict:
|
||||
"""Build ComfyUI API workflow from an ExpandedPrompt result."""
|
||||
return {
|
||||
"1": {"class_type": "CheckpointLoaderSimple",
|
||||
"inputs": {"ckpt_name": "realvisxlV50_v50LightningBakedvae.safetensors"}},
|
||||
"1": {"class_type": "CheckpointLoaderSimple",
|
||||
"inputs": {"ckpt_name": ckpt_name}},
|
||||
"2": {"class_type": "LoadImage",
|
||||
"inputs": {"image": img_name, "upload": "image"}},
|
||||
"3": {"class_type": "CLIPTextEncode", # Positive prompt
|
||||
@@ -114,17 +159,23 @@ def build_workflow(img_name: str, expanded: "ExpandedPrompt") -> dict:
|
||||
}
|
||||
|
||||
|
||||
async def queue_prompt(workflow: dict) -> str:
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
r = await client.post(f"{COMFY}/prompt",
|
||||
json={"prompt": workflow, "client_id": str(uuid.uuid4())})
|
||||
r.raise_for_status()
|
||||
return r.json()["prompt_id"]
|
||||
async def queue_prompt(workflow: dict) -> str:
|
||||
async with comfy_client(timeout=30) as client:
|
||||
r = await client.post(f"{COMFY}/prompt",
|
||||
json={"prompt": workflow, "client_id": str(uuid.uuid4())})
|
||||
if r.status_code >= 400:
|
||||
detail = r.text
|
||||
try:
|
||||
detail = json.dumps(r.json())
|
||||
except Exception:
|
||||
pass
|
||||
raise HTTPException(status_code=502, detail=f"ComfyUI rejected Dream Weaver workflow: {detail}")
|
||||
return r.json()["prompt_id"]
|
||||
|
||||
|
||||
async def poll_result(prompt_id: str, timeout: int = 300):
|
||||
start = time.time()
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
async def poll_result(prompt_id: str, timeout: int = 300):
|
||||
start = time.time()
|
||||
async with comfy_client(timeout=10) as client:
|
||||
while time.time() - start < timeout:
|
||||
r = await client.get(f"{COMFY}/history/{prompt_id}")
|
||||
if r.status_code == 200:
|
||||
@@ -149,23 +200,36 @@ async def background_poll(job_id: str, prompt_id: str):
|
||||
|
||||
# ─── Endpoints ───────────────────────────────────────────────────────────────
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
comfy_ok = False
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5) as c:
|
||||
r = await c.get(f"{COMFY}/system_stats")
|
||||
comfy_ok = r.status_code == 200
|
||||
except Exception:
|
||||
pass
|
||||
return {
|
||||
"status": "ok",
|
||||
"comfyui": comfy_ok,
|
||||
"gpu": "4x NVIDIA L4 (96GB VRAM)",
|
||||
"model": "RealVisXL V5.0 Lightning",
|
||||
"llm_expansion": LLM_AVAILABLE,
|
||||
"version": "2.0.0"
|
||||
}
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
comfy_ok = False
|
||||
checkpoints: list[str] = []
|
||||
try:
|
||||
async with comfy_client(timeout=5) as c:
|
||||
r = await c.get(f"{COMFY}/system_stats")
|
||||
comfy_ok = r.status_code == 200
|
||||
except Exception:
|
||||
pass
|
||||
if comfy_ok:
|
||||
try:
|
||||
checkpoints = await list_comfy_checkpoints()
|
||||
except Exception:
|
||||
checkpoints = []
|
||||
return {
|
||||
"status": "ok",
|
||||
"comfyui": comfy_ok,
|
||||
"gpu": "4x NVIDIA L4 (96GB VRAM)",
|
||||
"model": "RealVisXL V5.0 Lightning",
|
||||
"comfyui_url": COMFY,
|
||||
"checkpoint_ready": bool(checkpoints),
|
||||
"checkpoint_count": len(checkpoints),
|
||||
"preferred_checkpoints": PREFERRED_CHECKPOINTS,
|
||||
"available_checkpoints": checkpoints[:12],
|
||||
"llm_expansion": LLM_AVAILABLE,
|
||||
"version": "2.0.0",
|
||||
"auth_required": GATEWAY_API_KEY is not None,
|
||||
"auth_scheme": "x-dream-weaver-api-key"
|
||||
}
|
||||
|
||||
|
||||
@app.get("/room-types")
|
||||
@@ -185,8 +249,9 @@ async def room_types():
|
||||
}
|
||||
|
||||
|
||||
@app.post("/dream-weaver/expand", response_model=ExpandResponse)
|
||||
async def expand_endpoint(req: ExpandRequest):
|
||||
@app.post("/dream-weaver/expand", response_model=ExpandResponse)
|
||||
async def expand_endpoint(req: ExpandRequest, request: Request):
|
||||
ensure_gateway_auth(request)
|
||||
"""
|
||||
Preview the LLM-generated prompt WITHOUT submitting to ComfyUI.
|
||||
Use this to let the user review/edit the prompt before generating.
|
||||
@@ -228,8 +293,9 @@ async def expand_endpoint(req: ExpandRequest):
|
||||
|
||||
|
||||
@app.post("/dream-weaver")
|
||||
async def dream_weaver(
|
||||
image: UploadFile = File(...),
|
||||
async def dream_weaver(
|
||||
request: Request,
|
||||
image: UploadFile = File(...),
|
||||
# ── Dynamic keyword mode (new) ──
|
||||
keywords: str = Form(default=""), # comma-separated: "blue marble, gold, renaissance"
|
||||
room_type: str = Form(default="living_room"),
|
||||
@@ -239,7 +305,7 @@ async def dream_weaver(
|
||||
custom_negative: str = Form(default=""),
|
||||
denoise: float = Form(default=0.0), # 0.0 = use LLM recommendation
|
||||
cfg_scale: float = Form(default=0.0), # 0.0 = use LLM recommendation
|
||||
):
|
||||
):
|
||||
"""
|
||||
Submit a room photo for AI redesign using dynamic keyword → LLM → ComfyUI pipeline.
|
||||
|
||||
@@ -249,7 +315,8 @@ async def dream_weaver(
|
||||
|
||||
Returns job_id for async polling.
|
||||
"""
|
||||
job_id = str(uuid.uuid4())
|
||||
ensure_gateway_auth(request)
|
||||
job_id = str(uuid.uuid4())
|
||||
jobs[job_id] = {"status": "uploading", "created": time.time()}
|
||||
|
||||
try:
|
||||
@@ -312,9 +379,11 @@ async def dream_weaver(
|
||||
"room_type": room_type,
|
||||
})
|
||||
|
||||
# Submit workflow
|
||||
wf = build_workflow(comfy_name, expanded)
|
||||
prompt_id = await queue_prompt(wf)
|
||||
# Submit workflow
|
||||
ckpt_name = await resolve_checkpoint()
|
||||
jobs[job_id]["checkpoint"] = ckpt_name
|
||||
wf = build_workflow(comfy_name, expanded, ckpt_name)
|
||||
prompt_id = await queue_prompt(wf)
|
||||
jobs[job_id].update({"status": "processing", "prompt_id": prompt_id})
|
||||
|
||||
# Start background polling
|
||||
@@ -339,9 +408,10 @@ async def dream_weaver(
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/dream-weaver/status/{job_id}")
|
||||
async def status(job_id: str):
|
||||
job = jobs.get(job_id)
|
||||
@app.get("/dream-weaver/status/{job_id}")
|
||||
async def status(job_id: str, request: Request):
|
||||
ensure_gateway_auth(request)
|
||||
job = jobs.get(job_id)
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
result = {k: v for k, v in job.items() if k != "output"}
|
||||
@@ -351,15 +421,16 @@ async def status(job_id: str):
|
||||
return result
|
||||
|
||||
|
||||
@app.get("/dream-weaver/result/{job_id}")
|
||||
async def result(job_id: str):
|
||||
job = jobs.get(job_id)
|
||||
@app.get("/dream-weaver/result/{job_id}")
|
||||
async def result(job_id: str, request: Request):
|
||||
ensure_gateway_auth(request)
|
||||
job = jobs.get(job_id)
|
||||
if not job or job.get("status") != "done":
|
||||
raise HTTPException(status_code=404, detail="Result not ready")
|
||||
img = job["output"]
|
||||
url = (f"{COMFY}/view?filename={img['filename']}"
|
||||
f"&subfolder={img.get('subfolder','')}&type={img.get('type','output')}")
|
||||
async with httpx.AsyncClient(timeout=30) as c:
|
||||
async with comfy_client(timeout=30) as c:
|
||||
r = await c.get(url)
|
||||
return StreamingResponse(
|
||||
io.BytesIO(r.content),
|
||||
@@ -369,19 +440,21 @@ async def result(job_id: str):
|
||||
|
||||
|
||||
@app.post("/dream-weaver/sync")
|
||||
async def dream_weaver_sync(
|
||||
image: UploadFile = File(...),
|
||||
async def dream_weaver_sync(
|
||||
request: Request,
|
||||
image: UploadFile = File(...),
|
||||
keywords: str = Form(default=""),
|
||||
room_type: str = Form(default="living_room"),
|
||||
additional_notes: str = Form(default=""),
|
||||
custom_positive: str = Form(default=""),
|
||||
custom_negative: str = Form(default=""),
|
||||
):
|
||||
):
|
||||
"""
|
||||
Blocking version — waits up to 120s and returns image bytes directly.
|
||||
Use for testing. Prefer async /dream-weaver for production.
|
||||
"""
|
||||
data = await image.read()
|
||||
ensure_gateway_auth(request)
|
||||
data = await image.read()
|
||||
filename = f"sync_{uuid.uuid4().hex[:8]}_{image.filename or 'room.jpg'}"
|
||||
comfy_name = await upload_to_comfy(data, filename)
|
||||
|
||||
@@ -404,14 +477,15 @@ async def dream_weaver_sync(
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Provide keywords or custom_positive")
|
||||
|
||||
wf = build_workflow(comfy_name, expanded)
|
||||
ckpt_name = await resolve_checkpoint()
|
||||
wf = build_workflow(comfy_name, expanded, ckpt_name)
|
||||
prompt_id = await queue_prompt(wf)
|
||||
img, err = await poll_result(prompt_id, timeout=120)
|
||||
if err:
|
||||
raise HTTPException(status_code=500, detail=str(err))
|
||||
url = (f"{COMFY}/view?filename={img['filename']}"
|
||||
f"&subfolder={img.get('subfolder','')}&type={img.get('type','output')}")
|
||||
async with httpx.AsyncClient(timeout=30) as c:
|
||||
async with comfy_client(timeout=30) as c:
|
||||
r = await c.get(url)
|
||||
return StreamingResponse(io.BytesIO(r.content), media_type="image/png",
|
||||
headers={"X-Style": expanded.style_name,
|
||||
|
||||
49
comfy_engine/scripts/gateway_auth.py
Normal file
49
comfy_engine/scripts/gateway_auth.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Mapping
|
||||
|
||||
|
||||
_SUPPORTED_ENV_KEYS = (
|
||||
"DREAM_WEAVER_GATEWAY_API_KEY",
|
||||
"DREAM_WEAVER_API_KEY",
|
||||
)
|
||||
|
||||
|
||||
def load_gateway_api_key(env: Mapping[str, str] | None = None) -> str | None:
|
||||
values = env if env is not None else os.environ
|
||||
for key in _SUPPORTED_ENV_KEYS:
|
||||
raw = values.get(key)
|
||||
if raw is None:
|
||||
continue
|
||||
trimmed = raw.strip()
|
||||
if trimmed:
|
||||
return trimmed
|
||||
return None
|
||||
|
||||
|
||||
def extract_gateway_api_key(headers: Mapping[str, str]) -> str | None:
|
||||
for header_name in ("x-dream-weaver-api-key", "x-api-key"):
|
||||
value = headers.get(header_name)
|
||||
if value:
|
||||
trimmed = value.strip()
|
||||
if trimmed:
|
||||
return trimmed
|
||||
|
||||
authorization = headers.get("authorization", "")
|
||||
if authorization.lower().startswith("bearer "):
|
||||
token = authorization[7:].strip()
|
||||
if token:
|
||||
return token
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def is_gateway_request_authorized(
|
||||
headers: Mapping[str, str],
|
||||
required_api_key: str | None,
|
||||
) -> bool:
|
||||
if required_api_key is None:
|
||||
return True
|
||||
presented = extract_gateway_api_key(headers)
|
||||
return presented == required_api_key
|
||||
Reference in New Issue
Block a user