forked from sagnik/Project_Velocity
154 lines
5.3 KiB
Python
154 lines
5.3 KiB
Python
"""
|
|
backend/auth/dependencies.py — FastAPI RBAC Dependency Injection
|
|
|
|
Provides:
|
|
- get_current_user: decodes JWT and returns UserPrincipal
|
|
- require_role(min_role): raises HTTP 403 if user role is insufficient
|
|
|
|
Role hierarchy (ascending):
|
|
JUNIOR_BROKER < SENIOR_BROKER < SALES_DIRECTOR < ADMIN
|
|
"""
|
|
|
|
from __future__ import annotations

import os
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import Optional

from fastapi import Depends, Header, HTTPException, status
from jose import JWTError, jwt
from passlib.context import CryptContext
|
|
|
|
# ── Role hierarchy ────────────────────────────────────────────────────────────

# Role name -> privilege level. Roles are listed from least to most privileged,
# so each role's level is simply its position in the tuple (0 = lowest).
ROLE_HIERARCHY = {
    role_name: level
    for level, role_name in enumerate(
        ("JUNIOR_BROKER", "SENIOR_BROKER", "SALES_DIRECTOR", "ADMIN")
    )
}
|
|
|
|
|
|
def default_tenant_id() -> str:
    """Return the tenant id to use when none is supplied.

    Reads ``VELOCITY_DEFAULT_TENANT_ID`` and strips surrounding whitespace.
    An unset, empty, or whitespace-only value falls back to ``"tenant_velocity"``.
    """
    fallback = "tenant_velocity"
    configured = os.getenv("VELOCITY_DEFAULT_TENANT_ID", fallback).strip()
    return configured if configured else fallback
|
|
|
|
# ── Password hashing ──────────────────────────────────────────────────────────

# Shared passlib context used by hash_password/verify_password. bcrypt is the
# only enabled scheme; deprecated="auto" treats every non-default scheme as
# deprecated.
pwd_context = CryptContext(deprecated="auto", schemes=["bcrypt"])
|
|
|
|
|
|
def _truncate_bcrypt_input(value: str) -> str:
    """Clip *value* to bcrypt's 72-byte input limit.

    bcrypt only considers the first 72 bytes of a password. A string whose
    UTF-8 encoding fits is returned unchanged. A longer one is cut at byte 72
    and decoded back, discarding any multi-byte character split by the cut.
    """
    encoded = value.encode("utf-8")
    if len(encoded) > 72:
        # errors="ignore" drops the trailing partial code point, if the cut made one.
        return encoded[:72].decode("utf-8", errors="ignore")
    return value
|
|
|
|
|
|
def hash_password(plain: str) -> str:
    """Return a bcrypt hash of *plain*.

    The input is first clipped to bcrypt's 72-byte limit with the same
    helper verify_password uses, so hashing and verification stay consistent.
    """
    clipped = _truncate_bcrypt_input(plain)
    return pwd_context.hash(clipped)
|
|
|
|
|
|
def verify_password(plain: str, hashed: str) -> bool:
    """Return True if *plain* matches the stored bcrypt hash *hashed*.

    *plain* is clipped to 72 bytes exactly as hash_password does before hashing.
    NOTE(review): passlib's verify() raises rather than returning False when
    *hashed* is malformed or not a recognised hash; confirm callers handle that.
    """
    clipped = _truncate_bcrypt_input(plain)
    return pwd_context.verify(clipped, hashed)
|
|
|
|
|
|
# ── JWT helpers ───────────────────────────────────────────────────────────────

# The signing secret comes from the environment and is never hardcoded. A
# missing VELOCITY_JWT_SECRET raises KeyError at import time, so the service
# fails fast instead of starting without a key.
JWT_SECRET = os.environ["VELOCITY_JWT_SECRET"]
# A blank HS256 key would let anyone forge valid tokens, so reject it as well.
if not JWT_SECRET.strip():
    raise RuntimeError(
        "VELOCITY_JWT_SECRET is set but empty; refusing to sign JWTs with a blank key."
    )
