"""Connection manager class and get_connection function are defined here.""" from asyncio import Lock from contextlib import asynccontextmanager from itertools import cycle from typing import Any, AsyncIterator import structlog from sqlalchemy import select, text from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine from {{project_slug}}.db.config import DBConfig class PostgresConnectionManager: # pylint: disable=too-many-instance-attributes """Connection manager for PostgreSQL database""" def __init__( # pylint: disable=too-many-arguments self, master: DBConfig, replicas: list[DBConfig] | None, logger: structlog.stdlib.BoundLogger, *, engine_options: dict[str, Any] | None = None, application_name: str | None = None, ) -> None: """Initialize connection manager entity.""" self._master_engine: AsyncEngine | None = None self._replica_engines: list[AsyncEngine] = [] self._master = master self._replicas = replicas or [] self._lock = Lock() self._logger = logger self._engine_options = engine_options or {} self._application_name = application_name # Iterator for round-robin through replicas self._replica_cycle = None async def update( # pylint: disable=too-many-arguments self, *, master: DBConfig | None = None, replicas: list[DBConfig] | None = None, logger: structlog.stdlib.BoundLogger | None = None, application_name: str | None = None, engine_options: dict[str, Any] | None = None, ) -> None: """Update connection manager parameters and refresh connection.""" self._master = master or self._master self._replicas = replicas or self._replicas self._logger = logger or self._logger self._application_name = application_name or self._application_name self._engine_options = engine_options or self._engine_options if self.initialized: await self.refresh() @property def initialized(self) -> bool: return self._master_engine is not None async def refresh(self, no_force_refresh: bool = False) -> None: """Initialize or reinitialize connection engine. Params: no_force_refresh (bool): if set to True and ConnectionManager is already initialized, no refresh is performed """ async with self._lock: if no_force_refresh and self.initialized: return await self.shutdown(use_lock=False) await self._logger.ainfo( "creating postgres master connection pool", max_size=self._master.pool_size, user=self._master.user, host=self._master.host, port=self._master.port, database=self._master.database, ) self._master_engine = create_async_engine( f"postgresql+asyncpg://{self._master.user}:{self._master.password.get_secret_value()}" f"@{self._master.host}:{self._master.port}/{self._master.database}", future=True, pool_size=max(1, self._master.pool_size - 5), max_overflow=min(self._master.pool_size - 1, 5), **self._engine_options, ) try: async with self._master_engine.connect() as conn: cur = await conn.execute(select(text("1"))) assert cur.fetchone()[0] == 1 except Exception as exc: self._master_engine = None raise RuntimeError("something wrong with database connection, aborting") from exc if len(self._replicas) > 0: for replica in self._replicas: await self._logger.ainfo( "creating postgres readonly connection pool", max_size=replica.pool_size, user=replica.user, host=replica.host, port=replica.port, database=replica.database, ) replica_engine = create_async_engine( f"postgresql+asyncpg://{replica.user}:{replica.password.get_secret_value()}@" f"{replica.host}:{replica.port}/{replica.database}", future=True, pool_size=max(1, self._master.pool_size - 5), max_overflow=min(self._master.pool_size - 1, 5), **self._engine_options, ) try: async with replica_engine.connect() as conn: cur = await conn.execute(select(1)) assert cur.fetchone()[0] == 1 self._replica_engines.append(replica_engine) except Exception as exc: # pylint: disable=broad-except await replica_engine.dispose() await self._logger.aexception("error connecting to replica", host=replica.host, error=repr(exc)) if self._replica_engines: self._replica_cycle = cycle(self._replica_engines) else: self._replica_cycle = None await self._logger.awarning("no available replicas, read queries will go to the master") async def shutdown(self, use_lock: bool = True) -> None: """Dispose connection pool and deinitialize. Can be called multiple times.""" if use_lock: async with self._lock: await self.shutdown(use_lock=False) return if self.initialized: self._logger.info("shutting down postgres connection engine") await self._master_engine.dispose() self._master_engine = None for engine in self._replica_engines: await engine.dispose() self._replica_engines.clear() @asynccontextmanager async def get_connection(self) -> AsyncIterator[AsyncConnection]: """Get an async connection to the database with read-write ability.""" if not self.initialized: await self.refresh(no_force_refresh=True) async with self._master_engine.connect() as conn: if self._application_name is not None: await conn.execute(text(f'SET application_name TO "{self._application_name}"')) await conn.commit() yield conn @asynccontextmanager async def get_ro_connection(self) -> AsyncIterator[AsyncConnection]: """Get an async connection to the database which can be read-only and will attempt to use replica instances of the database.""" if not self.initialized: await self.refresh(no_force_refresh=True) # If there are no replicas, use master if self._replica_cycle is None: async with self.get_connection() as conn: yield conn return # Select the next replica (round-robin), `self._replica_cycle` is guaranteed to have values here engine = next(self._replica_cycle) # pylint: disable=stop-iteration-return conn = None try: conn = await engine.connect() if self._application_name is not None: await conn.execute(text(f'SET application_name TO "{self._application_name}"')) await conn.commit() except Exception as exc: # pylint: disable=broad-except if conn is not None: try: conn.close() except Exception: # pylint: disable=broad-except pass await self._logger.awarning( "error connecting to replica, falling back to master", error=repr(exc), error_type=type(exc).__name__ ) # On exception from replica fallback to master connection async with self.get_connection() as conn: yield conn return try: yield conn finally: await conn.close()