forked from sagnik/Velocity-OS
103 lines
3.0 KiB
Python
103 lines
3.0 KiB
Python
from __future__ import annotations

import hashlib
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable
|
|
|
|
|
|
# Default folder scanned for ``*.sql`` migration files: a ``versions``
# directory that sits next to this module (resolved to an absolute path).
MIGRATIONS_DIR = Path(__file__).resolve().parent.joinpath("versions")
|
|
|
|
|
|
@dataclass(frozen=True, eq=True)
class Migration:
    """A single SQL migration file loaded from disk.

    Instances are immutable (``frozen=True``); together with ``eq=True`` the
    dataclass machinery generates ``__eq__`` and ``__hash__`` from all fields.
    """

    # Version prefix of the file stem (text before the first "_", as split
    # by discover_migrations).
    version: str
    # Remainder of the file stem after the first "_".
    name: str
    # Location of the .sql file the migration was read from.
    path: Path
    # Digest of ``sql`` used to detect edits to already-applied migrations.
    checksum: str
    # Full SQL text executed when the migration is applied.
    sql: str
|
|
|
|
|
|
def _checksum(sql: str) -> str:
    """Return the lowercase hex SHA-256 digest of *sql* encoded as UTF-8."""
    hasher = hashlib.sha256()
    hasher.update(sql.encode("utf-8"))
    return hasher.hexdigest()
|
|
|
|
|
|
def _version_sort_key(version: str) -> tuple:
    """Return a natural-order sort key for a migration version string.

    Runs of ASCII digits compare numerically and all other text compares as
    strings, so ``"2"`` sorts before ``"10"`` and ``"v2"`` before ``"v10"``.
    """
    # re.split with a capturing group alternates text and digit runs, so every
    # even index is text and every odd index is a digit run. Converting only
    # the odd indices to int means two keys never compare an int with a str.
    parts = re.split(r"([0-9]+)", version)
    return tuple(int(part) if index % 2 else part for index, part in enumerate(parts))


def discover_migrations(directory: Path = MIGRATIONS_DIR) -> list[Migration]:
    """Load every ``*.sql`` file directly inside *directory* as a Migration.

    Filenames must look like ``<version>_<name>.sql``; the stem is split at the
    first underscore. The result is ordered by natural version order, so
    ``2_x.sql`` comes before ``10_y.sql``. A plain lexicographic sort would
    put ``10`` first and cause it to be applied too early.

    Args:
        directory: Folder to scan (non-recursively). Defaults to
            ``MIGRATIONS_DIR``.

    Returns:
        The discovered migrations in version order, or an empty list when
        *directory* does not exist.

    Raises:
        ValueError: if a filename lacks a version or a name, or if two files
            share the same version string.
    """
    if not directory.exists():
        return []

    migrations: list[Migration] = []
    # Read files in a deterministic (path-sorted) order so that the first
    # invalid filename reported is stable across runs.
    for path in sorted(directory.glob("*.sql")):
        version, _, name = path.stem.partition("_")
        if not version or not name:
            raise ValueError(f"Invalid migration filename: {path.name}")
        sql = path.read_text(encoding="utf-8")
        migrations.append(
            Migration(
                version=version,
                name=name,
                path=path,
                checksum=_checksum(sql),
                sql=sql,
            )
        )

    seen: set[str] = set()
    for migration in migrations:
        if migration.version in seen:
            raise ValueError(f"Duplicate migration version: {migration.version}")
        seen.add(migration.version)

    # Natural order first; the raw version string breaks ties such as "01"
    # vs "1", which are distinct versions with the same numeric value.
    migrations.sort(key=lambda m: (_version_sort_key(m.version), m.version))
    return migrations
|
|
|
|
|
|
async def ensure_migration_table(conn) -> None:
    """Create the ``schema_migrations`` bookkeeping table when it is absent.

    The DDL uses ``CREATE TABLE IF NOT EXISTS``, so calling this repeatedly is
    safe. *conn* only needs an awaitable ``execute`` method.
    """
    ddl = """
        CREATE TABLE IF NOT EXISTS schema_migrations (
            version TEXT PRIMARY KEY,
            name TEXT NOT NULL,
            checksum TEXT NOT NULL,
            applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
        )
        """
    await conn.execute(ddl)
|
|
|
|
|
|
async def applied_versions(conn) -> dict[str, str]:
    """Return a ``{version: checksum}`` map of every recorded migration.

    The bookkeeping table is created first via ``ensure_migration_table`` so
    the query also works against a fresh database. Rows are read with
    ``conn.fetch`` and indexed by column name.
    """
    await ensure_migration_table(conn)
    recorded: dict[str, str] = {}
    for record in await conn.fetch("SELECT version, checksum FROM schema_migrations"):
        recorded[record["version"]] = record["checksum"]
    return recorded
|
|
|
|
|
|
async def apply_migrations(conn, migrations: Iterable[Migration] | None = None) -> list[str]:
    """Apply every migration whose version is not yet recorded in the database.

    Each migration's SQL and the ``schema_migrations`` row that records it are
    executed inside the same ``conn.transaction()`` block.

    Args:
        conn: Async database connection providing ``execute``, ``fetch`` and
            ``transaction()``. NOTE(review): the ``$1``-style placeholders
            suggest an asyncpg connection; confirm against callers.
        migrations: Migrations to consider, applied in the order given.
            Defaults to ``discover_migrations()`` on the default directory.

    Returns:
        The versions applied by this call, in application order. The list is
        empty when nothing was pending.

    Raises:
        RuntimeError: if a version is already recorded with a different
            checksum, i.e. an applied migration file was edited afterwards.
    """
    pending = list(migrations if migrations is not None else discover_migrations())
    applied = await applied_versions(conn)
    applied_now: list[str] = []

    for migration in pending:
        # Use membership, not the truthiness of the stored checksum, to decide
        # "already applied". An empty recorded checksum is therefore reported
        # as a mismatch, instead of re-running SQL that was already applied.
        if migration.version in applied:
            if applied[migration.version] != migration.checksum:
                raise RuntimeError(
                    f"Migration checksum mismatch for {migration.version}; "
                    "create a new migration instead of editing an applied one."
                )
            continue

        async with conn.transaction():
            await conn.execute(migration.sql)
            await conn.execute(
                """
                INSERT INTO schema_migrations (version, name, checksum)
                VALUES ($1, $2, $3)
                """,
                migration.version,
                migration.name,
                migration.checksum,
            )

        # Keep the local view in sync with the table. A version repeated later
        # in *migrations* is then skipped (same checksum) or rejected
        # (different checksum) instead of failing on the primary-key insert.
        applied[migration.version] = migration.checksum
        applied_now.append(migration.version)

    return applied_now
|
|
|