#3 Self-approved and unit tests passed with flying colors. Co-authored-by: Sagnik <sagnik7896@gmail.com> Reviewed-on: #5
499 lines
17 KiB
Python
499 lines
17 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Dream Weaver Batch Processor
|
|
============================
|
|
Automated batch processing script for Dream Weaver interior restyling workflow.
|
|
Handles directory monitoring, automatic mask caching, and queue management.
|
|
|
|
Target Hardware: Dual NVIDIA RTX PRO 6000 Blackwell (96GB GDDR7 each)
|
|
Author: Project Velocity Team
|
|
Version: 1.0.0
|
|
"""
|
|
|
|
import argparse
import asyncio
import copy
import hashlib
import json
import logging
import os
import shutil
import sys
import time
from collections import Counter
from dataclasses import dataclass, asdict
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import aiofiles
import requests
import websockets
from watchdog.events import FileSystemEventHandler
from watchdog.observers import Observer
|
|
|
|
# Configuration
# Default settings handed to BatchProcessor and read directly by main().
# Every filesystem path here is relative, so it resolves against the
# current working directory of the process.
CONFIG = {
    # Base URL for ComfyUI HTTP calls (the client appends /prompt and /queue).
    "comfyui_server": "http://localhost:8188",
    # WebSocket URL the client connects to while waiting for a workflow.
    "comfyui_ws": "ws://localhost:8188/ws",
    # Directory scanned by --batch and watched by --monitor.
    "input_directory": "Project_Velocity/comfy_engine/test_inputs/",
    # Directory used to build each job's output path; created on startup.
    "output_directory": "Project_Velocity/comfy_engine/test_outputs/",
    # Directory where MaskCacheManager stores cached masks; created on startup.
    "cache_directory": "Project_Velocity/comfy_engine/cache/masks/",
    # Workflow JSON files loaded per phase (1, 2, 3) by BatchProcessor.
    "workflow_phase1": "Project_Velocity/comfy_engine/workflows/dreamweaver_phase1_depth.json",
    "workflow_phase2": "Project_Velocity/comfy_engine/workflows/dreamweaver_phase2_multicontrol.json",
    "workflow_phase3": "Project_Velocity/comfy_engine/workflows/dreamweaver_phase3_batch.json",
    # Number of pending jobs at which run() switches to process_batch().
    "batch_size": 8,
    # NOTE(review): not read anywhere in this file — presumably consumed by
    # the workflows or other tooling; confirm before changing.
    "target_resolution": (1024, 1024),
    # When True, add_job() checks the mask cache for each new input image.
    "enable_mask_cache": True,
    # NOTE(review): not read anywhere in this file — confirm who consumes it.
    "gpu_sharding": True,
    # Must be truthy for run() to take the batch path.
    "dual_gpu": True,
}
|
|
|
|
# Setup logging
# Runs at import time: records at INFO and above are written both to
# 'dreamweaver_batch.log' (relative to the working directory) and to a
# StreamHandler (stderr unless another stream is supplied).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('dreamweaver_batch.log'),
        logging.StreamHandler()
    ]
)
# Single named logger used by every class and function in this module.
logger = logging.getLogger('DreamWeaver')
|
|
|
|
|
|
@dataclass
class ProcessingJob:
    """Represents a single image processing job.

    The job is a plain record that BatchProcessor mutates as work
    progresses: ``status`` moves from "pending" to "processing" and then to
    "completed" or "failed", and the timestamp fields are filled in along
    the way.
    """

    # Identity and routing.
    job_id: str
    input_path: str
    output_path: str
    style_template: str
    phase: int  # selects which loaded workflow is used
    # Lifecycle state: "pending", "processing", "completed" or "failed".
    status: str = "pending"
    # Timestamps stay None until the corresponding event has happened;
    # created_at is filled in by __post_init__ when not supplied.
    created_at: Optional[datetime] = None
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    # Human-readable failure reason; None unless the job failed.
    error_message: Optional[str] = None
    # True when a cached mask already existed for the input image.
    mask_cached: bool = False

    def __post_init__(self):
        """Stamp the creation time when the caller did not provide one."""
        if self.created_at is None:
            self.created_at = datetime.now()

    @staticmethod
    def _iso(moment: Optional[datetime]) -> Optional[str]:
        """Return ``moment`` as an ISO-8601 string, or None when unset."""
        return moment.isoformat() if moment else None

    def to_dict(self) -> Dict:
        """Return a JSON-serialisable snapshot of the job.

        Returns:
            A dict with one entry per field, in declaration order;
            datetimes are rendered with ``isoformat()`` and unset
            timestamps stay None.
        """
        return {
            "job_id": self.job_id,
            "input_path": self.input_path,
            "output_path": self.output_path,
            "style_template": self.style_template,
            "phase": self.phase,
            "status": self.status,
            "created_at": self._iso(self.created_at),
            "started_at": self._iso(self.started_at),
            "completed_at": self._iso(self.completed_at),
            "error_message": self.error_message,
            "mask_cached": self.mask_cached,
        }
|
|
|
|
|
|
class MaskCacheManager:
    """Manages caching of segmentation masks for improved performance.

    Masks are stored as ``<md5-of-image-bytes>.png`` inside ``cache_dir``,
    so two inputs with identical content share one cached mask regardless
    of their file names.
    """

    # Read size used while hashing, so large images are never held in
    # memory in one piece.
    _HASH_CHUNK_SIZE = 1024 * 1024

    def __init__(self, cache_dir: str):
        """Create the cache directory (including parents) if needed.

        Args:
            cache_dir: Directory in which cached masks are stored.
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Mask cache initialized at: {self.cache_dir}")

    def _get_cache_key(self, image_path: str) -> str:
        """Generate cache key from image content hash.

        Args:
            image_path: Path of the image whose bytes are hashed.

        Returns:
            Hex MD5 digest of the file content. MD5 is used only as a
            content fingerprint here, not for security.

        Raises:
            OSError: If the image cannot be opened or read.
        """
        hasher = hashlib.md5()
        with open(image_path, 'rb') as f:
            # Feed the hash in fixed-size chunks; the digest is identical
            # to hashing the whole file at once.
            for chunk in iter(lambda: f.read(self._HASH_CHUNK_SIZE), b""):
                hasher.update(chunk)
        return hasher.hexdigest()

    def _cache_path_for(self, image_path: str) -> Path:
        """Return the location a mask for ``image_path`` is cached at."""
        return self.cache_dir / f"{self._get_cache_key(image_path)}.png"

    def get_cached_mask(self, image_path: str) -> Optional[str]:
        """Retrieve cached mask path if it exists.

        Args:
            image_path: Path of the source image.

        Returns:
            The cached mask path as a string, or None on a cache miss.

        Raises:
            OSError: If the source image cannot be read for hashing.
        """
        cached_path = self._cache_path_for(image_path)
        if cached_path.exists():
            logger.info(f"Cache hit for {image_path}")
            return str(cached_path)
        return None

    def cache_mask(self, image_path: str, mask_path: str) -> str:
        """Cache a mask file for future use.

        Args:
            image_path: Path of the source image the mask belongs to.
            mask_path: Path of the mask file to copy into the cache.

        Returns:
            The path of the cached copy as a string.

        Raises:
            OSError: If the image cannot be hashed or the copy fails.
        """
        cached_path = self._cache_path_for(image_path)
        # copy2 also carries over the mask file's metadata (timestamps).
        shutil.copy2(mask_path, cached_path)
        logger.info(f"Cached mask for {image_path} at {cached_path}")
        return str(cached_path)
