326 lines
12 KiB
Python
326 lines
12 KiB
Python
import asyncio
|
|
import json
|
|
import logging
|
|
import subprocess
|
|
import uuid
|
|
from pathlib import Path
|
|
from typing import Iterable, Optional
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.db.session import SessionLocal
|
|
from app.models import Asset, Job, JobEvent, JobOutput
|
|
from app.services.comfy_client import comfy_client
|
|
from app.services.storage import asset_storage, output_storage
|
|
from app.services.workflow_binder import WorkflowBinder, select_template_name
|
|
|
|
logger = logging.getLogger(__name__)

# Lower-case file suffixes (leading dot included) that are treated as video
# outputs; every other suffix is classified as an image.
VIDEO_EXTENSIONS = {".avi", ".mkv", ".mov", ".mp4", ".webm"}

# ComfyUI loader node ``class_type`` -> name of the input holding the model
# file it loads. Iteration order is significant: it fixes the order in which
# missing-model problems are reported.
MODEL_LOADER_INPUTS = {
    "CLIPLoader": "clip_name",
    "VAELoader": "vae_name",
    "UNETLoader": "unet_name",
    "LoraLoaderModelOnly": "lora_name",
}
|
|
|
|
|
|
def _add_event(db: Session, job_id: str, event_type: str, message: str, payload: dict | None = None) -> None:
    """Append a ``JobEvent`` for *job_id* and commit it right away.

    *payload* is stored JSON-encoded in ``payload_json``; a ``None`` or empty
    payload is stored as NULL instead of ``"{}"``.
    """
    encoded_payload = None
    if payload:
        encoded_payload = json.dumps(payload)
    db.add(
        JobEvent(
            job_id=job_id,
            event_type=event_type,
            message=message,
            payload_json=encoded_payload,
        )
    )
    db.commit()
|
|
|
|
|
|
def _set_status(db: Session, job: Job, status: str, error: str | None = None) -> None:
    """Move *job* to *status*, commit, then log a ``status_change`` event.

    A truthy *error* overwrites ``job.error_message``; a falsy one leaves any
    earlier message in place (this function never clears it).
    """
    event_text = f"Job status -> {status}"
    job.status = status
    if error:
        job.error_message = error
    db.commit()
    _add_event(db, job.id, "status_change", event_text)
|
|
|
|
|
|
def _extract_history_error(history: dict) -> str | None:
|
|
status = history.get("status", {}) or {}
|
|
if status.get("status_str") == "error":
|
|
for message in status.get("messages", []) or []:
|
|
if not isinstance(message, (list, tuple)) or len(message) < 2:
|
|
continue
|
|
payload = message[1] or {}
|
|
if message[0] == "execution_error":
|
|
exception_message = payload.get("exception_message")
|
|
node_id = payload.get("node_id")
|
|
node_type = payload.get("node_type")
|
|
if exception_message and node_id and node_type:
|
|
return f"ComfyUI execution error on node {node_id} ({node_type}): {exception_message}"
|
|
if exception_message:
|
|
return f"ComfyUI execution error: {exception_message}"
|
|
return f"ComfyUI execution error: {payload}"
|
|
return "ComfyUI execution failed without a detailed error message."
|
|
|
|
if history.get("node_errors"):
|
|
return f"ComfyUI node validation failed: {history['node_errors']}"
|
|
|
|
return None
|
|
|
|
|
|
def _output_type_for_filename(filename: str) -> str:
    """Classify *filename* as ``"video"`` or ``"image"`` by its (case-insensitive) suffix."""
    extension = Path(filename).suffix.lower()
    if extension in VIDEO_EXTENSIONS:
        return "video"
    return "image"
|
|
|
|
|
|
def _required_model_values(workflow: dict) -> dict[str, set[str]]:
    """Collect the model file names referenced by loader nodes in *workflow*.

    Returns ``{loader_class: {model_name, ...}}`` containing only loaders that
    reference at least one non-empty string value, keyed in
    ``MODEL_LOADER_INPUTS`` order. Non-dict workflow entries are ignored.
    """
    buckets: dict[str, set[str]] = {name: set() for name in MODEL_LOADER_INPUTS}
    nodes = (entry for entry in workflow.values() if isinstance(entry, dict))
    for node in nodes:
        loader = node.get("class_type")
        input_key = MODEL_LOADER_INPUTS.get(loader)
        if input_key:
            model_file = (node.get("inputs") or {}).get(input_key)
            if isinstance(model_file, str) and model_file:
                buckets[loader].add(model_file)
    return {name: files for name, files in buckets.items() if files}
