# Commit note: #7 Task completed. Co-authored-by: Sayan Datta <sayan@Sayans-MacBook-Air.local> Reviewed-on: #8
# File viewer metadata: 389 lines, 13 KiB, Python
#!/usr/bin/env python3
"""
Dream Weaver Mask Preprocessor
==============================
Utility script for preprocessing and caching segmentation masks.
Enables offline mask generation to speed up production workflows.

Target Hardware: Dual NVIDIA RTX PRO 6000 Blackwell
Author: Project Velocity Team
Version: 1.0.0
"""

import argparse
import hashlib
import json
import logging
import os
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple, Dict

import cv2
import numpy as np
from PIL import Image

# Configure logging.
# Root logging is configured at import time so that the script prints
# timestamped INFO-level messages without any extra setup by the caller.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-wide logger used by the preprocessor class and the CLI entry point.
logger = logging.getLogger('MaskPreprocessor')


@dataclass
class MaskConfig:
    """Configuration for mask generation.

    When ``target_classes`` is left as ``None`` it is replaced with a fresh
    ``["wall", "floor", "ceiling"]`` list in ``__post_init__``, so instances
    never share the default list.
    """

    # Dilation radius, in pixels, applied when growing a mask.
    grow_pixels: int = 3
    # Gaussian blur radius, in pixels, applied when feathering a mask.
    feather_pixels: int = 5
    # Fraction of 255 used as the binarisation threshold.
    threshold: float = 0.3
    # Region class names; ``None`` selects the default list.
    target_classes: Optional[List[str]] = None

    def __post_init__(self):
        """Fill in the default class list when none was supplied."""
        if self.target_classes is None:
            self.target_classes = ["wall", "floor", "ceiling"]