|
|
|
|
|
|
class ComfyUIClient:
    """Client for communicating with ComfyUI server.

    HTTP is used to queue prompts and inspect the queue; a WebSocket is
    used to follow execution of a queued prompt.
    """

    # Message types that end a prompt, checked against its prompt_id.
    # "completed" is kept for backward compatibility with the original
    # implementation of this client.
    _SUCCESS_TYPES = ("execution_success", "completed")
    _FAILURE_TYPES = ("execution_error", "execution_interrupted")

    def __init__(self, server_url: str, ws_url: str, request_timeout: float = 30.0):
        """Store endpoints and create a client id for this session.

        Args:
            server_url: Base HTTP URL of the ComfyUI server.
            ws_url: WebSocket URL of the ComfyUI server.
            request_timeout: Seconds allowed for each HTTP request, so a
                stalled server cannot hang the processor indefinitely.
        """
        self.server_url = server_url
        self.ws_url = ws_url
        self.request_timeout = request_timeout
        self.client_id = self._generate_client_id()
        logger.info(f"ComfyUI client initialized with ID: {self.client_id}")

    def _generate_client_id(self) -> str:
        """Generate unique client ID from a timestamp plus 4 random bytes."""
        return f"dreamweaver_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{os.urandom(4).hex()}"

    @staticmethod
    def _prepare_workflow(workflow: Dict, input_image: str) -> Dict:
        """Return a copy of ``workflow`` pointed at ``input_image``.

        The template is deep-copied so the caller's workflow, which is
        shared between jobs, is never modified.
        """
        prepared = copy.deepcopy(workflow)
        for node in prepared.values():
            if not isinstance(node, dict):
                continue
            class_type = node.get("class_type")
            if class_type == "LoadImage":
                node.setdefault("inputs", {})["image"] = input_image
            elif class_type == "LoadImageBatch":
                node.setdefault("inputs", {})["directory"] = os.path.dirname(input_image)
        return prepared

    @staticmethod
    def _decode_message(message) -> Optional[Dict]:
        """Parse one WebSocket frame into a dict, or None if not usable.

        Non-text frames and frames that are not a JSON object are skipped
        rather than raising.
        """
        if not isinstance(message, str):
            return None
        try:
            event = json.loads(message)
        except ValueError:
            return None
        return event if isinstance(event, dict) else None

    @classmethod
    def _classify_event(cls, event: Dict, prompt_id: str) -> Optional[bool]:
        """Decide whether ``event`` ends the wait for ``prompt_id``.

        Returns:
            True when the prompt finished, False when it failed, None when
            the event is unrelated or only reports progress.
        """
        event_type = event.get("type")
        body = event.get("data")
        if not isinstance(body, dict):
            body = {}
        same_prompt = body.get("prompt_id") == prompt_id

        if event_type == "executing":
            # An "executing" event whose node is explicitly null marks the
            # end of execution for that prompt.
            if same_prompt and "node" in body and body["node"] is None:
                return True
            return None
        if event_type in cls._SUCCESS_TYPES:
            return True if same_prompt else None
        if event_type in cls._FAILURE_TYPES:
            return False if same_prompt else None
        if event_type == "error":
            # Legacy behaviour of this client: any "error" event aborts.
            return False
        return None

    async def submit_workflow(self, workflow: Dict, input_image: str) -> str:
        """Submit a workflow to ComfyUI queue.

        Args:
            workflow: Workflow template; it is copied, never modified.
            input_image: Image path written into LoadImage nodes; its
                directory is written into LoadImageBatch nodes.

        Returns:
            The prompt_id assigned by the server.

        Raises:
            requests.RequestException: On connection problems, timeouts or
                an HTTP error status.
            RuntimeError: If the response carries no prompt_id.
        """
        payload = {
            "prompt": self._prepare_workflow(workflow, input_image),
            "client_id": self.client_id
        }

        # requests is blocking; run it in the default executor so the
        # event loop stays responsive.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            lambda: requests.post(
                f"{self.server_url}/prompt",
                json=payload,
                timeout=self.request_timeout,
            ),
        )
        response.raise_for_status()
        prompt_id = response.json().get("prompt_id")
        if not prompt_id:
            raise RuntimeError("ComfyUI response did not contain a prompt_id")
        logger.info(f"Submitted workflow with prompt_id: {prompt_id}")
        return prompt_id

    async def get_queue_status(self) -> Dict:
        """Get current queue status.

        Returns:
            The decoded JSON body of ``GET /queue``.

        Raises:
            requests.RequestException: On connection problems, timeouts or
                an HTTP error status.
        """
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            lambda: requests.get(
                f"{self.server_url}/queue",
                timeout=self.request_timeout,
            ),
        )
        response.raise_for_status()
        return response.json()

    async def wait_for_completion(self, prompt_id: str, timeout: int = 300) -> bool:
        """Wait for workflow completion via WebSocket.

        Args:
            prompt_id: Identifier returned by submit_workflow().
            timeout: Maximum number of seconds to wait overall.

        Returns:
            True if the prompt finished, False on failure or timeout.

        NOTE(review): the socket is opened only after the prompt has been
        queued, so the finish event of a very short workflow may already
        have been sent; in that case this returns False after ``timeout``.
        """
        # monotonic() is immune to wall-clock adjustments during the wait.
        deadline = time.monotonic() + timeout

        async with websockets.connect(
            f"{self.ws_url}?clientId={self.client_id}"
        ) as websocket:
            while True:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    break
                try:
                    # Wake up at least every 5 s, but never wait past the
                    # overall deadline.
                    message = await asyncio.wait_for(
                        websocket.recv(),
                        timeout=min(5.0, remaining)
                    )
                except asyncio.TimeoutError:
                    continue

                event = self._decode_message(message)
                if event is None:
                    continue

                body = event.get("data")
                if (
                    event.get("type") == "executing"
                    and isinstance(body, dict)
                    and body.get("prompt_id") == prompt_id
                ):
                    logger.debug(f"Executing node: {body.get('node')}")

                verdict = self._classify_event(event, prompt_id)
                if verdict is True:
                    logger.info(f"Workflow {prompt_id} completed")
                    return True
                if verdict is False:
                    logger.error(f"Workflow error: {event}")
                    return False

        logger.warning(f"Workflow {prompt_id} timed out")
        return False
