feat: Overlay the mathematical Sun Path over the live camera feed or 3D model view
This commit is contained in:
@@ -1,498 +1,498 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Dream Weaver Batch Processor
|
||||
============================
|
||||
Automated batch processing script for Dream Weaver interior restyling workflow.
|
||||
Handles directory monitoring, automatic mask caching, and queue management.
|
||||
|
||||
Target Hardware: Dual NVIDIA RTX PRO 6000 Blackwell (96GB GDDR7 each)
|
||||
Author: Project Velocity Team
|
||||
Version: 1.0.0
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import time
|
||||
import hashlib
|
||||
import asyncio
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from dataclasses import dataclass, asdict
|
||||
import requests
|
||||
import websockets
|
||||
import aiofiles
|
||||
from watchdog.observers import Observer
|
||||
from watchdog.events import FileSystemEventHandler
|
||||
|
||||
# Configuration
|
||||
CONFIG = {
|
||||
"comfyui_server": "http://localhost:8188",
|
||||
"comfyui_ws": "ws://localhost:8188/ws",
|
||||
"input_directory": "Project_Velocity/comfy_engine/test_inputs/",
|
||||
"output_directory": "Project_Velocity/comfy_engine/test_outputs/",
|
||||
"cache_directory": "Project_Velocity/comfy_engine/cache/masks/",
|
||||
"workflow_phase1": "Project_Velocity/comfy_engine/workflows/dreamweaver_phase1_depth.json",
|
||||
"workflow_phase2": "Project_Velocity/comfy_engine/workflows/dreamweaver_phase2_multicontrol.json",
|
||||
"workflow_phase3": "Project_Velocity/comfy_engine/workflows/dreamweaver_phase3_batch.json",
|
||||
"batch_size": 8,
|
||||
"target_resolution": (1024, 1024),
|
||||
"enable_mask_cache": True,
|
||||
"gpu_sharding": True,
|
||||
"dual_gpu": True,
|
||||
}
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler('dreamweaver_batch.log'),
|
||||
logging.StreamHandler()
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger('DreamWeaver')
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcessingJob:
|
||||
"""Represents a single image processing job."""
|
||||
job_id: str
|
||||
input_path: str
|
||||
output_path: str
|
||||
style_template: str
|
||||
phase: int
|
||||
status: str = "pending"
|
||||
created_at: datetime = None
|
||||
started_at: datetime = None
|
||||
completed_at: datetime = None
|
||||
error_message: str = None
|
||||
mask_cached: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.created_at is None:
|
||||
self.created_at = datetime.now()
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
return {
|
||||
"job_id": self.job_id,
|
||||
"input_path": self.input_path,
|
||||
"output_path": self.output_path,
|
||||
"style_template": self.style_template,
|
||||
"phase": self.phase,
|
||||
"status": self.status,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"started_at": self.started_at.isoformat() if self.started_at else None,
|
||||
"completed_at": self.completed_at.isoformat() if self.completed_at else None,
|
||||
"error_message": self.error_message,
|
||||
"mask_cached": self.mask_cached
|
||||
}
|
||||
|
||||
|
||||
class MaskCacheManager:
|
||||
"""Manages caching of segmentation masks for improved performance."""
|
||||
|
||||
def __init__(self, cache_dir: str):
|
||||
self.cache_dir = Path(cache_dir)
|
||||
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
logger.info(f"Mask cache initialized at: {self.cache_dir}")
|
||||
|
||||
def _get_cache_key(self, image_path: str) -> str:
|
||||
"""Generate cache key from image content hash."""
|
||||
hasher = hashlib.md5()
|
||||
with open(image_path, 'rb') as f:
|
||||
hasher.update(f.read())
|
||||
return hasher.hexdigest()
|
||||
|
||||
def get_cached_mask(self, image_path: str) -> Optional[str]:
|
||||
"""Retrieve cached mask path if it exists."""
|
||||
cache_key = self._get_cache_key(image_path)
|
||||
cached_path = self.cache_dir / f"{cache_key}.png"
|
||||
|
||||
if cached_path.exists():
|
||||
logger.info(f"Cache hit for {image_path}")
|
||||
return str(cached_path)
|
||||
return None
|
||||
|
||||
def cache_mask(self, image_path: str, mask_path: str) -> str:
|
||||
"""Cache a mask file for future use."""
|
||||
cache_key = self._get_cache_key(image_path)
|
||||
cached_path = self.cache_dir / f"{cache_key}.png"
|
||||
|
||||
import shutil
|
||||
shutil.copy2(mask_path, cached_path)
|
||||
logger.info(f"Cached mask for {image_path} at {cached_path}")
|
||||
return str(cached_path)
|
||||
|
||||
|
||||
class ComfyUIClient:
|
||||
"""Client for communicating with ComfyUI server."""
|
||||
|
||||
def __init__(self, server_url: str, ws_url: str):
|
||||
self.server_url = server_url
|
||||
self.ws_url = ws_url
|
||||
self.client_id = self._generate_client_id()
|
||||
logger.info(f"ComfyUI client initialized with ID: {self.client_id}")
|
||||
|
||||
def _generate_client_id(self) -> str:
|
||||
"""Generate unique client ID."""
|
||||
return f"dreamweaver_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{os.urandom(4).hex()}"
|
||||
|
||||
async def submit_workflow(self, workflow: Dict, input_image: str) -> str:
|
||||
"""Submit a workflow to ComfyUI queue."""
|
||||
# Update workflow with input image
|
||||
for node_id, node in workflow.items():
|
||||
if node.get("class_type") == "LoadImage":
|
||||
node["inputs"]["image"] = input_image
|
||||
if node.get("class_type") == "LoadImageBatch":
|
||||
node["inputs"]["directory"] = os.path.dirname(input_image)
|
||||
|
||||
payload = {
|
||||
"prompt": workflow,
|
||||
"client_id": self.client_id
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
f"{self.server_url}/prompt",
|
||||
json=payload
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
prompt_id = result.get("prompt_id")
|
||||
logger.info(f"Submitted workflow with prompt_id: {prompt_id}")
|
||||
return prompt_id
|
||||
|
||||
async def get_queue_status(self) -> Dict:
|
||||
"""Get current queue status."""
|
||||
response = requests.get(f"{self.server_url}/queue")
|
||||
return response.json()
|
||||
|
||||
async def wait_for_completion(self, prompt_id: str, timeout: int = 300) -> bool:
|
||||
"""Wait for workflow completion via WebSocket."""
|
||||
start_time = time.time()
|
||||
|
||||
async with websockets.connect(
|
||||
f"{self.ws_url}?clientId={self.client_id}"
|
||||
) as websocket:
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
message = await asyncio.wait_for(
|
||||
websocket.recv(),
|
||||
timeout=5.0
|
||||
)
|
||||
data = json.loads(message)
|
||||
|
||||
if data.get("type") == "executing":
|
||||
if data["data"].get("prompt_id") == prompt_id:
|
||||
node_id = data["data"].get("node")
|
||||
logger.debug(f"Executing node: {node_id}")
|
||||
|
||||
elif data.get("type") == "completed":
|
||||
if data["data"].get("prompt_id") == prompt_id:
|
||||
logger.info(f"Workflow {prompt_id} completed")
|
||||
return True
|
||||
|
||||
elif data.get("type") == "error":
|
||||
logger.error(f"Workflow error: {data}")
|
||||
return False
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
logger.warning(f"Workflow {prompt_id} timed out")
|
||||
return False
|
||||
|
||||
|
||||
class BatchProcessor:
|
||||
"""Main batch processing controller."""
|
||||
|
||||
def __init__(self, config: Dict):
|
||||
self.config = config
|
||||
self.queue: List[ProcessingJob] = []
|
||||
self.processing = False
|
||||
self.cache_manager = MaskCacheManager(config["cache_directory"])
|
||||
self.comfy_client = ComfyUIClient(
|
||||
config["comfyui_server"],
|
||||
config["comfyui_ws"]
|
||||
)
|
||||
|
||||
# Load workflow templates
|
||||
self.workflows = self._load_workflows()
|
||||
|
||||
# Ensure output directory exists
|
||||
Path(config["output_directory"]).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _load_workflows(self) -> Dict[int, Dict]:
|
||||
"""Load workflow JSON files."""
|
||||
workflows = {}
|
||||
workflow_paths = {
|
||||
1: self.config["workflow_phase1"],
|
||||
2: self.config["workflow_phase2"],
|
||||
3: self.config["workflow_phase3"]
|
||||
}
|
||||
|
||||
for phase, path in workflow_paths.items():
|
||||
try:
|
||||
with open(path, 'r') as f:
|
||||
workflows[phase] = json.load(f)
|
||||
logger.info(f"Loaded Phase {phase} workflow")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load Phase {phase} workflow: {e}")
|
||||
|
||||
return workflows
|
||||
|
||||
def add_job(self, input_path: str, style_template: str = "scandinavian_minimalist", phase: int = 1) -> str:
|
||||
"""Add a new processing job to the queue."""
|
||||
job_id = hashlib.md5(f"{input_path}_{time.time()}".encode()).hexdigest()[:12]
|
||||
output_filename = f"{Path(input_path).stem}_restyled_{job_id}.png"
|
||||
output_path = os.path.join(self.config["output_directory"], output_filename)
|
||||
|
||||
job = ProcessingJob(
|
||||
job_id=job_id,
|
||||
input_path=input_path,
|
||||
output_path=output_path,
|
||||
style_template=style_template,
|
||||
phase=phase
|
||||
)
|
||||
|
||||
# Check if mask is cached
|
||||
if self.config["enable_mask_cache"]:
|
||||
cached_mask = self.cache_manager.get_cached_mask(input_path)
|
||||
job.mask_cached = cached_mask is not None
|
||||
|
||||
self.queue.append(job)
|
||||
logger.info(f"Added job {job_id} to queue. Queue size: {len(self.queue)}")
|
||||
return job_id
|
||||
|
||||
async def process_single(self, job: ProcessingJob) -> bool:
|
||||
"""Process a single job."""
|
||||
job.status = "processing"
|
||||
job.started_at = datetime.now()
|
||||
|
||||
try:
|
||||
logger.info(f"Processing job {job.job_id}: {job.input_path}")
|
||||
|
||||
# Get workflow for phase
|
||||
workflow = self.workflows.get(job.phase)
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow for phase {job.phase} not found")
|
||||
|
||||
# Submit to ComfyUI
|
||||
prompt_id = await self.comfy_client.submit_workflow(
|
||||
workflow,
|
||||
job.input_path
|
||||
)
|
||||
|
||||
# Wait for completion
|
||||
success = await self.comfy_client.wait_for_completion(prompt_id)
|
||||
|
||||
if success:
|
||||
job.status = "completed"
|
||||
job.completed_at = datetime.now()
|
||||
logger.info(f"Job {job.job_id} completed successfully")
|
||||
return True
|
||||
else:
|
||||
job.status = "failed"
|
||||
job.error_message = "Workflow execution failed or timed out"
|
||||
logger.error(f"Job {job.job_id} failed")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
job.status = "failed"
|
||||
job.error_message = str(e)
|
||||
logger.error(f"Error processing job {job.job_id}: {e}")
|
||||
return False
|
||||
|
||||
async def process_batch(self, jobs: List[ProcessingJob]) -> List[bool]:
|
||||
"""Process multiple jobs in batch (Phase 3)."""
|
||||
if not jobs:
|
||||
return []
|
||||
|
||||
logger.info(f"Processing batch of {len(jobs)} jobs")
|
||||
results = []
|
||||
|
||||
# For batch processing, use Phase 3 workflow
|
||||
workflow = self.workflows.get(3)
|
||||
if not workflow:
|
||||
logger.warning("Phase 3 workflow not available, processing sequentially")
|
||||
for job in jobs:
|
||||
result = await self.process_single(job)
|
||||
results.append(result)
|
||||
return results
|
||||
|
||||
# TODO: Implement true batch processing with Phase 3 workflow
|
||||
# This would require grouping images and processing together
|
||||
for job in jobs:
|
||||
result = await self.process_single(job)
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
|
||||
async def run(self):
|
||||
"""Main processing loop."""
|
||||
logger.info("Starting batch processor")
|
||||
self.processing = True
|
||||
|
||||
while self.processing:
|
||||
# Get pending jobs
|
||||
pending_jobs = [j for j in self.queue if j.status == "pending"]
|
||||
|
||||
if not pending_jobs:
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
# Check if batch processing is appropriate
|
||||
if len(pending_jobs) >= self.config["batch_size"] and self.config.get("dual_gpu"):
|
||||
# Process in batches for Phase 3
|
||||
batch = pending_jobs[:self.config["batch_size"]]
|
||||
await self.process_batch(batch)
|
||||
else:
|
||||
# Process single job with appropriate phase
|
||||
job = pending_jobs[0]
|
||||
await self.process_single(job)
|
||||
|
||||
def stop(self):
|
||||
"""Stop the processing loop."""
|
||||
logger.info("Stopping batch processor")
|
||||
self.processing = False
|
||||
|
||||
def get_status(self) -> Dict:
|
||||
"""Get current processing status."""
|
||||
total = len(self.queue)
|
||||
pending = len([j for j in self.queue if j.status == "pending"])
|
||||
processing = len([j for j in self.queue if j.status == "processing"])
|
||||
completed = len([j for j in self.queue if j.status == "completed"])
|
||||
failed = len([j for j in self.queue if j.status == "failed"])
|
||||
|
||||
return {
|
||||
"total_jobs": total,
|
||||
"pending": pending,
|
||||
"processing": processing,
|
||||
"completed": completed,
|
||||
"failed": failed,
|
||||
"is_running": self.processing
|
||||
}
|
||||
|
||||
|
||||
class InputDirectoryHandler(FileSystemEventHandler):
|
||||
"""Handles new file events in input directory."""
|
||||
|
||||
def __init__(self, processor: BatchProcessor):
|
||||
self.processor = processor
|
||||
|
||||
def on_created(self, event):
|
||||
if not event.is_directory:
|
||||
file_path = event.src_path
|
||||
if file_path.lower().endswith(('.jpg', '.jpeg', '.png', '.webp')):
|
||||
logger.info(f"New image detected: {file_path}")
|
||||
self.processor.add_job(file_path)
|
||||
|
||||
|
||||
def load_style_template(template_name: str) -> str:
|
||||
"""Load a style template from prompts directory."""
|
||||
template_path = Path("Project_Velocity/comfy_engine/prompts/") / f"{template_name}.txt"
|
||||
if template_path.exists():
|
||||
with open(template_path, 'r') as f:
|
||||
content = f.read()
|
||||
# Extract positive prompt
|
||||
lines = content.split('\n')
|
||||
positive_lines = []
|
||||
in_positive = False
|
||||
for line in lines:
|
||||
if 'POSITIVE PROMPT:' in line:
|
||||
in_positive = True
|
||||
continue
|
||||
if in_positive and line.startswith('Style Weight:'):
|
||||
break
|
||||
if in_positive and line.strip() and not line.startswith('-'):
|
||||
positive_lines.append(line.strip())
|
||||
return ' '.join(positive_lines)
|
||||
return ""
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main entry point."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Dream Weaver Batch Processor"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--monitor",
|
||||
action="store_true",
|
||||
help="Enable directory monitoring mode"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input",
|
||||
type=str,
|
||||
help="Single input image to process"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--style",
|
||||
type=str,
|
||||
default="scandinavian_minimalist",
|
||||
choices=["scandinavian_minimalist", "art_deco_luxe", "cyberpunk_neon", "biophilic_organic", "japandi_fusion"],
|
||||
help="Style template to apply"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--phase",
|
||||
type=int,
|
||||
default=1,
|
||||
choices=[1, 2, 3],
|
||||
help="Processing phase to use"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch",
|
||||
action="store_true",
|
||||
help="Process all images in input directory"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Initialize processor
|
||||
processor = BatchProcessor(CONFIG)
|
||||
|
||||
if args.input:
|
||||
# Process single image
|
||||
job_id = processor.add_job(args.input, args.style, args.phase)
|
||||
await processor.process_single(processor.queue[-1])
|
||||
print(f"Processed image: {args.input}")
|
||||
print(f"Job ID: {job_id}")
|
||||
|
||||
elif args.batch:
|
||||
# Process all images in directory
|
||||
input_dir = Path(CONFIG["input_directory"])
|
||||
image_files = list(input_dir.glob("*.jpg")) + list(input_dir.glob("*.png"))
|
||||
|
||||
for img_file in image_files:
|
||||
processor.add_job(str(img_file), args.style, args.phase)
|
||||
|
||||
await processor.run()
|
||||
|
||||
elif args.monitor:
|
||||
# Start directory monitoring
|
||||
event_handler = InputDirectoryHandler(processor)
|
||||
observer = Observer()
|
||||
observer.schedule(
|
||||
event_handler,
|
||||
CONFIG["input_directory"],
|
||||
recursive=False
|
||||
)
|
||||
observer.start()
|
||||
logger.info(f"Started monitoring: {CONFIG['input_directory']}")
|
||||
|
||||
try:
|
||||
# Run processor
|
||||
await processor.run()
|
||||
except KeyboardInterrupt:
|
||||
processor.stop()
|
||||
observer.stop()
|
||||
|
||||
observer.join()
|
||||
else:
|
||||
print("No action specified. Use --help for usage information.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Dream Weaver Batch Processor
|
||||
============================
|
||||
Automated batch processing script for Dream Weaver interior restyling workflow.
|
||||
Handles directory monitoring, automatic mask caching, and queue management.
|
||||
|
||||
Target Hardware: Dual NVIDIA RTX PRO 6000 Blackwell (96GB GDDR7 each)
|
||||
Author: Project Velocity Team
|
||||
Version: 1.0.0
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import time
|
||||
import hashlib
|
||||
import asyncio
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from dataclasses import dataclass, asdict
|
||||
import requests
|
||||
import websockets
|
||||
import aiofiles
|
||||
from watchdog.observers import Observer
|
||||
from watchdog.events import FileSystemEventHandler
|
||||
|
||||
# Configuration
|
||||
CONFIG = {
|
||||
"comfyui_server": "http://localhost:8188",
|
||||
"comfyui_ws": "ws://localhost:8188/ws",
|
||||
"input_directory": "Project_Velocity/comfy_engine/test_inputs/",
|
||||
"output_directory": "Project_Velocity/comfy_engine/test_outputs/",
|
||||
"cache_directory": "Project_Velocity/comfy_engine/cache/masks/",
|
||||
"workflow_phase1": "Project_Velocity/comfy_engine/workflows/dreamweaver_phase1_depth.json",
|
||||
"workflow_phase2": "Project_Velocity/comfy_engine/workflows/dreamweaver_phase2_multicontrol.json",
|
||||
"workflow_phase3": "Project_Velocity/comfy_engine/workflows/dreamweaver_phase3_batch.json",
|
||||
"batch_size": 8,
|
||||
"target_resolution": (1024, 1024),
|
||||
"enable_mask_cache": True,
|
||||
"gpu_sharding": True,
|
||||
"dual_gpu": True,
|
||||
}
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler('dreamweaver_batch.log'),
|
||||
logging.StreamHandler()
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger('DreamWeaver')
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcessingJob:
|
||||
"""Represents a single image processing job."""
|
||||
job_id: str
|
||||
input_path: str
|
||||
output_path: str
|
||||
style_template: str
|
||||
phase: int
|
||||
status: str = "pending"
|
||||
created_at: datetime = None
|
||||
started_at: datetime = None
|
||||
completed_at: datetime = None
|
||||
error_message: str = None
|
||||
mask_cached: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
if self.created_at is None:
|
||||
self.created_at = datetime.now()
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
return {
|
||||
"job_id": self.job_id,
|
||||
"input_path": self.input_path,
|
||||
"output_path": self.output_path,
|
||||
"style_template": self.style_template,
|
||||
"phase": self.phase,
|
||||
"status": self.status,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"started_at": self.started_at.isoformat() if self.started_at else None,
|
||||
"completed_at": self.completed_at.isoformat() if self.completed_at else None,
|
||||
"error_message": self.error_message,
|
||||
"mask_cached": self.mask_cached
|
||||
}
|
||||
|
||||
|
||||
class MaskCacheManager:
|
||||
"""Manages caching of segmentation masks for improved performance."""
|
||||
|
||||
def __init__(self, cache_dir: str):
|
||||
self.cache_dir = Path(cache_dir)
|
||||
self.cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
logger.info(f"Mask cache initialized at: {self.cache_dir}")
|
||||
|
||||
def _get_cache_key(self, image_path: str) -> str:
|
||||
"""Generate cache key from image content hash."""
|
||||
hasher = hashlib.md5()
|
||||
with open(image_path, 'rb') as f:
|
||||
hasher.update(f.read())
|
||||
return hasher.hexdigest()
|
||||
|
||||
def get_cached_mask(self, image_path: str) -> Optional[str]:
|
||||
"""Retrieve cached mask path if it exists."""
|
||||
cache_key = self._get_cache_key(image_path)
|
||||
cached_path = self.cache_dir / f"{cache_key}.png"
|
||||
|
||||
if cached_path.exists():
|
||||
logger.info(f"Cache hit for {image_path}")
|
||||
return str(cached_path)
|
||||
return None
|
||||
|
||||
def cache_mask(self, image_path: str, mask_path: str) -> str:
|
||||
"""Cache a mask file for future use."""
|
||||
cache_key = self._get_cache_key(image_path)
|
||||
cached_path = self.cache_dir / f"{cache_key}.png"
|
||||
|
||||
import shutil
|
||||
shutil.copy2(mask_path, cached_path)
|
||||
logger.info(f"Cached mask for {image_path} at {cached_path}")
|
||||
return str(cached_path)
|
||||
|
||||
|
||||
class ComfyUIClient:
|
||||
"""Client for communicating with ComfyUI server."""
|
||||
|
||||
def __init__(self, server_url: str, ws_url: str):
|
||||
self.server_url = server_url
|
||||
self.ws_url = ws_url
|
||||
self.client_id = self._generate_client_id()
|
||||
logger.info(f"ComfyUI client initialized with ID: {self.client_id}")
|
||||
|
||||
def _generate_client_id(self) -> str:
|
||||
"""Generate unique client ID."""
|
||||
return f"dreamweaver_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{os.urandom(4).hex()}"
|
||||
|
||||
async def submit_workflow(self, workflow: Dict, input_image: str) -> str:
|
||||
"""Submit a workflow to ComfyUI queue."""
|
||||
# Update workflow with input image
|
||||
for node_id, node in workflow.items():
|
||||
if node.get("class_type") == "LoadImage":
|
||||
node["inputs"]["image"] = input_image
|
||||
if node.get("class_type") == "LoadImageBatch":
|
||||
node["inputs"]["directory"] = os.path.dirname(input_image)
|
||||
|
||||
payload = {
|
||||
"prompt": workflow,
|
||||
"client_id": self.client_id
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
f"{self.server_url}/prompt",
|
||||
json=payload
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
prompt_id = result.get("prompt_id")
|
||||
logger.info(f"Submitted workflow with prompt_id: {prompt_id}")
|
||||
return prompt_id
|
||||
|
||||
async def get_queue_status(self) -> Dict:
|
||||
"""Get current queue status."""
|
||||
response = requests.get(f"{self.server_url}/queue")
|
||||
return response.json()
|
||||
|
||||
async def wait_for_completion(self, prompt_id: str, timeout: int = 300) -> bool:
|
||||
"""Wait for workflow completion via WebSocket."""
|
||||
start_time = time.time()
|
||||
|
||||
async with websockets.connect(
|
||||
f"{self.ws_url}?clientId={self.client_id}"
|
||||
) as websocket:
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
message = await asyncio.wait_for(
|
||||
websocket.recv(),
|
||||
timeout=5.0
|
||||
)
|
||||
data = json.loads(message)
|
||||
|
||||
if data.get("type") == "executing":
|
||||
if data["data"].get("prompt_id") == prompt_id:
|
||||
node_id = data["data"].get("node")
|
||||
logger.debug(f"Executing node: {node_id}")
|
||||
|
||||
elif data.get("type") == "completed":
|
||||
if data["data"].get("prompt_id") == prompt_id:
|
||||
logger.info(f"Workflow {prompt_id} completed")
|
||||
return True
|
||||
|
||||
elif data.get("type") == "error":
|
||||
logger.error(f"Workflow error: {data}")
|
||||
return False
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
logger.warning(f"Workflow {prompt_id} timed out")
|
||||
return False
|
||||
|
||||
|
||||
class BatchProcessor:
|
||||
"""Main batch processing controller."""
|
||||
|
||||
def __init__(self, config: Dict):
|
||||
self.config = config
|
||||
self.queue: List[ProcessingJob] = []
|
||||
self.processing = False
|
||||
self.cache_manager = MaskCacheManager(config["cache_directory"])
|
||||
self.comfy_client = ComfyUIClient(
|
||||
config["comfyui_server"],
|
||||
config["comfyui_ws"]
|
||||
)
|
||||
|
||||
# Load workflow templates
|
||||
self.workflows = self._load_workflows()
|
||||
|
||||
# Ensure output directory exists
|
||||
Path(config["output_directory"]).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _load_workflows(self) -> Dict[int, Dict]:
|
||||
"""Load workflow JSON files."""
|
||||
workflows = {}
|
||||
workflow_paths = {
|
||||
1: self.config["workflow_phase1"],
|
||||
2: self.config["workflow_phase2"],
|
||||
3: self.config["workflow_phase3"]
|
||||
}
|
||||
|
||||
for phase, path in workflow_paths.items():
|
||||
try:
|
||||
with open(path, 'r') as f:
|
||||
workflows[phase] = json.load(f)
|
||||
logger.info(f"Loaded Phase {phase} workflow")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load Phase {phase} workflow: {e}")
|
||||
|
||||
return workflows
|
||||
|
||||
def add_job(self, input_path: str, style_template: str = "scandinavian_minimalist", phase: int = 1) -> str:
|
||||
"""Add a new processing job to the queue."""
|
||||
job_id = hashlib.md5(f"{input_path}_{time.time()}".encode()).hexdigest()[:12]
|
||||
output_filename = f"{Path(input_path).stem}_restyled_{job_id}.png"
|
||||
output_path = os.path.join(self.config["output_directory"], output_filename)
|
||||
|
||||
job = ProcessingJob(
|
||||
job_id=job_id,
|
||||
input_path=input_path,
|
||||
output_path=output_path,
|
||||
style_template=style_template,
|
||||
phase=phase
|
||||
)
|
||||
|
||||
# Check if mask is cached
|
||||
if self.config["enable_mask_cache"]:
|
||||
cached_mask = self.cache_manager.get_cached_mask(input_path)
|
||||
job.mask_cached = cached_mask is not None
|
||||
|
||||
self.queue.append(job)
|
||||
logger.info(f"Added job {job_id} to queue. Queue size: {len(self.queue)}")
|
||||
return job_id
|
||||
|
||||
async def process_single(self, job: ProcessingJob) -> bool:
|
||||
"""Process a single job."""
|
||||
job.status = "processing"
|
||||
job.started_at = datetime.now()
|
||||
|
||||
try:
|
||||
logger.info(f"Processing job {job.job_id}: {job.input_path}")
|
||||
|
||||
# Get workflow for phase
|
||||
workflow = self.workflows.get(job.phase)
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow for phase {job.phase} not found")
|
||||
|
||||
# Submit to ComfyUI
|
||||
prompt_id = await self.comfy_client.submit_workflow(
|
||||
workflow,
|
||||
job.input_path
|
||||
)
|
||||
|
||||
# Wait for completion
|
||||
success = await self.comfy_client.wait_for_completion(prompt_id)
|
||||
|
||||
if success:
|
||||
job.status = "completed"
|
||||
job.completed_at = datetime.now()
|
||||
logger.info(f"Job {job.job_id} completed successfully")
|
||||
return True
|
||||
else:
|
||||
job.status = "failed"
|
||||
job.error_message = "Workflow execution failed or timed out"
|
||||
logger.error(f"Job {job.job_id} failed")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
job.status = "failed"
|
||||
job.error_message = str(e)
|
||||
logger.error(f"Error processing job {job.job_id}: {e}")
|
||||
return False
|
||||
|
||||
async def process_batch(self, jobs: List[ProcessingJob]) -> List[bool]:
|
||||
"""Process multiple jobs in batch (Phase 3)."""
|
||||
if not jobs:
|
||||
return []
|
||||
|
||||
logger.info(f"Processing batch of {len(jobs)} jobs")
|
||||
results = []
|
||||
|
||||
# For batch processing, use Phase 3 workflow
|
||||
workflow = self.workflows.get(3)
|
||||
if not workflow:
|
||||
logger.warning("Phase 3 workflow not available, processing sequentially")
|
||||
for job in jobs:
|
||||
result = await self.process_single(job)
|
||||
results.append(result)
|
||||
return results
|
||||
|
||||
# TODO: Implement true batch processing with Phase 3 workflow
|
||||
# This would require grouping images and processing together
|
||||
for job in jobs:
|
||||
result = await self.process_single(job)
|
||||
results.append(result)
|
||||
|
||||
return results
|
||||
|
||||
async def run(self):
|
||||
"""Main processing loop."""
|
||||
logger.info("Starting batch processor")
|
||||
self.processing = True
|
||||
|
||||
while self.processing:
|
||||
# Get pending jobs
|
||||
pending_jobs = [j for j in self.queue if j.status == "pending"]
|
||||
|
||||
if not pending_jobs:
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
# Check if batch processing is appropriate
|
||||
if len(pending_jobs) >= self.config["batch_size"] and self.config.get("dual_gpu"):
|
||||
# Process in batches for Phase 3
|
||||
batch = pending_jobs[:self.config["batch_size"]]
|
||||
await self.process_batch(batch)
|
||||
else:
|
||||
# Process single job with appropriate phase
|
||||
job = pending_jobs[0]
|
||||
await self.process_single(job)
|
||||
|
||||
def stop(self):
|
||||
"""Stop the processing loop."""
|
||||
logger.info("Stopping batch processor")
|
||||
self.processing = False
|
||||
|
||||
def get_status(self) -> Dict:
|
||||
"""Get current processing status."""
|
||||
total = len(self.queue)
|
||||
pending = len([j for j in self.queue if j.status == "pending"])
|
||||
processing = len([j for j in self.queue if j.status == "processing"])
|
||||
completed = len([j for j in self.queue if j.status == "completed"])
|
||||
failed = len([j for j in self.queue if j.status == "failed"])
|
||||
|
||||
return {
|
||||
"total_jobs": total,
|
||||
"pending": pending,
|
||||
"processing": processing,
|
||||
"completed": completed,
|
||||
"failed": failed,
|
||||
"is_running": self.processing
|
||||
}
|
||||
|
||||
|
||||
class InputDirectoryHandler(FileSystemEventHandler):
|
||||
"""Handles new file events in input directory."""
|
||||
|
||||
def __init__(self, processor: BatchProcessor):
|
||||
self.processor = processor
|
||||
|
||||
def on_created(self, event):
|
||||
if not event.is_directory:
|
||||
file_path = event.src_path
|
||||
if file_path.lower().endswith(('.jpg', '.jpeg', '.png', '.webp')):
|
||||
logger.info(f"New image detected: {file_path}")
|
||||
self.processor.add_job(file_path)
|
||||
|
||||
|
||||
def load_style_template(template_name: str) -> str:
|
||||
"""Load a style template from prompts directory."""
|
||||
template_path = Path("Project_Velocity/comfy_engine/prompts/") / f"{template_name}.txt"
|
||||
if template_path.exists():
|
||||
with open(template_path, 'r') as f:
|
||||
content = f.read()
|
||||
# Extract positive prompt
|
||||
lines = content.split('\n')
|
||||
positive_lines = []
|
||||
in_positive = False
|
||||
for line in lines:
|
||||
if 'POSITIVE PROMPT:' in line:
|
||||
in_positive = True
|
||||
continue
|
||||
if in_positive and line.startswith('Style Weight:'):
|
||||
break
|
||||
if in_positive and line.strip() and not line.startswith('-'):
|
||||
positive_lines.append(line.strip())
|
||||
return ' '.join(positive_lines)
|
||||
return ""
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main entry point."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Dream Weaver Batch Processor"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--monitor",
|
||||
action="store_true",
|
||||
help="Enable directory monitoring mode"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input",
|
||||
type=str,
|
||||
help="Single input image to process"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--style",
|
||||
type=str,
|
||||
default="scandinavian_minimalist",
|
||||
choices=["scandinavian_minimalist", "art_deco_luxe", "cyberpunk_neon", "biophilic_organic", "japandi_fusion"],
|
||||
help="Style template to apply"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--phase",
|
||||
type=int,
|
||||
default=1,
|
||||
choices=[1, 2, 3],
|
||||
help="Processing phase to use"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch",
|
||||
action="store_true",
|
||||
help="Process all images in input directory"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Initialize processor
|
||||
processor = BatchProcessor(CONFIG)
|
||||
|
||||
if args.input:
|
||||
# Process single image
|
||||
job_id = processor.add_job(args.input, args.style, args.phase)
|
||||
await processor.process_single(processor.queue[-1])
|
||||
print(f"Processed image: {args.input}")
|
||||
print(f"Job ID: {job_id}")
|
||||
|
||||
elif args.batch:
|
||||
# Process all images in directory
|
||||
input_dir = Path(CONFIG["input_directory"])
|
||||
image_files = list(input_dir.glob("*.jpg")) + list(input_dir.glob("*.png"))
|
||||
|
||||
for img_file in image_files:
|
||||
processor.add_job(str(img_file), args.style, args.phase)
|
||||
|
||||
await processor.run()
|
||||
|
||||
elif args.monitor:
|
||||
# Start directory monitoring
|
||||
event_handler = InputDirectoryHandler(processor)
|
||||
observer = Observer()
|
||||
observer.schedule(
|
||||
event_handler,
|
||||
CONFIG["input_directory"],
|
||||
recursive=False
|
||||
)
|
||||
observer.start()
|
||||
logger.info(f"Started monitoring: {CONFIG['input_directory']}")
|
||||
|
||||
try:
|
||||
# Run processor
|
||||
await processor.run()
|
||||
except KeyboardInterrupt:
|
||||
processor.stop()
|
||||
observer.stop()
|
||||
|
||||
observer.join()
|
||||
else:
|
||||
print("No action specified. Use --help for usage information.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
||||
Reference in New Issue
Block a user