forked from sagnik/Project_Velocity
feat: Ipad app features and Dream Weaver for Velocity WebOS
This commit is contained in:
@@ -4,38 +4,116 @@ from pathlib import Path
|
||||
from typing import Optional, List
|
||||
import httpx
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, UploadFile, File, HTTPException, Form, BackgroundTasks
|
||||
from fastapi import FastAPI, UploadFile, File, HTTPException, Form, Request
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
|
||||
SCRIPTS_DIR = Path(__file__).parent / "scripts"
|
||||
ROOT_DIR = Path(__file__).resolve().parent
|
||||
SCRIPTS_DIR = ROOT_DIR / "scripts"
|
||||
if not SCRIPTS_DIR.exists():
|
||||
SCRIPTS_DIR = ROOT_DIR / "comfy_engine" / "scripts"
|
||||
sys.path.insert(0, str(SCRIPTS_DIR))
|
||||
|
||||
try:
|
||||
from prompt_expander import expand_prompt, expand_prompt_simple, ROOM_CONTEXTS, ExpandedPrompt
|
||||
from gateway_auth import load_gateway_api_key, is_gateway_request_authorized
|
||||
LLM_AVAILABLE = True
|
||||
except ImportError:
|
||||
LLM_AVAILABLE = False
|
||||
logging.warning("prompt_expander not found — LLM expansion disabled")
|
||||
|
||||
class ExpandedPrompt(BaseModel):
|
||||
style_name: str
|
||||
positive_prompt: str
|
||||
negative_prompt: str
|
||||
steps: int = 28
|
||||
cfg: float = 7.0
|
||||
denoise: float = 0.72
|
||||
|
||||
ROOM_CONTEXTS = {}
|
||||
|
||||
def expand_prompt(keywords: List[str], room_type: str) -> ExpandedPrompt:
|
||||
pretty_room = room_type.replace("_", " ").strip() or "living room"
|
||||
pretty_keywords = ", ".join(keywords) if keywords else "modern, photorealistic"
|
||||
return ExpandedPrompt(
|
||||
style_name="Fallback Prompt Expansion",
|
||||
positive_prompt=(
|
||||
f"photorealistic premium {pretty_room} interior design, {pretty_keywords}, "
|
||||
"natural lighting, realistic materials, architect-grade composition"
|
||||
),
|
||||
negative_prompt=(
|
||||
"worst quality, low quality, blurry, distorted perspective, "
|
||||
"people, watermark, text, duplicate objects"
|
||||
),
|
||||
)
|
||||
|
||||
expand_prompt_simple = expand_prompt
|
||||
from gateway_auth import load_gateway_api_key, is_gateway_request_authorized
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
||||
logger = logging.getLogger("DreamWeaverGateway")
|
||||
COMFY = "http://127.0.0.1:8188"
|
||||
COMFY = (os.environ.get("COMFYUI_URL") or os.environ.get("COMFY_URL") or "http://127.0.0.1:8188").rstrip("/")
|
||||
COMFY_TLS_VERIFY = os.environ.get("COMFYUI_TLS_VERIFY", "true").strip().lower() not in {"0", "false", "no", "off"}
|
||||
GATEWAY_API_KEY = load_gateway_api_key()
|
||||
PREFERRED_CHECKPOINTS = [
|
||||
"realvisxlV50_v50LightningBakedvae.safetensors",
|
||||
"realvisxlV50Lightning_v50Lightning.safetensors",
|
||||
]
|
||||
|
||||
app = FastAPI(title="Dream Weaver API v2", version="2.0.0")
|
||||
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
|
||||
jobs: dict = {}
|
||||
|
||||
def comfy_client(timeout: float = 30) -> httpx.AsyncClient:
|
||||
return httpx.AsyncClient(timeout=timeout, verify=COMFY_TLS_VERIFY, follow_redirects=True)
|
||||
|
||||
async def list_comfy_checkpoints() -> list[str]:
|
||||
async with comfy_client(timeout=10) as client:
|
||||
response = await client.get(f"{COMFY}/models/checkpoints")
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
if isinstance(payload, list):
|
||||
return [item for item in payload if isinstance(item, str)]
|
||||
return []
|
||||
|
||||
async def resolve_checkpoint() -> str:
|
||||
checkpoints = await list_comfy_checkpoints()
|
||||
if not checkpoints:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=(
|
||||
"ComfyUI is online but has no checkpoint models installed. "
|
||||
"Hydrate RealVisXL into ComfyUI/models/checkpoints before generating."
|
||||
),
|
||||
)
|
||||
lower_lookup = {item.lower(): item for item in checkpoints}
|
||||
for preferred in PREFERRED_CHECKPOINTS:
|
||||
match = lower_lookup.get(preferred.lower())
|
||||
if match:
|
||||
return match
|
||||
return checkpoints[0]
|
||||
|
||||
def gateway_urls(job_id: str) -> dict:
|
||||
return {
|
||||
"poll_url": f"/dream-weaver/status/{job_id}",
|
||||
"result_url": f"/dream-weaver/result/{job_id}",
|
||||
}
|
||||
|
||||
def ensure_gateway_auth(request: Request) -> None:
|
||||
if is_gateway_request_authorized(request.headers, GATEWAY_API_KEY):
|
||||
return
|
||||
raise HTTPException(status_code=401, detail="Dream Weaver gateway API key is required or invalid.")
|
||||
|
||||
async def upload_to_comfy(data: bytes, filename: str) -> str:
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
async with comfy_client(timeout=30) as client:
|
||||
r = await client.post(f"{COMFY}/upload/image", files={"image": (filename, data, "image/jpeg")}, data={"overwrite": "true"})
|
||||
r.raise_for_status()
|
||||
return r.json()["name"]
|
||||
|
||||
def build_workflow(img_name: str, expanded: "ExpandedPrompt") -> dict:
|
||||
def build_workflow(img_name: str, expanded: "ExpandedPrompt", ckpt_name: str) -> dict:
|
||||
return {
|
||||
"1": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": "realvisxlV50_v50LightningBakedvae.safetensors"}},
|
||||
"1": {"class_type": "CheckpointLoaderSimple", "inputs": {"ckpt_name": ckpt_name}},
|
||||
"2": {"class_type": "LoadImage", "inputs": {"image": img_name, "upload": "image"}},
|
||||
"3": {"class_type": "CLIPTextEncode", "inputs": {"text": expanded.positive_prompt, "clip": ["1", 1]}},
|
||||
"4": {"class_type": "CLIPTextEncode", "inputs": {"text": expanded.negative_prompt, "clip": ["1", 1]}},
|
||||
@@ -46,14 +124,20 @@ def build_workflow(img_name: str, expanded: "ExpandedPrompt") -> dict:
|
||||
}
|
||||
|
||||
async def queue_prompt(workflow: dict) -> str:
|
||||
async with httpx.AsyncClient(timeout=30) as client:
|
||||
async with comfy_client(timeout=30) as client:
|
||||
r = await client.post(f"{COMFY}/prompt", json={"prompt": workflow, "client_id": str(uuid.uuid4())})
|
||||
r.raise_for_status()
|
||||
if r.status_code >= 400:
|
||||
detail = r.text
|
||||
try:
|
||||
detail = json.dumps(r.json())
|
||||
except Exception:
|
||||
pass
|
||||
raise HTTPException(status_code=502, detail=f"ComfyUI rejected Dream Weaver workflow: {detail}")
|
||||
return r.json()["prompt_id"]
|
||||
|
||||
async def poll_result(prompt_id: str, timeout: int = 300):
|
||||
start = time.time()
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
async with comfy_client(timeout=10) as client:
|
||||
while time.time() - start < timeout:
|
||||
r = await client.get(f"{COMFY}/history/{prompt_id}")
|
||||
if r.status_code == 200:
|
||||
@@ -70,29 +154,89 @@ async def background_poll(job_id: str, prompt_id: str):
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok", "comfyui": True, "llm_expansion": LLM_AVAILABLE, "version": "2.0.0"}
|
||||
checkpoints: list[str] = []
|
||||
try:
|
||||
checkpoints = await list_comfy_checkpoints()
|
||||
except Exception:
|
||||
checkpoints = []
|
||||
return {
|
||||
"status": "ok",
|
||||
"comfyui": True,
|
||||
"comfyui_url": COMFY,
|
||||
"checkpoint_ready": bool(checkpoints),
|
||||
"checkpoint_count": len(checkpoints),
|
||||
"preferred_checkpoints": PREFERRED_CHECKPOINTS,
|
||||
"available_checkpoints": checkpoints[:12],
|
||||
"llm_expansion": LLM_AVAILABLE,
|
||||
"version": "2.0.0",
|
||||
"auth_required": GATEWAY_API_KEY is not None,
|
||||
"auth_scheme": "x-dream-weaver-api-key"
|
||||
}
|
||||
|
||||
@app.get("/dream-weaver/status/{job_id}")
|
||||
async def status(job_id: str):
|
||||
async def status(job_id: str, request: Request):
|
||||
ensure_gateway_auth(request)
|
||||
job = jobs.get(job_id)
|
||||
if not job: raise HTTPException(status_code=404, detail="Job not found")
|
||||
res = {k: v for k, v in job.items() if k != "output"}
|
||||
res["ready"] = job.get("status") == "done"
|
||||
if res["ready"]:
|
||||
res.update(gateway_urls(job_id))
|
||||
return res
|
||||
|
||||
@app.post("/dream-weaver")
|
||||
async def dream_weaver(image: UploadFile = File(...), keywords: str = Form(default=""), room_type: str = Form(default="living_room")):
|
||||
async def dream_weaver(
|
||||
request: Request,
|
||||
image: UploadFile = File(...),
|
||||
keywords: str = Form(default=""),
|
||||
room_type: str = Form(default="living_room")
|
||||
):
|
||||
ensure_gateway_auth(request)
|
||||
job_id = str(uuid.uuid4())
|
||||
jobs[job_id] = {"status": "uploading", "created": time.time()}
|
||||
data = await image.read()
|
||||
comfy_name = await upload_to_comfy(data, f"dw_{job_id[:8]}.jpg")
|
||||
kw_list = [k.strip() for k in keywords.split(",") if k.strip()]
|
||||
expanded = await asyncio.to_thread(expand_prompt, keywords=kw_list, room_type=room_type)
|
||||
wf = build_workflow(comfy_name, expanded)
|
||||
ckpt_name = await resolve_checkpoint()
|
||||
jobs[job_id]["checkpoint"] = ckpt_name
|
||||
wf = build_workflow(comfy_name, expanded, ckpt_name)
|
||||
prompt_id = await queue_prompt(wf)
|
||||
jobs[job_id].update({"status": "processing", "prompt_id": prompt_id})
|
||||
asyncio.create_task(background_poll(job_id, prompt_id))
|
||||
return {"job_id": job_id, "status": "processing"}
|
||||
return {
|
||||
"job_id": job_id,
|
||||
"status": "processing",
|
||||
**gateway_urls(job_id),
|
||||
}
|
||||
|
||||
@app.get("/dream-weaver/result/{job_id}")
|
||||
async def result(job_id: str, request: Request):
|
||||
ensure_gateway_auth(request)
|
||||
job = jobs.get(job_id)
|
||||
if not job or job.get("status") != "done":
|
||||
raise HTTPException(status_code=404, detail="Result not ready")
|
||||
|
||||
img = job.get("output")
|
||||
if not img:
|
||||
raise HTTPException(status_code=404, detail="Result not ready")
|
||||
|
||||
async with comfy_client(timeout=30) as client:
|
||||
response = await client.get(
|
||||
f"{COMFY}/view",
|
||||
params={
|
||||
"filename": img["filename"],
|
||||
"subfolder": img.get("subfolder", ""),
|
||||
"type": img.get("type", "output"),
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return StreamingResponse(
|
||||
io.BytesIO(response.content),
|
||||
media_type="image/png",
|
||||
headers={"Content-Disposition": f"attachment; filename=dreamweaver_{job_id[:8]}.png"},
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(app, host="0.0.0.0", port=8082, log_level="info")
|
||||
|
||||
Reference in New Issue
Block a user