Files
Project_Velocity/comfy_engine/scripts/dreamweaver_batch_processor.py

499 lines
18 KiB
Python

#!/usr/bin/env python3
"""
Dream Weaver Batch Processor
============================
Automated batch processing script for Dream Weaver interior restyling workflow.
Handles directory monitoring, automatic mask caching, and queue management.
Target Hardware: Dual NVIDIA RTX PRO 6000 Blackwell (96GB GDDR7 each)
Author: Project Velocity Team
Version: 1.0.0
"""
import os
import sys
import json
import time
import hashlib
import asyncio
import argparse
import logging
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass, asdict
import requests
import websockets
import aiofiles
from watchdog.observers import Observer
from watchdog.events import FileSystemEventHandler
# Configuration
# Default settings for the whole script. main() hands this dict to
# BatchProcessor and also reads "input_directory" from it directly.
CONFIG = {
    # ComfyUI endpoints: HTTP API base URL and WebSocket URL.
    "comfyui_server": "http://localhost:8188",
    "comfyui_ws": "ws://localhost:8188/ws",

    # Working directories (relative paths, resolved against the current
    # working directory of the process).
    "input_directory": "Project_Velocity/comfy_engine/test_inputs/",
    "output_directory": "Project_Velocity/comfy_engine/test_outputs/",
    "cache_directory": "Project_Velocity/comfy_engine/cache/masks/",

    # Workflow JSON files, one per processing phase (1, 2, 3).
    "workflow_phase1": "Project_Velocity/comfy_engine/workflows/dreamweaver_phase1_depth.json",
    "workflow_phase2": "Project_Velocity/comfy_engine/workflows/dreamweaver_phase2_multicontrol.json",
    "workflow_phase3": "Project_Velocity/comfy_engine/workflows/dreamweaver_phase3_batch.json",

    # Number of pending jobs at which BatchProcessor.run() switches to
    # process_batch() (only when "dual_gpu" is truthy).
    "batch_size": 8,

    # NOTE(review): "target_resolution" and "gpu_sharding" are not referenced
    # by any code in the visible source - confirm whether they are still used.
    "target_resolution": (1024, 1024),

    # When True, add_job() checks the mask cache for each new input image.
    "enable_mask_cache": True,
    "gpu_sharding": True,
    "dual_gpu": True,
}
# Setup logging
# Every record at INFO or above goes to two destinations: a log file created
# in the current working directory and the default console stream.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_log_handlers = [
    logging.FileHandler('dreamweaver_batch.log'),
    logging.StreamHandler(),
]
logging.basicConfig(
    level=logging.INFO,
    format=_LOG_FORMAT,
    handlers=_log_handlers,
)

# Shared logger used by every class and function in this module.
logger = logging.getLogger('DreamWeaver')
@dataclass
class ProcessingJob:
    """Represents a single image processing job.

    The first five fields are required; the rest track the job's lifecycle.
    ``status`` starts as "pending"; BatchProcessor.process_single() moves it
    to "processing" and then to "completed" or "failed".
    """

    job_id: str
    input_path: str
    output_path: str
    style_template: str
    phase: int
    status: str = "pending"
    # Timestamps stay None until the corresponding lifecycle step happens,
    # hence Optional. created_at is filled in by __post_init__ when omitted.
    created_at: Optional[datetime] = None
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    # Human-readable failure reason; None while the job has not failed.
    error_message: Optional[str] = None
    # True when a cached mask was found for the input image at queue time.
    mask_cached: bool = False

    def __post_init__(self):
        """Stamp the creation time when the caller did not supply one."""
        if self.created_at is None:
            self.created_at = datetime.now()

    @staticmethod
    def _iso(moment: Optional[datetime]) -> Optional[str]:
        """Return ``moment`` in ISO-8601 form, or None when it is unset."""
        return moment.isoformat() if moment else None

    def to_dict(self) -> Dict:
        """Return a JSON-friendly dict of the job.

        Returns:
            A dict with the same keys as the dataclass fields; the three
            timestamps are rendered as ISO-8601 strings (or None).
        """
        return {
            "job_id": self.job_id,
            "input_path": self.input_path,
            "output_path": self.output_path,
            "style_template": self.style_template,
            "phase": self.phase,
            "status": self.status,
            "created_at": self._iso(self.created_at),
            "started_at": self._iso(self.started_at),
            "completed_at": self._iso(self.completed_at),
            "error_message": self.error_message,
            "mask_cached": self.mask_cached,
        }
