Initial Animatrix import

This commit is contained in:
Sagnik
2026-04-17 19:11:57 +05:30
commit c7994d17a9
60 changed files with 8516 additions and 0 deletions

11
backend/.env.example Normal file
View File

@@ -0,0 +1,11 @@
SECRET_KEY=change-me
DATABASE_URL=sqlite:///./animatrix.db
ASSET_STORAGE_ROOT=./storage/assets
OUTPUT_STORAGE_ROOT=./storage/outputs
COMFYUI_BASE_URL=https://comfy.desineuron.in
# Set to the public HTTPS origin in production so generated backend URLs and cookie policy stay correct.
# Example production value:
# BACKEND_BASE_URL=https://animatrix.desineuron.in
BACKEND_BASE_URL=http://localhost:8000
CORS_ORIGINS=http://localhost:3000,https://animatrix.desineuron.in
ACCESS_TOKEN_EXPIRE_MINUTES=10080

1
backend/app/__init__.py Normal file
View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,28 @@
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from app.core.deps import get_current_user
from app.db.session import get_db
from app.models import Job, User
from app.services.comfy_client import comfy_client
router = APIRouter(prefix="/api/admin", tags=["admin"])
@router.get("/health")
async def health(_: User = Depends(get_current_user)):
return {"api": "ok", "comfyui": await comfy_client.health_check()}
@router.get("/queue")
async def queue(_: User = Depends(get_current_user)):
return await comfy_client.get_queue()
@router.get("/jobs-summary")
def jobs_summary(db: Session = Depends(get_db), _: User = Depends(get_current_user)):
total = db.query(Job).count()
active = db.query(Job).filter(Job.status.in_(["validating", "uploading_assets", "queued", "executing", "collecting_outputs"])).count()
completed = db.query(Job).filter(Job.status == "completed").count()
failed = db.query(Job).filter(Job.status == "failed").count()
return {"total": total, "active": active, "completed": completed, "failed": failed}

View File

