Files
Project_Animatix/backend/app/services/orchestrator.py
2026-04-17 19:11:57 +05:30

326 lines
12 KiB
Python

import asyncio
import json
import logging
import subprocess
import uuid
from pathlib import Path
from typing import Iterable, Optional
from sqlalchemy.orm import Session
from app.db.session import SessionLocal
from app.models import Asset, Job, JobEvent, JobOutput
from app.services.comfy_client import comfy_client
from app.services.storage import asset_storage, output_storage
from app.services.workflow_binder import WorkflowBinder, select_template_name
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Lower-case file suffixes that _output_type_for_filename reports as "video";
# any other suffix is reported as "image".
VIDEO_EXTENSIONS = {".mp4", ".mov", ".webm", ".mkv", ".avi"}
# Maps a workflow node ``class_type`` to the name of the node input that is
# checked against the runtime's available model files. Read by
# _required_model_values and _validate_runtime_models.
MODEL_LOADER_INPUTS = {
"CLIPLoader": "clip_name",
"VAELoader": "vae_name",
"UNETLoader": "unet_name",
"LoraLoaderModelOnly": "lora_name",
}
def _add_event(db: Session, job_id: str, event_type: str, message: str, payload: dict | None = None) -> None:
    """Persist one ``JobEvent`` row for ``job_id`` and commit it immediately.

    Args:
        db: Open SQLAlchemy session used for the insert and the commit.
        job_id: Identifier of the job the event belongs to.
        event_type: Short machine-readable event category.
        message: Human-readable description of the event.
        payload: Optional extra data; stored as a JSON string when truthy,
            otherwise stored as ``None`` (an empty dict is also stored as
            ``None``).
    """
    serialized_payload = None
    if payload:
        serialized_payload = json.dumps(payload)

    record = JobEvent(
        job_id=job_id,
        event_type=event_type,
        message=message,
        payload_json=serialized_payload,
    )
    db.add(record)
    db.commit()
def _set_status(db: Session, job: Job, status: str, error: str | None = None) -> None:
    """Update ``job.status``, commit, and record a ``status_change`` event.

    Args:
        db: Open SQLAlchemy session that owns ``job``.
        job: Job row to mutate.
        status: New status value.
        error: When truthy, also written to ``job.error_message``; a falsy
            value leaves any existing error message untouched.
    """
    job.status = status
    if error:
        job.error_message = error
    db.commit()

    _add_event(
        db,
        job.id,
        event_type="status_change",
        message=f"Job status -> {status}",
    )
def _extract_history_error(history: dict) -> str | None:
status = history.get("status", {}) or {}
if status.get("status_str") == "error":
for message in status.get("messages", []) or []:
if not isinstance(message, (list, tuple)) or len(message) < 2:
continue
payload = message[1] or {}
if message[0] == "execution_error":
exception_message = payload.get("exception_message")
node_id = payload.get("node_id")
node_type = payload.get("node_type")
if exception_message and node_id and node_type:
return f"ComfyUI execution error on node {node_id} ({node_type}): {exception_message}"
if exception_message:
return f"ComfyUI execution error: {exception_message}"
return f"ComfyUI execution error: {payload}"
return "ComfyUI execution failed without a detailed error message."
if history.get("node_errors"):
return f"ComfyUI node validation failed: {history['node_errors']}"
return None
def _output_type_for_filename(filename: str) -> str:
    """Classify ``filename`` as ``"video"`` or ``"image"`` by its suffix.

    The suffix is compared case-insensitively against ``VIDEO_EXTENSIONS``;
    anything not listed there is treated as an image.
    """
    extension = Path(filename).suffix.lower()
    if extension in VIDEO_EXTENSIONS:
        return "video"
    return "image"
def _required_model_values(workflow: dict) -> dict[str, set[str]]:
    """Collect the model names referenced by loader nodes in ``workflow``.

    Only nodes whose ``class_type`` is a key of ``MODEL_LOADER_INPUTS`` are
    considered, and only non-empty string values of the mapped input count.

    Returns:
        Mapping of loader ``class_type`` to the set of referenced values.
        Loaders with no referenced value are omitted; keys follow the order
        of ``MODEL_LOADER_INPUTS``.
    """
    found: dict[str, set[str]] = {}
    for node in workflow.values():
        if not isinstance(node, dict):
            continue
        class_type = node.get("class_type")
        input_name = MODEL_LOADER_INPUTS.get(class_type)
        if not input_name:
            continue
        node_inputs = node.get("inputs") or {}
        value = node_inputs.get(input_name)
        if isinstance(value, str) and value:
            found.setdefault(class_type, set()).add(value)

    # Emit in MODEL_LOADER_INPUTS order so downstream messages are stable.
    return {loader: found[loader] for loader in MODEL_LOADER_INPUTS if loader in found}