|
|
|
|
|
|
class BatchProcessor:
    """Main batch processing controller.

    Owns the in-memory job queue, the mask cache, the ComfyUI client and
    the workflow templates, and drives jobs through ComfyUI one by one.
    """

    def __init__(self, config: Dict):
        """Build the collaborators described by ``config``.

        Args:
            config: Settings dict; must provide the cache/output
                directories, both ComfyUI URLs and the three
                ``workflow_phaseN`` paths.

        Raises:
            KeyError: If a required setting is missing.
        """
        self.config = config
        self.queue: List["ProcessingJob"] = []
        self.processing = False
        self.cache_manager = MaskCacheManager(config["cache_directory"])
        self.comfy_client = ComfyUIClient(
            config["comfyui_server"],
            config["comfyui_ws"]
        )

        # Load workflow templates
        self.workflows = self._load_workflows()

        # Ensure output directory exists
        Path(config["output_directory"]).mkdir(parents=True, exist_ok=True)

    def _load_workflows(self) -> Dict[int, Dict]:
        """Load workflow JSON files.

        Returns:
            Mapping of phase number (1-3) to the parsed workflow. A phase
            whose file is missing or invalid is logged and left out, so
            jobs for that phase fail individually later.
        """
        workflows = {}
        for phase in (1, 2, 3):
            path = self.config[f"workflow_phase{phase}"]
            try:
                with open(path, 'r', encoding="utf-8") as f:
                    workflows[phase] = json.load(f)
            except (OSError, ValueError) as e:
                # ValueError covers both malformed JSON and decode errors.
                logger.error(f"Failed to load Phase {phase} workflow: {e}")
            else:
                logger.info(f"Loaded Phase {phase} workflow")
        return workflows

    def add_job(self, input_path: str, style_template: str = "scandinavian_minimalist", phase: int = 1) -> str:
        """Add a new processing job to the queue.

        Args:
            input_path: Image to process.
            style_template: Name of the style recorded on the job.
            phase: Workflow phase the job should run with.

        Returns:
            The 12-character id of the queued job.
        """
        # Path plus current time keeps ids distinct when the same file is
        # queued more than once.
        job_id = hashlib.md5(f"{input_path}_{time.time()}".encode()).hexdigest()[:12]
        output_filename = f"{Path(input_path).stem}_restyled_{job_id}.png"
        output_path = os.path.join(self.config["output_directory"], output_filename)

        job = ProcessingJob(
            job_id=job_id,
            input_path=input_path,
            output_path=output_path,
            style_template=style_template,
            phase=phase
        )

        # Check if mask is cached
        if self.config["enable_mask_cache"]:
            try:
                cached_mask = self.cache_manager.get_cached_mask(input_path)
            except OSError as e:
                # The lookup hashes the file; a file that vanished or is
                # still being written must not prevent queueing.
                logger.warning(f"Mask cache lookup failed for {input_path}: {e}")
                cached_mask = None
            job.mask_cached = cached_mask is not None

        self.queue.append(job)
        logger.info(f"Added job {job_id} to queue. Queue size: {len(self.queue)}")
        return job_id

    async def process_single(self, job: "ProcessingJob") -> bool:
        """Process a single job.

        Marks the job as processing, submits its phase workflow to
        ComfyUI and waits for the result. Every failure is recorded on
        the job instead of being raised.

        Returns:
            True if the workflow completed, False otherwise.
        """
        job.status = "processing"
        job.started_at = datetime.now()

        try:
            logger.info(f"Processing job {job.job_id}: {job.input_path}")

            # Get workflow for phase
            workflow = self.workflows.get(job.phase)
            if not workflow:
                raise ValueError(f"Workflow for phase {job.phase} not found")

            # Submit to ComfyUI
            prompt_id = await self.comfy_client.submit_workflow(
                workflow,
                job.input_path
            )

            # Wait for completion
            success = await self.comfy_client.wait_for_completion(prompt_id)
        except Exception as e:
            # Boundary for one job: record the failure so the remaining
            # queue can still be processed.
            job.status = "failed"
            job.error_message = str(e)
            logger.error(f"Error processing job {job.job_id}: {e}")
            return False

        if not success:
            job.status = "failed"
            job.error_message = "Workflow execution failed or timed out"
            logger.error(f"Job {job.job_id} failed")
            return False

        job.status = "completed"
        job.completed_at = datetime.now()
        logger.info(f"Job {job.job_id} completed successfully")
        return True

    async def process_batch(self, jobs: List["ProcessingJob"]) -> List[bool]:
        """Process multiple jobs in batch (Phase 3).

        Returns:
            One success flag per job, in input order; ``[]`` for no jobs.
        """
        if not jobs:
            return []

        logger.info(f"Processing batch of {len(jobs)} jobs")

        if not self.workflows.get(3):
            logger.warning("Phase 3 workflow not available, processing sequentially")

        # TODO: Implement true batch processing with Phase 3 workflow
        # This would require grouping images and processing together.
        # Until then each job runs on its own, with its own phase.
        results = []
        for job in jobs:
            results.append(await self.process_single(job))
        return results

    async def run(self, exit_when_idle: bool = False):
        """Main processing loop.

        Args:
            exit_when_idle: When True, return as soon as no pending job
                is left. The default keeps polling once per second until
                stop() is called.
        """
        logger.info("Starting batch processor")
        self.processing = True

        while self.processing:
            # Get pending jobs
            pending_jobs = [j for j in self.queue if j.status == "pending"]

            if not pending_jobs:
                if exit_when_idle:
                    self.processing = False
                    break
                await asyncio.sleep(1)
                continue

            # A non-positive batch size would yield empty batches and spin
            # forever without making progress.
            batch_size = max(1, self.config["batch_size"])

            # Check if batch processing is appropriate
            if len(pending_jobs) >= batch_size and self.config.get("dual_gpu"):
                await self.process_batch(pending_jobs[:batch_size])
            else:
                # Process single job with appropriate phase
                await self.process_single(pending_jobs[0])

    def stop(self):
        """Stop the processing loop after the current iteration."""
        logger.info("Stopping batch processor")
        self.processing = False

    def get_status(self) -> Dict:
        """Get current processing status.

        Returns:
            Dict with the total number of jobs, a count per status and
            whether the loop is currently running.
        """
        # One pass over the queue instead of one per status.
        counts = Counter(job.status for job in self.queue)
        return {
            "total_jobs": len(self.queue),
            "pending": counts["pending"],
            "processing": counts["processing"],
            "completed": counts["completed"],
            "failed": counts["failed"],
            "is_running": self.processing
        }