@@ -0,0 +1,120 @@
from datetime import datetime, timedelta, timezone
from typing import List, Optional
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
from sqlalchemy.orm import Session
from app.core.deps import get_current_user
from app.db.session import get_db
from app.models import Asset, User
from app.schemas import AssetResponse, AssetTrashRequest
from app.services.storage import asset_storage
router = APIRouter(prefix="/api/assets", tags=["assets"])
ALLOWED_TYPES = {
"image": ["image/jpeg", "image/png", "image/webp"],
"video": ["video/mp4", "video/webm", "video/quicktime"],
"audio": ["audio/mpeg", "audio/mp4", "audio/wav", "audio/ogg", "audio/x-wav"],
"pose_sheet": ["image/jpeg", "image/png", "image/webp"],
}
MAX_SIZE_BYTES = 500 * 1024 * 1024
@router.post("/upload", response_model=AssetResponse, status_code=201)
async def upload_asset(
file: UploadFile = File(...),
asset_type: str = Form(...),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
if asset_type not in ALLOWED_TYPES:
raise HTTPException(400, f"asset_type must be one of {list(ALLOWED_TYPES.keys())}")
mime = file.content_type or ""
if mime not in ALLOWED_TYPES[asset_type]:
raise HTTPException(400, f"Unsupported mime type {mime} for {asset_type}")
subfolder = f"{current_user.id}/{asset_type}"
storage_path, size_bytes = await asset_storage.save_upload(file, subfolder)
if size_bytes > MAX_SIZE_BYTES:
raise HTTPException(413, "File too large (max 500MB)")
thumbnail_path = None
width = height = None
duration_seconds = None
if asset_type in ("image", "pose_sheet"):
thumbnail_path = asset_storage.generate_thumbnail(storage_path, f"{current_user.id}/thumbs")
try:
from PIL import Image
abs_path = asset_storage.absolute_path(storage_path)
with Image.open(abs_path) as image:
width, height = image.size
except Exception:
pass
elif asset_type == "video":
thumbnail_path = asset_storage.generate_video_thumbnail(storage_path, f"{current_user.id}/thumbs")
duration_seconds = asset_storage.detect_duration_seconds(storage_path)
else:
duration_seconds = asset_storage.detect_duration_seconds(storage_path)
asset = Asset(
owner_id=current_user.id,
asset_type=asset_type,
mime_type=mime,
original_filename=file.filename or "upload",
storage_path=storage_path,
thumbnail_path=thumbnail_path,
size_bytes=size_bytes,
width=width,
height=height,
duration_seconds=duration_seconds,
)
db.add(asset)
db.commit()
db.refresh(asset)
return asset
@router.get("/", response_model=List[AssetResponse])
def list_assets(
asset_type: Optional[str] = None,
include_trashed: bool = False,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
query = db.query(Asset).filter(Asset.owner_id == current_user.id)
if not include_trashed:
query = query.filter(Asset.is_trashed.is_(False))
if asset_type:
query = query.filter(Asset.asset_type == asset_type)
return query.order_by(Asset.created_at.desc()).all()
@router.post("/trash", status_code=200)
def move_assets_to_trash(
payload: AssetTrashRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
if not payload.asset_ids:
raise HTTPException(400, "No asset ids provided")
assets = (
db.query(Asset)
.filter(Asset.owner_id == current_user.id, Asset.id.in_(payload.asset_ids))
.all()
)
if not assets:
raise HTTPException(404, "No matching assets found")
delete_after_at = datetime.now(timezone.utc) + timedelta(days=30)
for asset in assets:
asset.is_trashed = True
asset.delete_after_at = delete_after_at
db.commit()
return {
"moved_to_trash": len(assets),
"delete_after_at": delete_after_at.isoformat(),
}

View File

@@ -0,0 +1,61 @@
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.deps import get_current_user
from app.core.security import create_access_token, hash_password, verify_password
from app.db.session import get_db
from app.models import User
from app.schemas import LoginRequest, RegisterRequest, UserResponse
router = APIRouter(prefix="/api/auth", tags=["auth"])
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
def register(payload: RegisterRequest, db: Session = Depends(get_db)):
existing = db.query(User).filter(User.email == payload.email).first()
if existing:
raise HTTPException(status_code=400, detail="Email already registered")
user = User(email=payload.email, password_hash=hash_password(payload.password))
db.add(user)
db.commit()
db.refresh(user)
return user
def _is_secure_request(request: Request) -> bool:
forwarded_proto = request.headers.get("x-forwarded-proto", "")
if "https" in forwarded_proto.lower():
return True
if request.url.scheme == "https":
return True
return settings.BACKEND_BASE_URL.startswith("https://")
@router.post("/login")
def login(payload: LoginRequest, request: Request, response: Response, db: Session = Depends(get_db)):
user = db.query(User).filter(User.email == payload.email).first()
if not user or not verify_password(payload.password, user.password_hash):
raise HTTPException(status_code=401, detail="Invalid credentials")
token = create_access_token(subject=user.id)
response.set_cookie(
key="access_token",
value=token,
httponly=True,
samesite="lax",
secure=_is_secure_request(request),
max_age=60 * 60 * 24 * 7,
)
return {"message": "Logged in", "user": UserResponse.model_validate(user)}
@router.post("/logout")
def logout(response: Response):
response.delete_cookie("access_token")
return {"message": "Logged out"}
@router.get("/me", response_model=UserResponse)
def me(current_user: User = Depends(get_current_user)):
return current_user

View File

@@ -0,0 +1,118 @@
import json
from typing import List, Optional
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session, selectinload
from app.core.deps import get_current_user
from app.db.session import get_db
from app.models import Asset, Job, JobOutput, User
from app.schemas import JobCreateRequest, JobListResponse, JobResponse
from app.services.orchestrator import reconcile_job_outputs_if_missing, run_job
from app.services.storage import output_storage
router = APIRouter(prefix="/api/jobs", tags=["jobs"])
@router.post("/", response_model=JobResponse, status_code=201)
async def create_job(
payload: JobCreateRequest,
background_tasks: BackgroundTasks,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
def assert_owns(asset_id: Optional[str], label: str):
if asset_id:
asset = (
db.query(Asset)
.filter(Asset.id == asset_id, Asset.owner_id == current_user.id, Asset.is_trashed.is_(False))
.first()
)
if not asset:
raise HTTPException(400, f"{label} asset not found or not owned by user")
assert_owns(payload.ground_truth_asset_id, "ground_truth")
assert_owns(payload.motion_asset_id, "motion")
assert_owns(payload.audio_asset_id, "audio")
assert_owns(payload.pose_asset_id, "pose_sheet")
for ref_id in payload.reference_asset_ids or []:
assert_owns(ref_id, f"reference {ref_id}")
job = Job(
owner_id=current_user.id,
mode=payload.mode,
submode=payload.submode,
prompt=payload.prompt,
negative_prompt=payload.negative_prompt,
status="created",
ground_truth_asset_id=payload.ground_truth_asset_id,
motion_asset_id=payload.motion_asset_id,
audio_asset_id=payload.audio_asset_id,
pose_asset_id=payload.pose_asset_id,
reference_asset_ids_json=json.dumps(payload.reference_asset_ids) if payload.reference_asset_ids else None,
settings_json=json.dumps(payload.settings) if payload.settings else None,
)
db.add(job)
db.commit()
db.refresh(job)
background_tasks.add_task(run_job, job.id)
return job
@router.get("/", response_model=List[JobListResponse])
def list_jobs(
mode: Optional[str] = None,
status: Optional[str] = None,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
query = db.query(Job).filter(Job.owner_id == current_user.id)
query = query.options(selectinload(Job.outputs))
if mode:
query = query.filter(Job.mode == mode)
if status:
query = query.filter(Job.status == status)
return query.order_by(Job.created_at.desc()).limit(100).all()
@router.get("/{job_id}", response_model=JobResponse)
async def get_job(job_id: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
job = (
db.query(Job)
.options(selectinload(Job.outputs), selectinload(Job.events))
.filter(Job.id == job_id, Job.owner_id == current_user.id)
.first()
)
if not job:
raise HTTPException(404, "Job not found")
if job.status == "completed" and not job.outputs and job.comfy_prompt_id:
await reconcile_job_outputs_if_missing(job.id)
db.expire_all()
job = (
db.query(Job)
.options(selectinload(Job.outputs), selectinload(Job.events))
.filter(Job.id == job_id, Job.owner_id == current_user.id)
.first()
)
return job
@router.get("/{job_id}/outputs/{output_id}/download")
def download_output(
job_id: str,
output_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
job = db.query(Job).filter(Job.id == job_id, Job.owner_id == current_user.id).first()
if not job:
raise HTTPException(404, "Job not found")
output = db.query(JobOutput).filter(JobOutput.id == output_id, JobOutput.job_id == job_id).first()
if not output:
raise HTTPException(404, "Output not found")
abs_path = output_storage.absolute_path(output.file_path)
if not abs_path.exists():
raise HTTPException(404, "Output file not found")
return FileResponse(str(abs_path))

View File

@@ -0,0 +1,25 @@
from typing import List
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
SECRET_KEY: str = "dev-secret-change-in-production"
DATABASE_URL: str = "sqlite:///./animatrix.db"
ASSET_STORAGE_ROOT: str = "./storage/assets"
OUTPUT_STORAGE_ROOT: str = "./storage/outputs"
COMFYUI_BASE_URL: str = "https://comfy.desineuron.in"
BACKEND_BASE_URL: str = "http://localhost:8000"
CORS_ORIGINS: str = "http://localhost:3000"
ACCESS_TOKEN_EXPIRE_MINUTES: int = 10080
@property
def cors_origins_list(self) -> List[str]:
return [o.strip() for o in self.CORS_ORIGINS.split(",") if o.strip()]
class Config:
env_file = ".env"
extra = "ignore"
settings = Settings()

30
backend/app/core/deps.py Normal file
View File

@@ -0,0 +1,30 @@
from typing import Optional
from fastapi import Cookie, Depends, HTTPException, status
from sqlalchemy.orm import Session
from app.core.security import decode_access_token
from app.db.session import get_db
from app.models import User
def get_current_user(
access_token: Optional[str] = Cookie(default=None),
db: Session = Depends(get_db),
) -> User:
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
)
if not access_token:
raise credentials_exception
user_id = decode_access_token(access_token)
if not user_id:
raise credentials_exception
user = db.query(User).filter(User.id == user_id, User.is_active.is_(True)).first()
if not user:
raise credentials_exception
return user

View File

@@ -0,0 +1,34 @@
from datetime import datetime, timedelta, timezone
from typing import Optional
from jose import JWTError, jwt
from passlib.context import CryptContext
from app.core.config import settings
pwd_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
ALGORITHM = "HS256"
def hash_password(password: str) -> str:
return pwd_context.hash(password)
def verify_password(plain_password: str, hashed_password: str) -> bool:
return pwd_context.verify(plain_password, hashed_password)
def create_access_token(subject: str, expires_delta: Optional[timedelta] = None) -> str:
expire = datetime.now(timezone.utc) + (
expires_delta or timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
)
payload = {"sub": subject, "exp": expire}
return jwt.encode(payload, settings.SECRET_KEY, algorithm=ALGORITHM)
def decode_access_token(token: str) -> Optional[str]:
try:
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
return payload.get("sub")
except JWTError:
return None

63
backend/app/db/init_db.py Normal file
View File

@@ -0,0 +1,63 @@
from datetime import datetime, timezone
from pathlib import Path
from sqlalchemy import text
from app.core.config import settings
from app.db.session import Base, engine
from app.models import Asset, Job, JobEvent, JobOutput, User # noqa: F401
from app.services.storage import asset_storage
def init_db() -> None:
Path(settings.ASSET_STORAGE_ROOT).mkdir(parents=True, exist_ok=True)
Path(settings.OUTPUT_STORAGE_ROOT).mkdir(parents=True, exist_ok=True)
Base.metadata.create_all(bind=engine)
_migrate_assets_table()
_cleanup_expired_trashed_assets()
def _migrate_assets_table() -> None:
with engine.begin() as conn:
columns = {
row[1]
for row in conn.execute(text("PRAGMA table_info(assets)")).fetchall()
}
if "is_trashed" not in columns:
conn.execute(text("ALTER TABLE assets ADD COLUMN is_trashed BOOLEAN NOT NULL DEFAULT 0"))
if "delete_after_at" not in columns:
conn.execute(text("ALTER TABLE assets ADD COLUMN delete_after_at DATETIME NULL"))
def _cleanup_expired_trashed_assets() -> None:
now = datetime.now(timezone.utc).isoformat()
with engine.begin() as conn:
rows = conn.execute(
text(
"""
SELECT id, storage_path, thumbnail_path
FROM assets
WHERE is_trashed = 1
AND delete_after_at IS NOT NULL
AND delete_after_at <= :now
"""
),
{"now": now},
).fetchall()
for _, storage_path, thumbnail_path in rows:
asset_storage.delete_relative_path(storage_path)
asset_storage.delete_relative_path(thumbnail_path)
if rows:
conn.execute(
text(
"""
DELETE FROM assets
WHERE is_trashed = 1
AND delete_after_at IS NOT NULL
AND delete_after_at <= :now
"""
),
{"now": now},
)

23
backend/app/db/session.py Normal file
View File

@@ -0,0 +1,23 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, sessionmaker
from app.core.config import settings
connect_args = {}
if settings.DATABASE_URL.startswith("sqlite"):
connect_args = {"check_same_thread": False}
engine = create_engine(settings.DATABASE_URL, connect_args=connect_args)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
class Base(DeclarativeBase):
pass
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()

32
backend/app/main.py Normal file
View File

@@ -0,0 +1,32 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from app.api.routes import admin, assets, auth, jobs
from app.core.config import settings
from app.db.init_db import init_db
app = FastAPI(title="Animatrix API", version="0.1.0")
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins_list,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
init_db()
app.include_router(auth.router)
app.include_router(assets.router)
app.include_router(jobs.router)
app.include_router(admin.router)
app.mount("/storage/assets", StaticFiles(directory=settings.ASSET_STORAGE_ROOT), name="assets")
app.mount("/storage/outputs", StaticFiles(directory=settings.OUTPUT_STORAGE_ROOT), name="outputs")
@app.get("/health")
def health() -> dict[str, str]:
return {"status": "ok"}

View File

@@ -0,0 +1,3 @@
from app.models.models import Asset, Job, JobEvent, JobOutput, User
__all__ = ["User", "Asset", "Job", "JobOutput", "JobEvent"]

View File

@@ -0,0 +1,106 @@
import uuid
from datetime import datetime, timezone
from typing import Optional
from sqlalchemy import Boolean, DateTime, Float, ForeignKey, Integer, String, Text
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.db.session import Base
def utcnow():
return datetime.now(timezone.utc)
def new_uuid():
return str(uuid.uuid4())
class User(Base):
__tablename__ = "users"
id: Mapped[str] = mapped_column(String, primary_key=True, default=new_uuid)
email: Mapped[str] = mapped_column(String, unique=True, nullable=False, index=True)
password_hash: Mapped[str] = mapped_column(String, nullable=False)
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utcnow)
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utcnow, onupdate=utcnow)
assets: Mapped[list["Asset"]] = relationship("Asset", back_populates="owner", cascade="all, delete-orphan")
jobs: Mapped[list["Job"]] = relationship("Job", back_populates="owner", cascade="all, delete-orphan")
class Asset(Base):
__tablename__ = "assets"
id: Mapped[str] = mapped_column(String, primary_key=True, default=new_uuid)
owner_id: Mapped[str] = mapped_column(String, ForeignKey("users.id"), nullable=False, index=True)
asset_type: Mapped[str] = mapped_column(String, nullable=False)
mime_type: Mapped[str] = mapped_column(String, nullable=False)
original_filename: Mapped[str] = mapped_column(String, nullable=False)
storage_path: Mapped[str] = mapped_column(String, nullable=False)
thumbnail_path: Mapped[Optional[str]] = mapped_column(String, nullable=True)
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
width: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
height: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
duration_seconds: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
is_trashed: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False, index=True)
delete_after_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utcnow)
owner: Mapped["User"] = relationship("User", back_populates="assets")
class Job(Base):
__tablename__ = "jobs"
id: Mapped[str] = mapped_column(String, primary_key=True, default=new_uuid)
owner_id: Mapped[str] = mapped_column(String, ForeignKey("users.id"), nullable=False, index=True)
mode: Mapped[str] = mapped_column(String, nullable=False)
submode: Mapped[Optional[str]] = mapped_column(String, nullable=True)
prompt: Mapped[str] = mapped_column(Text, nullable=False)
negative_prompt: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
status: Mapped[str] = mapped_column(String, nullable=False, default="created", index=True)
comfy_prompt_id: Mapped[Optional[str]] = mapped_column(String, nullable=True)
workflow_template_name: Mapped[Optional[str]] = mapped_column(String, nullable=True)
workflow_template_version: Mapped[Optional[str]] = mapped_column(String, nullable=True)
settings_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
ground_truth_asset_id: Mapped[Optional[str]] = mapped_column(String, ForeignKey("assets.id"), nullable=True)
motion_asset_id: Mapped[Optional[str]] = mapped_column(String, ForeignKey("assets.id"), nullable=True)
audio_asset_id: Mapped[Optional[str]] = mapped_column(String, ForeignKey("assets.id"), nullable=True)
reference_asset_ids_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
pose_asset_id: Mapped[Optional[str]] = mapped_column(String, ForeignKey("assets.id"), nullable=True)
error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utcnow)
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utcnow, onupdate=utcnow)
owner: Mapped["User"] = relationship("User", back_populates="jobs")
outputs: Mapped[list["JobOutput"]] = relationship("JobOutput", back_populates="job", cascade="all, delete-orphan")
events: Mapped[list["JobEvent"]] = relationship("JobEvent", back_populates="job", cascade="all, delete-orphan")
class JobOutput(Base):
__tablename__ = "job_outputs"
id: Mapped[str] = mapped_column(String, primary_key=True, default=new_uuid)
job_id: Mapped[str] = mapped_column(String, ForeignKey("jobs.id"), nullable=False, index=True)
output_type: Mapped[str] = mapped_column(String, nullable=False)
file_path: Mapped[str] = mapped_column(String, nullable=False)
poster_path: Mapped[Optional[str]] = mapped_column(String, nullable=True)
metadata_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utcnow)
job: Mapped["Job"] = relationship("Job", back_populates="outputs")
class JobEvent(Base):
__tablename__ = "job_events"
id: Mapped[str] = mapped_column(String, primary_key=True, default=new_uuid)
job_id: Mapped[str] = mapped_column(String, ForeignKey("jobs.id"), nullable=False, index=True)
event_type: Mapped[str] = mapped_column(String, nullable=False)
message: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
payload_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utcnow)
job: Mapped["Job"] = relationship("Job", back_populates="events")

