Files
template-fastapi/{{project_slug}}/middlewares/observability.py.jinja
Aleksei Sokol 53f14a8624
All checks were successful
Run linters on applied template / Python 3.13 lint and build (push) Successful in 1m40s
Version 0.4.0
Changes:
- put ObservabilityMiddleware before ExceptionHandlerMiddleware to avoid repetitive code
- add application startup and last metrics update metrics along with CPU usage metric and threads count
- move host and port to new uvicorn section at config along with new reload and forwarded_allow_ips
- add request_id and remove trace_id/span_id generation if tracing is disabled
- move logging logic from utils to observability
- pass trace_id/span_id in HEX form
2026-01-03 16:29:58 +03:00

109 lines
4.2 KiB
Django/Jinja

"""Observability middleware is defined here."""
import time
import uuid
from fastapi import FastAPI, Request
from opentelemetry import context as tracing_context
from opentelemetry import trace
from opentelemetry.semconv.attributes import http_attributes, url_attributes
from opentelemetry.trace import NonRecordingSpan, SpanContext, TraceFlags
from starlette.middleware.base import BaseHTTPMiddleware
from {{project_slug}}.dependencies import logger_dep
from {{project_slug}}.observability.metrics import Metrics
from {{project_slug}}.observability.utils import URLsMapper, get_tracing_headers
# Module-level tracer, created once at import time from the currently configured
# tracer provider; `ObservabilityMiddleware.dispatch` uses it to open the per-request span.
_tracer = trace.get_tracer_provider().get_tracer(__name__)
class ObservabilityMiddleware(BaseHTTPMiddleware):  # pylint: disable=too-few-public-methods
    """Middleware that wraps every HTTP request with tracing, logging and metrics.

    For each request it:

    - adopts the parent span passed in the 'X-Trace-Id' / 'X-Span-Id' request headers
      (when present) and opens an "http request" span;
    - generates a request id, binds it to the request logger and re-attaches that logger
      to the request through `logger_dep.attach_to_request`;
    - adds 'X-Request-Id' and the headers returned by `get_tracing_headers()` to the response;
    - updates the HTTP metrics: started / finished / in-flight requests and processing duration.
    """

    def __init__(self, app: FastAPI, metrics: Metrics, urls_mapper: URLsMapper):
        """Store the HTTP metrics group and the URL mapper used to label metrics.

        Args:
            app: application (or next middleware) this middleware wraps.
            metrics: metrics container; only its `http` attribute is used here.
            urls_mapper: maps a (method, path) pair to the path value used as a metric label.
        """
        super().__init__(app)
        self._http_metrics = metrics.http
        self._urls_mapper = urls_mapper

    async def dispatch(self, request: Request, call_next):
        """Process one request inside a span, logging it and recording metrics.

        Args:
            request: incoming HTTP request.
            call_next: awaitable callable that passes the request further down the
                middleware chain and returns the response.

        Returns:
            The response produced by `call_next`, with 'X-Request-Id' and tracing headers added.

        Raises:
            Any exception raised by `call_next` is propagated unchanged; the in-flight
            requests metric is still decremented in that case.
        """
        logger = logger_dep.from_request(request)
        _try_get_parent_span_id(request)

        # `request.client` is optional in Starlette (it is None when the server does not
        # provide peer information), so it must not be dereferenced unconditionally.
        client_host = request.client.host if request.client is not None else None

        with _tracer.start_as_current_span("http request") as span:
            request_id = str(uuid.uuid4())
            logger = logger.bind(request_id=request_id)
            logger_dep.attach_to_request(request, logger)

            span_attributes = {
                http_attributes.HTTP_REQUEST_METHOD: request.method,
                url_attributes.URL_PATH: request.url.path,
                url_attributes.URL_QUERY: str(request.query_params),
                "request_id": request_id,
            }
            if client_host is not None:
                # None is not a valid span attribute value, so the key is omitted instead.
                span_attributes["request_client"] = client_host
            span.set_attributes(span_attributes)

            await logger.ainfo(
                "http begin",
                client=client_host,
                path_params=request.path_params,
                method=request.method,
                url=str(request.url),
            )

            path_for_metric = self._urls_mapper.map(request.method, request.url.path)
            self._http_metrics.requests_started.add(1, {"method": request.method, "path": path_for_metric})
            self._http_metrics.inflight_requests.add(1)
            try:
                time_begin = time.monotonic()
                result = await call_next(request)
                duration_seconds = time.monotonic() - time_begin

                result.headers.update({"X-Request-Id": request_id} | get_tracing_headers())
                await logger.ainfo("http end", time_consumed=round(duration_seconds, 3), status_code=result.status_code)
                self._http_metrics.requests_finished.add(
                    1,
                    {
                        http_attributes.HTTP_REQUEST_METHOD: request.method,
                        url_attributes.URL_PATH: path_for_metric,
                        http_attributes.HTTP_RESPONSE_STATUS_CODE: result.status_code,
                    },
                )
            finally:
                # Always balance the increment above, otherwise a failing request would
                # leave the in-flight metric permanently too high.
                self._http_metrics.inflight_requests.add(-1)

            # Only 2xx responses are explicitly marked OK; other codes leave the span status untouched.
            if result.status_code // 100 == 2:
                span.set_status(trace.StatusCode.OK)
            span.set_attribute(http_attributes.HTTP_RESPONSE_STATUS_CODE, result.status_code)
            self._http_metrics.request_processing_duration.record(
                duration_seconds, {"method": request.method, "path": path_for_metric}
            )
            return result
def _try_get_parent_span_id(request: Request) -> None:
    """Attach the span described by the request headers as the current tracing context.

    The 'X-Trace-Id' and 'X-Span-Id' headers are parsed as hexadecimal integers and turned
    into a remote span context flagged with `TraceFlags(0x01)`, which is then attached via
    `tracing_context.attach`. Nothing happens when either header is missing, is not purely
    alphanumeric, or cannot be converted into a span context.

    The token returned by `tracing_context.attach` is discarded, so this function never
    detaches the context it attaches.
    """
    header_values = (request.headers.get("X-Trace-Id"), request.headers.get("X-Span-Id"))
    if any(value is None for value in header_values):
        return
    # Cheap pre-filter; non-hex alphanumerics are still rejected by int(..., 16) below.
    if not all(value.isalnum() for value in header_values):
        return

    trace_id_hex, span_id_hex = header_values
    try:
        parent_context = SpanContext(
            trace_id=int(trace_id_hex, 16),
            span_id=int(span_id_hex, 16),
            is_remote=True,
            trace_flags=TraceFlags(0x01),
        )
    except Exception:  # pylint: disable=broad-except
        # Best effort: malformed ids simply mean the request starts a trace of its own.
        return

    tracing_context.attach(trace.set_span_in_context(NonRecordingSpan(parent_context)))