"""Request observability middleware for Starlette applications.

For every HTTP request passing through :class:`RequestObservabilityMiddleware`:

- A request ID is attached, taken from the ``x-request-id`` header or
  freshly generated.
- Responses produced downstream get ``X-Request-ID`` and
  ``X-Response-Time-Ms`` headers.
- A structured log line is emitted.
- An entry is appended to a bounded in-memory buffer stored on
  ``app.state.request_metrics``. :func:`metrics_snapshot` reads that buffer.
"""

from __future__ import annotations

import logging
import time
import uuid
from collections import deque
from dataclasses import asdict, dataclass

from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import ASGIApp

logger = logging.getLogger("velocity.observability")


@dataclass(frozen=True)
class RequestMetric:
    """Immutable record of one request handled by the middleware."""

    request_id: str  # Same value as the X-Request-ID response header.
    method: str  # HTTP method of the request, e.g. "GET".
    path: str  # request.url.path at recording time.
    status_code: int  # Response status, or 500 when the downstream app raised.
    duration_ms: float  # time.perf_counter() delta in ms, rounded to 2 places.


class RequestObservabilityMiddleware(BaseHTTPMiddleware):
    """Tag, time, log, and record every HTTP request.

    Args:
        app: The downstream ASGI application.
        max_metrics: Capacity of the metrics deque created lazily on
            ``app.state.request_metrics``. Once it is full, the oldest
            entries are discarded; ``0`` keeps nothing. Only the middleware
            instance that creates the deque first determines its capacity;
            an existing deque is reused as-is.

    Raises:
        ValueError: If ``max_metrics`` is negative.
    """

    def __init__(self, app: ASGIApp, *, max_metrics: int = 500) -> None:
        # deque(maxlen=...) rejects negative values; check here so a bad
        # configuration fails once, clearly, instead of inside request handling.
        if max_metrics < 0:
            raise ValueError(f"max_metrics must be non-negative, got {max_metrics!r}")
        super().__init__(app)
        self.max_metrics = max_metrics

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        """Run the downstream app, then annotate, log, and record the result.

        The incoming ``x-request-id`` header is reused when present and
        non-empty; otherwise a random UUID4 string is generated. The ID is
        also stored as ``request.state.request_id`` for downstream code.

        Raises:
            Exception: Any exception raised by the downstream app is logged,
                recorded as a 500 metric, and re-raised unchanged.
        """
        request_id = request.headers.get("x-request-id") or str(uuid.uuid4())
        request.state.request_id = request_id
        started = time.perf_counter()
        # Only call_next belongs in the try. A failure while finalizing an
        # already-produced response must not be logged as "request_failed",
        # nor recorded a second time.
        try:
            response = await call_next(request)
        except Exception:
            duration_ms = (time.perf_counter() - started) * 1000
            # No response exists, so the failure is recorded as a 500.
            self._record_metric(request, request_id, 500, duration_ms)
            logger.exception(
                "request_failed",
                extra={
                    "request_id": request_id,
                    "method": request.method,
                    "path": request.url.path,
                    "duration_ms": round(duration_ms, 2),
                },
            )
            raise
        return self._finalize(
            request, response, request_id, started, response.status_code
        )

    def _finalize(
        self,
        request: Request,
        response: Response,
        request_id: str,
        started: float,
        status_code: int,
    ) -> Response:
        """Add observability headers, record the metric, log, and return *response*.

        The duration is measured from ``started`` until this call. For
        streamed bodies it presumably excludes body transmission.
        NOTE(review): confirm against how ``call_next`` returns responses.
        """
        duration_ms = (time.perf_counter() - started) * 1000
        response.headers["X-Request-ID"] = request_id
        response.headers["X-Response-Time-Ms"] = f"{duration_ms:.2f}"
        self._record_metric(request, request_id, status_code, duration_ms)
        logger.info(
            "request_completed",
            extra={
                "request_id": request_id,
                "method": request.method,
                "path": request.url.path,
                "status_code": status_code,
                "duration_ms": round(duration_ms, 2),
            },
        )
        return response

    def _record_metric(
        self,
        request: Request,
        request_id: str,
        status_code: int,
        duration_ms: float,
    ) -> None:
        """Append a :class:`RequestMetric` to ``app.state.request_metrics``.

        Creates the bounded deque (``maxlen=self.max_metrics``) on first use.
        """
        metrics = getattr(request.app.state, "request_metrics", None)
        if metrics is None:
            metrics = deque(maxlen=self.max_metrics)
            request.app.state.request_metrics = metrics
        metrics.append(
            RequestMetric(
                request_id=request_id,
                method=request.method,
                path=request.url.path,
                status_code=status_code,
                duration_ms=round(duration_ms, 2),
            )
        )


def metrics_snapshot(app, *, limit: int = 50) -> list[dict]:
    """Return up to ``limit`` recent request metrics as dicts, newest first.

    Args:
        app: Object whose ``state`` may hold the ``request_metrics`` deque
            written by :class:`RequestObservabilityMiddleware`.
        limit: Maximum number of entries to return. Zero or negative values
            yield an empty list.

    Returns:
        ``dataclasses.asdict`` views of :class:`RequestMetric`, most recent
        first. The list is empty if no request has been recorded yet.
    """
    if limit <= 0:
        # A bare ``[-limit:]`` slice would return *everything* for 0 and
        # silently drop the oldest entries for negative values.
        return []
    metrics = getattr(app.state, "request_metrics", ())
    # Copy the buffer in one step before converting entries, so the
    # per-entry asdict() work below does not iterate the live deque.
    recent = list(metrics)[-limit:]
    return [asdict(metric) for metric in reversed(recent)]