View File

@@ -0,0 +1,25 @@
from app.schemas.schemas import (
AssetResponse,
AssetTrashRequest,
JobCreateRequest,
JobEventResponse,
JobListResponse,
JobOutputResponse,
JobResponse,
LoginRequest,
RegisterRequest,
UserResponse,
)
__all__ = [
"RegisterRequest",
"LoginRequest",
"UserResponse",
"AssetResponse",
"AssetTrashRequest",
"JobCreateRequest",
"JobOutputResponse",
"JobEventResponse",
"JobResponse",
"JobListResponse",
]

View File

@@ -0,0 +1,143 @@
from datetime import datetime
from typing import List, Optional
from pydantic import BaseModel, EmailStr, field_validator
class RegisterRequest(BaseModel):
email: EmailStr
password: str
@field_validator("password")
@classmethod
def password_min_length(cls, v: str) -> str:
if len(v) < 8:
raise ValueError("Password must be at least 8 characters")
return v
class LoginRequest(BaseModel):
email: EmailStr
password: str
class UserResponse(BaseModel):
id: str
email: str
created_at: datetime
class Config:
from_attributes = True
class AssetResponse(BaseModel):
id: str
asset_type: str
mime_type: str
original_filename: str
storage_path: str
size_bytes: int
width: Optional[int] = None
height: Optional[int] = None
duration_seconds: Optional[float] = None
thumbnail_path: Optional[str] = None
is_trashed: bool = False
delete_after_at: Optional[datetime] = None
created_at: datetime
class Config:
from_attributes = True
class JobCreateRequest(BaseModel):
mode: str
submode: Optional[str] = None
prompt: str
negative_prompt: Optional[str] = None
ground_truth_asset_id: str
motion_asset_id: Optional[str] = None
audio_asset_id: Optional[str] = None
reference_asset_ids: Optional[List[str]] = None
pose_asset_id: Optional[str] = None
settings: Optional[dict] = None
@field_validator("mode")
@classmethod
def validate_mode(cls, v: str) -> str:
if v not in ("animate", "audio"):
raise ValueError("mode must be 'animate' or 'audio'")
return v
@field_validator("submode")
@classmethod
def validate_submode(cls, v: Optional[str]) -> Optional[str]:
if v is not None and v not in ("move", "mix"):
raise ValueError("submode must be 'move' or 'mix'")
return v
class JobOutputResponse(BaseModel):
id: str
output_type: str
file_path: str
poster_path: Optional[str] = None
created_at: datetime
class Config:
from_attributes = True
class JobEventResponse(BaseModel):
id: str
event_type: str
message: Optional[str] = None
created_at: datetime
class Config:
from_attributes = True
class AssetTrashRequest(BaseModel):
asset_ids: List[str]
class JobResponse(BaseModel):
id: str
mode: str
submode: Optional[str] = None
prompt: str
negative_prompt: Optional[str] = None
status: str
comfy_prompt_id: Optional[str] = None
workflow_template_name: Optional[str] = None
error_message: Optional[str] = None
ground_truth_asset_id: Optional[str] = None
motion_asset_id: Optional[str] = None
audio_asset_id: Optional[str] = None
pose_asset_id: Optional[str] = None
outputs: List[JobOutputResponse] = []
events: List[JobEventResponse] = []
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class JobListResponse(BaseModel):
id: str
mode: str
submode: Optional[str] = None
prompt: str
error_message: Optional[str] = None
status: str
ground_truth_asset_id: Optional[str] = None
motion_asset_id: Optional[str] = None
audio_asset_id: Optional[str] = None
pose_asset_id: Optional[str] = None
outputs: List[JobOutputResponse] = []
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True