async def _validate_runtime_models(workflow: dict) -> None:
    """Check that every model referenced by ``workflow`` is offered by ComfyUI.

    For each loader type in use, the loader's object info is fetched from
    ``comfy_client`` and the first element of the relevant required input is
    taken as the list of available values.

    Raises:
        RuntimeError: When at least one referenced value is not available;
            the message lists missing and available values per loader.
    """
    required = _required_model_values(workflow)
    if required:
        problems: list[str] = []
        for loader, wanted in required.items():
            info = await comfy_client.get_object_info(loader)
            input_key = MODEL_LOADER_INPUTS[loader]
            input_section = info.get("input") or {}
            required_section = input_section.get("required") or {}
            # Fall back to a one-element spec so the [0] lookup is always valid.
            spec = required_section.get(input_key) or [[]]
            available = set(spec[0] or [])
            missing = sorted(wanted - available)
            if missing:
                problems.append(f"{loader} missing {missing}; available={sorted(available)}")

        if problems:
            raise RuntimeError(
                "ComfyUI runtime is missing required Wan model files. "
                + " | ".join(problems)
            )
async def _get_history_with_fallback(prompt_id: str) -> dict:
    """Fetch the history entry for ``prompt_id``.

    The per-prompt lookup is tried first. When it returns something falsy, the
    full history is fetched and ``prompt_id`` is looked up in it, defaulting to
    an empty dict.
    """
    direct = await comfy_client.get_history(prompt_id)
    if not direct:
        everything = await comfy_client.get_history_all()
        return everything.get(prompt_id, {})
    return direct
def _iter_history_files(node_output: dict) -> Iterable[dict]:
    """Yield one descriptor per file listed in a node's history output.

    Entries under ``"videos"`` come first and are always typed ``"video"``.
    Entries under ``"images"`` follow and are typed by
    ``_output_type_for_filename``. A missing or ``None`` list is treated as
    empty.

    Each descriptor has the keys ``filename``, ``subfolder`` (default ``""``),
    ``folder_type`` (taken from the entry's ``"type"``, default ``"output"``)
    and ``output_type``.
    """

    def describe(entry: dict, output_type: str) -> dict:
        # Shared shape for both video and image entries.
        return {
            "filename": entry["filename"],
            "subfolder": entry.get("subfolder", ""),
            "folder_type": entry.get("type", "output"),
            "output_type": output_type,
        }

    for video_entry in node_output.get("videos", []) or []:
        yield describe(video_entry, "video")

    for image_entry in node_output.get("images", []) or []:
        yield describe(image_entry, _output_type_for_filename(image_entry["filename"]))
def _generate_poster(job_id: str, video_rel_path: str, video_filename: str) -> str | None:
    """Extract one frame of a saved video as a JPEG poster (best effort).

    Runs ``ffmpeg`` synchronously, so call it from a worker thread when inside
    the event loop.

    Returns:
        The poster's path relative to the output storage root, or ``None``
        when the poster file does not exist afterwards or any step failed.
    """
    poster_fname = f"poster_{Path(video_filename).stem}.jpg"
    poster_rel = f"{job_id}/{poster_fname}"
    try:
        poster_abs = str(output_storage.absolute_path(poster_rel))
        result = subprocess.run(
            ["ffmpeg", "-y", "-i", str(output_storage.absolute_path(video_rel_path)), "-vframes", "1", poster_abs],
            capture_output=True,
            timeout=30,
            check=False,
        )
        if Path(poster_abs).exists():
            return poster_rel
        logger.warning(
            "ffmpeg produced no poster for %s (exit code %s)",
            video_rel_path,
            result.returncode,
        )
    except Exception:
        # Posters are optional: a missing ffmpeg binary or a timeout must not
        # fail output collection, but it should be visible in the logs.
        logger.warning("Poster generation failed for %s", video_rel_path, exc_info=True)
    return None


