#!/usr/bin/env python3
"""Hydrate ComfyUI checkpoint files on a running GPU instance.

The script locates a running EC2 instance (by explicit id or by tag), then
pipes a bash script over SSH that copies any missing checkpoint files from
their source URIs into the ComfyUI checkpoint directory, restarts the
``comfyui`` service when something changed, and finally queries the local
ComfyUI checkpoint listing.

Environment variables read: ``OPS_ENV_FILE``, ``COMFY_INSTANCE_ID``,
``COMFY_INSTANCE_TAG_KEY``, ``COMFY_INSTANCE_TAG_VALUE``,
``COMFY_CHECKPOINTS_JSON``, ``GPU_SSH_KEY_PATH``, ``GPU_SSH_USER`` and the
standard ``AWS_*`` credential/region variables.
"""
from __future__ import annotations

import json
import os
import subprocess
import sys
from pathlib import Path

# Checkpoints hydrated when COMFY_CHECKPOINTS_JSON is not set:
# target filename -> source URI handed to `aws s3 cp`.
DEFAULT_CHECKPOINTS = {
    "realvisxlV50_v50LightningBakedvae.safetensors": (
        "s3://project-velocity/models/realvisxlV50_v50LightningBakedvae.safetensors"
    ),
}


def load_env_file(path: Path) -> dict[str, str]:
    """Parse a dotenv-style ``KEY=VALUE`` file into a dictionary.

    Blank lines, ``#`` comment lines and lines without ``=`` are skipped.
    A leading ``export `` on the key and one pair of matching surrounding
    quotes on the value are removed, so ``export A="b c"`` yields
    ``{"A": "b c"}``.

    Args:
        path: Location of the env file.

    Returns:
        The parsed mapping; empty when the file does not exist.
    """
    try:
        text = path.read_text(encoding="utf-8")
    except FileNotFoundError:
        return {}

    data: dict[str, str] = {}
    export_prefix = "export "
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, value = line.split("=", 1)
        key = key.strip()
        if key.startswith(export_prefix):
            key = key[len(export_prefix):].strip()
        value = value.strip()
        # Drop one pair of matching quotes so quoted secrets and paths are
        # not passed on with literal quote characters.
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
            value = value[1:-1]
        data[key] = value
    return data


def env(name: str, default: str = "") -> str:
    """Return environment variable *name*, or *default* when it is unset."""
    return os.environ.get(name, default)


def resolve_target_instance(ec2) -> dict | None:
    """Pick the running instance that should receive the checkpoints.

    When ``COMFY_INSTANCE_ID`` is set, only that instance is looked up.
    Otherwise running instances are filtered by the tag given in
    ``COMFY_INSTANCE_TAG_KEY`` / ``COMFY_INSTANCE_TAG_VALUE`` (defaults
    ``DesineuronRole`` / ``comfyui``).

    Args:
        ec2: Object exposing ``describe_instances`` (a boto3 EC2 client in
            ``main``).

    Returns:
        The instance description with the latest ``LaunchTime`` among the
        running matches, or ``None`` when there is none.
    """
    explicit_instance_id = env("COMFY_INSTANCE_ID")
    if explicit_instance_id:
        response = ec2.describe_instances(InstanceIds=[explicit_instance_id])
    else:
        tag_key = env("COMFY_INSTANCE_TAG_KEY", "DesineuronRole")
        tag_value = env("COMFY_INSTANCE_TAG_VALUE", "comfyui")
        response = ec2.describe_instances(
            Filters=[
                {"Name": "instance-state-name", "Values": ["running"]},
                {"Name": f"tag:{tag_key}", "Values": [tag_value]},
            ]
        )

    # The explicit-id lookup is not state-filtered by the API call, so the
    # running check is applied to every result.
    running = [
        instance
        for reservation in response["Reservations"]
        for instance in reservation["Instances"]
        if instance["State"]["Name"] == "running"
    ]
    if not running:
        return None
    return max(running, key=lambda row: row["LaunchTime"])


def parse_checkpoints() -> dict[str, str]:
    """Return the filename -> source URI mapping to hydrate.

    Reads ``COMFY_CHECKPOINTS_JSON``; when it is unset or empty a copy of
    ``DEFAULT_CHECKPOINTS`` is returned.

    Raises:
        ValueError: If the variable is not valid JSON
            (``json.JSONDecodeError`` is a ``ValueError``) or does not
            decode to a JSON object.
    """
    raw = env("COMFY_CHECKPOINTS_JSON")
    if not raw:
        return dict(DEFAULT_CHECKPOINTS)
    parsed = json.loads(raw)
    if not isinstance(parsed, dict):
        raise ValueError("COMFY_CHECKPOINTS_JSON must be a JSON object of filename to source URI")
    return {str(name): str(source) for name, source in parsed.items()}