View File

@@ -0,0 +1,128 @@
import logging
from pathlib import Path
from typing import Any, Dict
import httpx
from app.core.config import settings
logger = logging.getLogger(__name__)
class ComfyClient:
def __init__(self, base_url: str | None = None):
self.base_url = (base_url or settings.COMFYUI_BASE_URL).rstrip("/")
self._client = httpx.AsyncClient(timeout=120.0)
async def close(self) -> None:
await self._client.aclose()
async def health_check(self) -> bool:
for endpoint in ("/system_stats", "/"):
try:
response = await self._client.get(f"{self.base_url}{endpoint}")
if response.status_code == 200:
return True
except Exception as exc:
logger.warning("ComfyUI health check failed at %s: %s", endpoint, exc)
return False
async def upload_image(self, file_path: str, filename: str) -> str:
with open(file_path, "rb") as handle:
files = {"image": (filename, handle, "application/octet-stream")}
response = await self._client.post(f"{self.base_url}/upload/image", files=files)
response.raise_for_status()
data = response.json()
return data.get("name", filename)
async def upload_media(self, file_path: str, filename: str, media_type: str) -> str:
endpoint = {
"image": "/upload/image",
"pose_sheet": "/upload/image",
"video": "/upload/video",
"audio": "/upload/audio",
}.get(media_type)
field_name = {
"image": "image",
"pose_sheet": "image",
"video": "video",
"audio": "audio",
}.get(media_type)
if not endpoint or not field_name:
raise ValueError(f"Unsupported ComfyUI upload media type: {media_type}")
mime_type = "application/octet-stream"
suffix = Path(filename).suffix.lower()
if media_type in ("image", "pose_sheet"):
mime_type = {
".jpg": "image/jpeg",
".jpeg": "image/jpeg",
".png": "image/png",
".webp": "image/webp",
}.get(suffix, mime_type)
elif media_type == "video":
mime_type = {
".mp4": "video/mp4",
".webm": "video/webm",
".mov": "video/quicktime",
}.get(suffix, mime_type)
elif media_type == "audio":
mime_type = {
".mp3": "audio/mpeg",
".mp4": "audio/mp4",
".wav": "audio/wav",
".ogg": "audio/ogg",
}.get(suffix, mime_type)
with open(file_path, "rb") as handle:
files = {field_name: (filename, handle, mime_type)}
response = await self._client.post(f"{self.base_url}{endpoint}", files=files)
response.raise_for_status()
data = response.json()
return data.get("name", filename)
async def submit_prompt(self, workflow: Dict[str, Any], client_id: str | None = None) -> str:
payload: Dict[str, Any] = {"prompt": workflow}
if client_id:
payload["client_id"] = client_id
response = await self._client.post(f"{self.base_url}/prompt", json=payload)
if response.is_error:
detail = response.text
raise RuntimeError(f"ComfyUI prompt submission failed ({response.status_code}): {detail}")
data = response.json()
prompt_id = data.get("prompt_id")
if not prompt_id:
raise RuntimeError(f"No prompt_id returned by ComfyUI: {data}")
return prompt_id
async def get_history(self, prompt_id: str) -> Dict[str, Any]:
response = await self._client.get(f"{self.base_url}/history/{prompt_id}")
response.raise_for_status()
data = response.json()
return data.get(prompt_id, {})
async def get_history_all(self) -> Dict[str, Any]:
response = await self._client.get(f"{self.base_url}/history")
response.raise_for_status()
return response.json()
async def get_queue(self) -> Dict[str, Any]:
response = await self._client.get(f"{self.base_url}/queue")
response.raise_for_status()
return response.json()
async def get_object_info(self, node_name: str) -> Dict[str, Any]:
response = await self._client.get(f"{self.base_url}/object_info/{node_name}")
response.raise_for_status()
return response.json().get(node_name, {})
async def download_output(self, filename: str, subfolder: str = "", folder_type: str = "output") -> bytes:
params = {"filename": filename, "subfolder": subfolder, "type": folder_type}
response = await self._client.get(f"{self.base_url}/view", params=params)
response.raise_for_status()
return response.content
comfy_client = ComfyClient()

