Initial Animatrix import
This commit is contained in:
325
backend/app/services/orchestrator.py
Normal file
325
backend/app/services/orchestrator.py
Normal file
@@ -0,0 +1,325 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import subprocess
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db.session import SessionLocal
|
||||
from app.models import Asset, Job, JobEvent, JobOutput
|
||||
from app.services.comfy_client import comfy_client
|
||||
from app.services.storage import asset_storage, output_storage
|
||||
from app.services.workflow_binder import WorkflowBinder, select_template_name
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
VIDEO_EXTENSIONS = {".mp4", ".mov", ".webm", ".mkv", ".avi"}
|
||||
MODEL_LOADER_INPUTS = {
|
||||
"CLIPLoader": "clip_name",
|
||||
"VAELoader": "vae_name",
|
||||
"UNETLoader": "unet_name",
|
||||
"LoraLoaderModelOnly": "lora_name",
|
||||
}
|
||||
|
||||
|
||||
def _add_event(db: Session, job_id: str, event_type: str, message: str, payload: dict | None = None) -> None:
|
||||
event = JobEvent(
|
||||
job_id=job_id,
|
||||
event_type=event_type,
|
||||
message=message,
|
||||
payload_json=json.dumps(payload) if payload else None,
|
||||
)
|
||||
db.add(event)
|
||||
db.commit()
|
||||
|
||||
|
||||
def _set_status(db: Session, job: Job, status: str, error: str | None = None) -> None:
|
||||
job.status = status
|
||||
if error:
|
||||
job.error_message = error
|
||||
db.commit()
|
||||
_add_event(db, job.id, "status_change", f"Job status -> {status}")
|
||||
|
||||
|
||||
def _extract_history_error(history: dict) -> str | None:
|
||||
status = history.get("status", {}) or {}
|
||||
if status.get("status_str") == "error":
|
||||
for message in status.get("messages", []) or []:
|
||||
if not isinstance(message, (list, tuple)) or len(message) < 2:
|
||||
continue
|
||||
payload = message[1] or {}
|
||||
if message[0] == "execution_error":
|
||||
exception_message = payload.get("exception_message")
|
||||
node_id = payload.get("node_id")
|
||||
node_type = payload.get("node_type")
|
||||
if exception_message and node_id and node_type:
|
||||
return f"ComfyUI execution error on node {node_id} ({node_type}): {exception_message}"
|
||||
if exception_message:
|
||||
return f"ComfyUI execution error: {exception_message}"
|
||||
return f"ComfyUI execution error: {payload}"
|
||||
return "ComfyUI execution failed without a detailed error message."
|
||||
|
||||
if history.get("node_errors"):
|
||||
return f"ComfyUI node validation failed: {history['node_errors']}"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _output_type_for_filename(filename: str) -> str:
|
||||
return "video" if Path(filename).suffix.lower() in VIDEO_EXTENSIONS else "image"
|
||||
|
||||
|
||||
def _required_model_values(workflow: dict) -> dict[str, set[str]]:
|
||||
required: dict[str, set[str]] = {loader: set() for loader in MODEL_LOADER_INPUTS}
|
||||
for node in workflow.values():
|
||||
if not isinstance(node, dict):
|
||||
continue
|
||||
class_type = node.get("class_type")
|
||||
input_name = MODEL_LOADER_INPUTS.get(class_type)
|
||||
if not input_name:
|
||||
continue
|
||||
value = (node.get("inputs") or {}).get(input_name)
|
||||
if isinstance(value, str) and value:
|
||||
required[class_type].add(value)
|
||||
return {loader: values for loader, values in required.items() if values}
|
||||
|
||||
|
||||
async def _validate_runtime_models(workflow: dict) -> None:
|
||||
required = _required_model_values(workflow)
|
||||
if not required:
|
||||
return
|
||||
|
||||
missing_by_loader: list[str] = []
|
||||
for loader, expected_values in required.items():
|
||||
object_info = await comfy_client.get_object_info(loader)
|
||||
loader_input = MODEL_LOADER_INPUTS[loader]
|
||||
available_raw = (((object_info.get("input") or {}).get("required") or {}).get(loader_input) or [[]])[0]
|
||||
available = set(available_raw or [])
|
||||
missing = sorted(value for value in expected_values if value not in available)
|
||||
if missing:
|
||||
missing_by_loader.append(f"{loader} missing {missing}; available={sorted(available)}")
|
||||
|
||||
if missing_by_loader:
|
||||
raise RuntimeError(
|
||||
"ComfyUI runtime is missing required Wan model files. "
|
||||
+ " | ".join(missing_by_loader)
|
||||
)
|
||||
|
||||
|
||||
async def _get_history_with_fallback(prompt_id: str) -> dict:
|
||||
history = await comfy_client.get_history(prompt_id)
|
||||
if history:
|
||||
return history
|
||||
all_history = await comfy_client.get_history_all()
|
||||
return all_history.get(prompt_id, {})
|
||||
|
||||
|
||||
def _iter_history_files(node_output: dict) -> Iterable[dict]:
|
||||
for video in node_output.get("videos", []) or []:
|
||||
yield {
|
||||
"filename": video["filename"],
|
||||
"subfolder": video.get("subfolder", ""),
|
||||
"folder_type": video.get("type", "output"),
|
||||
"output_type": "video",
|
||||
}
|
||||
|
||||
for image in node_output.get("images", []) or []:
|
||||
filename = image["filename"]
|
||||
yield {
|
||||
"filename": filename,
|
||||
"subfolder": image.get("subfolder", ""),
|
||||
"folder_type": image.get("type", "output"),
|
||||
"output_type": _output_type_for_filename(filename),
|
||||
}
|
||||
|
||||
|
||||
async def _collect_outputs_from_history(db: Session, job: Job, history: dict) -> int:
|
||||
existing_paths = {output.file_path for output in job.outputs}
|
||||
created = 0
|
||||
|
||||
for node_id, node_output in (history.get("outputs", {}) or {}).items():
|
||||
for file_info in _iter_history_files(node_output):
|
||||
fname = file_info["filename"]
|
||||
data = await comfy_client.download_output(
|
||||
fname,
|
||||
file_info["subfolder"],
|
||||
file_info["folder_type"],
|
||||
)
|
||||
rel_path = output_storage.save_bytes(data, job.id, fname)
|
||||
if rel_path in existing_paths:
|
||||
continue
|
||||
|
||||
poster_path = None
|
||||
if file_info["output_type"] == "video":
|
||||
try:
|
||||
poster_fname = f"poster_{Path(fname).stem}.jpg"
|
||||
poster_abs = str(output_storage.absolute_path(f"{job.id}/{poster_fname}"))
|
||||
subprocess.run(
|
||||
["ffmpeg", "-y", "-i", str(output_storage.absolute_path(rel_path)), "-vframes", "1", poster_abs],
|
||||
capture_output=True,
|
||||
timeout=30,
|
||||
check=False,
|
||||
)
|
||||
if Path(poster_abs).exists():
|
||||
poster_path = f"{job.id}/{poster_fname}"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
db.add(
|
||||
JobOutput(
|
||||
job_id=job.id,
|
||||
output_type=file_info["output_type"],
|
||||
file_path=rel_path,
|
||||
poster_path=poster_path,
|
||||
metadata_json=json.dumps({"node_id": node_id, "filename": fname}),
|
||||
)
|
||||
)
|
||||
existing_paths.add(rel_path)
|
||||
created += 1
|
||||
|
||||
if created:
|
||||
db.commit()
|
||||
|
||||
return created
|
||||
|
||||
|
||||
async def reconcile_job_outputs_if_missing(job_id: str) -> bool:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
if not job:
|
||||
return False
|
||||
if job.status != "completed" or not job.comfy_prompt_id or job.outputs:
|
||||
return False
|
||||
|
||||
history = await _get_history_with_fallback(job.comfy_prompt_id)
|
||||
if not history or _extract_history_error(history):
|
||||
return False
|
||||
|
||||
created = await _collect_outputs_from_history(db, job, history)
|
||||
if created:
|
||||
_add_event(db, job.id, "outputs_reconciled", f"Recovered {created} output file(s) from ComfyUI history.")
|
||||
return True
|
||||
return False
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
async def _upload_asset_to_comfy(db: Session, asset_id: Optional[str]) -> Optional[str]:
|
||||
if not asset_id:
|
||||
return None
|
||||
asset = db.query(Asset).filter(Asset.id == asset_id).first()
|
||||
if not asset:
|
||||
raise ValueError(f"Asset {asset_id} not found")
|
||||
if asset.is_trashed:
|
||||
raise ValueError(f"Asset {asset.original_filename} is in trash")
|
||||
return await comfy_client.upload_media(
|
||||
str(asset_storage.absolute_path(asset.storage_path)),
|
||||
asset.original_filename,
|
||||
asset.asset_type,
|
||||
)
|
||||
|
||||
|
||||
def _validate_job(job: Job) -> list[str]:
|
||||
errors = []
|
||||
if not job.prompt or not job.prompt.strip():
|
||||
errors.append("Prompt is required")
|
||||
if not job.ground_truth_asset_id:
|
||||
errors.append("Ground truth image is required")
|
||||
if job.mode == "animate":
|
||||
if job.submode not in ("move", "mix"):
|
||||
errors.append("Animate mode requires submode 'move' or 'mix'")
|
||||
elif job.mode == "audio":
|
||||
if not job.audio_asset_id:
|
||||
errors.append("Audio mode requires an audio file")
|
||||
else:
|
||||
errors.append("Unknown mode")
|
||||
return errors
|
||||
|
||||
|
||||
async def run_job(job_id: str) -> None:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
if not job:
|
||||
return
|
||||
|
||||
_set_status(db, job, "validating")
|
||||
errors = _validate_job(job)
|
||||
if errors:
|
||||
_set_status(db, job, "failed", "; ".join(errors))
|
||||
return
|
||||
|
||||
_set_status(db, job, "uploading_assets")
|
||||
gt_name = await _upload_asset_to_comfy(db, job.ground_truth_asset_id)
|
||||
motion_name = await _upload_asset_to_comfy(db, job.motion_asset_id)
|
||||
audio_name = await _upload_asset_to_comfy(db, job.audio_asset_id)
|
||||
pose_name = await _upload_asset_to_comfy(db, job.pose_asset_id)
|
||||
ref_names = []
|
||||
if job.reference_asset_ids_json:
|
||||
for ref_id in json.loads(job.reference_asset_ids_json):
|
||||
uploaded = await _upload_asset_to_comfy(db, ref_id)
|
||||
if uploaded:
|
||||
ref_names.append(uploaded)
|
||||
|
||||
settings_dict = json.loads(job.settings_json) if job.settings_json else {}
|
||||
binder = WorkflowBinder(select_template_name(job.mode, job.submode))
|
||||
if "PLACEHOLDER" in binder.status.upper():
|
||||
raise RuntimeError(
|
||||
f"Workflow template '{select_template_name(job.mode, job.submode)}' is still a placeholder. "
|
||||
"Replace it with the production ComfyUI export before running real generations."
|
||||
)
|
||||
raw_seed = settings_dict.get("seed", 0)
|
||||
seed = raw_seed if isinstance(raw_seed, int) and raw_seed >= 0 else 0
|
||||
|
||||
params = {
|
||||
"positive_prompt": job.prompt,
|
||||
"negative_prompt": job.negative_prompt or "",
|
||||
"ground_truth": gt_name,
|
||||
"motion_video": motion_name,
|
||||
"audio": audio_name,
|
||||
"pose_sheet": pose_name,
|
||||
"reference_image": ref_names[0] if ref_names else None,
|
||||
"seed": seed,
|
||||
"steps": settings_dict.get("steps", 20),
|
||||
"cfg": settings_dict.get("cfg", 7.0),
|
||||
}
|
||||
workflow = binder.bind(params)
|
||||
await _validate_runtime_models(workflow)
|
||||
job.workflow_template_name = select_template_name(job.mode, job.submode)
|
||||
job.workflow_template_version = binder.version
|
||||
db.commit()
|
||||
|
||||
_set_status(db, job, "queued")
|
||||
prompt_id = await comfy_client.submit_prompt(workflow, client_id=str(uuid.uuid4()))
|
||||
job.comfy_prompt_id = prompt_id
|
||||
db.commit()
|
||||
_add_event(db, job.id, "submitted", f"ComfyUI prompt_id: {prompt_id}")
|
||||
|
||||
_set_status(db, job, "executing")
|
||||
history = {}
|
||||
for _ in range(360):
|
||||
await asyncio.sleep(5)
|
||||
history = await _get_history_with_fallback(prompt_id)
|
||||
history_error = _extract_history_error(history)
|
||||
if history_error:
|
||||
_set_status(db, job, "failed", history_error)
|
||||
return
|
||||
if history.get("status", {}).get("completed"):
|
||||
break
|
||||
else:
|
||||
_set_status(db, job, "failed", "Timed out waiting for ComfyUI")
|
||||
return
|
||||
|
||||
_set_status(db, job, "collecting_outputs")
|
||||
await _collect_outputs_from_history(db, job, history)
|
||||
_set_status(db, job, "completed")
|
||||
except Exception as exc:
|
||||
logger.exception("Job %s failed: %s", job_id, exc)
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
if job:
|
||||
_set_status(db, job, "failed", str(exc))
|
||||
finally:
|
||||
db.close()
|
||||
Reference in New Issue
Block a user