All checks were successful
Run linters on applied template / Python 3.13 lint and build (push) Successful in 1m0s
This is a FastAPI backend microservice template for use with the `copier` utility. Features of the applied template: - Configuration file processing logic - Optional metrics and tracing configuration - Debug endpoints - Database migration commands and a prepared Alembic environment - Database usage example in the ping_db endpoint - Gitea sanity-check pipeline
178 lines
6.7 KiB
Django/Jinja
178 lines
6.7 KiB
Django/Jinja
"""Observability middleware is defined here."""
|
|
|
|
import time
|
|
from random import randint
|
|
|
|
import structlog
|
|
from fastapi import FastAPI, HTTPException, Request, Response
|
|
from opentelemetry import context as tracing_context
|
|
from opentelemetry import trace
|
|
from opentelemetry.semconv.attributes import exception_attributes, http_attributes, url_attributes
|
|
from opentelemetry.trace import NonRecordingSpan, Span, SpanContext, TraceFlags
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
|
|
from {{project_slug}}.dependencies import logger_dep
|
|
from {{project_slug}}.exceptions.mapper import ExceptionMapper
|
|
from {{project_slug}}.observability.metrics import Metrics
|
|
from {{project_slug}}.utils.observability import URLsMapper, get_handler_from_path
|
|
|
|
# Module-level tracer taken from the global tracer provider; the middleware below uses it
# to open the "http-request" span around each request.
_tracer = trace.get_tracer_provider().get_tracer(__name__)
|
|
|
|
|
|
class ObservableException(RuntimeError):
    """Runtime error carrying the `trace_id` and `span_id` of the request that failed.

    The observability middleware raises it with `raise ... from exc`, so `.__cause__` holds
    the original exception.
    """

    def __init__(self, trace_id: str, span_id: int):
        """Store the tracing identifiers and expose them in the exception message.

        Args:
            trace_id: Trace identifier as a string (the middleware passes it hex-encoded).
            span_id: Span identifier as an integer.
        """
        # Put the ids into the message so that str()/repr() and tracebacks are not empty.
        super().__init__(f"request failed (trace_id={trace_id}, span_id={span_id})")
        self.trace_id = trace_id
        self.span_id = span_id
