Initial Animatrix import
This commit is contained in:
118
backend/app/api/routes/jobs.py
Normal file
118
backend/app/api/routes/jobs.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import json
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
from app.core.deps import get_current_user
|
||||
from app.db.session import get_db
|
||||
from app.models import Asset, Job, JobOutput, User
|
||||
from app.schemas import JobCreateRequest, JobListResponse, JobResponse
|
||||
from app.services.orchestrator import reconcile_job_outputs_if_missing, run_job
|
||||
from app.services.storage import output_storage
|
||||
|
||||
router = APIRouter(prefix="/api/jobs", tags=["jobs"])
|
||||
|
||||
|
||||
@router.post("/", response_model=JobResponse, status_code=201)
|
||||
async def create_job(
|
||||
payload: JobCreateRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
def assert_owns(asset_id: Optional[str], label: str):
|
||||
if asset_id:
|
||||
asset = (
|
||||
db.query(Asset)
|
||||
.filter(Asset.id == asset_id, Asset.owner_id == current_user.id, Asset.is_trashed.is_(False))
|
||||
.first()
|
||||
)
|
||||
if not asset:
|
||||
raise HTTPException(400, f"{label} asset not found or not owned by user")
|
||||
|
||||
assert_owns(payload.ground_truth_asset_id, "ground_truth")
|
||||
assert_owns(payload.motion_asset_id, "motion")
|
||||
assert_owns(payload.audio_asset_id, "audio")
|
||||
assert_owns(payload.pose_asset_id, "pose_sheet")
|
||||
for ref_id in payload.reference_asset_ids or []:
|
||||
assert_owns(ref_id, f"reference {ref_id}")
|
||||
|
||||
job = Job(
|
||||
owner_id=current_user.id,
|
||||
mode=payload.mode,
|
||||
submode=payload.submode,
|
||||
prompt=payload.prompt,
|
||||
negative_prompt=payload.negative_prompt,
|
||||
status="created",
|
||||
ground_truth_asset_id=payload.ground_truth_asset_id,
|
||||
motion_asset_id=payload.motion_asset_id,
|
||||
audio_asset_id=payload.audio_asset_id,
|
||||
pose_asset_id=payload.pose_asset_id,
|
||||
reference_asset_ids_json=json.dumps(payload.reference_asset_ids) if payload.reference_asset_ids else None,
|
||||
settings_json=json.dumps(payload.settings) if payload.settings else None,
|
||||
)
|
||||
db.add(job)
|
||||
db.commit()
|
||||
db.refresh(job)
|
||||
|
||||
background_tasks.add_task(run_job, job.id)
|
||||
return job
|
||||
|
||||
|
||||
@router.get("/", response_model=List[JobListResponse])
|
||||
def list_jobs(
|
||||
mode: Optional[str] = None,
|
||||
status: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
query = db.query(Job).filter(Job.owner_id == current_user.id)
|
||||
query = query.options(selectinload(Job.outputs))
|
||||
if mode:
|
||||
query = query.filter(Job.mode == mode)
|
||||
if status:
|
||||
query = query.filter(Job.status == status)
|
||||
return query.order_by(Job.created_at.desc()).limit(100).all()
|
||||
|
||||
|
||||
@router.get("/{job_id}", response_model=JobResponse)
|
||||
async def get_job(job_id: str, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)):
|
||||
job = (
|
||||
db.query(Job)
|
||||
.options(selectinload(Job.outputs), selectinload(Job.events))
|
||||
.filter(Job.id == job_id, Job.owner_id == current_user.id)
|
||||
.first()
|
||||
)
|
||||
if not job:
|
||||
raise HTTPException(404, "Job not found")
|
||||
if job.status == "completed" and not job.outputs and job.comfy_prompt_id:
|
||||
await reconcile_job_outputs_if_missing(job.id)
|
||||
db.expire_all()
|
||||
job = (
|
||||
db.query(Job)
|
||||
.options(selectinload(Job.outputs), selectinload(Job.events))
|
||||
.filter(Job.id == job_id, Job.owner_id == current_user.id)
|
||||
.first()
|
||||
)
|
||||
return job
|
||||
|
||||
|
||||
@router.get("/{job_id}/outputs/{output_id}/download")
|
||||
def download_output(
|
||||
job_id: str,
|
||||
output_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
job = db.query(Job).filter(Job.id == job_id, Job.owner_id == current_user.id).first()
|
||||
if not job:
|
||||
raise HTTPException(404, "Job not found")
|
||||
output = db.query(JobOutput).filter(JobOutput.id == output_id, JobOutput.job_id == job_id).first()
|
||||
if not output:
|
||||
raise HTTPException(404, "Output not found")
|
||||
abs_path = output_storage.absolute_path(output.file_path)
|
||||
if not abs_path.exists():
|
||||
raise HTTPException(404, "Output file not found")
|
||||
return FileResponse(str(abs_path))
|
||||
Reference in New Issue
Block a user