async def _collect_outputs_from_history(db: Session, job: Job, history: dict) -> int:
    """Download every file listed in ``history`` and record it as a ``JobOutput``.

    Each file is downloaded through ``comfy_client`` and saved with
    ``output_storage``. Files whose saved path is already attached to the job
    (or was added earlier in this call) are skipped. Videos additionally get a
    best-effort poster image.

    Args:
        db: Open SQLAlchemy session that owns ``job``.
        job: Job whose outputs are being collected.
        history: ComfyUI history entry for the job's prompt.

    Returns:
        Number of new ``JobOutput`` rows created. The session is committed
        only when that number is non-zero.
    """
    existing_paths = {output.file_path for output in job.outputs}
    created = 0
    for node_id, node_output in (history.get("outputs", {}) or {}).items():
        for file_info in _iter_history_files(node_output):
            fname = file_info["filename"]
            data = await comfy_client.download_output(
                fname,
                file_info["subfolder"],
                file_info["folder_type"],
            )
            rel_path = output_storage.save_bytes(data, job.id, fname)
            if rel_path in existing_paths:
                continue

            poster_path = None
            if file_info["output_type"] == "video":
                # ffmpeg can block for up to 30 s; keep it off the event loop.
                poster_path = await asyncio.to_thread(_generate_poster, job.id, rel_path, fname)

            db.add(
                JobOutput(
                    job_id=job.id,
                    output_type=file_info["output_type"],
                    file_path=rel_path,
                    poster_path=poster_path,
                    metadata_json=json.dumps({"node_id": node_id, "filename": fname}),
                )
            )
            existing_paths.add(rel_path)
            created += 1

    if created:
        db.commit()
    return created
async def reconcile_job_outputs_if_missing(job_id: str) -> bool:
    """Recover outputs for a completed job that has none recorded.

    Nothing is done unless the job exists, is ``"completed"``, has a ComfyUI
    prompt id, and has no outputs yet. The prompt's history must also be
    present and error-free.

    Returns:
        ``True`` when at least one output was recovered (an
        ``outputs_reconciled`` event is recorded), otherwise ``False``.
    """
    session = SessionLocal()
    try:
        job = session.query(Job).filter(Job.id == job_id).first()
        if not job:
            return False

        needs_recovery = (
            job.status == "completed"
            and bool(job.comfy_prompt_id)
            and not job.outputs
        )
        if not needs_recovery:
            return False

        history = await _get_history_with_fallback(job.comfy_prompt_id)
        if not history:
            return False
        if _extract_history_error(history):
            return False

        recovered = await _collect_outputs_from_history(session, job, history)
        if not recovered:
            return False

        _add_event(
            session,
            job.id,
            "outputs_reconciled",
            f"Recovered {recovered} output file(s) from ComfyUI history.",
        )
        return True
    finally:
        session.close()
async def _upload_asset_to_comfy(db: Session, asset_id: Optional[str]) -> Optional[str]:
    """Upload the stored file of ``asset_id`` to ComfyUI.

    Args:
        db: Open SQLAlchemy session used to look up the asset.
        asset_id: Asset to upload; a falsy id means "nothing to upload".

    Returns:
        Whatever ``comfy_client.upload_media`` returns, or ``None`` when
        ``asset_id`` is falsy.

    Raises:
        ValueError: When the asset does not exist or is in the trash.
    """
    if not asset_id:
        return None

    asset = db.query(Asset).filter(Asset.id == asset_id).first()
    if not asset:
        raise ValueError(f"Asset {asset_id} not found")
    if asset.is_trashed:
        raise ValueError(f"Asset {asset.original_filename} is in trash")

    local_path = str(asset_storage.absolute_path(asset.storage_path))
    uploaded_name = await comfy_client.upload_media(
        local_path,
        asset.original_filename,
        asset.asset_type,
    )
    return uploaded_name
def _validate_job(job: Job) -> list[str]:
    """Return the list of validation problems for ``job`` (empty when valid).

    Checks, in order: a non-blank prompt, a ground-truth asset id, and the
    mode-specific requirement (``animate`` needs submode ``move`` or ``mix``;
    ``audio`` needs an audio asset id). Any other mode is reported as unknown.
    """
    problems: list[str] = []

    prompt_text = job.prompt
    if not (prompt_text and prompt_text.strip()):
        problems.append("Prompt is required")

    if not job.ground_truth_asset_id:
        problems.append("Ground truth image is required")

    mode = job.mode
    if mode not in ("animate", "audio"):
        problems.append("Unknown mode")
    elif mode == "animate" and job.submode not in ("move", "mix"):
        problems.append("Animate mode requires submode 'move' or 'mix'")
    elif mode == "audio" and not job.audio_asset_id:
        problems.append("Audio mode requires an audio file")

    return problems
