forked from sagnik/Project_Velocity
Merge Conflicts (#41)
Co-authored-by: Sayan Datta <sayan@Sayans-MacBook-Air.local> Reviewed-on: sagnik/Project_Velocity#41
This commit is contained in:
102
backend/migrations/runner.py
Normal file
102
backend/migrations/runner.py
Normal file
@@ -0,0 +1,102 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
|
||||
MIGRATIONS_DIR = Path(__file__).resolve().parent / "versions"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Migration:
|
||||
version: str
|
||||
name: str
|
||||
path: Path
|
||||
checksum: str
|
||||
sql: str
|
||||
|
||||
|
||||
def _checksum(sql: str) -> str:
|
||||
return hashlib.sha256(sql.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def discover_migrations(directory: Path = MIGRATIONS_DIR) -> list[Migration]:
|
||||
if not directory.exists():
|
||||
return []
|
||||
|
||||
migrations: list[Migration] = []
|
||||
for path in sorted(directory.glob("*.sql")):
|
||||
version, _, name = path.stem.partition("_")
|
||||
if not version or not name:
|
||||
raise ValueError(f"Invalid migration filename: {path.name}")
|
||||
sql = path.read_text(encoding="utf-8")
|
||||
migrations.append(
|
||||
Migration(
|
||||
version=version,
|
||||
name=name,
|
||||
path=path,
|
||||
checksum=_checksum(sql),
|
||||
sql=sql,
|
||||
)
|
||||
)
|
||||
|
||||
seen: set[str] = set()
|
||||
for migration in migrations:
|
||||
if migration.version in seen:
|
||||
raise ValueError(f"Duplicate migration version: {migration.version}")
|
||||
seen.add(migration.version)
|
||||
return migrations
|
||||
|
||||
|
||||
async def ensure_migration_table(conn) -> None:
|
||||
await conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS schema_migrations (
|
||||
version TEXT PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
checksum TEXT NOT NULL,
|
||||
applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
async def applied_versions(conn) -> dict[str, str]:
|
||||
await ensure_migration_table(conn)
|
||||
rows = await conn.fetch("SELECT version, checksum FROM schema_migrations")
|
||||
return {row["version"]: row["checksum"] for row in rows}
|
||||
|
||||
|
||||
async def apply_migrations(conn, migrations: Iterable[Migration] | None = None) -> list[str]:
|
||||
pending = list(migrations if migrations is not None else discover_migrations())
|
||||
applied = await applied_versions(conn)
|
||||
applied_now: list[str] = []
|
||||
|
||||
for migration in pending:
|
||||
existing_checksum = applied.get(migration.version)
|
||||
if existing_checksum == migration.checksum:
|
||||
continue
|
||||
if existing_checksum and existing_checksum != migration.checksum:
|
||||
raise RuntimeError(
|
||||
f"Migration checksum mismatch for {migration.version}; "
|
||||
"create a new migration instead of editing an applied one."
|
||||
)
|
||||
|
||||
transaction = conn.transaction()
|
||||
async with transaction:
|
||||
await conn.execute(migration.sql)
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO schema_migrations (version, name, checksum)
|
||||
VALUES ($1, $2, $3)
|
||||
""",
|
||||
migration.version,
|
||||
migration.name,
|
||||
migration.checksum,
|
||||
)
|
||||
applied_now.append(migration.version)
|
||||
|
||||
return applied_now
|
||||
|
||||
Reference in New Issue
Block a user