Files
Project_Velocity/comfy_engine/scripts/mask_preprocessor.py
2026-03-21 17:01:06 +05:30

389 lines
13 KiB
Python

#!/usr/bin/env python3
"""
Dream Weaver Mask Preprocessor
==============================
Utility script for preprocessing and caching segmentation masks.
Enables offline mask generation to speed up production workflows.
Target Hardware: Dual NVIDIA RTX PRO 6000 Blackwell
Author: Project Velocity Team
Version: 1.0.0
"""
import os
import sys
import json
import hashlib
import argparse
import logging
from pathlib import Path
from typing import List, Optional, Tuple, Dict
from dataclasses import dataclass
import numpy as np
from PIL import Image
import cv2
# Script-style logging setup: INFO level with timestamp, logger name and level
# in every record.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Logger used by every component in this module.
logger = logging.getLogger('MaskPreprocessor')
@dataclass
class MaskConfig:
    """Tunable parameters for mask generation.

    ``MaskPreprocessor`` reads ``grow_pixels``, ``feather_pixels`` and
    ``threshold`` from its ``config`` instance.

    Raises:
        ValueError: at construction, if ``grow_pixels`` or ``feather_pixels``
            is negative or ``threshold`` lies outside ``[0, 1]``.
    """

    # Dilation radius in pixels; the kernel side is 2 * grow_pixels + 1.
    grow_pixels: int = 3
    # Gaussian-blur radius in pixels; the kernel side is 2 * feather_pixels + 1.
    feather_pixels: int = 5
    # Binarization cutoff as a fraction of 255.
    threshold: float = 0.3
    # Region names of interest; None selects ["wall", "floor", "ceiling"].
    # NOTE(review): not read anywhere in this module; presumably consumed
    # elsewhere - confirm.
    target_classes: Optional[List[str]] = None

    def __post_init__(self):
        """Fill in the default class list and reject out-of-range values."""
        if self.target_classes is None:
            # Built per instance so configs never share one mutable list.
            self.target_classes = ["wall", "floor", "ceiling"]
        if self.grow_pixels < 0:
            raise ValueError(f"grow_pixels must be >= 0, got {self.grow_pixels}")
        if self.feather_pixels < 0:
            raise ValueError(
                f"feather_pixels must be >= 0, got {self.feather_pixels}"
            )
        if not 0.0 <= self.threshold <= 1.0:
            raise ValueError(
                f"threshold must be within [0, 1], got {self.threshold}"
            )