class MaskPreprocessor:
    """Preprocesses and caches segmentation masks for Dream Weaver.

    Masks are written as PNG files into ``cache_dir``. Each file is named
    after the MD5 digest of the source image's bytes followed by a per-mask
    suffix (for example ``_wall`` or ``_structural``).

    NOTE(review): cache file names do not encode the ``MaskConfig`` values,
    so masks cached under one grow/feather/threshold setting are reused
    after the settings change. Call ``clear_cache()`` when changing them.
    """

    # Operation names accepted by ``combine_masks``.
    _COMBINE_OPERATIONS = ("union", "intersection", "difference")
    # Suffixes/keys used for the combined masks; region names must avoid them.
    _RESERVED_REGION_NAMES = ("structural", "stylable", "combined_structural")
    # Number of bytes read per step while hashing an image file.
    _HASH_CHUNK_SIZE = 1024 * 1024

    def __init__(self, cache_dir: str = "Project_Velocity/comfy_engine/cache/masks/"):
        """Create the cache directory if needed and load the default config.

        Args:
            cache_dir: Directory in which cached mask PNG files are stored.
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.config = MaskConfig()
        # Maps image path -> ((mtime_ns, size), digest) so one image is not
        # re-read from disk for every cache lookup made while processing it.
        self._hash_cache = {}
        logger.info(f"MaskPreprocessor initialized. Cache directory: {self.cache_dir}")

    def _get_image_hash(self, image_path: str) -> str:
        """Return the MD5 hex digest of the image file's bytes.

        The digest is only a cache key, not a security measure. The file is
        read in chunks, and the result is remembered per path for as long as
        the file's modification time and size stay the same.

        Raises:
            OSError: If the file cannot be stat'ed or opened.
        """
        file_stat = Path(image_path).stat()
        signature = (file_stat.st_mtime_ns, file_stat.st_size)

        remembered = self._hash_cache.get(str(image_path))
        if remembered is not None and remembered[0] == signature:
            return remembered[1]

        hasher = hashlib.md5()
        with open(image_path, 'rb') as f:
            for chunk in iter(lambda: f.read(self._HASH_CHUNK_SIZE), b""):
                hasher.update(chunk)

        digest = hasher.hexdigest()
        self._hash_cache[str(image_path)] = (signature, digest)
        return digest

    def _get_cache_path(self, image_path: str, suffix: str = "") -> Path:
        """Return the cache file path for an image and mask suffix."""
        image_hash = self._get_image_hash(image_path)
        filename = f"{image_hash}{suffix}.png"
        return self.cache_dir / filename

    def is_cached(self, image_path: str, suffix: str = "") -> bool:
        """Return True if the mask file for this image/suffix exists."""
        return self._get_cache_path(image_path, suffix).exists()

    def load_from_cache(self, image_path: str, suffix: str = "") -> Optional[np.ndarray]:
        """Load a mask from the cache.

        Returns:
            The mask read as a grayscale image, or ``None`` when no cache
            file exists or the file could not be decoded.
        """
        cache_path = self._get_cache_path(image_path, suffix)
        if not cache_path.exists():
            return None

        logger.info(f"Loading cached mask from {cache_path}")
        mask = cv2.imread(str(cache_path), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            # cv2.imread signals a decode failure by returning None.
            logger.warning(f"Cached mask could not be read: {cache_path}")
        return mask

    def save_to_cache(self, image_path: str, mask: np.ndarray, suffix: str = "") -> str:
        """Write a mask to the cache and return the file path as a string.

        Raises:
            OSError: If OpenCV reports that the file was not written.
        """
        cache_path = self._get_cache_path(image_path, suffix)
        # cv2.imwrite returns False on failure instead of raising, so the
        # result must be checked before claiming the mask was saved.
        if not cv2.imwrite(str(cache_path), mask):
            raise OSError(f"Failed to write mask to cache: {cache_path}")
        logger.info(f"Saved mask to cache: {cache_path}")
        return str(cache_path)

    def create_structural_mask(self, image_path: str, mask_data: np.ndarray) -> np.ndarray:
        """Create a binary structural preservation mask from segmentation data.

        Multi-channel input is reduced to one channel, then every pixel above
        ``int(255 * config.threshold)`` becomes 255 and the rest become 0.

        Args:
            image_path: Source image path (not used by this method).
            mask_data: Segmentation data as a 2-D or 3-D array.

        Returns:
            The thresholded mask as ``uint8``.
        """
        if mask_data.ndim == 3:
            if mask_data.shape[2] == 1:
                # A trailing singleton channel is already grayscale.
                mask_data = np.ascontiguousarray(mask_data[:, :, 0])
            else:
                mask_data = cv2.cvtColor(mask_data, cv2.COLOR_BGR2GRAY)

        _, binary_mask = cv2.threshold(
            mask_data,
            int(255 * self.config.threshold),
            255,
            cv2.THRESH_BINARY
        )

        return binary_mask.astype(np.uint8)

    def grow_mask(self, mask: np.ndarray, pixels: Optional[int] = None) -> np.ndarray:
        """Grow (dilate) the mask by the given number of pixels.

        Expanding the mask slightly prevents edge bleeding.

        Args:
            mask: Mask to dilate.
            pixels: Dilation radius; defaults to ``config.grow_pixels``.

        Raises:
            ValueError: If ``pixels`` is negative.
        """
        if pixels is None:
            pixels = self.config.grow_pixels
        if pixels < 0:
            raise ValueError(f"pixels must be non-negative, got {pixels}")
        if pixels == 0:
            # A 1x1 kernel leaves the mask unchanged; skip the OpenCV call.
            return mask.copy()

        kernel = np.ones((pixels * 2 + 1, pixels * 2 + 1), np.uint8)
        return cv2.dilate(mask, kernel, iterations=1)

    def feather_mask(self, mask: np.ndarray, pixels: Optional[int] = None) -> np.ndarray:
        """Apply a Gaussian blur to soften the mask edges.

        Args:
            mask: Mask to blur.
            pixels: Blur radius; defaults to ``config.feather_pixels``.

        Raises:
            ValueError: If ``pixels`` is negative.
        """
        if pixels is None:
            pixels = self.config.feather_pixels
        if pixels < 0:
            raise ValueError(f"pixels must be non-negative, got {pixels}")
        if pixels == 0:
            # A 1x1 kernel leaves the mask unchanged; skip the OpenCV call.
            return mask.copy()

        # 2 * pixels + 1 is always odd, as GaussianBlur requires.
        kernel_size = pixels * 2 + 1
        return cv2.GaussianBlur(mask, (kernel_size, kernel_size), 0)

    def invert_mask(self, mask: np.ndarray) -> np.ndarray:
        """Invert mask (structural -> stylable or vice versa)."""
        return cv2.bitwise_not(mask)

    def combine_masks(self, masks: List[np.ndarray], operation: str = "union") -> Optional[np.ndarray]:
        """Combine multiple masks into one.

        Args:
            masks: Masks to combine, folded left to right.
            operation: ``'union'`` (OR), ``'intersection'`` (AND) or
                ``'difference'`` (first mask AND NOT each later mask).

        Returns:
            The combined mask, or ``None`` when ``masks`` is empty.

        Raises:
            ValueError: If ``operation`` is not one of the supported names.
        """
        if not masks:
            return None
        if operation not in self._COMBINE_OPERATIONS:
            # Previously an unknown name silently returned the first mask.
            raise ValueError(
                f"Unknown mask operation {operation!r}; "
                f"expected one of {self._COMBINE_OPERATIONS}"
            )

        result = masks[0].copy()
        for mask in masks[1:]:
            if operation == "union":
                result = cv2.bitwise_or(result, mask)
            elif operation == "intersection":
                result = cv2.bitwise_and(result, mask)
            else:  # "difference"
                result = cv2.bitwise_and(result, cv2.bitwise_not(mask))

        return result

    def create_multi_region_mask(
        self,
        image_path: str,
        regions: Dict[str, np.ndarray]
    ) -> Dict[str, Dict]:
        """Create, cache and return masks for multiple regions.

        Each region is thresholded, grown, feathered and cached under the
        suffix ``_<region_name>``. The union of all regions is cached as
        ``_structural`` and its inverse as ``_stylable``.

        Args:
            image_path: Source image; its hash names the cache files.
            regions: Mapping of region name to raw segmentation data.

        Returns:
            Mapping of mask name to ``{"mask": array, "cache_path": str}``,
            including the ``"combined_structural"`` and ``"stylable"`` keys.

        Raises:
            ValueError: If ``regions`` is empty or uses a reserved name.
        """
        if not regions:
            # Without regions there is nothing to union or invert.
            raise ValueError("regions must contain at least one mask")
        clashing = [name for name in regions if name in self._RESERVED_REGION_NAMES]
        if clashing:
            # These names would overwrite the combined masks or their files.
            raise ValueError(f"Reserved region names are not allowed: {clashing}")

        processed_masks = {}

        for region_name, mask_data in regions.items():
            logger.info(f"Processing mask for region: {region_name}")

            base_mask = self.create_structural_mask(image_path, mask_data)
            # Grow mask to prevent edge bleeding, then feather the edges.
            grown_mask = self.grow_mask(base_mask)
            feathered_mask = self.feather_mask(grown_mask)

            cache_path = self.save_to_cache(
                image_path,
                feathered_mask,
                suffix=f"_{region_name}"
            )

            processed_masks[region_name] = {
                "mask": feathered_mask,
                "cache_path": cache_path
            }

        # Combined structural mask and its inverse (the stylable area).
        all_structural = [m["mask"] for m in processed_masks.values()]
        combined_structural = self.combine_masks(all_structural, operation="union")
        stylable_mask = self.invert_mask(combined_structural)

        structural_cache = self.save_to_cache(
            image_path,
            combined_structural,
            suffix="_structural"
        )
        stylable_cache = self.save_to_cache(
            image_path,
            stylable_mask,
            suffix="_stylable"
        )

        processed_masks["combined_structural"] = {
            "mask": combined_structural,
            "cache_path": structural_cache
        }
        processed_masks["stylable"] = {
            "mask": stylable_mask,
            "cache_path": stylable_cache
        }

        return processed_masks

    def _build_placeholder_regions(self, height: int, width: int) -> Dict[str, np.ndarray]:
        """Build fixed horizontal-band masks that stand in for real segmentation."""
        # Placeholder masks (in production, these would come from SAM).
        wall_mask = np.zeros((height, width), dtype=np.uint8)
        wall_mask[0:int(height * 0.6), :] = 255

        floor_mask = np.zeros((height, width), dtype=np.uint8)
        floor_mask[int(height * 0.6):, :] = 255

        ceiling_mask = np.zeros((height, width), dtype=np.uint8)
        ceiling_mask[0:int(height * 0.15), :] = 255

        return {"wall": wall_mask, "floor": floor_mask, "ceiling": ceiling_mask}

    def preprocess_image(self, image_path: str) -> Dict:
        """Run the complete preprocessing pipeline for a single image.

        Returns:
            Metadata about the generated masks. When both combined masks are
            already cached, only their paths are returned under the keys
            ``"structural"`` and ``"stylable"``; otherwise every generated
            mask is listed by name together with the image dimensions.
            NOTE(review): the two result shapes use different key names for
            the combined structural mask ("structural" vs
            "combined_structural"); kept as-is for existing callers.

        Raises:
            ValueError: If the image cannot be loaded.
            OSError: If the image file cannot be read for hashing.
        """
        logger.info(f"Preprocessing image: {image_path}")

        # Both combined masks must exist; otherwise a path to a missing
        # stylable mask would be reported as cached.
        if self.is_cached(image_path, "_structural") and self.is_cached(image_path, "_stylable"):
            logger.info(f"Image already preprocessed: {image_path}")
            return {
                "image_path": image_path,
                "cached": True,
                "masks": {
                    "structural": str(self._get_cache_path(image_path, "_structural")),
                    "stylable": str(self._get_cache_path(image_path, "_stylable"))
                }
            }

        # The image is loaded only to obtain its dimensions.
        img = cv2.imread(image_path)
        if img is None:
            raise ValueError(f"Could not load image: {image_path}")

        height, width = img.shape[:2]

        regions = self._build_placeholder_regions(height, width)
        processed = self.create_multi_region_mask(image_path, regions)

        return {
            "image_path": image_path,
            "cached": False,
            "dimensions": (width, height),
            "masks": {
                name: data["cache_path"]
                for name, data in processed.items()
            }
        }

    def batch_preprocess(self, directory: str, pattern: str = "*.jpg") -> List[Dict]:
        """Preprocess all images in a directory.

        Files matching ``pattern`` and files matching ``*.png`` are
        processed. A file matched by both is processed once. Failures are
        logged and skipped so one bad image does not stop the batch.

        Returns:
            One result dictionary per successfully processed image.
        """
        input_dir = Path(directory)
        if not input_dir.is_dir():
            logger.warning(f"Input directory does not exist: {input_dir}")

        matches = list(input_dir.glob(pattern))
        matches.extend(input_dir.glob("*.png"))
        # dict.fromkeys drops duplicates while keeping first-seen order.
        image_files = list(dict.fromkeys(matches))

        results = []
        for img_file in image_files:
            try:
                results.append(self.preprocess_image(str(img_file)))
            except Exception as e:
                logger.error(f"Failed to preprocess {img_file}: {e}")

        return results

    def clear_cache(self):
        """Delete every ``*.png`` file in the cache directory."""
        for cache_file in self.cache_dir.glob("*.png"):
            cache_file.unlink()
        logger.info("Cache cleared")

    def get_cache_stats(self) -> Dict:
        """Return the file count, total size in MiB and cache directory."""
        cache_files = list(self.cache_dir.glob("*.png"))
        total_size = sum(f.stat().st_size for f in cache_files)

        return {
            "cached_files": len(cache_files),
            "total_size_mb": total_size / (1024 * 1024),
            "cache_directory": str(self.cache_dir)
        }


def main():
    """Main entry point for command-line usage.

    Parses the command line, builds a ``MaskPreprocessor`` with the
    requested cache directory and grow/feather sizes, then performs the
    first requested action in this order: ``--clear-cache``, ``--stats``,
    ``--image``, ``--directory``. Results are printed as JSON.

    Exits with status 2 on invalid arguments and status 1 when the single
    image given with ``--image`` cannot be processed.
    """
    parser = argparse.ArgumentParser(
        description="Dream Weaver Mask Preprocessor"
    )
    parser.add_argument(
        "--image",
        type=str,
        help="Single image to preprocess"
    )
    parser.add_argument(
        "--directory",
        type=str,
        help="Directory of images to preprocess"
    )
    parser.add_argument(
        "--cache-dir",
        type=str,
        default="Project_Velocity/comfy_engine/cache/masks/",
        help="Cache directory for masks"
    )
    parser.add_argument(
        "--grow",
        type=int,
        default=3,
        help="Pixels to grow mask (dilation)"
    )
    parser.add_argument(
        "--feather",
        type=int,
        default=5,
        help="Pixels to feather mask edges"
    )
    parser.add_argument(
        "--clear-cache",
        action="store_true",
        help="Clear all cached masks"
    )
    parser.add_argument(
        "--stats",
        action="store_true",
        help="Show cache statistics"
    )

    args = parser.parse_args()

    # Negative sizes cannot form a valid kernel; reject them up front,
    # before the cache directory is created.
    if args.grow < 0 or args.feather < 0:
        parser.error("--grow and --feather must be non-negative")

    # Initialize preprocessor
    preprocessor = MaskPreprocessor(cache_dir=args.cache_dir)
    preprocessor.config.grow_pixels = args.grow
    preprocessor.config.feather_pixels = args.feather

    if args.clear_cache:
        preprocessor.clear_cache()
        return

    if args.stats:
        stats = preprocessor.get_cache_stats()
        print(json.dumps(stats, indent=2))
        return

    if args.image:
        try:
            result = preprocessor.preprocess_image(args.image)
        except (ValueError, OSError) as exc:
            # Report a missing or unreadable image as a short error and a
            # non-zero exit status rather than a traceback.
            logger.error(f"Failed to preprocess {args.image}: {exc}")
            sys.exit(1)
        print(json.dumps(result, indent=2))

    elif args.directory:
        results = preprocessor.batch_preprocess(args.directory)
        print(json.dumps(results, indent=2))
        print(f"\nProcessed {len(results)} images")

    else:
        print("No action specified. Use --help for usage information.")


# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    main()