Initial Animatrix import
This commit is contained in:
128
backend/app/services/comfy_client.py
Normal file
128
backend/app/services/comfy_client.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import httpx
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ComfyClient:
|
||||
def __init__(self, base_url: str | None = None):
|
||||
self.base_url = (base_url or settings.COMFYUI_BASE_URL).rstrip("/")
|
||||
self._client = httpx.AsyncClient(timeout=120.0)
|
||||
|
||||
async def close(self) -> None:
|
||||
await self._client.aclose()
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
for endpoint in ("/system_stats", "/"):
|
||||
try:
|
||||
response = await self._client.get(f"{self.base_url}{endpoint}")
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.warning("ComfyUI health check failed at %s: %s", endpoint, exc)
|
||||
return False
|
||||
|
||||
async def upload_image(self, file_path: str, filename: str) -> str:
|
||||
with open(file_path, "rb") as handle:
|
||||
files = {"image": (filename, handle, "application/octet-stream")}
|
||||
response = await self._client.post(f"{self.base_url}/upload/image", files=files)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("name", filename)
|
||||
|
||||
async def upload_media(self, file_path: str, filename: str, media_type: str) -> str:
|
||||
endpoint = {
|
||||
"image": "/upload/image",
|
||||
"pose_sheet": "/upload/image",
|
||||
"video": "/upload/video",
|
||||
"audio": "/upload/audio",
|
||||
}.get(media_type)
|
||||
field_name = {
|
||||
"image": "image",
|
||||
"pose_sheet": "image",
|
||||
"video": "video",
|
||||
"audio": "audio",
|
||||
}.get(media_type)
|
||||
|
||||
if not endpoint or not field_name:
|
||||
raise ValueError(f"Unsupported ComfyUI upload media type: {media_type}")
|
||||
|
||||
mime_type = "application/octet-stream"
|
||||
suffix = Path(filename).suffix.lower()
|
||||
if media_type in ("image", "pose_sheet"):
|
||||
mime_type = {
|
||||
".jpg": "image/jpeg",
|
||||
".jpeg": "image/jpeg",
|
||||
".png": "image/png",
|
||||
".webp": "image/webp",
|
||||
}.get(suffix, mime_type)
|
||||
elif media_type == "video":
|
||||
mime_type = {
|
||||
".mp4": "video/mp4",
|
||||
".webm": "video/webm",
|
||||
".mov": "video/quicktime",
|
||||
}.get(suffix, mime_type)
|
||||
elif media_type == "audio":
|
||||
mime_type = {
|
||||
".mp3": "audio/mpeg",
|
||||
".mp4": "audio/mp4",
|
||||
".wav": "audio/wav",
|
||||
".ogg": "audio/ogg",
|
||||
}.get(suffix, mime_type)
|
||||
|
||||
with open(file_path, "rb") as handle:
|
||||
files = {field_name: (filename, handle, mime_type)}
|
||||
response = await self._client.post(f"{self.base_url}{endpoint}", files=files)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
return data.get("name", filename)
|
||||
|
||||
async def submit_prompt(self, workflow: Dict[str, Any], client_id: str | None = None) -> str:
|
||||
payload: Dict[str, Any] = {"prompt": workflow}
|
||||
if client_id:
|
||||
payload["client_id"] = client_id
|
||||
response = await self._client.post(f"{self.base_url}/prompt", json=payload)
|
||||
if response.is_error:
|
||||
detail = response.text
|
||||
raise RuntimeError(f"ComfyUI prompt submission failed ({response.status_code}): {detail}")
|
||||
data = response.json()
|
||||
prompt_id = data.get("prompt_id")
|
||||
if not prompt_id:
|
||||
raise RuntimeError(f"No prompt_id returned by ComfyUI: {data}")
|
||||
return prompt_id
|
||||
|
||||
async def get_history(self, prompt_id: str) -> Dict[str, Any]:
|
||||
response = await self._client.get(f"{self.base_url}/history/{prompt_id}")
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get(prompt_id, {})
|
||||
|
||||
async def get_history_all(self) -> Dict[str, Any]:
|
||||
response = await self._client.get(f"{self.base_url}/history")
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def get_queue(self) -> Dict[str, Any]:
|
||||
response = await self._client.get(f"{self.base_url}/queue")
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def get_object_info(self, node_name: str) -> Dict[str, Any]:
|
||||
response = await self._client.get(f"{self.base_url}/object_info/{node_name}")
|
||||
response.raise_for_status()
|
||||
return response.json().get(node_name, {})
|
||||
|
||||
async def download_output(self, filename: str, subfolder: str = "", folder_type: str = "output") -> bytes:
|
||||
params = {"filename": filename, "subfolder": subfolder, "type": folder_type}
|
||||
response = await self._client.get(f"{self.base_url}/view", params=params)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
|
||||
|
||||
comfy_client = ComfyClient()
|
||||
325
backend/app/services/orchestrator.py
Normal file
325
backend/app/services/orchestrator.py
Normal file
@@ -0,0 +1,325 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import subprocess
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db.session import SessionLocal
|
||||
from app.models import Asset, Job, JobEvent, JobOutput
|
||||
from app.services.comfy_client import comfy_client
|
||||
from app.services.storage import asset_storage, output_storage
|
||||
from app.services.workflow_binder import WorkflowBinder, select_template_name
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
VIDEO_EXTENSIONS = {".mp4", ".mov", ".webm", ".mkv", ".avi"}
|
||||
MODEL_LOADER_INPUTS = {
|
||||
"CLIPLoader": "clip_name",
|
||||
"VAELoader": "vae_name",
|
||||
"UNETLoader": "unet_name",
|
||||
"LoraLoaderModelOnly": "lora_name",
|
||||
}
|
||||
|
||||
|
||||
def _add_event(db: Session, job_id: str, event_type: str, message: str, payload: dict | None = None) -> None:
|
||||
event = JobEvent(
|
||||
job_id=job_id,
|
||||
event_type=event_type,
|
||||
message=message,
|
||||
payload_json=json.dumps(payload) if payload else None,
|
||||
)
|
||||
db.add(event)
|
||||
db.commit()
|
||||
|
||||
|
||||
def _set_status(db: Session, job: Job, status: str, error: str | None = None) -> None:
|
||||
job.status = status
|
||||
if error:
|
||||
job.error_message = error
|
||||
db.commit()
|
||||
_add_event(db, job.id, "status_change", f"Job status -> {status}")
|
||||
|
||||
|
||||
def _extract_history_error(history: dict) -> str | None:
|
||||
status = history.get("status", {}) or {}
|
||||
if status.get("status_str") == "error":
|
||||
for message in status.get("messages", []) or []:
|
||||
if not isinstance(message, (list, tuple)) or len(message) < 2:
|
||||
continue
|
||||
payload = message[1] or {}
|
||||
if message[0] == "execution_error":
|
||||
exception_message = payload.get("exception_message")
|
||||
node_id = payload.get("node_id")
|
||||
node_type = payload.get("node_type")
|
||||
if exception_message and node_id and node_type:
|
||||
return f"ComfyUI execution error on node {node_id} ({node_type}): {exception_message}"
|
||||
if exception_message:
|
||||
return f"ComfyUI execution error: {exception_message}"
|
||||
return f"ComfyUI execution error: {payload}"
|
||||
return "ComfyUI execution failed without a detailed error message."
|
||||
|
||||
if history.get("node_errors"):
|
||||
return f"ComfyUI node validation failed: {history['node_errors']}"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _output_type_for_filename(filename: str) -> str:
|
||||
return "video" if Path(filename).suffix.lower() in VIDEO_EXTENSIONS else "image"
|
||||
|
||||
|
||||
def _required_model_values(workflow: dict) -> dict[str, set[str]]:
|
||||
required: dict[str, set[str]] = {loader: set() for loader in MODEL_LOADER_INPUTS}
|
||||
for node in workflow.values():
|
||||
if not isinstance(node, dict):
|
||||
continue
|
||||
class_type = node.get("class_type")
|
||||
input_name = MODEL_LOADER_INPUTS.get(class_type)
|
||||
if not input_name:
|
||||
continue
|
||||
value = (node.get("inputs") or {}).get(input_name)
|
||||
if isinstance(value, str) and value:
|
||||
required[class_type].add(value)
|
||||
return {loader: values for loader, values in required.items() if values}
|
||||
|
||||
|
||||
async def _validate_runtime_models(workflow: dict) -> None:
|
||||
required = _required_model_values(workflow)
|
||||
if not required:
|
||||
return
|
||||
|
||||
missing_by_loader: list[str] = []
|
||||
for loader, expected_values in required.items():
|
||||
object_info = await comfy_client.get_object_info(loader)
|
||||
loader_input = MODEL_LOADER_INPUTS[loader]
|
||||
available_raw = (((object_info.get("input") or {}).get("required") or {}).get(loader_input) or [[]])[0]
|
||||
available = set(available_raw or [])
|
||||
missing = sorted(value for value in expected_values if value not in available)
|
||||
if missing:
|
||||
missing_by_loader.append(f"{loader} missing {missing}; available={sorted(available)}")
|
||||
|
||||
if missing_by_loader:
|
||||
raise RuntimeError(
|
||||
"ComfyUI runtime is missing required Wan model files. "
|
||||
+ " | ".join(missing_by_loader)
|
||||
)
|
||||
|
||||
|
||||
async def _get_history_with_fallback(prompt_id: str) -> dict:
|
||||
history = await comfy_client.get_history(prompt_id)
|
||||
if history:
|
||||
return history
|
||||
all_history = await comfy_client.get_history_all()
|
||||
return all_history.get(prompt_id, {})
|
||||
|
||||
|
||||
def _iter_history_files(node_output: dict) -> Iterable[dict]:
|
||||
for video in node_output.get("videos", []) or []:
|
||||
yield {
|
||||
"filename": video["filename"],
|
||||
"subfolder": video.get("subfolder", ""),
|
||||
"folder_type": video.get("type", "output"),
|
||||
"output_type": "video",
|
||||
}
|
||||
|
||||
for image in node_output.get("images", []) or []:
|
||||
filename = image["filename"]
|
||||
yield {
|
||||
"filename": filename,
|
||||
"subfolder": image.get("subfolder", ""),
|
||||
"folder_type": image.get("type", "output"),
|
||||
"output_type": _output_type_for_filename(filename),
|
||||
}
|
||||
|
||||
|
||||
async def _collect_outputs_from_history(db: Session, job: Job, history: dict) -> int:
|
||||
existing_paths = {output.file_path for output in job.outputs}
|
||||
created = 0
|
||||
|
||||
for node_id, node_output in (history.get("outputs", {}) or {}).items():
|
||||
for file_info in _iter_history_files(node_output):
|
||||
fname = file_info["filename"]
|
||||
data = await comfy_client.download_output(
|
||||
fname,
|
||||
file_info["subfolder"],
|
||||
file_info["folder_type"],
|
||||
)
|
||||
rel_path = output_storage.save_bytes(data, job.id, fname)
|
||||
if rel_path in existing_paths:
|
||||
continue
|
||||
|
||||
poster_path = None
|
||||
if file_info["output_type"] == "video":
|
||||
try:
|
||||
poster_fname = f"poster_{Path(fname).stem}.jpg"
|
||||
poster_abs = str(output_storage.absolute_path(f"{job.id}/{poster_fname}"))
|
||||
subprocess.run(
|
||||
["ffmpeg", "-y", "-i", str(output_storage.absolute_path(rel_path)), "-vframes", "1", poster_abs],
|
||||
capture_output=True,
|
||||
timeout=30,
|
||||
check=False,
|
||||
)
|
||||
if Path(poster_abs).exists():
|
||||
poster_path = f"{job.id}/{poster_fname}"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
db.add(
|
||||
JobOutput(
|
||||
job_id=job.id,
|
||||
output_type=file_info["output_type"],
|
||||
file_path=rel_path,
|
||||
poster_path=poster_path,
|
||||
metadata_json=json.dumps({"node_id": node_id, "filename": fname}),
|
||||
)
|
||||
)
|
||||
existing_paths.add(rel_path)
|
||||
created += 1
|
||||
|
||||
if created:
|
||||
db.commit()
|
||||
|
||||
return created
|
||||
|
||||
|
||||
async def reconcile_job_outputs_if_missing(job_id: str) -> bool:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
if not job:
|
||||
return False
|
||||
if job.status != "completed" or not job.comfy_prompt_id or job.outputs:
|
||||
return False
|
||||
|
||||
history = await _get_history_with_fallback(job.comfy_prompt_id)
|
||||
if not history or _extract_history_error(history):
|
||||
return False
|
||||
|
||||
created = await _collect_outputs_from_history(db, job, history)
|
||||
if created:
|
||||
_add_event(db, job.id, "outputs_reconciled", f"Recovered {created} output file(s) from ComfyUI history.")
|
||||
return True
|
||||
return False
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
async def _upload_asset_to_comfy(db: Session, asset_id: Optional[str]) -> Optional[str]:
|
||||
if not asset_id:
|
||||
return None
|
||||
asset = db.query(Asset).filter(Asset.id == asset_id).first()
|
||||
if not asset:
|
||||
raise ValueError(f"Asset {asset_id} not found")
|
||||
if asset.is_trashed:
|
||||
raise ValueError(f"Asset {asset.original_filename} is in trash")
|
||||
return await comfy_client.upload_media(
|
||||
str(asset_storage.absolute_path(asset.storage_path)),
|
||||
asset.original_filename,
|
||||
asset.asset_type,
|
||||
)
|
||||
|
||||
|
||||
def _validate_job(job: Job) -> list[str]:
|
||||
errors = []
|
||||
if not job.prompt or not job.prompt.strip():
|
||||
errors.append("Prompt is required")
|
||||
if not job.ground_truth_asset_id:
|
||||
errors.append("Ground truth image is required")
|
||||
if job.mode == "animate":
|
||||
if job.submode not in ("move", "mix"):
|
||||
errors.append("Animate mode requires submode 'move' or 'mix'")
|
||||
elif job.mode == "audio":
|
||||
if not job.audio_asset_id:
|
||||
errors.append("Audio mode requires an audio file")
|
||||
else:
|
||||
errors.append("Unknown mode")
|
||||
return errors
|
||||
|
||||
|
||||
async def run_job(job_id: str) -> None:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
if not job:
|
||||
return
|
||||
|
||||
_set_status(db, job, "validating")
|
||||
errors = _validate_job(job)
|
||||
if errors:
|
||||
_set_status(db, job, "failed", "; ".join(errors))
|
||||
return
|
||||
|
||||
_set_status(db, job, "uploading_assets")
|
||||
gt_name = await _upload_asset_to_comfy(db, job.ground_truth_asset_id)
|
||||
motion_name = await _upload_asset_to_comfy(db, job.motion_asset_id)
|
||||
audio_name = await _upload_asset_to_comfy(db, job.audio_asset_id)
|
||||
pose_name = await _upload_asset_to_comfy(db, job.pose_asset_id)
|
||||
ref_names = []
|
||||
if job.reference_asset_ids_json:
|
||||
for ref_id in json.loads(job.reference_asset_ids_json):
|
||||
uploaded = await _upload_asset_to_comfy(db, ref_id)
|
||||
if uploaded:
|
||||
ref_names.append(uploaded)
|
||||
|
||||
settings_dict = json.loads(job.settings_json) if job.settings_json else {}
|
||||
binder = WorkflowBinder(select_template_name(job.mode, job.submode))
|
||||
if "PLACEHOLDER" in binder.status.upper():
|
||||
raise RuntimeError(
|
||||
f"Workflow template '{select_template_name(job.mode, job.submode)}' is still a placeholder. "
|
||||
"Replace it with the production ComfyUI export before running real generations."
|
||||
)
|
||||
raw_seed = settings_dict.get("seed", 0)
|
||||
seed = raw_seed if isinstance(raw_seed, int) and raw_seed >= 0 else 0
|
||||
|
||||
params = {
|
||||
"positive_prompt": job.prompt,
|
||||
"negative_prompt": job.negative_prompt or "",
|
||||
"ground_truth": gt_name,
|
||||
"motion_video": motion_name,
|
||||
"audio": audio_name,
|
||||
"pose_sheet": pose_name,
|
||||
"reference_image": ref_names[0] if ref_names else None,
|
||||
"seed": seed,
|
||||
"steps": settings_dict.get("steps", 20),
|
||||
"cfg": settings_dict.get("cfg", 7.0),
|
||||
}
|
||||
workflow = binder.bind(params)
|
||||
await _validate_runtime_models(workflow)
|
||||
job.workflow_template_name = select_template_name(job.mode, job.submode)
|
||||
job.workflow_template_version = binder.version
|
||||
db.commit()
|
||||
|
||||
_set_status(db, job, "queued")
|
||||
prompt_id = await comfy_client.submit_prompt(workflow, client_id=str(uuid.uuid4()))
|
||||
job.comfy_prompt_id = prompt_id
|
||||
db.commit()
|
||||
_add_event(db, job.id, "submitted", f"ComfyUI prompt_id: {prompt_id}")
|
||||
|
||||
_set_status(db, job, "executing")
|
||||
history = {}
|
||||
for _ in range(360):
|
||||
await asyncio.sleep(5)
|
||||
history = await _get_history_with_fallback(prompt_id)
|
||||
history_error = _extract_history_error(history)
|
||||
if history_error:
|
||||
_set_status(db, job, "failed", history_error)
|
||||
return
|
||||
if history.get("status", {}).get("completed"):
|
||||
break
|
||||
else:
|
||||
_set_status(db, job, "failed", "Timed out waiting for ComfyUI")
|
||||
return
|
||||
|
||||
_set_status(db, job, "collecting_outputs")
|
||||
await _collect_outputs_from_history(db, job, history)
|
||||
_set_status(db, job, "completed")
|
||||
except Exception as exc:
|
||||
logger.exception("Job %s failed: %s", job_id, exc)
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
if job:
|
||||
_set_status(db, job, "failed", str(exc))
|
||||
finally:
|
||||
db.close()
|
||||
118
backend/app/services/storage.py
Normal file
118
backend/app/services/storage.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import subprocess
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import aiofiles
|
||||
from fastapi import UploadFile
|
||||
from PIL import Image
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
class LocalStorageService:
|
||||
def __init__(self, root: str):
|
||||
self.root = Path(root)
|
||||
self.root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async def save_upload(self, upload: UploadFile, subfolder: str) -> tuple[str, int]:
|
||||
dest_dir = self.root / subfolder
|
||||
dest_dir.mkdir(parents=True, exist_ok=True)
|
||||
ext = Path(upload.filename or "file").suffix
|
||||
filename = f"{uuid.uuid4()}{ext}"
|
||||
dest_path = dest_dir / filename
|
||||
content = await upload.read()
|
||||
async with aiofiles.open(dest_path, "wb") as handle:
|
||||
await handle.write(content)
|
||||
return str(dest_path.relative_to(self.root)).replace("\\", "/"), len(content)
|
||||
|
||||
def save_bytes(self, data: bytes, subfolder: str, filename: str) -> str:
|
||||
dest_dir = self.root / subfolder
|
||||
dest_dir.mkdir(parents=True, exist_ok=True)
|
||||
dest_path = dest_dir / filename
|
||||
with open(dest_path, "wb") as handle:
|
||||
handle.write(data)
|
||||
return str(dest_path.relative_to(self.root)).replace("\\", "/")
|
||||
|
||||
def absolute_path(self, relative_path: str) -> Path:
|
||||
return self.root / relative_path
|
||||
|
||||
def delete_relative_path(self, relative_path: Optional[str]) -> None:
|
||||
if not relative_path:
|
||||
return
|
||||
abs_path = self.absolute_path(relative_path)
|
||||
try:
|
||||
if abs_path.exists():
|
||||
abs_path.unlink()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def generate_thumbnail(self, image_path: str, thumb_subfolder: str) -> Optional[str]:
|
||||
try:
|
||||
abs_path = self.absolute_path(image_path)
|
||||
with Image.open(abs_path) as img:
|
||||
img.thumbnail((400, 400))
|
||||
thumb_dir = self.root / thumb_subfolder
|
||||
thumb_dir.mkdir(parents=True, exist_ok=True)
|
||||
thumb_name = f"thumb_{Path(image_path).stem}.jpg"
|
||||
thumb_path = thumb_dir / thumb_name
|
||||
img.convert("RGB").save(thumb_path, "JPEG", quality=80)
|
||||
return str(thumb_path.relative_to(self.root)).replace("\\", "/")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def generate_video_thumbnail(self, video_path: str, thumb_subfolder: str) -> Optional[str]:
|
||||
abs_path = self.absolute_path(video_path)
|
||||
thumb_dir = self.root / thumb_subfolder
|
||||
thumb_dir.mkdir(parents=True, exist_ok=True)
|
||||
thumb_name = f"thumb_{Path(video_path).stem}.jpg"
|
||||
thumb_path = thumb_dir / thumb_name
|
||||
try:
|
||||
subprocess.run(
|
||||
[
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-i",
|
||||
str(abs_path),
|
||||
"-ss",
|
||||
"00:00:00.500",
|
||||
"-vframes",
|
||||
"1",
|
||||
str(thumb_path),
|
||||
],
|
||||
capture_output=True,
|
||||
timeout=30,
|
||||
check=False,
|
||||
)
|
||||
if thumb_path.exists():
|
||||
return str(thumb_path.relative_to(self.root)).replace("\\", "/")
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
def detect_duration_seconds(self, relative_path: str) -> Optional[float]:
|
||||
abs_path = self.absolute_path(relative_path)
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[
|
||||
"ffprobe",
|
||||
"-v",
|
||||
"error",
|
||||
"-show_entries",
|
||||
"format=duration",
|
||||
"-of",
|
||||
"default=noprint_wrappers=1:nokey=1",
|
||||
str(abs_path),
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=20,
|
||||
check=True,
|
||||
)
|
||||
return round(float(result.stdout.strip()), 3)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
asset_storage = LocalStorageService(settings.ASSET_STORAGE_ROOT)
|
||||
output_storage = LocalStorageService(settings.OUTPUT_STORAGE_ROOT)
|
||||
64
backend/app/services/workflow_binder.py
Normal file
64
backend/app/services/workflow_binder.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
WORKFLOWS_ROOT = Path(__file__).parents[3] / "workflows"
|
||||
_REGISTRY: Dict[str, Path] = {}
|
||||
|
||||
|
||||
def _discover() -> None:
|
||||
_REGISTRY.clear()
|
||||
for path in WORKFLOWS_ROOT.rglob("*.json"):
|
||||
try:
|
||||
with open(path, encoding="utf-8") as handle:
|
||||
data = json.load(handle)
|
||||
meta = data.get("__animatrix_meta__", {})
|
||||
_REGISTRY[meta.get("name") or path.stem] = path
|
||||
except Exception as exc:
|
||||
logger.warning("Could not load workflow %s: %s", path, exc)
|
||||
|
||||
|
||||
_discover()
|
||||
|
||||
|
||||
def select_template_name(mode: str, submode: Optional[str]) -> str:
|
||||
if mode == "animate":
|
||||
return f"wan22_animate_{submode or 'move'}"
|
||||
if mode == "audio":
|
||||
return "wan22_s2v"
|
||||
raise ValueError(f"Unknown mode: {mode}")
|
||||
|
||||
|
||||
class WorkflowBinder:
|
||||
def __init__(self, template_name: str):
|
||||
if template_name not in _REGISTRY:
|
||||
_discover()
|
||||
if template_name not in _REGISTRY:
|
||||
raise FileNotFoundError(
|
||||
f"Workflow template '{template_name}' not found in {WORKFLOWS_ROOT}. Available: {list(_REGISTRY.keys())}"
|
||||
)
|
||||
with open(_REGISTRY[template_name], encoding="utf-8") as handle:
|
||||
self._raw = json.load(handle)
|
||||
self._meta = self._raw.get("__animatrix_meta__", {})
|
||||
self._param_nodes = self._meta.get("param_nodes", {})
|
||||
self.version = self._meta.get("version", "unknown")
|
||||
self.status = self._meta.get("status", "")
|
||||
|
||||
def bind(self, params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
workflow = copy.deepcopy(self._raw)
|
||||
workflow.pop("__animatrix_meta__", None)
|
||||
for param_key, value in params.items():
|
||||
if value is None:
|
||||
continue
|
||||
node_spec = self._param_nodes.get(param_key)
|
||||
if not node_spec:
|
||||
continue
|
||||
node_id = str(node_spec["node_id"])
|
||||
input_name = node_spec["input"]
|
||||
if node_id in workflow:
|
||||
workflow[node_id]["inputs"][input_name] = value
|
||||
return workflow
|
||||
Reference in New Issue
Block a user