View File

@@ -0,0 +1,325 @@
import asyncio
import json
import logging
import subprocess
import uuid
from pathlib import Path
from typing import Iterable, Optional
from sqlalchemy.orm import Session
from app.db.session import SessionLocal
from app.models import Asset, Job, JobEvent, JobOutput
from app.services.comfy_client import comfy_client
from app.services.storage import asset_storage, output_storage
from app.services.workflow_binder import WorkflowBinder, select_template_name
logger = logging.getLogger(__name__)
VIDEO_EXTENSIONS = {".mp4", ".mov", ".webm", ".mkv", ".avi"}
MODEL_LOADER_INPUTS = {
"CLIPLoader": "clip_name",
"VAELoader": "vae_name",
"UNETLoader": "unet_name",
"LoraLoaderModelOnly": "lora_name",
}
def _add_event(db: Session, job_id: str, event_type: str, message: str, payload: dict | None = None) -> None:
event = JobEvent(
job_id=job_id,
event_type=event_type,
message=message,
payload_json=json.dumps(payload) if payload else None,
)
db.add(event)
db.commit()
def _set_status(db: Session, job: Job, status: str, error: str | None = None) -> None:
job.status = status
if error:
job.error_message = error
db.commit()
_add_event(db, job.id, "status_change", f"Job status -> {status}")
def _extract_history_error(history: dict) -> str | None:
status = history.get("status", {}) or {}
if status.get("status_str") == "error":
for message in status.get("messages", []) or []:
if not isinstance(message, (list, tuple)) or len(message) < 2:
continue
payload = message[1] or {}
if message[0] == "execution_error":
exception_message = payload.get("exception_message")
node_id = payload.get("node_id")
node_type = payload.get("node_type")
if exception_message and node_id and node_type:
return f"ComfyUI execution error on node {node_id} ({node_type}): {exception_message}"
if exception_message:
return f"ComfyUI execution error: {exception_message}"
return f"ComfyUI execution error: {payload}"
return "ComfyUI execution failed without a detailed error message."
if history.get("node_errors"):
return f"ComfyUI node validation failed: {history['node_errors']}"
return None
def _output_type_for_filename(filename: str) -> str:
return "video" if Path(filename).suffix.lower() in VIDEO_EXTENSIONS else "image"
def _required_model_values(workflow: dict) -> dict[str, set[str]]:
required: dict[str, set[str]] = {loader: set() for loader in MODEL_LOADER_INPUTS}
for node in workflow.values():
if not isinstance(node, dict):
continue
class_type = node.get("class_type")
input_name = MODEL_LOADER_INPUTS.get(class_type)
if not input_name:
continue
value = (node.get("inputs") or {}).get(input_name)
if isinstance(value, str) and value:
required[class_type].add(value)
return {loader: values for loader, values in required.items() if values}
async def _validate_runtime_models(workflow: dict) -> None:
required = _required_model_values(workflow)
if not required:
return
missing_by_loader: list[str] = []
for loader, expected_values in required.items():
object_info = await comfy_client.get_object_info(loader)
loader_input = MODEL_LOADER_INPUTS[loader]
available_raw = (((object_info.get("input") or {}).get("required") or {}).get(loader_input) or [[]])[0]
available = set(available_raw or [])
missing = sorted(value for value in expected_values if value not in available)
if missing:
missing_by_loader.append(f"{loader} missing {missing}; available={sorted(available)}")
if missing_by_loader:
raise RuntimeError(
"ComfyUI runtime is missing required Wan model files. "
+ " | ".join(missing_by_loader)
)
async def _get_history_with_fallback(prompt_id: str) -> dict:
history = await comfy_client.get_history(prompt_id)
if history:
return history
all_history = await comfy_client.get_history_all()
return all_history.get(prompt_id, {})
def _iter_history_files(node_output: dict) -> Iterable[dict]:
for video in node_output.get("videos", []) or []:
yield {
"filename": video["filename"],
"subfolder": video.get("subfolder", ""),
"folder_type": video.get("type", "output"),
"output_type": "video",
}
for image in node_output.get("images", []) or []:
filename = image["filename"]
yield {
"filename": filename,
"subfolder": image.get("subfolder", ""),
"folder_type": image.get("type", "output"),
"output_type": _output_type_for_filename(filename),
}
async def _collect_outputs_from_history(db: Session, job: Job, history: dict) -> int:
existing_paths = {output.file_path for output in job.outputs}
created = 0
for node_id, node_output in (history.get("outputs", {}) or {}).items():
for file_info in _iter_history_files(node_output):
fname = file_info["filename"]
data = await comfy_client.download_output(
fname,
file_info["subfolder"],
file_info["folder_type"],
)
rel_path = output_storage.save_bytes(data, job.id, fname)
if rel_path in existing_paths:
continue
poster_path = None
if file_info["output_type"] == "video":
try:
poster_fname = f"poster_{Path(fname).stem}.jpg"
poster_abs = str(output_storage.absolute_path(f"{job.id}/{poster_fname}"))
subprocess.run(
["ffmpeg", "-y", "-i", str(output_storage.absolute_path(rel_path)), "-vframes", "1", poster_abs],
capture_output=True,
timeout=30,
check=False,
)
if Path(poster_abs).exists():
poster_path = f"{job.id}/{poster_fname}"
except Exception:
pass
db.add(
JobOutput(
job_id=job.id,
output_type=file_info["output_type"],
file_path=rel_path,
poster_path=poster_path,
metadata_json=json.dumps({"node_id": node_id, "filename": fname}),
)
)
existing_paths.add(rel_path)
created += 1
if created:
db.commit()
return created
async def reconcile_job_outputs_if_missing(job_id: str) -> bool:
db = SessionLocal()
try:
job = db.query(Job).filter(Job.id == job_id).first()
if not job:
return False
if job.status != "completed" or not job.comfy_prompt_id or job.outputs:
return False
history = await _get_history_with_fallback(job.comfy_prompt_id)
if not history or _extract_history_error(history):
return False
created = await _collect_outputs_from_history(db, job, history)
if created:
_add_event(db, job.id, "outputs_reconciled", f"Recovered {created} output file(s) from ComfyUI history.")
return True
return False
finally:
db.close()
async def _upload_asset_to_comfy(db: Session, asset_id: Optional[str]) -> Optional[str]:
if not asset_id:
return None
asset = db.query(Asset).filter(Asset.id == asset_id).first()
if not asset:
raise ValueError(f"Asset {asset_id} not found")
if asset.is_trashed:
raise ValueError(f"Asset {asset.original_filename} is in trash")
return await comfy_client.upload_media(
str(asset_storage.absolute_path(asset.storage_path)),
asset.original_filename,
asset.asset_type,
)
def _validate_job(job: Job) -> list[str]:
errors = []
if not job.prompt or not job.prompt.strip():
errors.append("Prompt is required")
if not job.ground_truth_asset_id:
errors.append("Ground truth image is required")
if job.mode == "animate":
if job.submode not in ("move", "mix"):
errors.append("Animate mode requires submode 'move' or 'mix'")
elif job.mode == "audio":
if not job.audio_asset_id:
errors.append("Audio mode requires an audio file")
else:
errors.append("Unknown mode")
return errors
async def run_job(job_id: str) -> None:
db = SessionLocal()
try:
job = db.query(Job).filter(Job.id == job_id).first()
if not job:
return
_set_status(db, job, "validating")
errors = _validate_job(job)
if errors:
_set_status(db, job, "failed", "; ".join(errors))
return
_set_status(db, job, "uploading_assets")
gt_name = await _upload_asset_to_comfy(db, job.ground_truth_asset_id)
motion_name = await _upload_asset_to_comfy(db, job.motion_asset_id)
audio_name = await _upload_asset_to_comfy(db, job.audio_asset_id)
pose_name = await _upload_asset_to_comfy(db, job.pose_asset_id)
ref_names = []
if job.reference_asset_ids_json:
for ref_id in json.loads(job.reference_asset_ids_json):
uploaded = await _upload_asset_to_comfy(db, ref_id)
if uploaded:
ref_names.append(uploaded)
settings_dict = json.loads(job.settings_json) if job.settings_json else {}
binder = WorkflowBinder(select_template_name(job.mode, job.submode))
if "PLACEHOLDER" in binder.status.upper():
raise RuntimeError(
f"Workflow template '{select_template_name(job.mode, job.submode)}' is still a placeholder. "
"Replace it with the production ComfyUI export before running real generations."
)
raw_seed = settings_dict.get("seed", 0)
seed = raw_seed if isinstance(raw_seed, int) and raw_seed >= 0 else 0
params = {
"positive_prompt": job.prompt,
"negative_prompt": job.negative_prompt or "",
"ground_truth": gt_name,
"motion_video": motion_name,
"audio": audio_name,
"pose_sheet": pose_name,
"reference_image": ref_names[0] if ref_names else None,
"seed": seed,
"steps": settings_dict.get("steps", 20),
"cfg": settings_dict.get("cfg", 7.0),
}
workflow = binder.bind(params)
await _validate_runtime_models(workflow)
job.workflow_template_name = select_template_name(job.mode, job.submode)
job.workflow_template_version = binder.version
db.commit()
_set_status(db, job, "queued")
prompt_id = await comfy_client.submit_prompt(workflow, client_id=str(uuid.uuid4()))
job.comfy_prompt_id = prompt_id
db.commit()
_add_event(db, job.id, "submitted", f"ComfyUI prompt_id: {prompt_id}")
_set_status(db, job, "executing")
history = {}
for _ in range(360):
await asyncio.sleep(5)
history = await _get_history_with_fallback(prompt_id)
history_error = _extract_history_error(history)
if history_error:
_set_status(db, job, "failed", history_error)
return
if history.get("status", {}).get("completed"):
break
else:
_set_status(db, job, "failed", "Timed out waiting for ComfyUI")
return
_set_status(db, job, "collecting_outputs")
await _collect_outputs_from_history(db, job, history)
_set_status(db, job, "completed")
except Exception as exc:
logger.exception("Job %s failed: %s", job_id, exc)
job = db.query(Job).filter(Job.id == job_id).first()
if job:
_set_status(db, job, "failed", str(exc))
finally:
db.close()