|
|
|
|
|
|
def _combo_choices(spec: object) -> set:
    """Return the selectable values declared by a ComfyUI combo input spec.

    Accepts the legacy shape ``[[choice, ...], {...}]`` and the newer
    ``["COMBO", {"options": [choice, ...]}]`` shape. Any other shape yields an
    empty set, so callers report the expected values as missing instead of
    crashing (or, as before, turning ``"COMBO"`` into a set of characters).
    """
    if not isinstance(spec, (list, tuple)) or not spec:
        return set()
    head = spec[0]
    if isinstance(head, (list, tuple)):
        return set(head)
    if len(spec) > 1 and isinstance(spec[1], dict):
        options = spec[1].get("options")
        if isinstance(options, (list, tuple)):
            return set(options)
    return set()


async def _validate_runtime_models(workflow: dict) -> None:
    """Check that every model file the bound *workflow* loads exists in ComfyUI.

    For each loader class used by *workflow*, the class's object info is
    fetched from ComfyUI, and the loader input's combo choices are compared
    against the values the workflow references.

    Raises:
        RuntimeError: listing, per loader, the missing values and what is available.
    """
    required = _required_model_values(workflow)
    if not required:
        return

    missing_by_loader: list[str] = []
    for loader, expected_values in required.items():
        # NOTE(review): assumes get_object_info() returns the node definition
        # itself (not wrapped under the class name) — confirm in comfy_client.
        object_info = await comfy_client.get_object_info(loader) or {}
        loader_input = MODEL_LOADER_INPUTS[loader]
        inputs = object_info.get("input") or {}
        # Loader inputs are normally "required", but fall back to "optional".
        spec = (inputs.get("required") or {}).get(loader_input) or (inputs.get("optional") or {}).get(loader_input)
        available = _combo_choices(spec)
        missing = sorted(value for value in expected_values if value not in available)
        if missing:
            missing_by_loader.append(f"{loader} missing {missing}; available={sorted(available)}")

    if missing_by_loader:
        raise RuntimeError(
            "ComfyUI runtime is missing required Wan model files. "
            + " | ".join(missing_by_loader)
        )
|
|
|
|
|
|
async def _get_history_with_fallback(prompt_id: str) -> dict:
    """Fetch the ComfyUI history entry for *prompt_id*.

    Tries the per-prompt history first; if that comes back empty, looks the
    prompt up in the full history and returns ``{}`` when it is absent there too.
    """
    direct = await comfy_client.get_history(prompt_id)
    if not direct:
        everything = await comfy_client.get_history_all()
        return everything.get(prompt_id, {})
    return direct
|
|
|
|
|
|
def _iter_history_files(node_output: dict) -> Iterable[dict]:
|
|
for video in node_output.get("videos", []) or []:
|
|
yield {
|
|
"filename": video["filename"],
|
|
"subfolder": video.get("subfolder", ""),
|
|
"folder_type": video.get("type", "output"),
|
|
"output_type": "video",
|
|
}
|
|
|
|
for image in node_output.get("images", []) or []:
|
|
filename = image["filename"]
|
|
yield {
|
|
"filename": filename,
|
|
"subfolder": image.get("subfolder", ""),
|
|
"folder_type": image.get("type", "output"),
|
|
"output_type": _output_type_for_filename(filename),
|
|
}
|
|
|
|
|
|
async def _extract_poster_frame(job_id: str, video_rel_path: str, video_filename: str) -> str | None:
    """Best-effort: save the first frame of a stored video as a JPEG poster.

    Returns the poster path relative to output storage
    (``"<job_id>/poster_<stem>.jpg"``), or ``None`` when ffmpeg is missing,
    exits non-zero, times out (30 s) or produces no file. Failures are logged
    and never raised, so a missing poster cannot fail the job.
    """
    poster_rel = f"{job_id}/poster_{Path(video_filename).stem}.jpg"
    try:
        poster_abs = str(output_storage.absolute_path(poster_rel))
        command = [
            "ffmpeg", "-y",
            "-i", str(output_storage.absolute_path(video_rel_path)),
            "-vframes", "1",
            poster_abs,
        ]
        # subprocess.run blocks; run it in a worker thread so the event loop
        # (and any other job polling ComfyUI) keeps making progress.
        result = await asyncio.to_thread(
            subprocess.run, command, capture_output=True, timeout=30, check=False
        )
        # Require a clean exit: with -y a poster left by an earlier run could
        # otherwise be mistaken for a fresh one.
        produced = result.returncode == 0 and Path(poster_abs).exists()
    except Exception:
        logger.warning("Poster extraction failed for %s", video_rel_path, exc_info=True)
        return None

    if not produced:
        stderr_tail = (result.stderr or b"").decode(errors="replace")[-500:]
        logger.warning(
            "ffmpeg produced no poster for %s (exit code %s): %s",
            video_rel_path,
            result.returncode,
            stderr_tail,
        )
        return None
    return poster_rel