# Polling budget while waiting for ComfyUI: 360 polls, 5 seconds apart.
_POLL_INTERVAL_SECONDS = 5
_MAX_POLL_ATTEMPTS = 360


async def run_job(job_id: str) -> None:
    """Run one generation job end to end and record its progress on the job row.

    The job moves through the statuses ``validating`` -> ``uploading_assets``
    -> ``queued`` -> ``executing`` -> ``collecting_outputs`` -> ``completed``.
    Validation errors, ComfyUI errors, a polling timeout, or any unexpected
    exception set the status to ``failed`` with an error message instead.

    Args:
        job_id: Identifier of the job to run. An unknown id is a no-op.
    """
    db = SessionLocal()
    try:
        job = db.query(Job).filter(Job.id == job_id).first()
        if not job:
            return

        _set_status(db, job, "validating")
        errors = _validate_job(job)
        if errors:
            _set_status(db, job, "failed", "; ".join(errors))
            return

        _set_status(db, job, "uploading_assets")
        gt_name = await _upload_asset_to_comfy(db, job.ground_truth_asset_id)
        motion_name = await _upload_asset_to_comfy(db, job.motion_asset_id)
        audio_name = await _upload_asset_to_comfy(db, job.audio_asset_id)
        pose_name = await _upload_asset_to_comfy(db, job.pose_asset_id)
        ref_names = []
        if job.reference_asset_ids_json:
            for ref_id in json.loads(job.reference_asset_ids_json):
                uploaded = await _upload_asset_to_comfy(db, ref_id)
                if uploaded:
                    ref_names.append(uploaded)

        settings_dict = json.loads(job.settings_json) if job.settings_json else {}

        # Resolve the template name once; it is needed for the binder, the
        # placeholder error message and the job record.
        template_name = select_template_name(job.mode, job.submode)
        binder = WorkflowBinder(template_name)
        if "PLACEHOLDER" in binder.status.upper():
            raise RuntimeError(
                f"Workflow template '{template_name}' is still a placeholder. "
                "Replace it with the production ComfyUI export before running real generations."
            )

        # Only a non-negative integer seed is accepted; anything else becomes 0.
        raw_seed = settings_dict.get("seed", 0)
        seed = raw_seed if isinstance(raw_seed, int) and raw_seed >= 0 else 0
        params = {
            "positive_prompt": job.prompt,
            "negative_prompt": job.negative_prompt or "",
            "ground_truth": gt_name,
            "motion_video": motion_name,
            "audio": audio_name,
            "pose_sheet": pose_name,
            "reference_image": ref_names[0] if ref_names else None,
            "seed": seed,
            "steps": settings_dict.get("steps", 20),
            "cfg": settings_dict.get("cfg", 7.0),
        }
        workflow = binder.bind(params)
        await _validate_runtime_models(workflow)

        job.workflow_template_name = template_name
        job.workflow_template_version = binder.version
        db.commit()

        _set_status(db, job, "queued")
        prompt_id = await comfy_client.submit_prompt(workflow, client_id=str(uuid.uuid4()))
        job.comfy_prompt_id = prompt_id
        db.commit()
        _add_event(db, job.id, "submitted", f"ComfyUI prompt_id: {prompt_id}")

        _set_status(db, job, "executing")
        history = {}
        for _ in range(_MAX_POLL_ATTEMPTS):
            await asyncio.sleep(_POLL_INTERVAL_SECONDS)
            history = await _get_history_with_fallback(prompt_id)
            history_error = _extract_history_error(history)
            if history_error:
                _set_status(db, job, "failed", history_error)
                return
            # "status" may be present but None; treat that like a missing status.
            if (history.get("status") or {}).get("completed"):
                break
        else:
            # Loop exhausted without a break: the prompt never completed.
            _set_status(db, job, "failed", "Timed out waiting for ComfyUI")
            return

        _set_status(db, job, "collecting_outputs")
        await _collect_outputs_from_history(db, job, history)
        _set_status(db, job, "completed")
    except Exception as exc:
        logger.exception("Job %s failed: %s", job_id, exc)
        # The exception may have left the session in a failed transaction
        # (e.g. an error during commit); roll back so the failure can still be
        # recorded on the job row.
        db.rollback()
        job = db.query(Job).filter(Job.id == job_id).first()
        if job:
            _set_status(db, job, "failed", str(exc))
    finally:
        db.close()