View File

@@ -0,0 +1,118 @@
import subprocess
import uuid
from pathlib import Path
from typing import Optional
import aiofiles
from fastapi import UploadFile
from PIL import Image
from app.core.config import settings
class LocalStorageService:
def __init__(self, root: str):
self.root = Path(root)
self.root.mkdir(parents=True, exist_ok=True)
async def save_upload(self, upload: UploadFile, subfolder: str) -> tuple[str, int]:
dest_dir = self.root / subfolder
dest_dir.mkdir(parents=True, exist_ok=True)
ext = Path(upload.filename or "file").suffix
filename = f"{uuid.uuid4()}{ext}"
dest_path = dest_dir / filename
content = await upload.read()
async with aiofiles.open(dest_path, "wb") as handle:
await handle.write(content)
return str(dest_path.relative_to(self.root)).replace("\\", "/"), len(content)
def save_bytes(self, data: bytes, subfolder: str, filename: str) -> str:
dest_dir = self.root / subfolder
dest_dir.mkdir(parents=True, exist_ok=True)
dest_path = dest_dir / filename
with open(dest_path, "wb") as handle:
handle.write(data)
return str(dest_path.relative_to(self.root)).replace("\\", "/")
def absolute_path(self, relative_path: str) -> Path:
return self.root / relative_path
def delete_relative_path(self, relative_path: Optional[str]) -> None:
if not relative_path:
return
abs_path = self.absolute_path(relative_path)
try:
if abs_path.exists():
abs_path.unlink()
except Exception:
pass
def generate_thumbnail(self, image_path: str, thumb_subfolder: str) -> Optional[str]:
try:
abs_path = self.absolute_path(image_path)
with Image.open(abs_path) as img:
img.thumbnail((400, 400))
thumb_dir = self.root / thumb_subfolder
thumb_dir.mkdir(parents=True, exist_ok=True)
thumb_name = f"thumb_{Path(image_path).stem}.jpg"
thumb_path = thumb_dir / thumb_name
img.convert("RGB").save(thumb_path, "JPEG", quality=80)
return str(thumb_path.relative_to(self.root)).replace("\\", "/")
except Exception:
return None
def generate_video_thumbnail(self, video_path: str, thumb_subfolder: str) -> Optional[str]:
abs_path = self.absolute_path(video_path)
thumb_dir = self.root / thumb_subfolder
thumb_dir.mkdir(parents=True, exist_ok=True)
thumb_name = f"thumb_{Path(video_path).stem}.jpg"
thumb_path = thumb_dir / thumb_name
try:
subprocess.run(
[
"ffmpeg",
"-y",
"-i",
str(abs_path),
"-ss",
"00:00:00.500",
"-vframes",
"1",
str(thumb_path),
],
capture_output=True,
timeout=30,
check=False,
)
if thumb_path.exists():
return str(thumb_path.relative_to(self.root)).replace("\\", "/")
except Exception:
return None
return None
def detect_duration_seconds(self, relative_path: str) -> Optional[float]:
abs_path = self.absolute_path(relative_path)
try:
result = subprocess.run(
[
"ffprobe",
"-v",
"error",
"-show_entries",
"format=duration",
"-of",
"default=noprint_wrappers=1:nokey=1",
str(abs_path),
],
capture_output=True,
text=True,
timeout=20,
check=True,
)
return round(float(result.stdout.strip()), 3)
except Exception:
return None
asset_storage = LocalStorageService(settings.ASSET_STORAGE_ROOT)
output_storage = LocalStorageService(settings.OUTPUT_STORAGE_ROOT)