async def _collect_outputs_from_history(db: Session, job: Job, history: dict) -> int:
    """Download the files listed in a ComfyUI *history* entry and record them for *job*.

    Each file is fetched from ComfyUI and saved via ``output_storage``. A
    ``JobOutput`` row is added unless one with the same storage path already
    exists for the job. Video outputs also get a best-effort poster frame.
    All new rows are committed in a single commit at the end.

    Returns:
        The number of ``JobOutput`` rows created.
    """
    existing_paths = {output.file_path for output in job.outputs}
    created = 0

    for node_id, node_output in (history.get("outputs", {}) or {}).items():
        for file_info in _iter_history_files(node_output):
            fname = file_info["filename"]
            data = await comfy_client.download_output(
                fname,
                file_info["subfolder"],
                file_info["folder_type"],
            )
            # NOTE(review): the storage path is only known after save_bytes, so
            # already-recorded files are downloaded and saved again before being
            # skipped here.
            rel_path = output_storage.save_bytes(data, job.id, fname)
            if rel_path in existing_paths:
                continue

            poster_path = None
            if file_info["output_type"] == "video":
                poster_path = await _extract_poster_frame(job.id, rel_path, fname)

            db.add(
                JobOutput(
                    job_id=job.id,
                    output_type=file_info["output_type"],
                    file_path=rel_path,
                    poster_path=poster_path,
                    metadata_json=json.dumps({"node_id": node_id, "filename": fname}),
                )
            )
            existing_paths.add(rel_path)
            created += 1

    if created:
        db.commit()

    return created
|
|
|
|
|
|
async def reconcile_job_outputs_if_missing(job_id: str) -> bool:
    """Recover outputs for a completed job that has none recorded.

    Only jobs that exist, are ``"completed"``, carry a ComfyUI prompt id, and
    have no outputs are considered. Their ComfyUI history is re-read. If it
    is non-empty and error-free, its files are collected and an
    ``outputs_reconciled`` event is logged.

    Returns:
        ``True`` if at least one output was recovered, ``False`` otherwise.
    """
    db = SessionLocal()
    try:
        job = db.query(Job).filter(Job.id == job_id).first()
        if not job or job.status != "completed":
            return False
        if not job.comfy_prompt_id or job.outputs:
            return False

        history = await _get_history_with_fallback(job.comfy_prompt_id)
        if not history or _extract_history_error(history):
            return False

        recovered = await _collect_outputs_from_history(db, job, history)
        if not recovered:
            return False
        _add_event(db, job.id, "outputs_reconciled", f"Recovered {recovered} output file(s) from ComfyUI history.")
        return True
    finally:
        db.close()
|
|
|
|
|
|
async def _upload_asset_to_comfy(db: Session, asset_id: Optional[str]) -> Optional[str]:
    """Upload the stored asset *asset_id* to ComfyUI and return its ComfyUI name.

    Returns ``None`` when *asset_id* is empty.

    Raises:
        ValueError: if the asset does not exist or is in the trash.
    """
    if not asset_id:
        return None

    asset = db.query(Asset).filter(Asset.id == asset_id).first()
    if not asset:
        raise ValueError(f"Asset {asset_id} not found")
    if asset.is_trashed:
        raise ValueError(f"Asset {asset.original_filename} is in trash")

    local_path = str(asset_storage.absolute_path(asset.storage_path))
    return await comfy_client.upload_media(local_path, asset.original_filename, asset.asset_type)
|
|
|
|
|
|
def _validate_job(job: Job) -> list[str]:
|
|
errors = []
|
|
if not job.prompt or not job.prompt.strip():
|
|
errors.append("Prompt is required")
|
|
if not job.ground_truth_asset_id:
|
|
errors.append("Ground truth image is required")
|
|
if job.mode == "animate":
|
|
if job.submode not in ("move", "mix"):
|
|
errors.append("Animate mode requires submode 'move' or 'mix'")
|
|
elif job.mode == "audio":
|
|
if not job.audio_asset_id:
|
|
errors.append("Audio mode requires an audio file")
|
|
else:
|
|
errors.append("Unknown mode")
|
|
return errors
|
|
|
|
|
|
async def _upload_job_assets(db: Session, job: Job) -> dict:
    """Upload every asset referenced by *job* to ComfyUI.

    Returns the workflow parameters that name the uploaded files:
    ``ground_truth``, ``motion_video``, ``audio``, ``pose_sheet`` and
    ``reference_image``. Unset assets map to ``None``. All references are
    uploaded, but only the first is used.
    """
    gt_name = await _upload_asset_to_comfy(db, job.ground_truth_asset_id)
    motion_name = await _upload_asset_to_comfy(db, job.motion_asset_id)
    audio_name = await _upload_asset_to_comfy(db, job.audio_asset_id)
    pose_name = await _upload_asset_to_comfy(db, job.pose_asset_id)
    ref_names = []
    if job.reference_asset_ids_json:
        for ref_id in json.loads(job.reference_asset_ids_json):
            uploaded = await _upload_asset_to_comfy(db, ref_id)
            if uploaded:
                ref_names.append(uploaded)
    return {
        "ground_truth": gt_name,
        "motion_video": motion_name,
        "audio": audio_name,
        "pose_sheet": pose_name,
        "reference_image": ref_names[0] if ref_names else None,
    }