# Pinned in code, not read from the environment; decoding passes it as the
# only accepted algorithm.
JWT_ALGORITHM = "HS256"
# Lifetime of issued access tokens.
JWT_EXPIRE_HOURS = 8
|
|
|
|
|
|
def create_access_token(user_id: str, role: str, tenant_id: Optional[str] = None) -> str:
    """Issue a signed JWT for *user_id*.

    Args:
        user_id: Subject of the token, stored in the ``sub`` claim. It is coerced
            to ``str`` because a JWT ``sub`` must be a string.
        role: Role name. It is stripped and upper-cased before being stored in ``role``.
        tenant_id: Tenant to embed in ``tenant_id``. ``None`` or a blank value
            falls back to default_tenant_id().

    Returns:
        The encoded token, signed with JWT_SECRET using JWT_ALGORITHM. Its
        ``iat`` is now and its ``exp`` is now + JWT_EXPIRE_HOURS.
    """
    # Take a single timestamp so iat and exp are exactly JWT_EXPIRE_HOURS apart.
    issued_at = datetime.now(timezone.utc)
    normalized_role = role.strip().upper()
    normalized_tenant = (tenant_id or "").strip() or default_tenant_id()
    payload = {
        "sub": str(user_id),
        "role": normalized_role,
        "tenant_id": normalized_tenant,
        "exp": issued_at + timedelta(hours=JWT_EXPIRE_HOURS),
        "iat": issued_at,
    }
    return jwt.encode(payload, JWT_SECRET, algorithm=JWT_ALGORITHM)
|
|
|
|
|
|
# ── UserPrincipal dataclass ───────────────────────────────────────────────────

@dataclass
class UserPrincipal:
    """Identity and role of an authenticated caller."""

    user_id: str  # subject identifier of the caller
    role: str  # role name, e.g. "SENIOR_BROKER"
    # Resolved per instance so the default follows VELOCITY_DEFAULT_TENANT_ID at
    # construction time, instead of freezing the value seen at import time.
    tenant_id: str = field(default_factory=default_tenant_id)

    @property
    def role_level(self) -> int:
        """Return the privilege level of ``role`` in ROLE_HIERARCHY, or -1 if unknown.

        The role is stripped and upper-cased before lookup, matching how role
        names are normalized when tokens are issued and decoded.
        """
        return ROLE_HIERARCHY.get(self.role.strip().upper(), -1)
|
|
|
|
|
|
# ── Dependency: parse bearer token ────────────────────────────────────────────

def get_current_user(
    authorization: Optional[str] = Header(default=None),
) -> UserPrincipal:
    """
    Extracts and validates a JWT from the Authorization: Bearer <token> header.

    The auth scheme is matched case-insensitively (RFC 7235). The token must
    have a valid signature, a present and unexpired ``exp`` claim, and
    non-empty ``sub`` and ``role`` claims.

    Returns:
        UserPrincipal with the role stripped and upper-cased. tenant_id falls
        back to default_tenant_id() when the claim is absent or blank.

    Raises:
        HTTPException: 401 on a missing or malformed header, an invalid or
            expired token, or missing required claims.
    """
    scheme, _, token = (authorization or "").partition(" ")
    token = token.strip()
    if scheme.lower() != "bearer" or not token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing or malformed Authorization header.",
            headers={"WWW-Authenticate": "Bearer"},
        )

    try:
        # python-jose ignores a PyJWT-style options={"require": [...]}; it uses
        # per-claim "require_<claim>" flags instead. Without them, a token that
        # lacks "exp" would be accepted and never expire.
        payload = jwt.decode(
            token,
            JWT_SECRET,
            algorithms=[JWT_ALGORITHM],
            options={"require_exp": True, "require_sub": True},
        )
    except JWTError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=f"Invalid token: {exc}",
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc

    # "role" is a custom claim the decoder cannot require, so check it (and
    # "sub") here. A token missing either then yields a 401, not a KeyError/500.
    missing = [claim for claim in ("sub", "role") if not payload.get(claim)]
    if missing:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=f"Invalid token: missing required claim(s): {', '.join(missing)}",
            headers={"WWW-Authenticate": "Bearer"},
        )

    tenant = str(payload.get("tenant_id") or "").strip()
    return UserPrincipal(
        user_id=payload["sub"],
        role=str(payload["role"]).strip().upper(),
        tenant_id=tenant or default_tenant_id(),
    )
|
|
|
|
|
|
# ── Dependency factory: role gate ─────────────────────────────────────────────

def require_role(minimum_role: str):
    """
    Returns a FastAPI dependency that raises HTTP 403 if the authenticated
    user's role is below `minimum_role` in the hierarchy.

    `minimum_role` is stripped and matched case-insensitively, the same
    normalization applied to role names carried in tokens.

    Usage:
        @router.get("/protected")
        async def protected(user: UserPrincipal = Depends(require_role("SENIOR_BROKER"))):
            ...

    Raises:
        ValueError: immediately, when `minimum_role` is not a known role.
    """
    required_role = str(minimum_role).strip().upper()
    min_level = ROLE_HIERARCHY.get(required_role)
    if min_level is None:
        raise ValueError(f"Unknown role: {minimum_role}")

    def _check(user: UserPrincipal = Depends(get_current_user)) -> UserPrincipal:
        # Authentication (401) is handled by get_current_user; this only
        # enforces the role threshold.
        if user.role_level < min_level:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail=f"Insufficient role. Required: {required_role}, current: {user.role}.",
            )
        return user

    return _check
|