View File

@@ -0,0 +1,64 @@
import copy
import json
import logging
from pathlib import Path
from typing import Any, Dict, Optional
logger = logging.getLogger(__name__)
WORKFLOWS_ROOT = Path(__file__).parents[3] / "workflows"
_REGISTRY: Dict[str, Path] = {}
def _discover() -> None:
_REGISTRY.clear()
for path in WORKFLOWS_ROOT.rglob("*.json"):
try:
with open(path, encoding="utf-8") as handle:
data = json.load(handle)
meta = data.get("__animatrix_meta__", {})
_REGISTRY[meta.get("name") or path.stem] = path
except Exception as exc:
logger.warning("Could not load workflow %s: %s", path, exc)
_discover()
def select_template_name(mode: str, submode: Optional[str]) -> str:
if mode == "animate":
return f"wan22_animate_{submode or 'move'}"
if mode == "audio":
return "wan22_s2v"
raise ValueError(f"Unknown mode: {mode}")
class WorkflowBinder:
def __init__(self, template_name: str):
if template_name not in _REGISTRY:
_discover()
if template_name not in _REGISTRY:
raise FileNotFoundError(
f"Workflow template '{template_name}' not found in {WORKFLOWS_ROOT}. Available: {list(_REGISTRY.keys())}"
)
with open(_REGISTRY[template_name], encoding="utf-8") as handle:
self._raw = json.load(handle)
self._meta = self._raw.get("__animatrix_meta__", {})
self._param_nodes = self._meta.get("param_nodes", {})
self.version = self._meta.get("version", "unknown")
self.status = self._meta.get("status", "")
def bind(self, params: Dict[str, Any]) -> Dict[str, Any]:
workflow = copy.deepcopy(self._raw)
workflow.pop("__animatrix_meta__", None)
for param_key, value in params.items():
if value is None:
continue
node_spec = self._param_nodes.get(param_key)
if not node_spec:
continue
node_id = str(node_spec["node_id"])
input_name = node_spec["input"]
if node_id in workflow:
workflow[node_id]["inputs"][input_name] = value
return workflow

13
backend/requirements.txt Normal file
View File

@@ -0,0 +1,13 @@
fastapi==0.111.0
uvicorn[standard]==0.29.0
sqlalchemy==2.0.30
alembic==1.13.1
pydantic==2.7.1
pydantic-settings==2.2.1
passlib==1.7.4
python-jose[cryptography]==3.3.0
python-multipart==0.0.9
httpx==0.27.0
aiofiles==23.2.1
Pillow==11.2.1
python-dotenv==1.0.1

5
backend/run.py Normal file
View File

@@ -0,0 +1,5 @@
import uvicorn
if __name__ == "__main__":
uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)

0
backend/storage/.gitkeep Normal file
View File