|
|
|
|
|
|
class ObservabilityMiddleware(BaseHTTPMiddleware):  # pylint: disable=too-few-public-methods
    """Middleware for global observability of requests.

    - Generates a tracing span and adds the response headers 'X-Trace-Id' and 'X-Span-Id'.
    - Binds trace_id and span_id to the logger and passes it on in request state (`request.state.logger`).
    - Collects metrics for Prometheus.

    When the span carries no ids (e.g. jaeger is not enabled), trace_id and span_id are generated randomly.
    """

    def __init__(self, app: FastAPI, exception_mapper: ExceptionMapper, metrics: Metrics, urls_mapper: URLsMapper):
        """Store the collaborators used while handling requests.

        Args:
            app: Application being wrapped by the middleware.
            exception_mapper: Used to tell known exceptions from unexpected ones when logging failures.
            metrics: Holder of the counters/histograms updated for every request.
            urls_mapper: Maps a concrete URL path to the value used as the `path` metric label.
        """
        super().__init__(app)
        self._exception_mapper = exception_mapper
        self._metrics = metrics
        self._urls_mapper = urls_mapper

    async def dispatch(self, request: Request, call_next):
        """Handle one request inside a tracing span, logging it and recording metrics.

        Args:
            request: Incoming request.
            call_next: Callable that passes the request further down the chain and returns the response.

        Returns:
            The downstream response with 'X-Trace-Id' and 'X-Span-Id' headers added.

        Raises:
            ObservableException: When handling fails with an `Exception`; the original error is its `__cause__`.
        """
        logger = logger_dep.obtain(request)
        _try_get_parent_span_id(request)

        # `request.client` is None when no peer address is available, so never dereference it blindly.
        client_host = request.client.host if request.client is not None else None

        with _tracer.start_as_current_span("http-request") as span:
            span_context = span.get_span_context()
            # Fall back to random ids when the span context provides zero ids.
            trace_id = hex(span_context.trace_id or randint(1, 1 << 63))[2:]
            span_id = span_context.span_id or randint(1, 1 << 32)

            span_attributes = {
                http_attributes.HTTP_REQUEST_METHOD: request.method,
                url_attributes.URL_PATH: request.url.path,
                url_attributes.URL_QUERY: str(request.query_params),
            }
            if client_host is not None:
                # The attribute is set only when a client address actually exists.
                span_attributes["request_client"] = client_host
            span.set_attributes(span_attributes)

            logger = logger.bind(trace_id=trace_id, span_id=span_id)
            request.state.logger = logger

            await logger.ainfo(
                "handling request",
                client=client_host,
                path_params=request.path_params,
                method=request.method,
                url=str(request.url),
            )

            path_for_metric = self._urls_mapper.map(request.url.path)
            self._metrics.requests_started.add(1, {"method": request.method, "path": path_for_metric})

            time_begin = time.monotonic()
            duration_seconds: float | None = None
            try:
                result = await call_next(request)
                duration_seconds = time.monotonic() - time_begin

                result.headers.update({"X-Trace-Id": trace_id, "X-Span-Id": str(span_id)})
                await self._handle_success(
                    request=request,
                    result=result,
                    logger=logger,
                    span=span,
                    path_for_metric=path_for_metric,
                    duration_seconds=duration_seconds,
                )
                return result
            except Exception as exc:
                duration_seconds = time.monotonic() - time_begin
                await self._handle_exception(
                    request=request, exc=exc, logger=logger, span=span, duration_seconds=duration_seconds
                )
                raise ObservableException(trace_id=trace_id, span_id=span_id) from exc
            finally:
                if duration_seconds is None:
                    # Only reached when `call_next` exits with a non-`Exception` BaseException
                    # (e.g. task cancellation): measure here so that the original exit is not
                    # masked by an UnboundLocalError.
                    duration_seconds = time.monotonic() - time_begin
                self._metrics.request_processing_duration.record(
                    duration_seconds, {"method": request.method, "path": path_for_metric}
                )

    async def _handle_success(  # pylint: disable=too-many-arguments
        self,
        *,
        request: Request,
        result: Response,
        logger: structlog.stdlib.BoundLogger,
        span: Span,
        path_for_metric: str,
        duration_seconds: float,
    ) -> None:
        """Log a successfully handled request, count it and store the status code on the span.

        Args:
            request: Request that was handled.
            result: Response produced for the request.
            logger: Logger already bound to the request's trace/span ids.
            span: Span opened for the request.
            path_for_metric: Value used as the `path` metric label.
            duration_seconds: Time spent producing the response.
        """
        await logger.ainfo("request handled successfully", time_consumed=round(duration_seconds, 3))
        self._metrics.requests_finished.add(
            1, {"method": request.method, "path": path_for_metric, "status_code": result.status_code}
        )

        span.set_attribute(http_attributes.HTTP_RESPONSE_STATUS_CODE, result.status_code)

    async def _handle_exception(  # pylint: disable=too-many-arguments
        self,
        *,
        request: Request,
        exc: Exception,
        logger: structlog.stdlib.BoundLogger,
        span: Span,
        duration_seconds: float,
    ) -> None:
        """Count, log and record on the span an exception raised while handling a request.

        Args:
            request: Request whose handling failed.
            exc: Exception that was raised.
            logger: Logger already bound to the request's trace/span ids.
            span: Span opened for the request.
            duration_seconds: Time spent before the failure.
        """
        cause = exc
        status_code = 500
        if isinstance(exc, HTTPException):
            status_code = exc.status_code
            # NOTE(review): nesting of this check under the HTTPException branch is reconstructed
            # from a source that lost its indentation - confirm against the repository.
            if exc.__cause__ is not None:
                cause = exc.__cause__

        # NOTE(review): unlike the other metrics, the `path` label here comes from
        # `get_handler_from_path` rather than `self._urls_mapper` - confirm this is intended.
        self._metrics.errors.add(
            1,
            {
                "method": request.method,
                "path": get_handler_from_path(request.url.path),
                "error_type": type(cause).__name__,
                "status_code": status_code,
            },
        )

        span.record_exception(exc)
        # Known exceptions are logged as plain errors; anything else goes through `aexception`.
        if self._exception_mapper.is_known(exc):
            log_func = logger.aerror
        else:
            log_func = logger.aexception
        await log_func(
            "failed to handle request", time_consumed=round(duration_seconds, 3), error_type=type(exc).__name__
        )

        span.set_attributes(
            {
                exception_attributes.EXCEPTION_TYPE: type(exc).__name__,
                exception_attributes.EXCEPTION_MESSAGE: repr(exc),
                http_attributes.HTTP_RESPONSE_STATUS_CODE: status_code,
            }
        )
|
|
|
|
|
|
def _try_get_parent_span_id(request: Request) -> None:
    """Attach a remote parent span context built from the request headers, when they are usable.

    Reads the 'X-Trace-Id' and 'X-Span-Id' headers. The trace id is parsed as hexadecimal and the
    span id as decimal, matching the way `ObservabilityMiddleware.dispatch` writes these headers
    on responses (`hex(...)[2:]` and `str(int)` respectively). Missing, malformed or out-of-range
    values are ignored and nothing is attached.

    Args:
        request: Incoming request whose headers may carry the parent trace/span ids.
    """
    trace_id_str = request.headers.get("X-Trace-Id")
    span_id_str = request.headers.get("X-Span-Id")

    if trace_id_str is None or span_id_str is None:
        return

    # Validate characters explicitly: `str.isnumeric()` accepts characters such as "²" that `int()`
    # rejects, and `int(..., 16)` alone would accept signs, whitespace, "0x" prefixes and underscores.
    hex_digits = frozenset("0123456789abcdefABCDEF")
    if not trace_id_str or not hex_digits.issuperset(trace_id_str):
        return
    if not (span_id_str.isascii() and span_id_str.isdigit()):
        return

    trace_id = int(trace_id_str, 16)
    span_id = int(span_id_str)

    # Trace ids are 128-bit and span ids 64-bit; zero is the "invalid" id for both.
    if not 0 < trace_id < (1 << 128) or not 0 < span_id < (1 << 64):
        return

    span_context = SpanContext(
        trace_id=trace_id, span_id=span_id, is_remote=True, trace_flags=TraceFlags(0x01)
    )
    ctx = trace.set_span_in_context(NonRecordingSpan(span_context))
    # NOTE(review): the token returned by `attach` is discarded, so this context is never
    # explicitly detached by this module.
    tracing_context.attach(ctx)
|