|
|
|
|
|
|
class InputDirectoryHandler(FileSystemEventHandler):
    """Handles new file events in input directory.

    Every newly created image file is queued on the given processor with
    the processor's default style and phase.
    """

    # Suffixes (compared case-insensitively) that count as images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.webp')

    def __init__(self, processor: "BatchProcessor"):
        """Remember the processor that receives the new jobs.

        Args:
            processor: Object whose ``add_job(path)`` is called for each
                new image.
        """
        super().__init__()
        self.processor = processor

    def on_created(self, event):
        """Queue a job when a supported image file appears.

        Args:
            event: Filesystem event; only ``is_directory`` and
                ``src_path`` are read.
        """
        if event.is_directory:
            return

        # fsdecode accepts both str and bytes paths.
        file_path = os.fsdecode(event.src_path)
        if not file_path.lower().endswith(self.IMAGE_EXTENSIONS):
            return

        logger.info(f"New image detected: {file_path}")
        try:
            self.processor.add_job(file_path)
        except Exception:
            # This runs on the observer's thread; log instead of letting
            # one bad file take the watcher down.
            logger.exception(f"Failed to queue {file_path}")
|
|
|
|
|
|
def load_style_template(
    template_name: str,
    prompts_dir: str = "Project_Velocity/comfy_engine/prompts/",
) -> str:
    """Load a style template from prompts directory.

    Reads ``<prompts_dir>/<template_name>.txt`` and extracts the positive
    prompt: every non-blank line after the one containing
    ``POSITIVE PROMPT:`` up to (not including) the first line starting
    with ``Style Weight:``. Lines starting with ``-`` are skipped.

    Args:
        template_name: File name of the template without ``.txt``.
        prompts_dir: Directory holding the template files.

    Returns:
        The positive prompt lines, stripped and joined with single
        spaces; an empty string if the file is missing or has no
        positive prompt section.
    """
    template_path = Path(prompts_dir) / f"{template_name}.txt"
    try:
        # Explicit encoding: prompt files must read the same on every
        # platform, whatever the locale default is.
        content = template_path.read_text(encoding="utf-8")
    except (FileNotFoundError, NotADirectoryError):
        return ""

    positive_lines = []
    in_positive = False
    for line in content.split('\n'):
        if 'POSITIVE PROMPT:' in line:
            in_positive = True
            continue
        if not in_positive:
            continue
        if line.startswith('Style Weight:'):
            break
        if line.strip() and not line.startswith('-'):
            positive_lines.append(line.strip())
    return ' '.join(positive_lines)