class MaskCacheManager:
    """Manages caching of segmentation masks for improved performance.

    Masks are stored in ``cache_dir`` as ``<md5-of-image-bytes>.png``, so the
    same image content maps to the same cache entry regardless of its path.
    """

    # Bytes read per iteration while hashing, so large images are never
    # loaded into memory in one piece.
    _HASH_CHUNK_SIZE = 1024 * 1024

    def __init__(self, cache_dir: str):
        """Create the cache directory (including parents) if it is missing.

        Args:
            cache_dir: Directory in which cached masks are stored.
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Mask cache initialized at: {self.cache_dir}")

    def _get_cache_key(self, image_path: str) -> str:
        """Generate cache key from image content hash.

        MD5 is used purely as a content fingerprint (not for security); the
        algorithm is kept so existing cache entries remain valid.

        Raises:
            OSError: If ``image_path`` cannot be opened or read.
        """
        hasher = hashlib.md5()
        with open(image_path, 'rb') as f:
            # iter() with a sentinel stops at the first empty read (EOF).
            for chunk in iter(lambda: f.read(self._HASH_CHUNK_SIZE), b''):
                hasher.update(chunk)
        return hasher.hexdigest()

    def _cache_path_for(self, image_path: str) -> Path:
        """Return the path where the mask for ``image_path`` would be cached."""
        return self.cache_dir / f"{self._get_cache_key(image_path)}.png"

    def get_cached_mask(self, image_path: str) -> Optional[str]:
        """Retrieve cached mask path if it exists.

        Returns:
            The cached mask path as a string, or None on a cache miss.

        Raises:
            OSError: If ``image_path`` cannot be read for hashing.
        """
        cached_path = self._cache_path_for(image_path)
        if cached_path.exists():
            logger.info(f"Cache hit for {image_path}")
            return str(cached_path)
        return None

    def cache_mask(self, image_path: str, mask_path: str) -> str:
        """Cache a mask file for future use.

        Copies ``mask_path`` (with metadata, via shutil.copy2) into the cache
        under the key derived from ``image_path``.

        Returns:
            The path of the cached copy as a string.
        """
        import shutil

        cached_path = self._cache_path_for(image_path)
        shutil.copy2(mask_path, cached_path)
        logger.info(f"Cached mask for {image_path} at {cached_path}")
        return str(cached_path)
class ComfyUIClient:
    """Client for communicating with ComfyUI server.

    Prompts are submitted over HTTP and their progress is followed over the
    server's WebSocket, using a client id generated once per instance.
    """

    # Seconds to wait for an HTTP response before giving up, so a stalled
    # server cannot hang the processor indefinitely.
    HTTP_TIMEOUT = 30

    def __init__(self, server_url: str, ws_url: str):
        """Store the endpoints and generate this client's id.

        Args:
            server_url: Base HTTP URL of the ComfyUI server.
            ws_url: WebSocket URL of the ComfyUI server.
        """
        self.server_url = server_url
        self.ws_url = ws_url
        self.client_id = self._generate_client_id()
        logger.info(f"ComfyUI client initialized with ID: {self.client_id}")

    def _generate_client_id(self) -> str:
        """Generate unique client ID (timestamp plus 4 random bytes as hex)."""
        return f"dreamweaver_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{os.urandom(4).hex()}"

    @staticmethod
    def _prepare_workflow(workflow: Dict, input_image: str) -> Dict:
        """Return a copy of ``workflow`` with its image loaders pointed at the input.

        The caller's dict is a shared template reused for every job, so it is
        deep-copied rather than edited in place.
        """
        import copy

        prepared = copy.deepcopy(workflow)
        for node in prepared.values():
            if not isinstance(node, dict):
                continue
            class_type = node.get("class_type")
            if class_type == "LoadImage":
                node.setdefault("inputs", {})["image"] = input_image
            elif class_type == "LoadImageBatch":
                node.setdefault("inputs", {})["directory"] = os.path.dirname(input_image)
        return prepared

    @staticmethod
    def _classify_event(data, prompt_id: str) -> Optional[bool]:
        """Interpret one decoded WebSocket message.

        Returns:
            True when the message reports that ``prompt_id`` finished,
            False when it reports a failure, None when waiting should go on.
        """
        if not isinstance(data, dict):
            return None
        event_type = data.get("type")
        if event_type == "error":
            # Legacy catch-all kept from the original implementation: any
            # "error" message fails the wait, whatever prompt it refers to.
            return False
        details = data.get("data")
        if not isinstance(details, dict) or details.get("prompt_id") != prompt_id:
            return None
        if event_type == "executing":
            # ComfyUI signals the end of a prompt with an "executing" event
            # whose node is null.
            if "node" in details and details["node"] is None:
                return True
            return None
        if event_type in ("execution_success", "completed"):
            return True
        if event_type in ("execution_error", "execution_interrupted"):
            return False
        return None

    async def submit_workflow(self, workflow: Dict, input_image: str) -> str:
        """Submit a workflow to ComfyUI queue.

        Args:
            workflow: Workflow dict mapping node ids to node definitions.
                It is not modified; a patched copy is sent.
            input_image: Path written into LoadImage nodes; its directory is
                written into LoadImageBatch nodes.

        Returns:
            The ``prompt_id`` reported by the server (None if the response
            body carries none).

        Raises:
            requests.RequestException: On connection errors, timeouts or a
                non-2xx response.
        """
        payload = {
            "prompt": self._prepare_workflow(workflow, input_image),
            "client_id": self.client_id
        }
        # NOTE: requests is synchronous, so the event loop is blocked until
        # the server answers; the timeout bounds that wait.
        response = requests.post(
            f"{self.server_url}/prompt",
            json=payload,
            timeout=self.HTTP_TIMEOUT
        )
        response.raise_for_status()
        prompt_id = response.json().get("prompt_id")
        logger.info(f"Submitted workflow with prompt_id: {prompt_id}")
        return prompt_id

    async def get_queue_status(self) -> Dict:
        """Get current queue status.

        Raises:
            requests.RequestException: On connection errors, timeouts or a
                non-2xx response.
        """
        response = requests.get(
            f"{self.server_url}/queue",
            timeout=self.HTTP_TIMEOUT
        )
        response.raise_for_status()
        return response.json()

    async def wait_for_completion(self, prompt_id: str, timeout: int = 300) -> bool:
        """Wait for workflow completion via WebSocket.

        Args:
            prompt_id: Id returned by submit_workflow().
            timeout: Overall number of seconds to wait.

        Returns:
            True if the server reported the prompt finished, False if it
            reported an error or the timeout elapsed.

        NOTE(review): the socket is opened only when this method is called,
        i.e. after the prompt was submitted, so events sent before the
        connection is established are not seen here.
        """
        deadline = time.monotonic() + timeout
        async with websockets.connect(
            f"{self.ws_url}?clientId={self.client_id}"
        ) as websocket:
            while True:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    break
                try:
                    # Poll in short slices so the overall deadline is honoured
                    # even when the server stays silent.
                    message = await asyncio.wait_for(
                        websocket.recv(),
                        timeout=min(5.0, remaining)
                    )
                except asyncio.TimeoutError:
                    continue
                if not isinstance(message, str):
                    # Binary frames (e.g. preview images) carry no status.
                    continue
                try:
                    data = json.loads(message)
                except ValueError:
                    logger.debug("Ignoring non-JSON WebSocket message")
                    continue
                if not isinstance(data, dict):
                    continue
                outcome = self._classify_event(data, prompt_id)
                if outcome is True:
                    logger.info(f"Workflow {prompt_id} completed")
                    return True
                if outcome is False:
                    logger.error(f"Workflow error: {data}")
                    return False
                if data.get("type") == "executing":
                    details = data.get("data")
                    if isinstance(details, dict) and details.get("prompt_id") == prompt_id:
                        node_id = details.get("node")
                        logger.debug(f"Executing node: {node_id}")
        logger.warning(f"Workflow {prompt_id} timed out")
        return False
class BatchProcessor:
    """Main batch processing controller.

    Keeps an in-memory list of ProcessingJob records, submits each pending
    job to ComfyUI through ComfyUIClient and records the outcome on the job.
    """

    def __init__(self, config: Dict):
        """Create the cache manager and ComfyUI client, then load workflows.

        Args:
            config: Settings dict. Keys read by this class:
                "cache_directory", "comfyui_server", "comfyui_ws",
                "output_directory", "workflow_phase1" .. "workflow_phase3",
                "enable_mask_cache", "batch_size" and (optional) "dual_gpu".
        """
        self.config = config
        self.queue: List["ProcessingJob"] = []
        self.processing = False
        self.cache_manager = MaskCacheManager(config["cache_directory"])
        self.comfy_client = ComfyUIClient(
            config["comfyui_server"],
            config["comfyui_ws"]
        )
        # Load workflow templates
        self.workflows = self._load_workflows()
        # Ensure output directory exists
        Path(config["output_directory"]).mkdir(parents=True, exist_ok=True)

    def _load_workflows(self) -> Dict[int, Dict]:
        """Load workflow JSON files.

        Returns:
            Mapping of phase number to parsed workflow. A phase whose file is
            missing, unreadable or not valid JSON is logged and left out.
        """
        workflow_paths = {
            1: self.config["workflow_phase1"],
            2: self.config["workflow_phase2"],
            3: self.config["workflow_phase3"]
        }
        workflows = {}
        for phase, path in workflow_paths.items():
            try:
                with open(path, 'r', encoding='utf-8') as f:
                    workflows[phase] = json.load(f)
            except (OSError, ValueError) as e:
                # ValueError covers json.JSONDecodeError and decoding errors.
                logger.error(f"Failed to load Phase {phase} workflow: {e}")
            else:
                logger.info(f"Loaded Phase {phase} workflow")
        return workflows

    def add_job(self, input_path: str, style_template: str = "scandinavian_minimalist", phase: int = 1) -> str:
        """Add a new processing job to the queue.

        Args:
            input_path: Path of the image to process.
            style_template: Style name stored on the job.
            phase: Workflow phase (key into the loaded workflows).

        Returns:
            The generated 12-character job id.
        """
        job_id = hashlib.md5(f"{input_path}_{time.time()}".encode()).hexdigest()[:12]
        output_filename = f"{Path(input_path).stem}_restyled_{job_id}.png"
        output_path = os.path.join(self.config["output_directory"], output_filename)
        job = ProcessingJob(
            job_id=job_id,
            input_path=input_path,
            output_path=output_path,
            style_template=style_template,
            phase=phase
        )
        # Check if mask is cached. The probe hashes the file, so it can fail
        # when the file is missing or unreadable; the cache is only an
        # optimisation, so such a failure must not prevent queueing.
        if self.config["enable_mask_cache"]:
            try:
                cached_mask = self.cache_manager.get_cached_mask(input_path)
            except OSError as e:
                logger.warning(f"Could not check mask cache for {input_path}: {e}")
            else:
                job.mask_cached = cached_mask is not None
        self.queue.append(job)
        logger.info(f"Added job {job_id} to queue. Queue size: {len(self.queue)}")
        return job_id

    async def process_single(self, job: "ProcessingJob") -> bool:
        """Process a single job.

        Submits the job's phase workflow to ComfyUI, waits for the result and
        updates ``job.status`` / timestamps / ``job.error_message``.

        Returns:
            True on success, False on any failure (exceptions are caught and
            recorded on the job rather than raised).

        NOTE(review): only the workflow and ``job.input_path`` are passed to
        the client; ``job.style_template`` and ``job.output_path`` are not
        applied to the submitted workflow here - confirm whether intended.
        """
        job.status = "processing"
        job.started_at = datetime.now()
        try:
            logger.info(f"Processing job {job.job_id}: {job.input_path}")
            # Get workflow for phase
            workflow = self.workflows.get(job.phase)
            if not workflow:
                raise ValueError(f"Workflow for phase {job.phase} not found")
            # Submit to ComfyUI
            prompt_id = await self.comfy_client.submit_workflow(
                workflow,
                job.input_path
            )
            if not prompt_id:
                # Without an id there is nothing to wait for; fail fast
                # instead of sitting out the full completion timeout.
                raise RuntimeError("ComfyUI did not return a prompt_id")
            # Wait for completion
            success = await self.comfy_client.wait_for_completion(prompt_id)
        except Exception as e:
            job.status = "failed"
            job.error_message = str(e)
            logger.error(f"Error processing job {job.job_id}: {e}")
            return False
        if success:
            job.status = "completed"
            job.completed_at = datetime.now()
            logger.info(f"Job {job.job_id} completed successfully")
            return True
        job.status = "failed"
        job.error_message = "Workflow execution failed or timed out"
        logger.error(f"Job {job.job_id} failed")
        return False

    async def process_batch(self, jobs: List["ProcessingJob"]) -> List[bool]:
        """Process multiple jobs in batch (Phase 3).

        Returns:
            One success flag per job, in input order.
        """
        if not jobs:
            return []
        logger.info(f"Processing batch of {len(jobs)} jobs")
        if not self.workflows.get(3):
            logger.warning("Phase 3 workflow not available, processing sequentially")
        # TODO: Implement true batch processing with Phase 3 workflow
        # This would require grouping images and processing together.
        # Until then every job is processed one at a time with its own phase.
        results = []
        for job in jobs:
            results.append(await self.process_single(job))
        return results

    async def run(self, exit_when_idle: bool = False):
        """Main processing loop.

        Args:
            exit_when_idle: When True, return as soon as no pending job is
                left. The default (False) keeps polling once per second
                until stop() is called.
        """
        logger.info("Starting batch processor")
        self.processing = True
        while self.processing:
            # Get pending jobs
            pending_jobs = [j for j in self.queue if j.status == "pending"]
            if not pending_jobs:
                if exit_when_idle:
                    break
                await asyncio.sleep(1)
                continue
            batch_size = self.config["batch_size"]
            # batch_size must be positive, otherwise the slice below would be
            # empty and the loop would spin without making progress.
            if 0 < batch_size <= len(pending_jobs) and self.config.get("dual_gpu"):
                # Process in batches for Phase 3
                await self.process_batch(pending_jobs[:batch_size])
            else:
                # Process single job with appropriate phase
                await self.process_single(pending_jobs[0])
        self.processing = False

    def stop(self):
        """Stop the processing loop (takes effect at its next iteration)."""
        logger.info("Stopping batch processor")
        self.processing = False

    def get_status(self) -> Dict:
        """Get current processing status.

        Returns:
            Counts of all jobs and of jobs per known status, plus whether the
            loop is running.
        """
        counts = {"pending": 0, "processing": 0, "completed": 0, "failed": 0}
        for job in self.queue:
            if job.status in counts:
                counts[job.status] += 1
        return {
            "total_jobs": len(self.queue),
            "pending": counts["pending"],
            "processing": counts["processing"],
            "completed": counts["completed"],
            "failed": counts["failed"],
            "is_running": self.processing
        }
class InputDirectoryHandler(FileSystemEventHandler):
    """Handles new file events in input directory.

    Every newly created image file is queued on the given processor with the
    processor's default style and phase.
    """

    # File extensions (compared case-insensitively) treated as images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.webp')

    def __init__(self, processor: "BatchProcessor"):
        """Remember the processor that receives the new jobs."""
        super().__init__()
        self.processor = processor

    def on_created(self, event):
        """Queue a job for a newly created image file.

        Directories and files with other extensions are ignored.
        """
        if event.is_directory:
            return
        file_path = event.src_path
        if not file_path.lower().endswith(self.IMAGE_EXTENSIONS):
            return
        logger.info(f"New image detected: {file_path}")
        try:
            self.processor.add_job(file_path)
        except Exception:
            # add_job reads the file to probe the mask cache and can raise
            # (e.g. the file vanished or is unreadable). This callback is
            # invoked by the observer, so log the failure here instead of
            # letting the exception escape into the observer's machinery.
            logger.exception(f"Could not queue {file_path}")
def load_style_template(
    template_name: str,
    prompts_dir: str = "Project_Velocity/comfy_engine/prompts/",
) -> str:
    """Load a style template from prompts directory.

    Reads ``<prompts_dir>/<template_name>.txt`` and extracts the positive
    prompt: the lines after the first line containing 'POSITIVE PROMPT:' and
    before the first line starting with 'Style Weight:'. Blank lines and
    lines starting with '-' are skipped; the rest are stripped and joined
    with single spaces.

    Args:
        template_name: Template file name without the ``.txt`` suffix.
        prompts_dir: Directory holding the template files.

    Returns:
        The extracted prompt, or "" when the template file does not exist or
        contains no positive-prompt section.
    """
    template_path = Path(prompts_dir) / f"{template_name}.txt"
    try:
        # Explicit encoding: the platform default would make the result
        # depend on the machine's locale.
        with open(template_path, 'r', encoding='utf-8') as f:
            content = f.read()
    except FileNotFoundError:
        return ""

    # Extract positive prompt
    positive_lines = []
    in_positive = False
    for line in content.split('\n'):
        if 'POSITIVE PROMPT:' in line:
            in_positive = True
            continue
        if not in_positive:
            continue
        if line.startswith('Style Weight:'):
            break
        if line.strip() and not line.startswith('-'):
            positive_lines.append(line.strip())
    return ' '.join(positive_lines)
async def main():
    """Main entry point.

    Parses the command line and runs one of three modes:

    * ``--input PATH``: process a single image and report the outcome.
    * ``--batch``: queue every image in the configured input directory,
      process them all, then return.
    * ``--monitor``: watch the input directory and process images as they
      appear, until interrupted.
    """
    parser = argparse.ArgumentParser(
        description="Dream Weaver Batch Processor"
    )
    parser.add_argument(
        "--monitor",
        action="store_true",
        help="Enable directory monitoring mode"
    )
    parser.add_argument(
        "--input",
        type=str,
        help="Single input image to process"
    )
    parser.add_argument(
        "--style",
        type=str,
        default="scandinavian_minimalist",
        choices=["scandinavian_minimalist", "art_deco_luxe", "cyberpunk_neon", "biophilic_organic", "japandi_fusion"],
        help="Style template to apply"
    )
    parser.add_argument(
        "--phase",
        type=int,
        default=1,
        choices=[1, 2, 3],
        help="Processing phase to use"
    )
    parser.add_argument(
        "--batch",
        action="store_true",
        help="Process all images in input directory"
    )
    args = parser.parse_args()

    # Initialize processor
    processor = BatchProcessor(CONFIG)

    if args.input:
        # Process single image and report what actually happened.
        job_id = processor.add_job(args.input, args.style, args.phase)
        succeeded = await processor.process_single(processor.queue[-1])
        if succeeded:
            print(f"Processed image: {args.input}")
        else:
            print(f"Failed to process image: {args.input}")
        print(f"Job ID: {job_id}")
    elif args.batch:
        # Process all images in directory. The same extensions as monitor
        # mode are accepted, compared case-insensitively.
        image_suffixes = ('.jpg', '.jpeg', '.png', '.webp')
        input_dir = Path(CONFIG["input_directory"])
        image_files = []
        if input_dir.is_dir():
            image_files = sorted(
                p for p in input_dir.iterdir()
                if p.is_file() and p.suffix.lower() in image_suffixes
            )
        if not image_files:
            logger.warning(f"No images found in: {input_dir}")
            return
        for img_file in image_files:
            processor.add_job(str(img_file), args.style, args.phase)
        # processor.run() polls forever once the queue is empty, so the
        # queued jobs are processed directly and the program then exits.
        results = await processor.process_batch(list(processor.queue))
        logger.info(f"Batch finished: {sum(results)}/{len(results)} jobs succeeded")
    elif args.monitor:
        # Start directory monitoring
        watch_dir = CONFIG["input_directory"]
        # The watched directory has to exist before it can be scheduled.
        Path(watch_dir).mkdir(parents=True, exist_ok=True)
        event_handler = InputDirectoryHandler(processor)
        observer = Observer()
        observer.schedule(
            event_handler,
            watch_dir,
            recursive=False
        )
        observer.start()
        logger.info(f"Started monitoring: {CONFIG['input_directory']}")
        try:
            # Run processor
            await processor.run()
        except (KeyboardInterrupt, asyncio.CancelledError):
            # Under asyncio.run(), Ctrl-C can reach this coroutine as a task
            # cancellation rather than as KeyboardInterrupt.
            logger.info("Shutdown requested")
        finally:
            # Always release the observer thread, whatever ended the loop.
            processor.stop()
            observer.stop()
            observer.join()
    else:
        print("No action specified. Use --help for usage information.")


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to end monitor mode: exit without a
        # traceback, using the conventional "interrupted" exit status.
        sys.exit(130)