69 lines
2.4 KiB
Python
69 lines
2.4 KiB
Python
import copy
|
|
import json
|
|
import logging
|
|
from pathlib import Path
|
|
from typing import Any, Dict, Optional
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Directory scanned for workflow templates: the "workflows" folder located
# three levels above the directory that contains this file.
WORKFLOWS_ROOT = Path(__file__).parents[3].joinpath("workflows")

# Template name -> path of the JSON file that defines it; filled by _discover().
_REGISTRY: Dict[str, Path] = dict()
|
|
|
|
|
|
def _discover() -> None:
    """Rebuild ``_REGISTRY`` by scanning ``WORKFLOWS_ROOT`` for ``*.json`` files.

    Each file is registered under the ``name`` found in its
    ``__animatrix_meta__`` mapping, falling back to the file stem when that
    name is missing or empty. Files that cannot be read or parsed are skipped
    with a warning rather than aborting the scan.

    The registry is cleared first, so entries whose files have disappeared
    are dropped.
    """
    _REGISTRY.clear()
    # rglob yields paths in no guaranteed order; sorting makes the winner of a
    # duplicate template name the same on every run and every filesystem.
    for path in sorted(WORKFLOWS_ROOT.rglob("*.json")):
        # Deliberately broad: this runs at import time, and one malformed
        # template must not prevent the module (and every other template)
        # from loading.
        try:
            with open(path, encoding="utf-8") as handle:
                data = json.load(handle)
            meta = data.get("__animatrix_meta__", {})
            name = meta.get("name") or path.stem
            previous = _REGISTRY.get(name)
            if previous is not None:
                # Surface the collision instead of silently overwriting.
                logger.warning(
                    "Duplicate workflow name %r: %s overrides %s",
                    name,
                    path,
                    previous,
                )
            _REGISTRY[name] = path
        except Exception as exc:
            logger.warning("Could not load workflow %s: %s", path, exc)
|
|
|
|
|
|
# Populate the registry once at import time so that lookups work without an
# explicit scan by the caller.
_discover()
|
|
|
|
|
|
def select_template_name(mode: str, submode: Optional[str], model_preset: Optional[str] = None) -> str:
    """Return the workflow template name for a mode / submode / preset combination.

    Args:
        mode: ``"animate"`` or ``"audio"``.
        submode: Animate submode; ``None`` or an empty string means ``"move"``.
            Ignored for ``"audio"``.
        model_preset: Optional preset identifier. Only
            ``"wan22-a14b-anime-style"`` changes the result, and only for
            ``"animate"``.

    Returns:
        The template name to look up.

    Raises:
        ValueError: If ``mode`` is not recognised, or if the anime-style
            preset is requested with a submode other than ``"move"``.
    """
    if mode == "audio":
        return "wan22_s2v"
    if mode != "animate":
        raise ValueError(f"Unknown mode: {mode}")

    # From here on the mode is "animate"; a missing submode means "move".
    effective_submode = submode or "move"
    if model_preset != "wan22-a14b-anime-style":
        return f"wan22_animate_{effective_submode}"

    if effective_submode != "move":
        raise ValueError("Anime Style preset is currently supported only for Animate / Move.")
    return "wan22_animate_move_anime_style"
|
|
|
|
|
|
class WorkflowBinder:
    """Load a registered workflow template and bind runtime parameters into it.

    The template is a JSON object whose ``__animatrix_meta__`` entry carries
    metadata, including ``param_nodes``: a mapping from parameter name to a
    spec with a ``node_id`` and an ``input`` name. All other top-level
    entries are treated as workflow nodes keyed by node id.

    Attributes:
        version: ``version`` from the template metadata, or ``"unknown"``.
        status: ``status`` from the template metadata, or ``""``.
    """

    def __init__(self, template_name: str):
        """Load the template registered under ``template_name``.

        Args:
            template_name: Key to look up in the module registry.

        Raises:
            FileNotFoundError: If the name is still unknown after a rescan
                of the workflows directory.
        """
        if template_name not in _REGISTRY:
            # The template may have been added since the last scan; rescan
            # once before giving up.
            _discover()
            if template_name not in _REGISTRY:
                raise FileNotFoundError(
                    f"Workflow template '{template_name}' not found in {WORKFLOWS_ROOT}. Available: {list(_REGISTRY.keys())}"
                )
        with open(_REGISTRY[template_name], encoding="utf-8") as handle:
            self._raw = json.load(handle)
        # "or {}" also covers an explicit JSON null, which .get()'s default
        # would let through and crash later attribute lookups.
        self._meta = self._raw.get("__animatrix_meta__") or {}
        self._param_nodes = self._meta.get("param_nodes") or {}
        self.version = self._meta.get("version", "unknown")
        self.status = self._meta.get("status", "")

    def bind(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of the template with ``params`` written into node inputs.

        Parameters whose value is ``None``, that have no ``param_nodes``
        spec, or whose spec points at a node id absent from the workflow are
        skipped. The loaded template itself is never modified.

        Args:
            params: Parameter name -> value to write.

        Returns:
            A deep copy of the template without the ``__animatrix_meta__``
            entry, with the bound values stored under each target node's
            ``"inputs"`` mapping.
        """
        workflow = copy.deepcopy(self._raw)
        workflow.pop("__animatrix_meta__", None)
        for param_key, value in params.items():
            # Only None means "leave the template default"; falsy values such
            # as 0, False or "" are legitimate overrides.
            if value is None:
                continue
            node_spec = self._param_nodes.get(param_key)
            if not node_spec:
                continue
            # Node ids are JSON object keys, hence strings, even when the
            # spec stores them as numbers.
            node_id = str(node_spec["node_id"])
            node = workflow.get(node_id)
            if node is None:
                continue
            # Create the inputs mapping on demand so a node declared without
            # one does not raise KeyError.
            node.setdefault("inputs", {})[node_spec["input"]] = value
        return workflow
|