Initial Animatrix import
This commit is contained in:
11
backend/.env.example
Normal file
11
backend/.env.example
Normal file
@@ -0,0 +1,11 @@
|
||||
SECRET_KEY=change-me
|
||||
DATABASE_URL=sqlite:///./animatrix.db
|
||||
ASSET_STORAGE_ROOT=./storage/assets
|
||||
OUTPUT_STORAGE_ROOT=./storage/outputs
|
||||
COMFYUI_BASE_URL=https://comfy.desineuron.in
|
||||
# Set to the public HTTPS origin in production so generated backend URLs and cookie policy stay correct.
|
||||
# Example production value:
|
||||
# BACKEND_BASE_URL=https://animatrix.desineuron.in
|
||||
BACKEND_BASE_URL=http://localhost:8000
|
||||
CORS_ORIGINS=http://localhost:3000,https://animatrix.desineuron.in
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES=10080
|
||||
1
backend/app/__init__.py
Normal file
1
backend/app/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
1
backend/app/api/routes/__init__.py
Normal file
1
backend/app/api/routes/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
28
backend/app/api/routes/admin.py
Normal file
28
backend/app/api/routes/admin.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.deps import get_current_user
|
||||
from app.db.session import get_db
|
||||
from app.models import Job, User
|
||||
from app.services.comfy_client import comfy_client
|
||||
|
||||
router = APIRouter(prefix="/api/admin", tags=["admin"])
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health(_: User = Depends(get_current_user)):
|
||||
return {"api": "ok", "comfyui": await comfy_client.health_check()}
|
||||
|
||||
|
||||
@router.get("/queue")
|
||||
async def queue(_: User = Depends(get_current_user)):
|
||||
return await comfy_client.get_queue()
|
||||
|
||||
|
||||
@router.get("/jobs-summary")
|
||||
def jobs_summary(db: Session = Depends(get_db), _: User = Depends(get_current_user)):
|
||||
total = db.query(Job).count()
|
||||
active = db.query(Job).filter(Job.status.in_(["validating", "uploading_assets", "queued", "executing", "collecting_outputs"])).count()
|
||||
completed = db.query(Job).filter(Job.status == "completed").count()
|
||||
failed = db.query(Job).filter(Job.status == "failed").count()
|
||||
return {"total": total, "active": active, "completed": completed, "failed": failed}
|
||||
120
backend/app/api/routes/assets.py
Normal file
120
backend/app/api/routes/assets.py
Normal file
@@ -0,0 +1,120 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.deps import get_current_user
|
||||
from app.db.session import get_db
|
||||
from app.models import Asset, User
|
||||
from app.schemas import AssetResponse, AssetTrashRequest
|
||||
from app.services.storage import asset_storage
|
||||
|
||||
router = APIRouter(prefix="/api/assets", tags=["assets"])
|
||||
|
||||
ALLOWED_TYPES = {
|
||||
"image": ["image/jpeg", "image/png", "image/webp"],
|
||||
"video": ["video/mp4", "video/webm", "video/quicktime"],
|
||||
"audio": ["audio/mpeg", "audio/mp4", "audio/wav", "audio/ogg", "audio/x-wav"],
|
||||
"pose_sheet": ["image/jpeg", "image/png", "image/webp"],
|
||||
}
|
||||
MAX_SIZE_BYTES = 500 * 1024 * 1024
|
||||
|
||||
|
||||
@router.post("/upload", response_model=AssetResponse, status_code=201)
|
||||
async def upload_asset(
|
||||
file: UploadFile = File(...),
|
||||
asset_type: str = Form(...),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
if asset_type not in ALLOWED_TYPES:
|
||||
raise HTTPException(400, f"asset_type must be one of {list(ALLOWED_TYPES.keys())}")
|
||||
|
||||
mime = file.content_type or ""
|
||||
if mime not in ALLOWED_TYPES[asset_type]:
|
||||
raise HTTPException(400, f"Unsupported mime type {mime} for {asset_type}")
|
||||
|
||||
subfolder = f"{current_user.id}/{asset_type}"
|
||||
storage_path, size_bytes = await asset_storage.save_upload(file, subfolder)
|
||||
if size_bytes > MAX_SIZE_BYTES:
|
||||
raise HTTPException(413, "File too large (max 500MB)")
|
||||
|
||||
thumbnail_path = None
|
||||
width = height = None
|
||||
duration_seconds = None
|
||||
if asset_type in ("image", "pose_sheet"):
|
||||
thumbnail_path = asset_storage.generate_thumbnail(storage_path, f"{current_user.id}/thumbs")
|
||||
try:
|
||||
from PIL import Image
|
||||
|
||||
abs_path = asset_storage.absolute_path(storage_path)
|
||||
with Image.open(abs_path) as image:
|
||||
width, height = image.size
|
||||
except Exception:
|
||||
pass
|
||||
elif asset_type == "video":
|
||||
thumbnail_path = asset_storage.generate_video_thumbnail(storage_path, f"{current_user.id}/thumbs")
|
||||
duration_seconds = asset_storage.detect_duration_seconds(storage_path)
|
||||
else:
|
||||
duration_seconds = asset_storage.detect_duration_seconds(storage_path)
|
||||
|
||||
asset = Asset(
|
||||
owner_id=current_user.id,
|
||||
asset_type=asset_type,
|
||||
mime_type=mime,
|
||||
original_filename=file.filename or "upload",
|
||||
storage_path=storage_path,
|
||||
thumbnail_path=thumbnail_path,
|
||||
size_bytes=size_bytes,
|
||||
width=width,
|
||||
height=height,
|
||||
duration_seconds=duration_seconds,
|
||||
)
|
||||
db.add(asset)
|
||||
db.commit()
|
||||
db.refresh(asset)
|
||||
return asset
|
||||
|
||||
|
||||
@router.get("/", response_model=List[AssetResponse])
|
||||
def list_assets(
|
||||
asset_type: Optional[str] = None,
|
||||
include_trashed: bool = False,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
query = db.query(Asset).filter(Asset.owner_id == current_user.id)
|
||||
if not include_trashed:
|
||||
query = query.filter(Asset.is_trashed.is_(False))
|
||||
if asset_type:
|
||||
query = query.filter(Asset.asset_type == asset_type)
|
||||
return query.order_by(Asset.created_at.desc()).all()
|
||||
|
||||
|
||||
@router.post("/trash", status_code=200)
|
||||
def move_assets_to_trash(
|
||||
payload: AssetTrashRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
if not payload.asset_ids:
|
||||
raise HTTPException(400, "No asset ids provided")
|
||||
|
||||
assets = (
|
||||
db.query(Asset)
|
||||
.filter(Asset.owner_id == current_user.id, Asset.id.in_(payload.asset_ids))
|
||||
.all()
|
||||
)
|
||||
if not assets:
|
||||
raise HTTPException(404, "No matching assets found")
|
||||
|
||||
delete_after_at = datetime.now(timezone.utc) + timedelta(days=30)
|
||||
for asset in assets:
|
||||
asset.is_trashed = True
|
||||
asset.delete_after_at = delete_after_at
|
||||
db.commit()
|
||||
return {
|
||||
"moved_to_trash": len(assets),
|
||||
"delete_after_at": delete_after_at.isoformat(),
|
||||
}
|
||||
61
backend/app/api/routes/auth.py
Normal file
61
backend/app/api/routes/auth.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.deps import get_current_user
|
||||
from app.core.security import create_access_token, hash_password, verify_password
|
||||
from app.db.session import get_db
|
||||
from app.models import User
|
||||
from app.schemas import LoginRequest, RegisterRequest, UserResponse
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["auth"])
|
||||
|
||||
|
||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||
def register(payload: RegisterRequest, db: Session = Depends(get_db)):
|
||||
existing = db.query(User).filter(User.email == payload.email).first()
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail="Email already registered")
|
||||
user = User(email=payload.email, password_hash=hash_password(payload.password))
|
||||
db.add(user)
|
||||
db.commit()
|
||||
db.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
def _is_secure_request(request: Request) -> bool:
|
||||
forwarded_proto = request.headers.get("x-forwarded-proto", "")
|
||||
if "https" in forwarded_proto.lower():
|
||||
return True
|
||||
if request.url.scheme == "https":
|
||||
return True
|
||||
return settings.BACKEND_BASE_URL.startswith("https://")
|
||||
|
||||
|
||||
@router.post("/login")
|
||||
def login(payload: LoginRequest, request: Request, response: Response, db: Session = Depends(get_db)):
|
||||
user = db.query(User).filter(User.email == payload.email).first()
|
||||
if not user or not verify_password(payload.password, user.password_hash):
|
||||
raise HTTPException(status_code=401, detail="Invalid credentials")
|
||||
|
||||
token = create_access_token(subject=user.id)
|
||||
response.set_cookie(
|
||||
key="access_token",
|
||||
value=token,
|
||||
httponly=True,
|
||||
samesite="lax",
|
||||
secure=_is_secure_request(request),
|
||||
max_age=60 * 60 * 24 * 7,
|
||||
)
|
||||
return {"message": "Logged in", "user": UserResponse.model_validate(user)}
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
def logout(response: Response):
|
||||
response.delete_cookie("access_token")
|
||||
return {"message": "Logged out"}
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
def me(current_user: User = Depends(get_current_user)):
|
||||
return current_user
|
||||
118
backend/app/api/routes/jobs.py
Normal file
118
backend/app/api/routes/jobs.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import json
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
from app.core.deps import get_current_user
|
||||
from app.db.session import get_db
|
||||
from app.models import Asset, Job, JobOutput, User
|
||||
from app.schemas import JobCreateRequest, JobListResponse, JobResponse
|
||||
from app.services.orchestrator import reconcile_job_outputs_if_missing, run_job
|
||||
from app.services.storage import output_storage
|
||||
|
||||
router = APIRouter(prefix="/api/jobs", tags=["jobs"])
|
||||
|
||||
|
||||
@router.post("/", response_model=JobResponse, status_code=201)
|
||||
async def create_job(
|
||||
payload: JobCreateRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
def assert_owns(asset_id: Optional[str], label: str):
|
||||
if asset_id:
|
||||
asset = (
|
||||
db.query(Asset)
|
||||
.filter(Asset.id == asset_id, Asset.owner_id == current_user.id, Asset.is_trashed.is_(False))
|
||||
.first()
|
||||
)
|
||||
if not asset:
|
||||
raise HTTPException(400, f"{label} asset not found or not owned by user")
|
||||
|
||||
assert_owns(payload.ground_truth_asset_id, "ground_truth")
|
||||
assert_owns(payload.motion_asset_id, "motion")
|
||||
assert_owns(payload.audio_asset_id, "audio")
|
||||
assert_owns(payload.pose_asset_id, "pose_sheet")
|
||||
for ref_id in payload.reference_asset_ids or []:
|
||||
assert_owns(ref_id, f"reference {ref_id}")
|
||||
|
||||
job = Job(
|
||||
owner_id=current_user.id,
|
||||
mode=payload.mode,
|
||||
submode=payload.submode,
|
||||
prompt=payload.prompt,
|
||||
negative_prompt=payload.negative_prompt,
|
||||
status="created",
|
||||
ground_truth_asset_id=payload.ground_truth_asset_id,
|
||||
motion_asset_id=payload.motion_asset_id,
|
||||
audio_asset_id=payload.audio_asset_id,
|
||||
pose_asset_id=payload.pose_asset_id,
|
||||
reference_asset_ids_json=json.dumps(payload.reference_asset_ids) if payload.reference_asset_ids else None,
|
||||
settings_json=json.dumps(payload.settings) if payload.settings else None,
|
||||
)
|
||||
db.add(job)
|
||||
db.commit()
|
||||
db.refresh(job)
|
||||
|
||||
background_tasks.add_task(run_job, job.id)
|
||||
return job
|
||||
|
||||
|
||||
@router.get("/", response_model=List[JobListResponse])
|
||||
def list_jobs(
|
||||
mode: Optional[str] = None,
|
||||
status: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
query = db.query(Job).filter(Job.owner_id == current_user.id)
|
||||
query = query.options(selectinload(Job.outputs))
|
||||
if mode:
|
||||
query = query.filter(Job.mode == mode)
|
||||
if status:
|
||||
query = query.filter(Job.status == status)
|
||||
return query.order_by(Job.created_at.desc()).limit(100).all()
|
||||
|
||||
|
||||
@router.get("/{job_id}", response_model=JobResponse)
|
||||
async def get_job(job_id: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
|
||||
job = (
|
||||
db.query(Job)
|
||||
.options(selectinload(Job.outputs), selectinload(Job.events))
|
||||
.filter(Job.id == job_id, Job.owner_id == current_user.id)
|
||||
.first()
|
||||
)
|
||||
if not job:
|
||||
raise HTTPException(404, "Job not found")
|
||||
if job.status == "completed" and not job.outputs and job.comfy_prompt_id:
|
||||
await reconcile_job_outputs_if_missing(job.id)
|
||||
db.expire_all()
|
||||
job = (
|
||||
db.query(Job)
|
||||
.options(selectinload(Job.outputs), selectinload(Job.events))
|
||||
.filter(Job.id == job_id, Job.owner_id == current_user.id)
|
||||
.first()
|
||||
)
|
||||
return job
|
||||
|
||||
|
||||
@router.get("/{job_id}/outputs/{output_id}/download")
|
||||
def download_output(
|
||||
job_id: str,
|
||||
output_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
job = db.query(Job).filter(Job.id == job_id, Job.owner_id == current_user.id).first()
|
||||
if not job:
|
||||
raise HTTPException(404, "Job not found")
|
||||
output = db.query(JobOutput).filter(JobOutput.id == output_id, JobOutput.job_id == job_id).first()
|
||||
if not output:
|
||||
raise HTTPException(404, "Output not found")
|
||||
abs_path = output_storage.absolute_path(output.file_path)
|
||||
if not abs_path.exists():
|
||||
raise HTTPException(404, "Output file not found")
|
||||
return FileResponse(str(abs_path))
|
||||
25
backend/app/core/config.py
Normal file
25
backend/app/core/config.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from typing import List
|
||||
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
SECRET_KEY: str = "dev-secret-change-in-production"
|
||||
DATABASE_URL: str = "sqlite:///./animatrix.db"
|
||||
ASSET_STORAGE_ROOT: str = "./storage/assets"
|
||||
OUTPUT_STORAGE_ROOT: str = "./storage/outputs"
|
||||
COMFYUI_BASE_URL: str = "https://comfy.desineuron.in"
|
||||
BACKEND_BASE_URL: str = "http://localhost:8000"
|
||||
CORS_ORIGINS: str = "http://localhost:3000"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES: int = 10080
|
||||
|
||||
@property
|
||||
def cors_origins_list(self) -> List[str]:
|
||||
return [o.strip() for o in self.CORS_ORIGINS.split(",") if o.strip()]
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
extra = "ignore"
|
||||
|
||||
|
||||
settings = Settings()
|
||||
30
backend/app/core/deps.py
Normal file
30
backend/app/core/deps.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Cookie, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.security import decode_access_token
|
||||
from app.db.session import get_db
|
||||
from app.models import User
|
||||
|
||||
|
||||
def get_current_user(
|
||||
access_token: Optional[str] = Cookie(default=None),
|
||||
db: Session = Depends(get_db),
|
||||
) -> User:
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
)
|
||||
if not access_token:
|
||||
raise credentials_exception
|
||||
|
||||
user_id = decode_access_token(access_token)
|
||||
if not user_id:
|
||||
raise credentials_exception
|
||||
|
||||
user = db.query(User).filter(User.id == user_id, User.is_active.is_(True)).first()
|
||||
if not user:
|
||||
raise credentials_exception
|
||||
|
||||
return user
|
||||
34
backend/app/core/security.py
Normal file
34
backend/app/core/security.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
pwd_context = CryptContext(schemes=["pbkdf2_sha256"], deprecated="auto")
|
||||
ALGORITHM = "HS256"
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def create_access_token(subject: str, expires_delta: Optional[timedelta] = None) -> str:
|
||||
expire = datetime.now(timezone.utc) + (
|
||||
expires_delta or timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
)
|
||||
payload = {"sub": subject, "exp": expire}
|
||||
return jwt.encode(payload, settings.SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
|
||||
def decode_access_token(token: str) -> Optional[str]:
|
||||
try:
|
||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
|
||||
return payload.get("sub")
|
||||
except JWTError:
|
||||
return None
|
||||
63
backend/app/db/init_db.py
Normal file
63
backend/app/db/init_db.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.core.config import settings
|
||||
from app.db.session import Base, engine
|
||||
from app.models import Asset, Job, JobEvent, JobOutput, User # noqa: F401
|
||||
from app.services.storage import asset_storage
|
||||
|
||||
|
||||
def init_db() -> None:
|
||||
Path(settings.ASSET_STORAGE_ROOT).mkdir(parents=True, exist_ok=True)
|
||||
Path(settings.OUTPUT_STORAGE_ROOT).mkdir(parents=True, exist_ok=True)
|
||||
Base.metadata.create_all(bind=engine)
|
||||
_migrate_assets_table()
|
||||
_cleanup_expired_trashed_assets()
|
||||
|
||||
|
||||
def _migrate_assets_table() -> None:
|
||||
with engine.begin() as conn:
|
||||
columns = {
|
||||
row[1]
|
||||
for row in conn.execute(text("PRAGMA table_info(assets)")).fetchall()
|
||||
}
|
||||
if "is_trashed" not in columns:
|
||||
conn.execute(text("ALTER TABLE assets ADD COLUMN is_trashed BOOLEAN NOT NULL DEFAULT 0"))
|
||||
if "delete_after_at" not in columns:
|
||||
conn.execute(text("ALTER TABLE assets ADD COLUMN delete_after_at DATETIME NULL"))
|
||||
|
||||
|
||||
def _cleanup_expired_trashed_assets() -> None:
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
with engine.begin() as conn:
|
||||
rows = conn.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT id, storage_path, thumbnail_path
|
||||
FROM assets
|
||||
WHERE is_trashed = 1
|
||||
AND delete_after_at IS NOT NULL
|
||||
AND delete_after_at <= :now
|
||||
"""
|
||||
),
|
||||
{"now": now},
|
||||
).fetchall()
|
||||
|
||||
for _, storage_path, thumbnail_path in rows:
|
||||
asset_storage.delete_relative_path(storage_path)
|
||||
asset_storage.delete_relative_path(thumbnail_path)
|
||||
|
||||
if rows:
|
||||
conn.execute(
|
||||
text(
|
||||
"""
|
||||
DELETE FROM assets
|
||||
WHERE is_trashed = 1
|
||||
AND delete_after_at IS NOT NULL
|
||||
AND delete_after_at <= :now
|
||||
"""
|
||||
),
|
||||
{"now": now},
|
||||
)
|
||||
23
backend/app/db/session.py
Normal file
23
backend/app/db/session.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import DeclarativeBase, sessionmaker
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
connect_args = {}
|
||||
if settings.DATABASE_URL.startswith("sqlite"):
|
||||
connect_args = {"check_same_thread": False}
|
||||
|
||||
engine = create_engine(settings.DATABASE_URL, connect_args=connect_args)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
32
backend/app/main.py
Normal file
32
backend/app/main.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from app.api.routes import admin, assets, auth, jobs
|
||||
from app.core.config import settings
|
||||
from app.db.init_db import init_db
|
||||
|
||||
app = FastAPI(title="Animatrix API", version="0.1.0")
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.cors_origins_list,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
init_db()
|
||||
|
||||
app.include_router(auth.router)
|
||||
app.include_router(assets.router)
|
||||
app.include_router(jobs.router)
|
||||
app.include_router(admin.router)
|
||||
|
||||
app.mount("/storage/assets", StaticFiles(directory=settings.ASSET_STORAGE_ROOT), name="assets")
|
||||
app.mount("/storage/outputs", StaticFiles(directory=settings.OUTPUT_STORAGE_ROOT), name="outputs")
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health() -> dict[str, str]:
|
||||
return {"status": "ok"}
|
||||
3
backend/app/models/__init__.py
Normal file
3
backend/app/models/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from app.models.models import Asset, Job, JobEvent, JobOutput, User
|
||||
|
||||
__all__ = ["User", "Asset", "Job", "JobOutput", "JobEvent"]
|
||||
106
backend/app/models/models.py
Normal file
106
backend/app/models/models.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Boolean, DateTime, Float, ForeignKey, Integer, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from app.db.session import Base
|
||||
|
||||
|
||||
def utcnow():
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def new_uuid():
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id: Mapped[str] = mapped_column(String, primary_key=True, default=new_uuid)
|
||||
email: Mapped[str] = mapped_column(String, unique=True, nullable=False, index=True)
|
||||
password_hash: Mapped[str] = mapped_column(String, nullable=False)
|
||||
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utcnow)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utcnow, onupdate=utcnow)
|
||||
|
||||
assets: Mapped[list["Asset"]] = relationship("Asset", back_populates="owner", cascade="all, delete-orphan")
|
||||
jobs: Mapped[list["Job"]] = relationship("Job", back_populates="owner", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
class Asset(Base):
|
||||
__tablename__ = "assets"
|
||||
|
||||
id: Mapped[str] = mapped_column(String, primary_key=True, default=new_uuid)
|
||||
owner_id: Mapped[str] = mapped_column(String, ForeignKey("users.id"), nullable=False, index=True)
|
||||
asset_type: Mapped[str] = mapped_column(String, nullable=False)
|
||||
mime_type: Mapped[str] = mapped_column(String, nullable=False)
|
||||
original_filename: Mapped[str] = mapped_column(String, nullable=False)
|
||||
storage_path: Mapped[str] = mapped_column(String, nullable=False)
|
||||
thumbnail_path: Mapped[Optional[str]] = mapped_column(String, nullable=True)
|
||||
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||
width: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
height: Mapped[Optional[int]] = mapped_column(Integer, nullable=True)
|
||||
duration_seconds: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
|
||||
is_trashed: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False, index=True)
|
||||
delete_after_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utcnow)
|
||||
|
||||
owner: Mapped["User"] = relationship("User", back_populates="assets")
|
||||
|
||||
|
||||
class Job(Base):
|
||||
__tablename__ = "jobs"
|
||||
|
||||
id: Mapped[str] = mapped_column(String, primary_key=True, default=new_uuid)
|
||||
owner_id: Mapped[str] = mapped_column(String, ForeignKey("users.id"), nullable=False, index=True)
|
||||
mode: Mapped[str] = mapped_column(String, nullable=False)
|
||||
submode: Mapped[Optional[str]] = mapped_column(String, nullable=True)
|
||||
prompt: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
negative_prompt: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
status: Mapped[str] = mapped_column(String, nullable=False, default="created", index=True)
|
||||
comfy_prompt_id: Mapped[Optional[str]] = mapped_column(String, nullable=True)
|
||||
workflow_template_name: Mapped[Optional[str]] = mapped_column(String, nullable=True)
|
||||
workflow_template_version: Mapped[Optional[str]] = mapped_column(String, nullable=True)
|
||||
settings_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
ground_truth_asset_id: Mapped[Optional[str]] = mapped_column(String, ForeignKey("assets.id"), nullable=True)
|
||||
motion_asset_id: Mapped[Optional[str]] = mapped_column(String, ForeignKey("assets.id"), nullable=True)
|
||||
audio_asset_id: Mapped[Optional[str]] = mapped_column(String, ForeignKey("assets.id"), nullable=True)
|
||||
reference_asset_ids_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
pose_asset_id: Mapped[Optional[str]] = mapped_column(String, ForeignKey("assets.id"), nullable=True)
|
||||
error_message: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utcnow)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utcnow, onupdate=utcnow)
|
||||
|
||||
owner: Mapped["User"] = relationship("User", back_populates="jobs")
|
||||
outputs: Mapped[list["JobOutput"]] = relationship("JobOutput", back_populates="job", cascade="all, delete-orphan")
|
||||
events: Mapped[list["JobEvent"]] = relationship("JobEvent", back_populates="job", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
class JobOutput(Base):
|
||||
__tablename__ = "job_outputs"
|
||||
|
||||
id: Mapped[str] = mapped_column(String, primary_key=True, default=new_uuid)
|
||||
job_id: Mapped[str] = mapped_column(String, ForeignKey("jobs.id"), nullable=False, index=True)
|
||||
output_type: Mapped[str] = mapped_column(String, nullable=False)
|
||||
file_path: Mapped[str] = mapped_column(String, nullable=False)
|
||||
poster_path: Mapped[Optional[str]] = mapped_column(String, nullable=True)
|
||||
metadata_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utcnow)
|
||||
|
||||
job: Mapped["Job"] = relationship("Job", back_populates="outputs")
|
||||
|
||||
|
||||
class JobEvent(Base):
|
||||
__tablename__ = "job_events"
|
||||
|
||||
id: Mapped[str] = mapped_column(String, primary_key=True, default=new_uuid)
|
||||
job_id: Mapped[str] = mapped_column(String, ForeignKey("jobs.id"), nullable=False, index=True)
|
||||
event_type: Mapped[str] = mapped_column(String, nullable=False)
|
||||
message: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
payload_json: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=utcnow)
|
||||
|
||||
job: Mapped["Job"] = relationship("Job", back_populates="events")
|
||||
25
backend/app/schemas/__init__.py
Normal file
25
backend/app/schemas/__init__.py
Normal file
@@ -0,0 +1,25 @@
|
||||
from app.schemas.schemas import (
|
||||
AssetResponse,
|
||||
AssetTrashRequest,
|
||||
JobCreateRequest,
|
||||
JobEventResponse,
|
||||
JobListResponse,
|
||||
JobOutputResponse,
|
||||
JobResponse,
|
||||
LoginRequest,
|
||||
RegisterRequest,
|
||||
UserResponse,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"RegisterRequest",
|
||||
"LoginRequest",
|
||||
"UserResponse",
|
||||
"AssetResponse",
|
||||
"AssetTrashRequest",
|
||||
"JobCreateRequest",
|
||||
"JobOutputResponse",
|
||||
"JobEventResponse",
|
||||
"JobResponse",
|
||||
"JobListResponse",
|
||||
]
|
||||
143
backend/app/schemas/schemas.py
Normal file
143
backend/app/schemas/schemas.py
Normal file
@@ -0,0 +1,143 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, EmailStr, field_validator
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
@field_validator("password")
|
||||
@classmethod
|
||||
def password_min_length(cls, v: str) -> str:
|
||||
if len(v) < 8:
|
||||
raise ValueError("Password must be at least 8 characters")
|
||||
return v
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class AssetResponse(BaseModel):
|
||||
id: str
|
||||
asset_type: str
|
||||
mime_type: str
|
||||
original_filename: str
|
||||
storage_path: str
|
||||
size_bytes: int
|
||||
width: Optional[int] = None
|
||||
height: Optional[int] = None
|
||||
duration_seconds: Optional[float] = None
|
||||
thumbnail_path: Optional[str] = None
|
||||
is_trashed: bool = False
|
||||
delete_after_at: Optional[datetime] = None
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class JobCreateRequest(BaseModel):
|
||||
mode: str
|
||||
submode: Optional[str] = None
|
||||
prompt: str
|
||||
negative_prompt: Optional[str] = None
|
||||
ground_truth_asset_id: str
|
||||
motion_asset_id: Optional[str] = None
|
||||
audio_asset_id: Optional[str] = None
|
||||
reference_asset_ids: Optional[List[str]] = None
|
||||
pose_asset_id: Optional[str] = None
|
||||
settings: Optional[dict] = None
|
||||
|
||||
@field_validator("mode")
|
||||
@classmethod
|
||||
def validate_mode(cls, v: str) -> str:
|
||||
if v not in ("animate", "audio"):
|
||||
raise ValueError("mode must be 'animate' or 'audio'")
|
||||
return v
|
||||
|
||||
@field_validator("submode")
|
||||
@classmethod
|
||||
def validate_submode(cls, v: Optional[str]) -> Optional[str]:
|
||||
if v is not None and v not in ("move", "mix"):
|
||||
raise ValueError("submode must be 'move' or 'mix'")
|
||||
return v
|
||||
|
||||
|
||||
class JobOutputResponse(BaseModel):
|
||||
id: str
|
||||
output_type: str
|
||||
file_path: str
|
||||
poster_path: Optional[str] = None
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class JobEventResponse(BaseModel):
|
||||
id: str
|
||||
event_type: str
|
||||
message: Optional[str] = None
|
||||
created_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class AssetTrashRequest(BaseModel):
|
||||
asset_ids: List[str]
|
||||
|
||||
|
||||
class JobResponse(BaseModel):
|
||||
id: str
|
||||
mode: str
|
||||
submode: Optional[str] = None
|
||||
prompt: str
|
||||
negative_prompt: Optional[str] = None
|
||||
status: str
|
||||
comfy_prompt_id: Optional[str] = None
|
||||
workflow_template_name: Optional[str] = None
|
||||
error_message: Optional[str] = None
|
||||
ground_truth_asset_id: Optional[str] = None
|
||||
motion_asset_id: Optional[str] = None
|
||||
audio_asset_id: Optional[str] = None
|
||||
pose_asset_id: Optional[str] = None
|
||||
outputs: List[JobOutputResponse] = []
|
||||
events: List[JobEventResponse] = []
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class JobListResponse(BaseModel):
|
||||
id: str
|
||||
mode: str
|
||||
submode: Optional[str] = None
|
||||
prompt: str
|
||||
error_message: Optional[str] = None
|
||||
status: str
|
||||
ground_truth_asset_id: Optional[str] = None
|
||||
motion_asset_id: Optional[str] = None
|
||||
audio_asset_id: Optional[str] = None
|
||||
pose_asset_id: Optional[str] = None
|
||||
outputs: List[JobOutputResponse] = []
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
128
backend/app/services/comfy_client.py
Normal file
128
backend/app/services/comfy_client.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
import httpx
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ComfyClient:
|
||||
def __init__(self, base_url: str | None = None):
|
||||
self.base_url = (base_url or settings.COMFYUI_BASE_URL).rstrip("/")
|
||||
self._client = httpx.AsyncClient(timeout=120.0)
|
||||
|
||||
async def close(self) -> None:
|
||||
await self._client.aclose()
|
||||
|
||||
async def health_check(self) -> bool:
|
||||
for endpoint in ("/system_stats", "/"):
|
||||
try:
|
||||
response = await self._client.get(f"{self.base_url}{endpoint}")
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.warning("ComfyUI health check failed at %s: %s", endpoint, exc)
|
||||
return False
|
||||
|
||||
async def upload_image(self, file_path: str, filename: str) -> str:
|
||||
with open(file_path, "rb") as handle:
|
||||
files = {"image": (filename, handle, "application/octet-stream")}
|
||||
response = await self._client.post(f"{self.base_url}/upload/image", files=files)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get("name", filename)
|
||||
|
||||
async def upload_media(self, file_path: str, filename: str, media_type: str) -> str:
|
||||
endpoint = {
|
||||
"image": "/upload/image",
|
||||
"pose_sheet": "/upload/image",
|
||||
"video": "/upload/video",
|
||||
"audio": "/upload/audio",
|
||||
}.get(media_type)
|
||||
field_name = {
|
||||
"image": "image",
|
||||
"pose_sheet": "image",
|
||||
"video": "video",
|
||||
"audio": "audio",
|
||||
}.get(media_type)
|
||||
|
||||
if not endpoint or not field_name:
|
||||
raise ValueError(f"Unsupported ComfyUI upload media type: {media_type}")
|
||||
|
||||
mime_type = "application/octet-stream"
|
||||
suffix = Path(filename).suffix.lower()
|
||||
if media_type in ("image", "pose_sheet"):
|
||||
mime_type = {
|
||||
".jpg": "image/jpeg",
|
||||
".jpeg": "image/jpeg",
|
||||
".png": "image/png",
|
||||
".webp": "image/webp",
|
||||
}.get(suffix, mime_type)
|
||||
elif media_type == "video":
|
||||
mime_type = {
|
||||
".mp4": "video/mp4",
|
||||
".webm": "video/webm",
|
||||
".mov": "video/quicktime",
|
||||
}.get(suffix, mime_type)
|
||||
elif media_type == "audio":
|
||||
mime_type = {
|
||||
".mp3": "audio/mpeg",
|
||||
".mp4": "audio/mp4",
|
||||
".wav": "audio/wav",
|
||||
".ogg": "audio/ogg",
|
||||
}.get(suffix, mime_type)
|
||||
|
||||
with open(file_path, "rb") as handle:
|
||||
files = {field_name: (filename, handle, mime_type)}
|
||||
response = await self._client.post(f"{self.base_url}{endpoint}", files=files)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
return data.get("name", filename)
|
||||
|
||||
async def submit_prompt(self, workflow: Dict[str, Any], client_id: str | None = None) -> str:
|
||||
payload: Dict[str, Any] = {"prompt": workflow}
|
||||
if client_id:
|
||||
payload["client_id"] = client_id
|
||||
response = await self._client.post(f"{self.base_url}/prompt", json=payload)
|
||||
if response.is_error:
|
||||
detail = response.text
|
||||
raise RuntimeError(f"ComfyUI prompt submission failed ({response.status_code}): {detail}")
|
||||
data = response.json()
|
||||
prompt_id = data.get("prompt_id")
|
||||
if not prompt_id:
|
||||
raise RuntimeError(f"No prompt_id returned by ComfyUI: {data}")
|
||||
return prompt_id
|
||||
|
||||
async def get_history(self, prompt_id: str) -> Dict[str, Any]:
|
||||
response = await self._client.get(f"{self.base_url}/history/{prompt_id}")
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return data.get(prompt_id, {})
|
||||
|
||||
async def get_history_all(self) -> Dict[str, Any]:
|
||||
response = await self._client.get(f"{self.base_url}/history")
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def get_queue(self) -> Dict[str, Any]:
|
||||
response = await self._client.get(f"{self.base_url}/queue")
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def get_object_info(self, node_name: str) -> Dict[str, Any]:
|
||||
response = await self._client.get(f"{self.base_url}/object_info/{node_name}")
|
||||
response.raise_for_status()
|
||||
return response.json().get(node_name, {})
|
||||
|
||||
async def download_output(self, filename: str, subfolder: str = "", folder_type: str = "output") -> bytes:
|
||||
params = {"filename": filename, "subfolder": subfolder, "type": folder_type}
|
||||
response = await self._client.get(f"{self.base_url}/view", params=params)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
|
||||
|
||||
comfy_client = ComfyClient()
|
||||
325
backend/app/services/orchestrator.py
Normal file
325
backend/app/services/orchestrator.py
Normal file
@@ -0,0 +1,325 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import subprocess
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db.session import SessionLocal
|
||||
from app.models import Asset, Job, JobEvent, JobOutput
|
||||
from app.services.comfy_client import comfy_client
|
||||
from app.services.storage import asset_storage, output_storage
|
||||
from app.services.workflow_binder import WorkflowBinder, select_template_name
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
VIDEO_EXTENSIONS = {".mp4", ".mov", ".webm", ".mkv", ".avi"}
|
||||
MODEL_LOADER_INPUTS = {
|
||||
"CLIPLoader": "clip_name",
|
||||
"VAELoader": "vae_name",
|
||||
"UNETLoader": "unet_name",
|
||||
"LoraLoaderModelOnly": "lora_name",
|
||||
}
|
||||
|
||||
|
||||
def _add_event(db: Session, job_id: str, event_type: str, message: str, payload: dict | None = None) -> None:
|
||||
event = JobEvent(
|
||||
job_id=job_id,
|
||||
event_type=event_type,
|
||||
message=message,
|
||||
payload_json=json.dumps(payload) if payload else None,
|
||||
)
|
||||
db.add(event)
|
||||
db.commit()
|
||||
|
||||
|
||||
def _set_status(db: Session, job: Job, status: str, error: str | None = None) -> None:
|
||||
job.status = status
|
||||
if error:
|
||||
job.error_message = error
|
||||
db.commit()
|
||||
_add_event(db, job.id, "status_change", f"Job status -> {status}")
|
||||
|
||||
|
||||
def _extract_history_error(history: dict) -> str | None:
|
||||
status = history.get("status", {}) or {}
|
||||
if status.get("status_str") == "error":
|
||||
for message in status.get("messages", []) or []:
|
||||
if not isinstance(message, (list, tuple)) or len(message) < 2:
|
||||
continue
|
||||
payload = message[1] or {}
|
||||
if message[0] == "execution_error":
|
||||
exception_message = payload.get("exception_message")
|
||||
node_id = payload.get("node_id")
|
||||
node_type = payload.get("node_type")
|
||||
if exception_message and node_id and node_type:
|
||||
return f"ComfyUI execution error on node {node_id} ({node_type}): {exception_message}"
|
||||
if exception_message:
|
||||
return f"ComfyUI execution error: {exception_message}"
|
||||
return f"ComfyUI execution error: {payload}"
|
||||
return "ComfyUI execution failed without a detailed error message."
|
||||
|
||||
if history.get("node_errors"):
|
||||
return f"ComfyUI node validation failed: {history['node_errors']}"
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _output_type_for_filename(filename: str) -> str:
|
||||
return "video" if Path(filename).suffix.lower() in VIDEO_EXTENSIONS else "image"
|
||||
|
||||
|
||||
def _required_model_values(workflow: dict) -> dict[str, set[str]]:
|
||||
required: dict[str, set[str]] = {loader: set() for loader in MODEL_LOADER_INPUTS}
|
||||
for node in workflow.values():
|
||||
if not isinstance(node, dict):
|
||||
continue
|
||||
class_type = node.get("class_type")
|
||||
input_name = MODEL_LOADER_INPUTS.get(class_type)
|
||||
if not input_name:
|
||||
continue
|
||||
value = (node.get("inputs") or {}).get(input_name)
|
||||
if isinstance(value, str) and value:
|
||||
required[class_type].add(value)
|
||||
return {loader: values for loader, values in required.items() if values}
|
||||
|
||||
|
||||
async def _validate_runtime_models(workflow: dict) -> None:
|
||||
required = _required_model_values(workflow)
|
||||
if not required:
|
||||
return
|
||||
|
||||
missing_by_loader: list[str] = []
|
||||
for loader, expected_values in required.items():
|
||||
object_info = await comfy_client.get_object_info(loader)
|
||||
loader_input = MODEL_LOADER_INPUTS[loader]
|
||||
available_raw = (((object_info.get("input") or {}).get("required") or {}).get(loader_input) or [[]])[0]
|
||||
available = set(available_raw or [])
|
||||
missing = sorted(value for value in expected_values if value not in available)
|
||||
if missing:
|
||||
missing_by_loader.append(f"{loader} missing {missing}; available={sorted(available)}")
|
||||
|
||||
if missing_by_loader:
|
||||
raise RuntimeError(
|
||||
"ComfyUI runtime is missing required Wan model files. "
|
||||
+ " | ".join(missing_by_loader)
|
||||
)
|
||||
|
||||
|
||||
async def _get_history_with_fallback(prompt_id: str) -> dict:
|
||||
history = await comfy_client.get_history(prompt_id)
|
||||
if history:
|
||||
return history
|
||||
all_history = await comfy_client.get_history_all()
|
||||
return all_history.get(prompt_id, {})
|
||||
|
||||
|
||||
def _iter_history_files(node_output: dict) -> Iterable[dict]:
|
||||
for video in node_output.get("videos", []) or []:
|
||||
yield {
|
||||
"filename": video["filename"],
|
||||
"subfolder": video.get("subfolder", ""),
|
||||
"folder_type": video.get("type", "output"),
|
||||
"output_type": "video",
|
||||
}
|
||||
|
||||
for image in node_output.get("images", []) or []:
|
||||
filename = image["filename"]
|
||||
yield {
|
||||
"filename": filename,
|
||||
"subfolder": image.get("subfolder", ""),
|
||||
"folder_type": image.get("type", "output"),
|
||||
"output_type": _output_type_for_filename(filename),
|
||||
}
|
||||
|
||||
|
||||
async def _collect_outputs_from_history(db: Session, job: Job, history: dict) -> int:
|
||||
existing_paths = {output.file_path for output in job.outputs}
|
||||
created = 0
|
||||
|
||||
for node_id, node_output in (history.get("outputs", {}) or {}).items():
|
||||
for file_info in _iter_history_files(node_output):
|
||||
fname = file_info["filename"]
|
||||
data = await comfy_client.download_output(
|
||||
fname,
|
||||
file_info["subfolder"],
|
||||
file_info["folder_type"],
|
||||
)
|
||||
rel_path = output_storage.save_bytes(data, job.id, fname)
|
||||
if rel_path in existing_paths:
|
||||
continue
|
||||
|
||||
poster_path = None
|
||||
if file_info["output_type"] == "video":
|
||||
try:
|
||||
poster_fname = f"poster_{Path(fname).stem}.jpg"
|
||||
poster_abs = str(output_storage.absolute_path(f"{job.id}/{poster_fname}"))
|
||||
subprocess.run(
|
||||
["ffmpeg", "-y", "-i", str(output_storage.absolute_path(rel_path)), "-vframes", "1", poster_abs],
|
||||
capture_output=True,
|
||||
timeout=30,
|
||||
check=False,
|
||||
)
|
||||
if Path(poster_abs).exists():
|
||||
poster_path = f"{job.id}/{poster_fname}"
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
db.add(
|
||||
JobOutput(
|
||||
job_id=job.id,
|
||||
output_type=file_info["output_type"],
|
||||
file_path=rel_path,
|
||||
poster_path=poster_path,
|
||||
metadata_json=json.dumps({"node_id": node_id, "filename": fname}),
|
||||
)
|
||||
)
|
||||
existing_paths.add(rel_path)
|
||||
created += 1
|
||||
|
||||
if created:
|
||||
db.commit()
|
||||
|
||||
return created
|
||||
|
||||
|
||||
async def reconcile_job_outputs_if_missing(job_id: str) -> bool:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
if not job:
|
||||
return False
|
||||
if job.status != "completed" or not job.comfy_prompt_id or job.outputs:
|
||||
return False
|
||||
|
||||
history = await _get_history_with_fallback(job.comfy_prompt_id)
|
||||
if not history or _extract_history_error(history):
|
||||
return False
|
||||
|
||||
created = await _collect_outputs_from_history(db, job, history)
|
||||
if created:
|
||||
_add_event(db, job.id, "outputs_reconciled", f"Recovered {created} output file(s) from ComfyUI history.")
|
||||
return True
|
||||
return False
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
async def _upload_asset_to_comfy(db: Session, asset_id: Optional[str]) -> Optional[str]:
|
||||
if not asset_id:
|
||||
return None
|
||||
asset = db.query(Asset).filter(Asset.id == asset_id).first()
|
||||
if not asset:
|
||||
raise ValueError(f"Asset {asset_id} not found")
|
||||
if asset.is_trashed:
|
||||
raise ValueError(f"Asset {asset.original_filename} is in trash")
|
||||
return await comfy_client.upload_media(
|
||||
str(asset_storage.absolute_path(asset.storage_path)),
|
||||
asset.original_filename,
|
||||
asset.asset_type,
|
||||
)
|
||||
|
||||
|
||||
def _validate_job(job: Job) -> list[str]:
|
||||
errors = []
|
||||
if not job.prompt or not job.prompt.strip():
|
||||
errors.append("Prompt is required")
|
||||
if not job.ground_truth_asset_id:
|
||||
errors.append("Ground truth image is required")
|
||||
if job.mode == "animate":
|
||||
if job.submode not in ("move", "mix"):
|
||||
errors.append("Animate mode requires submode 'move' or 'mix'")
|
||||
elif job.mode == "audio":
|
||||
if not job.audio_asset_id:
|
||||
errors.append("Audio mode requires an audio file")
|
||||
else:
|
||||
errors.append("Unknown mode")
|
||||
return errors
|
||||
|
||||
|
||||
async def run_job(job_id: str) -> None:
|
||||
db = SessionLocal()
|
||||
try:
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
if not job:
|
||||
return
|
||||
|
||||
_set_status(db, job, "validating")
|
||||
errors = _validate_job(job)
|
||||
if errors:
|
||||
_set_status(db, job, "failed", "; ".join(errors))
|
||||
return
|
||||
|
||||
_set_status(db, job, "uploading_assets")
|
||||
gt_name = await _upload_asset_to_comfy(db, job.ground_truth_asset_id)
|
||||
motion_name = await _upload_asset_to_comfy(db, job.motion_asset_id)
|
||||
audio_name = await _upload_asset_to_comfy(db, job.audio_asset_id)
|
||||
pose_name = await _upload_asset_to_comfy(db, job.pose_asset_id)
|
||||
ref_names = []
|
||||
if job.reference_asset_ids_json:
|
||||
for ref_id in json.loads(job.reference_asset_ids_json):
|
||||
uploaded = await _upload_asset_to_comfy(db, ref_id)
|
||||
if uploaded:
|
||||
ref_names.append(uploaded)
|
||||
|
||||
settings_dict = json.loads(job.settings_json) if job.settings_json else {}
|
||||
binder = WorkflowBinder(select_template_name(job.mode, job.submode))
|
||||
if "PLACEHOLDER" in binder.status.upper():
|
||||
raise RuntimeError(
|
||||
f"Workflow template '{select_template_name(job.mode, job.submode)}' is still a placeholder. "
|
||||
"Replace it with the production ComfyUI export before running real generations."
|
||||
)
|
||||
raw_seed = settings_dict.get("seed", 0)
|
||||
seed = raw_seed if isinstance(raw_seed, int) and raw_seed >= 0 else 0
|
||||
|
||||
params = {
|
||||
"positive_prompt": job.prompt,
|
||||
"negative_prompt": job.negative_prompt or "",
|
||||
"ground_truth": gt_name,
|
||||
"motion_video": motion_name,
|
||||
"audio": audio_name,
|
||||
"pose_sheet": pose_name,
|
||||
"reference_image": ref_names[0] if ref_names else None,
|
||||
"seed": seed,
|
||||
"steps": settings_dict.get("steps", 20),
|
||||
"cfg": settings_dict.get("cfg", 7.0),
|
||||
}
|
||||
workflow = binder.bind(params)
|
||||
await _validate_runtime_models(workflow)
|
||||
job.workflow_template_name = select_template_name(job.mode, job.submode)
|
||||
job.workflow_template_version = binder.version
|
||||
db.commit()
|
||||
|
||||
_set_status(db, job, "queued")
|
||||
prompt_id = await comfy_client.submit_prompt(workflow, client_id=str(uuid.uuid4()))
|
||||
job.comfy_prompt_id = prompt_id
|
||||
db.commit()
|
||||
_add_event(db, job.id, "submitted", f"ComfyUI prompt_id: {prompt_id}")
|
||||
|
||||
_set_status(db, job, "executing")
|
||||
history = {}
|
||||
for _ in range(360):
|
||||
await asyncio.sleep(5)
|
||||
history = await _get_history_with_fallback(prompt_id)
|
||||
history_error = _extract_history_error(history)
|
||||
if history_error:
|
||||
_set_status(db, job, "failed", history_error)
|
||||
return
|
||||
if history.get("status", {}).get("completed"):
|
||||
break
|
||||
else:
|
||||
_set_status(db, job, "failed", "Timed out waiting for ComfyUI")
|
||||
return
|
||||
|
||||
_set_status(db, job, "collecting_outputs")
|
||||
await _collect_outputs_from_history(db, job, history)
|
||||
_set_status(db, job, "completed")
|
||||
except Exception as exc:
|
||||
logger.exception("Job %s failed: %s", job_id, exc)
|
||||
job = db.query(Job).filter(Job.id == job_id).first()
|
||||
if job:
|
||||
_set_status(db, job, "failed", str(exc))
|
||||
finally:
|
||||
db.close()
|
||||
118
backend/app/services/storage.py
Normal file
118
backend/app/services/storage.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import subprocess
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import aiofiles
|
||||
from fastapi import UploadFile
|
||||
from PIL import Image
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
class LocalStorageService:
|
||||
def __init__(self, root: str):
|
||||
self.root = Path(root)
|
||||
self.root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async def save_upload(self, upload: UploadFile, subfolder: str) -> tuple[str, int]:
|
||||
dest_dir = self.root / subfolder
|
||||
dest_dir.mkdir(parents=True, exist_ok=True)
|
||||
ext = Path(upload.filename or "file").suffix
|
||||
filename = f"{uuid.uuid4()}{ext}"
|
||||
dest_path = dest_dir / filename
|
||||
content = await upload.read()
|
||||
async with aiofiles.open(dest_path, "wb") as handle:
|
||||
await handle.write(content)
|
||||
return str(dest_path.relative_to(self.root)).replace("\\", "/"), len(content)
|
||||
|
||||
def save_bytes(self, data: bytes, subfolder: str, filename: str) -> str:
|
||||
dest_dir = self.root / subfolder
|
||||
dest_dir.mkdir(parents=True, exist_ok=True)
|
||||
dest_path = dest_dir / filename
|
||||
with open(dest_path, "wb") as handle:
|
||||
handle.write(data)
|
||||
return str(dest_path.relative_to(self.root)).replace("\\", "/")
|
||||
|
||||
def absolute_path(self, relative_path: str) -> Path:
|
||||
return self.root / relative_path
|
||||
|
||||
def delete_relative_path(self, relative_path: Optional[str]) -> None:
|
||||
if not relative_path:
|
||||
return
|
||||
abs_path = self.absolute_path(relative_path)
|
||||
try:
|
||||
if abs_path.exists():
|
||||
abs_path.unlink()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def generate_thumbnail(self, image_path: str, thumb_subfolder: str) -> Optional[str]:
|
||||
try:
|
||||
abs_path = self.absolute_path(image_path)
|
||||
with Image.open(abs_path) as img:
|
||||
img.thumbnail((400, 400))
|
||||
thumb_dir = self.root / thumb_subfolder
|
||||
thumb_dir.mkdir(parents=True, exist_ok=True)
|
||||
thumb_name = f"thumb_{Path(image_path).stem}.jpg"
|
||||
thumb_path = thumb_dir / thumb_name
|
||||
img.convert("RGB").save(thumb_path, "JPEG", quality=80)
|
||||
return str(thumb_path.relative_to(self.root)).replace("\\", "/")
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def generate_video_thumbnail(self, video_path: str, thumb_subfolder: str) -> Optional[str]:
|
||||
abs_path = self.absolute_path(video_path)
|
||||
thumb_dir = self.root / thumb_subfolder
|
||||
thumb_dir.mkdir(parents=True, exist_ok=True)
|
||||
thumb_name = f"thumb_{Path(video_path).stem}.jpg"
|
||||
thumb_path = thumb_dir / thumb_name
|
||||
try:
|
||||
subprocess.run(
|
||||
[
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-i",
|
||||
str(abs_path),
|
||||
"-ss",
|
||||
"00:00:00.500",
|
||||
"-vframes",
|
||||
"1",
|
||||
str(thumb_path),
|
||||
],
|
||||
capture_output=True,
|
||||
timeout=30,
|
||||
check=False,
|
||||
)
|
||||
if thumb_path.exists():
|
||||
return str(thumb_path.relative_to(self.root)).replace("\\", "/")
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
def detect_duration_seconds(self, relative_path: str) -> Optional[float]:
|
||||
abs_path = self.absolute_path(relative_path)
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[
|
||||
"ffprobe",
|
||||
"-v",
|
||||
"error",
|
||||
"-show_entries",
|
||||
"format=duration",
|
||||
"-of",
|
||||
"default=noprint_wrappers=1:nokey=1",
|
||||
str(abs_path),
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
timeout=20,
|
||||
check=True,
|
||||
)
|
||||
return round(float(result.stdout.strip()), 3)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
asset_storage = LocalStorageService(settings.ASSET_STORAGE_ROOT)
|
||||
output_storage = LocalStorageService(settings.OUTPUT_STORAGE_ROOT)
|
||||
64
backend/app/services/workflow_binder.py
Normal file
64
backend/app/services/workflow_binder.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
WORKFLOWS_ROOT = Path(__file__).parents[3] / "workflows"
|
||||
_REGISTRY: Dict[str, Path] = {}
|
||||
|
||||
|
||||
def _discover() -> None:
|
||||
_REGISTRY.clear()
|
||||
for path in WORKFLOWS_ROOT.rglob("*.json"):
|
||||
try:
|
||||
with open(path, encoding="utf-8") as handle:
|
||||
data = json.load(handle)
|
||||
meta = data.get("__animatrix_meta__", {})
|
||||
_REGISTRY[meta.get("name") or path.stem] = path
|
||||
except Exception as exc:
|
||||
logger.warning("Could not load workflow %s: %s", path, exc)
|
||||
|
||||
|
||||
_discover()
|
||||
|
||||
|
||||
def select_template_name(mode: str, submode: Optional[str]) -> str:
|
||||
if mode == "animate":
|
||||
return f"wan22_animate_{submode or 'move'}"
|
||||
if mode == "audio":
|
||||
return "wan22_s2v"
|
||||
raise ValueError(f"Unknown mode: {mode}")
|
||||
|
||||
|
||||
class WorkflowBinder:
|
||||
def __init__(self, template_name: str):
|
||||
if template_name not in _REGISTRY:
|
||||
_discover()
|
||||
if template_name not in _REGISTRY:
|
||||
raise FileNotFoundError(
|
||||
f"Workflow template '{template_name}' not found in {WORKFLOWS_ROOT}. Available: {list(_REGISTRY.keys())}"
|
||||
)
|
||||
with open(_REGISTRY[template_name], encoding="utf-8") as handle:
|
||||
self._raw = json.load(handle)
|
||||
self._meta = self._raw.get("__animatrix_meta__", {})
|
||||
self._param_nodes = self._meta.get("param_nodes", {})
|
||||
self.version = self._meta.get("version", "unknown")
|
||||
self.status = self._meta.get("status", "")
|
||||
|
||||
def bind(self, params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
workflow = copy.deepcopy(self._raw)
|
||||
workflow.pop("__animatrix_meta__", None)
|
||||
for param_key, value in params.items():
|
||||
if value is None:
|
||||
continue
|
||||
node_spec = self._param_nodes.get(param_key)
|
||||
if not node_spec:
|
||||
continue
|
||||
node_id = str(node_spec["node_id"])
|
||||
input_name = node_spec["input"]
|
||||
if node_id in workflow:
|
||||
workflow[node_id]["inputs"][input_name] = value
|
||||
return workflow
|
||||
13
backend/requirements.txt
Normal file
13
backend/requirements.txt
Normal file
@@ -0,0 +1,13 @@
|
||||
fastapi==0.111.0
|
||||
uvicorn[standard]==0.29.0
|
||||
sqlalchemy==2.0.30
|
||||
alembic==1.13.1
|
||||
pydantic==2.7.1
|
||||
pydantic-settings==2.2.1
|
||||
passlib==1.7.4
|
||||
python-jose[cryptography]==3.3.0
|
||||
python-multipart==0.0.9
|
||||
httpx==0.27.0
|
||||
aiofiles==23.2.1
|
||||
Pillow==11.2.1
|
||||
python-dotenv==1.0.1
|
||||
5
backend/run.py
Normal file
5
backend/run.py
Normal file
@@ -0,0 +1,5 @@
|
||||
import uvicorn
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)
|
||||
0
backend/storage/.gitkeep
Normal file
0
backend/storage/.gitkeep
Normal file
Reference in New Issue
Block a user