class MaskPreprocessor:
    """Preprocess segmentation masks for Dream Weaver and cache them on disk.

    Processed masks are written as PNG files into ``cache_dir`` and named
    ``<md5 of image bytes><suffix>.png``. A cache entry is therefore keyed
    only by the source image's content plus a suffix such as ``"_wall"``.

    NOTE(review): the cache key does not include ``self.config`` (grow,
    feather, threshold). Masks already cached for an image are reused even
    after those settings change; run ``clear_cache()`` after changing them.
    """

    # Bytes read per chunk when hashing an image, so large files are not
    # loaded into memory all at once.
    _HASH_CHUNK_SIZE = 1 << 20

    # Operations accepted by combine_masks().
    _COMBINE_OPERATIONS = ("union", "intersection", "difference")

    # Region names that collide with the combined outputs of
    # create_multi_region_mask(): "structural" shares the combined cache
    # suffix, the other two are keys of the returned dict.
    _RESERVED_REGION_NAMES = frozenset({"structural", "stylable", "combined_structural"})

    def __init__(self, cache_dir: str = "Project_Velocity/comfy_engine/cache/masks/"):
        """Create the preprocessor and make sure ``cache_dir`` exists.

        Args:
            cache_dir: Directory for cached masks; created with parents if
                missing.
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.config = MaskConfig()
        logger.info(f"MaskPreprocessor initialized. Cache directory: {self.cache_dir}")

    def _get_image_hash(self, image_path: str) -> str:
        """Return the hex MD5 digest of the file at ``image_path``.

        MD5 serves only as a content-addressed cache key. The file is read in
        chunks so large images are never held in memory in full.

        Raises:
            OSError: if the file cannot be opened or read.
        """
        hasher = hashlib.md5()
        with open(image_path, 'rb') as f:
            for chunk in iter(lambda: f.read(self._HASH_CHUNK_SIZE), b""):
                hasher.update(chunk)
        return hasher.hexdigest()

    def _get_cache_path(self, image_path: str, suffix: str = "") -> Path:
        """Return the cache file path ``<cache_dir>/<image md5><suffix>.png``."""
        image_hash = self._get_image_hash(image_path)
        return self.cache_dir / f"{image_hash}{suffix}.png"

    def is_cached(self, image_path: str, suffix: str = "") -> bool:
        """Return True if a cache file exists for ``image_path`` and ``suffix``."""
        return self._get_cache_path(image_path, suffix).exists()

    def load_from_cache(self, image_path: str, suffix: str = "") -> Optional[np.ndarray]:
        """Load a cached mask as a grayscale array.

        Returns:
            The mask, or None if no cache file exists or it cannot be decoded.
        """
        cache_path = self._get_cache_path(image_path, suffix)
        if not cache_path.exists():
            return None
        logger.info(f"Loading cached mask from {cache_path}")
        mask = cv2.imread(str(cache_path), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            # cv2.imread signals failure with None instead of raising.
            logger.warning(f"Cached mask could not be decoded: {cache_path}")
        return mask

    def save_to_cache(self, image_path: str, mask: np.ndarray, suffix: str = "") -> str:
        """Write ``mask`` to the cache and return the file path as a string.

        Raises:
            OSError: if OpenCV reports that the file could not be written.
        """
        cache_path = self._get_cache_path(image_path, suffix)
        # cv2.imwrite returns False on failure rather than raising; without
        # this check a failed write would be logged and returned as success.
        if not cv2.imwrite(str(cache_path), mask):
            raise OSError(f"Failed to write mask to cache: {cache_path}")
        logger.info(f"Saved mask to cache: {cache_path}")
        return str(cache_path)

    def create_structural_mask(self, image_path: str, mask_data: np.ndarray) -> np.ndarray:
        """Binarize segmentation data into a structural preservation mask.

        The mask marks regions (walls, floors, ceilings, ...) to be preserved.

        Args:
            image_path: Source image path. Not used by the computation; kept
                for interface compatibility.
            mask_data: A 2-D mask, an ``(H, W, 1)`` array, a 3-channel
                (treated as BGR) or 4-channel (treated as BGRA) image, or a
                boolean mask.

        Returns:
            ``uint8`` array: 255 where the value exceeds
            ``int(255 * config.threshold)``, otherwise 0.
        """
        if mask_data.dtype == np.bool_:
            # cv2.threshold does not accept bool arrays; map True -> 255.
            mask_data = mask_data.astype(np.uint8) * 255
        if mask_data.ndim == 3:
            channels = mask_data.shape[2]
            if channels == 1:
                mask_data = mask_data[:, :, 0]
            elif channels == 4:
                mask_data = cv2.cvtColor(mask_data, cv2.COLOR_BGRA2GRAY)
            else:
                mask_data = cv2.cvtColor(mask_data, cv2.COLOR_BGR2GRAY)
        _, binary_mask = cv2.threshold(
            mask_data,
            int(255 * self.config.threshold),
            255,
            cv2.THRESH_BINARY,
        )
        return binary_mask.astype(np.uint8)

    def grow_mask(self, mask: np.ndarray, pixels: Optional[int] = None) -> np.ndarray:
        """Dilate ``mask`` by ``pixels`` (default ``config.grow_pixels``).

        Expanding the mask slightly prevents edge bleeding. Uses a square
        kernel of side ``2 * pixels + 1``.

        Raises:
            ValueError: if ``pixels`` is negative.
        """
        if pixels is None:
            pixels = self.config.grow_pixels
        if pixels < 0:
            raise ValueError(f"grow pixels must be >= 0, got {pixels}")
        kernel = np.ones((pixels * 2 + 1, pixels * 2 + 1), np.uint8)
        return cv2.dilate(mask, kernel, iterations=1)

    def feather_mask(self, mask: np.ndarray, pixels: Optional[int] = None) -> np.ndarray:
        """Gaussian-blur ``mask`` edges by ``pixels`` (default ``config.feather_pixels``).

        Produces smooth transitions at boundaries. The kernel side
        ``2 * pixels + 1`` is always odd, as GaussianBlur requires. Sigma 0
        lets OpenCV derive it from the kernel size.

        Raises:
            ValueError: if ``pixels`` is negative.
        """
        if pixels is None:
            pixels = self.config.feather_pixels
        if pixels < 0:
            raise ValueError(f"feather pixels must be >= 0, got {pixels}")
        kernel_size = pixels * 2 + 1
        return cv2.GaussianBlur(mask, (kernel_size, kernel_size), 0)

    def invert_mask(self, mask: np.ndarray) -> np.ndarray:
        """Return the bitwise inverse of ``mask`` (structural <-> stylable)."""
        return cv2.bitwise_not(mask)

    def combine_masks(self, masks: List[np.ndarray], operation: str = "union") -> Optional[np.ndarray]:
        """Fold ``masks`` left to right with a bitwise operation.

        Args:
            masks: Masks of identical shape and dtype.
            operation: ``"union"`` (OR), ``"intersection"`` (AND) or
                ``"difference"`` (first mask minus each later mask).

        Returns:
            The combined mask (a copy when only one mask is given), or None
            when ``masks`` is empty.

        Raises:
            ValueError: if ``operation`` is not supported.
        """
        if operation not in self._COMBINE_OPERATIONS:
            raise ValueError(
                f"Unknown mask operation {operation!r}; "
                f"expected one of {self._COMBINE_OPERATIONS}"
            )
        if not masks:
            return None
        result = masks[0].copy()
        for mask in masks[1:]:
            if operation == "union":
                result = cv2.bitwise_or(result, mask)
            elif operation == "intersection":
                result = cv2.bitwise_and(result, mask)
            else:  # "difference": drop pixels that are set in `mask`
                result = cv2.bitwise_and(result, cv2.bitwise_not(mask))
        return result

    def create_multi_region_mask(
        self,
        image_path: str,
        regions: Dict[str, np.ndarray],
    ) -> Dict[str, Dict]:
        """Process each region mask and derive combined structural/stylable masks.

        Each region is binarized, grown, feathered, and then cached with the
        suffix ``_<region_name>``. The union of the processed region masks is
        cached as ``_structural``, and its inverse as ``_stylable``.

        Args:
            image_path: Source image; its content hash keys the cache files.
            regions: Region name -> raw mask data (any layout accepted by
                ``create_structural_mask``).

        Returns:
            Dict mapping each region name plus ``"combined_structural"`` and
            ``"stylable"`` to ``{"mask": ndarray, "cache_path": str}``.

        Raises:
            ValueError: if ``regions`` is empty or uses a reserved name.
            OSError: if a mask cannot be written to the cache.
        """
        if not regions:
            raise ValueError("At least one region mask is required")
        clashes = self._RESERVED_REGION_NAMES.intersection(regions)
        if clashes:
            raise ValueError(f"Reserved region name(s) not allowed: {sorted(clashes)}")

        processed_masks = {}
        for region_name, mask_data in regions.items():
            logger.info(f"Processing mask for region: {region_name}")
            base_mask = self.create_structural_mask(image_path, mask_data)
            # Grow first so feathering softens an edge that already covers
            # the full region, preventing edge bleeding.
            feathered_mask = self.feather_mask(self.grow_mask(base_mask))
            cache_path = self.save_to_cache(
                image_path,
                feathered_mask,
                suffix=f"_{region_name}",
            )
            processed_masks[region_name] = {
                "mask": feathered_mask,
                "cache_path": cache_path,
            }

        combined_structural = self.combine_masks(
            [entry["mask"] for entry in processed_masks.values()],
            operation="union",
        )
        # Everything not structural may be restyled.
        stylable_mask = self.invert_mask(combined_structural)

        processed_masks["combined_structural"] = {
            "mask": combined_structural,
            "cache_path": self.save_to_cache(
                image_path, combined_structural, suffix="_structural"
            ),
        }
        processed_masks["stylable"] = {
            "mask": stylable_mask,
            "cache_path": self.save_to_cache(
                image_path, stylable_mask, suffix="_stylable"
            ),
        }
        return processed_masks

    @staticmethod
    def _placeholder_regions(height: int, width: int) -> Dict[str, np.ndarray]:
        """Build stand-in wall/floor/ceiling masks from fixed horizontal bands.

        This is a placeholder until real segmentation (e.g. SAM) is wired in.
        The bands are:

        - wall: the top 60% of rows;
        - floor: the remaining rows;
        - ceiling: the top 15% of rows, which overlaps the wall band.
        """
        wall_end = int(height * 0.6)
        ceiling_end = int(height * 0.15)
        bands = {
            "wall": (0, wall_end),
            "floor": (wall_end, height),
            "ceiling": (0, ceiling_end),
        }
        regions = {}
        for name, (start, stop) in bands.items():
            region = np.zeros((height, width), dtype=np.uint8)
            region[start:stop, :] = 255
            regions[name] = region
        return regions

    def preprocess_image(self, image_path: str) -> Dict:
        """Run the full mask pipeline for one image, reusing cached output.

        No work is done if both the combined structural and stylable masks are
        already cached. Otherwise the pipeline runs in full:

        - the image is loaded to get its size;
        - placeholder region masks are generated;
        - every mask is processed and cached.

        Returns:
            Metadata dict with ``image_path``, ``cached`` and ``masks``
            (name -> cache file path). In both branches, ``masks`` contains
            ``"structural"`` and ``"combined_structural"`` (the same file)
            and ``"stylable"``. A fresh run also lists each region and adds
            ``dimensions`` as ``(width, height)``.

        Raises:
            ValueError: if the image cannot be decoded.
            OSError: if the image cannot be read or a mask cannot be written.
        """
        logger.info(f"Preprocessing image: {image_path}")

        structural_path = self._get_cache_path(image_path, "_structural")
        stylable_path = self._get_cache_path(image_path, "_stylable")
        # Require both combined masks: a run interrupted between the two
        # writes must not be reported as cached with a missing file.
        if structural_path.exists() and stylable_path.exists():
            logger.info(f"Image already preprocessed: {image_path}")
            return {
                "image_path": image_path,
                "cached": True,
                "masks": {
                    "structural": str(structural_path),
                    "stylable": str(stylable_path),
                    # Same file as "structural"; matches the fresh-run key.
                    "combined_structural": str(structural_path),
                },
            }

        img = cv2.imread(image_path)
        if img is None:
            raise ValueError(f"Could not load image: {image_path}")
        height, width = img.shape[:2]

        regions = self._placeholder_regions(height, width)
        processed = self.create_multi_region_mask(image_path, regions)

        masks = {name: data["cache_path"] for name, data in processed.items()}
        # Same file as "combined_structural"; matches the cache-hit key.
        masks["structural"] = masks["combined_structural"]
        return {
            "image_path": image_path,
            "cached": False,
            "dimensions": (width, height),
            "masks": masks,
        }

    def batch_preprocess(self, directory: str, pattern: str = "*.jpg") -> List[Dict]:
        """Preprocess every image in ``directory`` (non-recursive).

        Files matching ``pattern`` and ``*.png`` are both processed. Each file
        is handled once, in sorted path order. Failures are logged and skipped.

        Returns:
            Result dicts from ``preprocess_image`` for the images that succeeded.
        """
        input_dir = Path(directory)
        if not input_dir.is_dir():
            logger.warning(f"Not a directory, nothing to preprocess: {input_dir}")
            return []
        candidates = list(input_dir.glob(pattern)) + list(input_dir.glob("*.png"))
        # De-duplicate: `pattern` may itself match PNGs (e.g. "*.png", "*").
        image_files = sorted(p for p in dict.fromkeys(candidates) if p.is_file())

        results = []
        for img_file in image_files:
            try:
                results.append(self.preprocess_image(str(img_file)))
            except Exception as e:
                # Best-effort batch: one bad image must not abort the rest.
                logger.error(f"Failed to preprocess {img_file}: {e}")
        return results

    def clear_cache(self):
        """Delete every ``*.png`` file in the cache directory."""
        for cache_file in self.cache_dir.glob("*.png"):
            try:
                cache_file.unlink()
            except FileNotFoundError:
                # Removed concurrently; the goal state is already reached.
                continue
        logger.info("Cache cleared")

    def get_cache_stats(self) -> Dict:
        """Return the PNG file count, total size in MiB and the cache directory."""
        cache_files = list(self.cache_dir.glob("*.png"))
        total_size = sum(f.stat().st_size for f in cache_files)
        return {
            "cached_files": len(cache_files),
            "total_size_mb": total_size / (1024 * 1024),
            "cache_directory": str(self.cache_dir),
        }
def main():
    """Command-line entry point for the mask preprocessor.

    Runs one action, checked in this order:

    1. ``--clear-cache``
    2. ``--stats``
    3. ``--image``
    4. ``--directory``

    ``--image`` and ``--directory`` are mutually exclusive. JSON results are
    printed to stdout. The batch summary line goes to stderr so stdout stays
    valid JSON. Invalid arguments, or no action at all, are reported through
    ``parser.error``, which exits with a non-zero status.
    """

    def non_negative_int(text: str) -> int:
        """argparse ``type=`` converter accepting integers >= 0."""
        value = int(text)
        if value < 0:
            raise argparse.ArgumentTypeError(f"must be >= 0, got {value}")
        return value

    # Take CLI defaults from MaskConfig so the two cannot drift apart.
    defaults = MaskConfig()

    parser = argparse.ArgumentParser(
        description="Dream Weaver Mask Preprocessor"
    )
    source = parser.add_mutually_exclusive_group()
    source.add_argument(
        "--image",
        type=str,
        help="Single image to preprocess"
    )
    source.add_argument(
        "--directory",
        type=str,
        help="Directory of images to preprocess"
    )
    parser.add_argument(
        "--cache-dir",
        type=str,
        default="Project_Velocity/comfy_engine/cache/masks/",
        help="Cache directory for masks"
    )
    parser.add_argument(
        "--grow",
        type=non_negative_int,
        default=defaults.grow_pixels,
        help="Pixels to grow mask (dilation)"
    )
    parser.add_argument(
        "--feather",
        type=non_negative_int,
        default=defaults.feather_pixels,
        help="Pixels to feather mask edges"
    )
    parser.add_argument(
        "--clear-cache",
        action="store_true",
        help="Clear all cached masks"
    )
    parser.add_argument(
        "--stats",
        action="store_true",
        help="Show cache statistics"
    )
    args = parser.parse_args()

    # Fail before touching the filesystem (the constructor creates the cache
    # directory) when there is nothing to do.
    if not (args.clear_cache or args.stats or args.image or args.directory):
        parser.error("No action specified. Use --help for usage information.")

    preprocessor = MaskPreprocessor(cache_dir=args.cache_dir)
    preprocessor.config.grow_pixels = args.grow
    preprocessor.config.feather_pixels = args.feather

    if args.clear_cache:
        preprocessor.clear_cache()
        return
    if args.stats:
        print(json.dumps(preprocessor.get_cache_stats(), indent=2))
        return
    if args.image:
        result = preprocessor.preprocess_image(args.image)
        print(json.dumps(result, indent=2))
    else:
        results = preprocessor.batch_preprocess(args.directory)
        print(json.dumps(results, indent=2))
        # Human-readable summary on stderr keeps stdout pure JSON for piping.
        print(f"\nProcessed {len(results)} images", file=sys.stderr)
if __name__ == "__main__":
    # Script entry point: run the CLI only when executed directly, not when
    # this module is imported.
    main()