def remote_hydration_script(checkpoints: dict[str, str]) -> str:
    """Build the bash script that is piped to ``bash -s`` on the GPU host.

    The script verifies that ``/opt/dlami/nvme`` is a mount point, writes the
    checkpoint mapping to a tab-separated manifest, copies each missing or
    empty target with ``aws s3 cp`` (via a ``.part`` file that is renamed on
    success), restarts ``comfyui`` if anything was copied, and finally curls
    the local ComfyUI checkpoint listing.

    Args:
        checkpoints: Target filename -> source URI.

    Returns:
        The script text.

    Raises:
        ValueError: If a filename or source is empty or contains a tab,
            carriage return or newline, which would corrupt the manifest.
    """
    for name, source in checkpoints.items():
        for label, value in (("filename", name), ("source", source)):
            if not value or any(ch in value for ch in "\t\r\n"):
                raise ValueError(
                    f"Checkpoint {label} {value!r} must be non-empty and free of tabs and newlines"
                )

    payload = json.dumps(checkpoints)
    # NOTE: the mount check runs before mkdir so that an unmounted node does
    # not get the checkpoint tree created on its root filesystem.
    # NOTE(review): the manifest uses a fixed /tmp path and the health check
    # waits a fixed 3 seconds; both are kept from the original script.
    return f"""#!/usr/bin/env bash
set -euo pipefail
CHECKPOINT_DIR="${{COMFY_CHECKPOINT_DIR:-/opt/dlami/nvme/ComfyUI/models/checkpoints}}"
if ! mountpoint -q /opt/dlami/nvme; then
  echo "GPU NVMe mount /opt/dlami/nvme is not mounted" >&2
  exit 2
fi
mkdir -p "$CHECKPOINT_DIR"
changed=0
python3 - <<'PY' > /tmp/desineuron-comfy-checkpoints.tsv
import json
for name, source in json.loads({payload!r}).items():
    print(f"{{name}}\\t{{source}}")
PY
while IFS=$'\\t' read -r filename source; do
  target="$CHECKPOINT_DIR/$filename"
  if [ ! -s "$target" ]; then
    tmp="$target.part"
    rm -f "$tmp"
    aws s3 cp "$source" "$tmp" --no-progress
    mv "$tmp" "$target"
    chmod 0644 "$target"
    changed=1
  fi
done < /tmp/desineuron-comfy-checkpoints.tsv
rm -f /tmp/desineuron-comfy-checkpoints.tsv
if [ "$changed" = "1" ]; then
  sudo systemctl restart comfyui
fi
sleep 3
curl -fsS http://127.0.0.1:8188/models/checkpoints
"""


def main() -> int:
    """Run the hydration end to end.

    Returns:
        ``0`` on success, ``1`` when the checkpoint configuration is invalid
        or no reachable running instance is found, otherwise the exit status
        of the ``ssh`` command.
    """
    ops_env = load_env_file(Path(env("OPS_ENV_FILE", "/opt/desineuron-ops-control-plane/.env")))

    # Real environment variables win over values from the env file.
    for key in ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_DEFAULT_REGION"):
        if key not in os.environ and key in ops_env:
            os.environ[key] = ops_env[key]
    os.environ.setdefault("AWS_DEFAULT_REGION", ops_env.get("OPS_DEFAULT_REGION", "us-east-1"))

    key_path = env(
        "GPU_SSH_KEY_PATH",
        ops_env.get(
            "OPS_SSH_KEY_PATH",
            "/opt/desineuron-ops-control-plane/state/desineuron-l4-node.pem",
        ),
    )
    # Map a /app/state/ key path onto the host state directory. Only the
    # leading prefix is rewritten; later occurrences are left untouched.
    app_state_prefix = "/app/state/"
    if key_path.startswith(app_state_prefix):
        key_path = "/opt/desineuron-ops-control-plane/state/" + key_path[len(app_state_prefix):]
    ssh_user = env("GPU_SSH_USER", "ubuntu")

    # Validate local configuration before touching AWS or the network.
    try:
        script = remote_hydration_script(parse_checkpoints())
    except ValueError as exc:
        print(f"Invalid checkpoint configuration: {exc}", file=sys.stderr)
        return 1

    # Imported here so the pure helpers above stay importable on machines
    # that do not have the AWS SDK installed.
    import boto3

    ec2 = boto3.client("ec2", region_name=os.environ["AWS_DEFAULT_REGION"])
    instance = resolve_target_instance(ec2)
    if not instance:
        print("No running ComfyUI GPU instance found", file=sys.stderr)
        return 1

    target_host = instance.get("PublicIpAddress") or instance.get("PrivateIpAddress")
    if not target_host:
        print("Target GPU instance has no reachable IP", file=sys.stderr)
        return 1

    # NOTE(review): StrictHostKeyChecking=no disables host key verification;
    # kept as in the original, confirm this is acceptable for the deployment.
    command = [
        "sudo",
        "ssh",
        "-o",
        "StrictHostKeyChecking=no",
        "-o",
        "ConnectTimeout=15",
        "-i",
        key_path,
        f"{ssh_user}@{target_host}",
        "bash -s",
    ]
    # The script is delivered on stdin and executed by `bash -s` remotely.
    result = subprocess.run(
        command,
        input=script,
        text=True,
        capture_output=True,
        check=False,
    )
    if result.stdout:
        print(result.stdout.strip())
    if result.returncode != 0:
        if result.stderr:
            print(result.stderr.strip(), file=sys.stderr)
        return result.returncode
    return 0


if __name__ == "__main__":
    raise SystemExit(main())