|
|
|
|
|
|
async def main():
    """Main entry point.

    Parses the command line and runs exactly one mode, checked in this
    order: ``--input`` (one image), ``--batch`` (every image in the input
    directory, then exit) or ``--monitor`` (watch the input directory
    until interrupted).
    """
    parser = argparse.ArgumentParser(
        description="Dream Weaver Batch Processor"
    )
    parser.add_argument(
        "--monitor",
        action="store_true",
        help="Enable directory monitoring mode"
    )
    parser.add_argument(
        "--input",
        type=str,
        help="Single input image to process"
    )
    parser.add_argument(
        "--style",
        type=str,
        default="scandinavian_minimalist",
        choices=["scandinavian_minimalist", "art_deco_luxe", "cyberpunk_neon", "biophilic_organic", "japandi_fusion"],
        help="Style template to apply"
    )
    parser.add_argument(
        "--phase",
        type=int,
        default=1,
        choices=[1, 2, 3],
        help="Processing phase to use"
    )
    parser.add_argument(
        "--batch",
        action="store_true",
        help="Process all images in input directory"
    )

    args = parser.parse_args()

    # Decide before building the processor, so a bare invocation does not
    # create directories or load workflows.
    if not (args.input or args.batch or args.monitor):
        print("No action specified. Use --help for usage information.")
        return

    # Initialize processor
    processor = BatchProcessor(CONFIG)

    if args.input:
        # Process single image
        if not Path(args.input).is_file():
            print(f"Input image not found: {args.input}")
            return
        job_id = processor.add_job(args.input, args.style, args.phase)
        job = processor.queue[-1]
        if await processor.process_single(job):
            print(f"Processed image: {args.input}")
        else:
            print(f"Failed to process image: {args.input} ({job.error_message})")
        print(f"Job ID: {job_id}")

    elif args.batch:
        # Process all images in directory
        input_dir = Path(CONFIG["input_directory"])
        if not input_dir.is_dir():
            print(f"Input directory not found: {input_dir}")
            return

        # Same suffixes, matched case-insensitively, as the directory
        # watcher accepts; sorted for a reproducible processing order.
        image_suffixes = {".jpg", ".jpeg", ".png", ".webp"}
        image_files = sorted(
            entry for entry in input_dir.iterdir()
            if entry.is_file() and entry.suffix.lower() in image_suffixes
        )

        for img_file in image_files:
            processor.add_job(str(img_file), args.style, args.phase)

        # Work through the queued jobs once and return; the endless run()
        # loop would keep polling after the last job.
        pending = [job for job in processor.queue if job.status == "pending"]
        results = await processor.process_batch(pending)
        succeeded = sum(1 for ok in results if ok)
        print(f"Batch finished: {succeeded} succeeded, {len(results) - succeeded} failed")

    elif args.monitor:
        # Start directory monitoring
        # The watched directory must exist before the observer starts.
        Path(CONFIG["input_directory"]).mkdir(parents=True, exist_ok=True)
        event_handler = InputDirectoryHandler(processor)
        observer = Observer()
        observer.schedule(
            event_handler,
            CONFIG["input_directory"],
            recursive=False
        )
        observer.start()
        logger.info(f"Started monitoring: {CONFIG['input_directory']}")

        try:
            # Run processor
            await processor.run()
        finally:
            # Ctrl-C reaches this coroutine as a cancellation rather than
            # as KeyboardInterrupt, so clean-up lives in ``finally``.
            processor.stop()
            observer.stop()
            observer.join()
|
|
|
|
|
|
if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to leave monitor mode; asyncio.run()
        # re-raises it after cancelling main(), so exit quietly with the
        # conventional SIGINT status instead of printing a traceback.
        sys.exit(130)
|