async def _wait_for_comfy_completion(
    prompt_id: str, poll_interval: float = 5.0, max_polls: int = 360
) -> tuple[dict, str | None]:
    """Poll ComfyUI history for *prompt_id* until it completes, errors, or times out.

    Returns ``(history, error)``. ``error`` is:

    * ``None`` once the history reports ``status.completed``;
    * the extracted ComfyUI error when the history reports a failure;
    * a timeout message after *max_polls* polls (default 360 × 5 s = 30 min).
    """
    history: dict = {}
    for _ in range(max_polls):
        await asyncio.sleep(poll_interval)
        history = await _get_history_with_fallback(prompt_id)
        history_error = _extract_history_error(history)
        if history_error:
            return history, history_error
        # "status" can be missing or null while the prompt is still running.
        if (history.get("status") or {}).get("completed"):
            return history, None
    return history, "Timed out waiting for ComfyUI"


async def run_job(job_id: str) -> None:
    """Execute job *job_id* end to end against ComfyUI, recording progress in the DB.

    Status flow:
        validating -> uploading_assets -> queued -> executing
        -> collecting_outputs -> completed

    Any validation problem, ComfyUI error, timeout or unexpected exception
    ends the job in ``"failed"`` with an error message. Unexpected exceptions
    are logged rather than re-raised. A missing job is silently ignored.
    """
    db = SessionLocal()
    try:
        job = db.query(Job).filter(Job.id == job_id).first()
        if not job:
            return

        _set_status(db, job, "validating")
        errors = _validate_job(job)
        if errors:
            _set_status(db, job, "failed", "; ".join(errors))
            return

        _set_status(db, job, "uploading_assets")
        uploaded_names = await _upload_job_assets(db, job)

        settings_dict = json.loads(job.settings_json) if job.settings_json else {}
        template_name = select_template_name(job.mode, job.submode)
        binder = WorkflowBinder(template_name)
        if "PLACEHOLDER" in binder.status.upper():
            raise RuntimeError(
                f"Workflow template '{template_name}' is still a placeholder. "
                "Replace it with the production ComfyUI export before running real generations."
            )
        raw_seed = settings_dict.get("seed", 0)
        seed = raw_seed if isinstance(raw_seed, int) and raw_seed >= 0 else 0

        params = {
            "positive_prompt": job.prompt,
            "negative_prompt": job.negative_prompt or "",
            # ground_truth, motion_video, audio, pose_sheet, reference_image
            **uploaded_names,
            "seed": seed,
            "steps": settings_dict.get("steps", 20),
            "cfg": settings_dict.get("cfg", 7.0),
        }
        workflow = binder.bind(params)
        await _validate_runtime_models(workflow)
        job.workflow_template_name = template_name
        job.workflow_template_version = binder.version
        db.commit()

        _set_status(db, job, "queued")
        prompt_id = await comfy_client.submit_prompt(workflow, client_id=str(uuid.uuid4()))
        job.comfy_prompt_id = prompt_id
        db.commit()
        _add_event(db, job.id, "submitted", f"ComfyUI prompt_id: {prompt_id}")

        _set_status(db, job, "executing")
        history, wait_error = await _wait_for_comfy_completion(prompt_id)
        if wait_error:
            _set_status(db, job, "failed", wait_error)
            return

        _set_status(db, job, "collecting_outputs")
        await _collect_outputs_from_history(db, job, history)
        _set_status(db, job, "completed")
    except Exception as exc:
        logger.exception("Job %s failed: %s", job_id, exc)
        try:
            # A failed flush/commit leaves the session unusable until it is
            # rolled back; without this, recording the failure raises too.
            db.rollback()
            job = db.query(Job).filter(Job.id == job_id).first()
            if job:
                # Some exceptions stringify to ""; keep a non-empty message.
                _set_status(db, job, "failed", str(exc) or type(exc).__name__)
        except Exception:
            logger.exception("Could not record failure for job %s", job_id)
    finally:
        db.close()
|