from threading import Condition
from datetime import datetime
from typing import Callable, Literal
from abc import ABC, ABCMeta, abstractmethod
from contextlib import contextmanager
from traceback import StackSummary, extract_stack

from api.miscellaneous import DEBUG_ENABLED, DIAGNOSTICS_TRACKER, DiagnosticsCounter

# Python types used for statement parameter values and for column values of result rows
DbValueType = int | str | bytes | datetime

_TransactionType = Literal["read", "write"]


def _create_diagnostics_counters(counter_id: str) -> dict[_TransactionType, DiagnosticsCounter]:
    """
    Registers one diagnostics counter per transaction type.

    :param counter_id: Last component of the counter names, e.g. "count" registers
                       "database.transaction.read.count" and "database.transaction.write.count"
    :return: The registered counters, keyed by transaction type
    """
    return {
        "read": DIAGNOSTICS_TRACKER.register_counter(f"database.transaction.read.{counter_id}"),
        "write": DIAGNOSTICS_TRACKER.register_counter(f"database.transaction.write.{counter_id}"),
    }


_COUNTERS_TRANSACTION_COUNT = _create_diagnostics_counters("count")
_COUNTERS_TRANSACTION_ERRORS = _create_diagnostics_counters("errors")
_COUNTERS_TRANSACTION_ATTEMPTED_RECONNECTS = _create_diagnostics_counters("attempted_reconnects")
_COUNTERS_TRANSACTION_SUCCESSFUL_RECONNECTS = _create_diagnostics_counters("successful_reconnects")
_COUNTERS_TRANSACTION_ERRORS_AFTER_RECONNECT = _create_diagnostics_counters("errors_after_reconnect")
_COUNTERS_TRANSACTION_ABORTED_BY_USER = _create_diagnostics_counters("aborted_by_user")


class DatabaseResultRow:
    """
    A single row of a result set.

    Values can be accessed either by column name (via ``column_mapping``) or by position.
    """

    def __init__(self, column_mapping: dict[str, int], values: tuple[DbValueType, ...]):
        """
        :param column_mapping: Maps each column name to its index in ``values``
        :param values: The column values of this row, in column order
        """
        super().__init__()
        self._column_mapping = column_mapping
        self._values = values

    def __contains__(self, item: str | int) -> bool:
        """
        True if ``item`` is a known column name (str) or a valid non-negative column index (int).
        """
        if isinstance(item, str):
            return item in self._column_mapping
        assert isinstance(item, int)
        return 0 <= item < len(self._values)

    def __getitem__(self, key: str | int) -> DbValueType:
        """
        Returns the value of the column with the given name (str) or at the given index (int).

        Raises :class:`KeyError` for an unknown column name and :class:`IndexError` for an out-of-range index.
        """
        if isinstance(key, str):
            return self._values[self._column_mapping[key]]
        assert isinstance(key, int)
        return self._values[key]

    def __str__(self):
        return str(self._values)

    def __repr__(self):
        return self.__str__()


class DatabaseWarning:
    """A warning reported by the database for a statement."""

    def __init__(self, message: str):
        super().__init__()
        self._message = message

    @property
    def message(self) -> str:
        """The warning text as reported by the database connection."""
        return self._message

    def __str__(self):
        return self._message

    def __repr__(self):
        # Used when a list of warnings is converted to a string (e.g. for WarningError's message)
        return f"DatabaseWarning({self._message!r})"


DbResultRow = DatabaseResultRow
DbResultSet = list[DbResultRow]
DbAffectedRows = int
# Result of a single statement: (exception or None, warnings, result set, affected row count).
# Index it with the DB_RESULT_* constants below.
DbResult = tuple[Exception | None, list[DatabaseWarning], DbResultSet, DbAffectedRows]
DB_RESULT_EXCEPTION = 0
DB_RESULT_WARNINGS = 1
DB_RESULT_SET = 2
DB_RESULT_AFFECTED_ROWS = 3

# Placeholder character for bound parameters in SQL statements
SQL_PARAMETER_INDICATOR = "?"


class DatabaseError(Exception):
    """Error for any problems of the database not directly caused by the caller, e.g. connection problems, etc."""


class WarningError(Exception):
    """Raised when a statement produced warnings and the caller did not ask to ignore them."""

    def __init__(self, warnings: list[DatabaseWarning]):
        super().__init__(str(warnings))
        self.warnings = warnings


class NoAvailableConnectionError(Exception):
    """Indicates that there are currently no connections available"""


class PreparedStatement:
    """
    A statement whose parameter count is computed once on construction.

    The count is a plain character count of SQL_PARAMETER_INDICATOR, so a '?' inside a string literal is
    counted as well.
    """

    def __init__(self, statement: str):
        super().__init__()
        self.statement = statement
        self.parameter_count = statement.count(SQL_PARAMETER_INDICATOR)

    def __str__(self):
        return self.statement


class FilledStatement:
    """
    A statement together with its parameter values, as queued in a transaction.

    Instances are hashed/compared by identity; transactions use them as keys to look up their results.
    """

    def __init__(self, statement: PreparedStatement | str, values: list[DbValueType]):
        super().__init__()
        self.statement = statement
        self.values = values
        # Stack of the queue_statement call; only recorded when DEBUG_ENABLED is set
        self.queue_traceback: StackSummary | None = None


class DbConnection(ABC):
    """A single database connection, implemented by a concrete database backend."""

    def __init__(self):
        super().__init__()

    @abstractmethod
    def is_disconnected(self, force_ping: bool = False) -> bool:
        """
        Returns True if this connection is definitely closed. False indicates that it is probably still intact.
        This may ping the database.

        :param force_ping: If this is true the database is guaranteed to be pinged
        """
        pass  # pragma: no cover

    def try_reconnect(self) -> bool:
        """
        Tries to reconnect to the database. Note that any ongoing transaction is rolled back

        The default implementation does not reconnect and always reports failure.

        :return: True if reconnection was successful
        """
        return False  # pragma: no cover

    @abstractmethod
    def close(self):
        """Closes this connection, freeing all resources"""
        pass  # pragma: no cover

    @abstractmethod
    def get_transaction_begin_statement(self, writable: bool) -> PreparedStatement | str:
        """
        Returns the statement to begin a transaction

        :param writable: Specifies whether this transaction is writable. If
                         :func:`supports_per_transaction_writeable_flag` of the :class:`DbConnectionFactory` is
                         False this parameter is ignored.
        """
        pass  # pragma: no cover

    @abstractmethod
    def get_transaction_end_statement(self, commit: bool) -> PreparedStatement | str:
        """Returns the statement to end a transaction (commit if ``commit`` is True, otherwise roll back)"""
        pass  # pragma: no cover

    @abstractmethod
    def execute_statements(self, statements: list[FilledStatement]) -> list[DbResult]:
        """
        Executes the given statements and returns their results, one result per statement in the same order.

        May raise :class:`DatabaseError`
        """
        pass  # pragma: no cover

    @abstractmethod
    def execute_script(self, script: str):
        """
        See execute_script of :class:`DbConnectionPool`
        """
        pass  # pragma: no cover


class DbConnectionFactory(ABC):
    """Creates :class:`DbConnection` instances for a :class:`DbConnectionPool`."""

    def __init__(self):
        super().__init__()

    @abstractmethod
    def supports_per_transaction_writeable_flag(self) -> bool:
        """
        True if a single connection can support read-only and read-write transactions. If False a connection may
        either be read-only or read-write for its full lifetime
        """
        pass  # pragma: no cover

    @abstractmethod
    def new_connection(self, writable: bool = True) -> DbConnection:
        """
        Creates a new connection

        :param writable: Ignored if :func:`supports_per_transaction_writeable_flag` is True. Otherwise, this
                         specifies whether the database can be modified with this connection
        :return: A new connection

        May raise :class:`DatabaseError`
        """
        pass  # pragma: no cover


class AbstractTransaction(ABC):
    """
    Common implementation of read and write transactions.

    Statements are queued with :func:`queue_statement` and executed lazily, as one batch, when a result is
    requested or the transaction is ended. ``release_connection`` is called exactly when the transaction is
    closed, to hand the connection back to its pool.
    """

    def __init__(self, type: _TransactionType, release_connection: Callable, connection: DbConnection):
        super().__init__()
        self._type = type
        self._release_connection = release_connection
        self._closed = False
        self._connection = connection
        # Once statements have run, a reconnect + retry would silently lose their effects, so it is not attempted
        self._has_executed_statements = False
        # Ensures the error counter is triggered at most once per transaction
        self._has_encountered_errors = False
        self._queued_statements: list[FilledStatement] = []
        self._statement_results: dict[FilledStatement, DbResult] = {}
        _COUNTERS_TRANSACTION_COUNT[type].trigger()

    def is_closed(self) -> bool:
        return self._closed

    def queue_statement(self, statement: PreparedStatement | str, *values: DbValueType) -> FilledStatement:
        """
        Queues the specified statement to be executed. The result can later be queried with :func:`get_result`.
        Note that the statements are always executed in the order in which they were queued.

        :return: An object with which the result can be retrieved later

        Throws :class:`ValueError` if the parameter count (count of ?) does not match the length of values

        Throws :class:`RuntimeError` if the transaction was already closed (committed/rolled back)
        """
        if self._closed:
            raise RuntimeError("Transaction already closed")  # pragma: no cover
        parameter_count = (statement.parameter_count
                           if isinstance(statement, PreparedStatement)
                           else statement.count(SQL_PARAMETER_INDICATOR))
        if parameter_count != len(values):
            raise ValueError(
                f"Parameter ({parameter_count}) and argument count ({len(values)}) don't match")  # pragma: no cover
        queued = FilledStatement(statement, list(values))
        if DEBUG_ENABLED:
            queued.queue_traceback = extract_stack()
        self._queued_statements.append(queued)
        return queued

    def get_result(self, stat: FilledStatement, ignore_warnings: bool = False) -> DbResultSet:
        """
        Does the same as :func:`get_full_result` with default raise_exception but only returns the result set
        """
        return self.get_full_result(stat, ignore_warnings=ignore_warnings)[DB_RESULT_SET]

    def get_full_result(self, stat: FilledStatement, raise_exception: bool = True,
                        ignore_warnings: bool = False) -> DbResult:
        """
        Returns the result of the specified statement. If the statement has not been executed yet, the statement
        and possibly other queued statements are executed. A result can only be retrieved once.

        :param stat: The filled statement returned by :func:`queue_statement`
        :param raise_exception: If True any exception generated by the statement execution will be raised by this
                                method (as a :class:`RuntimeError` chained from it). Note that if the statement was
                                executed in a multi-query, the exception may have been caused by another statement
        :param ignore_warnings: If False (the default) a :class:`WarningError` is raised if there are any warnings.
                                Note that if the statement was executed in a multi-query, the warnings may have been
                                caused by another statement
        :return: The result. You may use :func:`get_result` if you are only interested in the result set
                 Note that depending on the underlying connection, warnings may not be fetched.

        Raises :class:`ValueError` if the statement was not queued in this transaction or its result was already
        retrieved.

        May raise :class:`DatabaseError` (e.g. connection was permanently lost) (regardless of raise_exception)
        This causes the transaction to be closed (rolled back)
        """
        if stat not in self._statement_results:
            self._execute_queued()
        if stat not in self._statement_results:  # pragma: no cover
            raise ValueError("Statement was not queued or result was already queried")
        result = self._statement_results.pop(stat)
        if raise_exception and result[DB_RESULT_EXCEPTION] is not None:  # pragma: no cover
            raise RuntimeError("An exception had occurred while executing statement "
                               "(Or other statement in multi-query)") from result[DB_RESULT_EXCEPTION]
        if not ignore_warnings and len(result[DB_RESULT_WARNINGS]) > 0:
            raise WarningError(result[DB_RESULT_WARNINGS])
        return result

    def execute_statement(self, statement: PreparedStatement | str, *values: DbValueType,
                          ignore_warnings: bool = False) -> DbResultSet:
        """
        Queues the statement and immediately retrieves the result.

        See :func:`queue_statement` and :func:`get_result`
        """
        return self.get_result(self.queue_statement(statement, *values), ignore_warnings)

    def execute_statement_full_result(self, statement: PreparedStatement | str, *values: DbValueType,
                                      raise_exception: bool = True, ignore_warnings: bool = False) -> DbResult:
        """
        Queues the statement and immediately retrieves the result.

        See :func:`queue_statement` and :func:`get_full_result`
        """
        return self.get_full_result(self.queue_statement(statement, *values), raise_exception, ignore_warnings)

    def _execute_statements(self, *statements: tuple[PreparedStatement | str, list[DbValueType]],
                            after_queuing: Callable | None) -> tuple[DbResultSet, ...]:
        """
        Queues all ``(statement, values)`` pairs, optionally calls ``after_queuing`` (e.g. close/commit) and then
        retrieves the result set of every statement, in order.
        """
        filled_list = [self.queue_statement(stat, *values) for stat, values in statements]
        if after_queuing is not None:
            after_queuing()
        return tuple(self.get_result(filled) for filled in filled_list)

    def on_error_close_transaction(self):
        """
        Closes the transaction after an error: tries to roll back, discards all queued statements and releases
        the connection.
        """
        # noinspection PyBroadException
        try:
            self._connection.execute_statements(
                [FilledStatement(self._connection.get_transaction_end_statement(commit=False), [])])
        except Exception:
            # Best effort: the transaction is being abandoned anyway, and the caller is already handling an error
            pass
        self._closed = True
        self._queued_statements = []
        self._release_connection()

    def _execute_queued(self):
        """
        Executes all queued statements as one batch and stores their results.

        If the batch fails with a :class:`DatabaseError` before any statement of this transaction has run and the
        connection is definitely disconnected, one reconnect and retry is attempted. In every other failure case
        the transaction is closed (rolled back) and the original exception is re-raised.
        """
        if not self._has_queued_statements():
            return
        try:
            results = self._connection.execute_statements(self._queued_statements)
        except DatabaseError as e:
            if not self._has_encountered_errors:
                _COUNTERS_TRANSACTION_ERRORS[self._type].trigger()
                self._has_encountered_errors = True
            # Retrying is only safe if nothing has run yet and the connection is definitely gone
            if self._has_executed_statements or not self._connection.is_disconnected(force_ping=True):
                self.on_error_close_transaction()
                raise e
            # All triggers after here can be triggered at most once for a single transaction because after this
            # code block we have either executed statements or closed the transaction
            _COUNTERS_TRANSACTION_ATTEMPTED_RECONNECTS[self._type].trigger()
            if not self._connection.try_reconnect():
                self.on_error_close_transaction()
                raise e
            _COUNTERS_TRANSACTION_SUCCESSFUL_RECONNECTS[self._type].trigger()
            try:
                results = self._connection.execute_statements(self._queued_statements)
            except Exception:
                _COUNTERS_TRANSACTION_ERRORS_AFTER_RECONNECT[self._type].trigger()
                self.on_error_close_transaction()
                raise e  # Raise original exception
        except Exception as e:
            self.on_error_close_transaction()
            raise e
        self._has_executed_statements = True
        for result, stat in zip(results, self._queued_statements):
            self._statement_results[stat] = result
        self._queued_statements = []

    def _remove_all_statements(self):
        """Discards all queued statements and all results that were not retrieved yet."""
        self._queued_statements = []
        self._statement_results = {}

    def _has_queued_statements(self):
        return len(self._queued_statements) > 0

    def _close(self):
        """Marks the transaction as closed and releases the connection. Subclasses must end the transaction first."""
        if self._closed:
            return
        if self._has_queued_statements():
            raise RuntimeError("Subclass must commit/rollback transaction")  # pragma: no cover
        self._closed = True
        self._release_connection()


class ReadTransaction(AbstractTransaction):
    """A read-only transaction. It is begun on construction and ended (rolled back) with :func:`close`."""

    def __init__(self, release_connection: Callable, connection: DbConnection):
        super().__init__("read", release_connection, connection)
        self.queue_statement(connection.get_transaction_begin_statement(False))

    def execute_statement_and_close(self, statement: PreparedStatement | str, *values: DbValueType,
                                    ignore_warnings: bool = False) -> DbResultSet:
        """
        Queues the statement, closes the transaction and returns the result.

        See :func:`queue_statement`, :func:`close` and :func:`get_result`
        """
        filled = self.queue_statement(statement, *values)
        self.close()
        return self.get_result(filled, ignore_warnings)

    def execute_statement_full_result_and_close(self, statement: PreparedStatement | str, *values: DbValueType,
                                                raise_exception: bool = True,
                                                ignore_warnings: bool = False) -> DbResult:
        """
        Queues the statement, closes the transaction and returns the result.

        See :func:`queue_statement`, :func:`close` and :func:`get_full_result`
        """
        filled = self.queue_statement(statement, *values)
        self.close()
        return self.get_full_result(filled, raise_exception, ignore_warnings)

    def execute_statements(self, *statements: tuple[PreparedStatement | str, list[DbValueType]]
                           ) -> tuple[DbResultSet, ...]:
        """
        Queues all the statements and immediately retrieves all the results.

        See :func:`queue_statement` and :func:`get_result`
        """
        return self._execute_statements(*statements, after_queuing=None)

    def execute_statements_and_close(
            self, *statements: tuple[PreparedStatement | str, list[DbValueType]]) -> tuple[DbResultSet, ...]:
        """
        Queues all the statements, closes the transaction and returns all the results.

        See :func:`queue_statement`, :func:`close` and :func:`get_result`
        """
        return self._execute_statements(*statements, after_queuing=self.close)

    def close(self, ignore_unused_statements=False):
        """
        Executes all queued statements (if ignore_unused_statements is False) and then closes this transaction.
        After this method is called, no new statements may be queued.

        May raise :class:`DatabaseError` (e.g. connection was permanently lost)
        This causes the transaction to be closed (:func:`close`)
        """
        if self._closed:
            return
        if ignore_unused_statements:
            self._remove_all_statements()
        self.queue_statement(self._connection.get_transaction_end_statement(False))
        self._execute_queued()
        self._close()


class WriteTransaction(AbstractTransaction):
    """A read-write transaction. It is begun on construction and ended with :func:`commit` or :func:`rollback`."""

    def __init__(self, release_connection: Callable, connection: DbConnection):
        super().__init__("write", release_connection, connection)
        # Only set once the commit statement has actually been executed
        self._committed = False
        self.queue_statement(connection.get_transaction_begin_statement(True))

    def execute_statement_and_commit(self, statement: PreparedStatement | str, *values: DbValueType,
                                     ignore_warnings: bool = False) -> DbResultSet:
        """
        Queues the statement, commits and returns the result.

        See :func:`queue_statement`, :func:`commit` and :func:`get_result`
        """
        filled = self.queue_statement(statement, *values)
        self.commit()
        return self.get_result(filled, ignore_warnings)

    def execute_statement_full_result_and_commit(self, statement: PreparedStatement | str, *values: DbValueType,
                                                 raise_exception: bool = True,
                                                 ignore_warnings: bool = False) -> DbResult:
        """
        Queues the statement, commits and returns the result.

        See :func:`queue_statement`, :func:`commit` and :func:`get_full_result`
        """
        filled = self.queue_statement(statement, *values)
        self.commit()
        return self.get_full_result(filled, raise_exception, ignore_warnings)

    def execute_statements(self, *statements: tuple[PreparedStatement | str, list[DbValueType]]
                           ) -> tuple[DbResultSet, ...]:
        """
        Queues all the statements and immediately retrieves all the results.

        See :func:`queue_statement` and :func:`get_result`
        """
        return self._execute_statements(*statements, after_queuing=None)

    def execute_statements_and_commit(
            self, *statements: tuple[PreparedStatement | str, list[DbValueType]]) -> tuple[DbResultSet, ...]:
        """
        Queues all the statements, commits and returns all the results.

        See :func:`queue_statement`, :func:`commit` and :func:`get_result`
        """
        return self._execute_statements(*statements, after_queuing=self.commit)

    def commit(self, ignore_statement_exceptions: bool = False):
        """
        Executes all queued statements and then commits this transaction.
        After this method is called, no new statements may be queued.

        :param ignore_statement_exceptions: If False and a statement exception occurs/has occurred in any results
                                            which have not been retrieved yet, the transaction will be rolled back
                                            and the exception is raised by this method

        Raises :class:`RuntimeError` if the transaction was already rolled back (including a previous failed
        commit)

        May raise :class:`DatabaseError` (e.g. connection was permanently lost).
        This causes the transaction to be rolled back (:func:`rollback`)
        """
        if self._closed:
            if not self._committed:
                raise RuntimeError("Transaction was already rolled back")
            return
        if not ignore_statement_exceptions:
            self._execute_queued()
            for result in self._statement_results.values():
                if result[DB_RESULT_EXCEPTION] is not None:
                    self.on_error_close_transaction()
                    print("Not committing due to exception during statement execution")
                    raise result[DB_RESULT_EXCEPTION]
        self.queue_statement(self._connection.get_transaction_end_statement(True))
        self._execute_queued()
        # Marked only after the end statement ran: on any failure above the transaction was rolled back, and a
        # later commit()/rollback() must treat it as such
        self._committed = True
        self._close()

    def rollback(self):
        """
        Rolls back this transaction. All queued statements are discarded
        After this method is called, no new statements may be queued.

        Raises :class:`RuntimeError` if the transaction was already committed

        May raise :class:`DatabaseError` (e.g. connection was permanently lost).
        """
        if self._closed:
            if self._committed:
                raise RuntimeError("Transaction was already committed")
            return
        self._remove_all_statements()
        self.queue_statement(self._connection.get_transaction_end_statement(False))
        self._execute_queued()
        self._close()


class _ConnectionCache:
    """
    Pool of connections created by a single factory, holding at most ``max_count`` connections.

    A caller that finds no free connection waits up to ``max_wait_time_sec`` for one; at most
    ``max_waiting_count`` callers may wait at the same time.
    """

    def __init__(self, max_wait_time_sec: float, max_waiting_count: int, max_count: int,
                 factory: Callable[[], DbConnection]):
        super().__init__()
        self._max_wait_time_sec = max_wait_time_sec
        self._max_waiting_count = max_waiting_count
        self._max_count = max_count
        self._factory = factory

        self._all_connections: list[DbConnection] = []
        self._available_connections: list[DbConnection] = []
        self._available_condition: Condition = Condition()
        self._currently_waiting_count: int = 0

    def _try_get_connection(self) -> DbConnection | None:
        """
        Returns a free connection or a newly created one if below ``max_count``, otherwise None.

        Disconnected free connections are closed and dropped on the way. Must be called with the condition held.
        """
        while len(self._available_connections) > 0:
            connection = self._available_connections.pop()
            if not connection.is_disconnected():
                return connection
            self._all_connections.remove(connection)
            connection.close()

        if len(self._all_connections) >= self._max_count:
            return None
        connection = self._factory()
        self._all_connections.append(connection)
        return connection

    def get_connection(self) -> DbConnection:
        """
        Returns a connection, waiting for one to become available if necessary.

        Raises :class:`NoAvailableConnectionError` if too many callers are already waiting or none became
        available within the maximum wait time. May raise whatever the factory raises.
        """
        with self._available_condition:
            connection = self._try_get_connection()
            if connection is not None:
                return connection

            if self._currently_waiting_count >= self._max_waiting_count:
                raise NoAvailableConnectionError("Too many are already waiting")

            acquired: list[DbConnection] = []

            def try_acquire() -> bool:
                # A notification does not guarantee a connection for us: another caller may have taken it
                # before we re-acquired the lock, so retry until success or the deadline passes
                candidate = self._try_get_connection()
                if candidate is None:
                    return False
                acquired.append(candidate)
                return True

            self._currently_waiting_count += 1
            try:
                if not self._available_condition.wait_for(try_acquire, timeout=self._max_wait_time_sec):
                    raise NoAvailableConnectionError("Timed out")
            finally:
                self._currently_waiting_count -= 1
            return acquired[0]

    def return_connection(self, connection: DbConnection):
        """Returns a connection to the cache (closing it if disconnected) and wakes up one waiting caller."""
        with self._available_condition:
            if connection.is_disconnected():
                # Frees the slot so that a waiter can create a new connection
                self._all_connections.remove(connection)
                connection.close()
            else:
                self._available_connections.append(connection)
            self._available_condition.notify()


class DbConnectionPool:
    """
    Hands out connections wrapped in read or write transactions.

    If the factory does not support a per-transaction writable flag, separate caches for read-only and
    writable connections are kept.
    """

    def __init__(self, factory: DbConnectionFactory, max_connections: int, readonly_percent: float,
                 max_wait_time_sec: float, max_waiting_count: int):
        """
        :param factory: Creates the connections
        :param max_connections: Maximum total number of connections
        :param readonly_percent: Share of ``max_connections`` reserved for read-only connections; only used if the
                                 factory does not support a per-transaction writable flag. It is multiplied directly
                                 with ``max_connections`` (i.e. a fraction like 0.5, not 50)
        :param max_wait_time_sec: Maximum time to wait for a connection
        :param max_waiting_count: Maximum number of callers waiting for a connection at the same time
        """
        super().__init__()
        self._factory = factory

        if self._factory.supports_per_transaction_writeable_flag():
            if max_connections <= 0:
                raise ValueError("There must be at least one connection allowed")  # pragma: no cover
            self._cache = _ConnectionCache(
                max_wait_time_sec, max_waiting_count, max_connections, lambda: self._factory.new_connection())
        else:
            readonly_count = round(max_connections * readonly_percent)
            writeable_count = max_connections - readonly_count
            if readonly_count <= 0 or writeable_count <= 0:
                raise ValueError("Count too small. There must be at least one readonly and one writable connection "
                                 "allowed")  # pragma: no cover
            self._read_cache = _ConnectionCache(
                max_wait_time_sec, max_waiting_count, readonly_count, lambda: self._factory.new_connection(False))
            self._write_cache = _ConnectionCache(
                max_wait_time_sec, max_waiting_count, writeable_count, lambda: self._factory.new_connection(True))

    def _on_connection_released(self, writable: bool, connection: DbConnection):
        """Returns the connection to the cache it was taken from."""
        if self._factory.supports_per_transaction_writeable_flag():
            self._cache.return_connection(connection)
        elif writable:
            self._write_cache.return_connection(connection)
        else:
            self._read_cache.return_connection(connection)

    def _get_connection(self, writable: bool) -> DbConnection:
        """Takes a connection from the cache matching ``writable``."""
        if self._factory.supports_per_transaction_writeable_flag():
            return self._cache.get_connection()
        if writable:
            return self._write_cache.get_connection()
        return self._read_cache.get_connection()

    def execute_script(self, script: str, wrap_in_transaction: bool = False):
        """
        Executes the given script. By default, NO TRANSACTION MANAGEMENT IS DONE. The script must start and commit
        a transaction itself. After the script, any ongoing transaction is rolled back.

        If wrap_in_transaction is True, the transaction start and end statements are put at the beginning and end
        of the script (with simple string concatenation; the last statement must have a ';').

        Unlike other functions this immediately raises an exception for any error (usually :class:`DatabaseError`)
        and then rolls back.

        No results are returned
        """
        connection = self._get_connection(True)
        succeeded = False
        try:
            if wrap_in_transaction:
                # str() because the statements may be PreparedStatement instances
                script = (str(connection.get_transaction_begin_statement(writable=True)) + ";\n"
                          + script + "\n"
                          + str(connection.get_transaction_end_statement(commit=True)) + ";")
            connection.execute_script(script)
            succeeded = True
        finally:
            # Unwrapped scripts are always rolled back afterward; wrapped ones only if they did not complete
            # (including BaseExceptions, so the connection never returns to the pool mid-transaction)
            if not (wrap_in_transaction and succeeded):
                # noinspection PyBroadException
                try:
                    connection.execute_statements(
                        [FilledStatement(connection.get_transaction_end_statement(commit=False), [])])
                except Exception:
                    pass  # Best effort; the original error (if any) is more relevant
            self._on_connection_released(True, connection)

    @contextmanager
    def start_read_transaction(self) -> ReadTransaction:
        """
        Starts a read-only transaction.

        The transaction should be closed explicitly. If it is still open when the ``with`` block ends without an
        exception, it is closed with a printed warning. If the block raises, an open transaction is rolled back
        and the exception propagates.

        May raise :class:`NoAvailableConnectionError`
        """
        connection = self._get_connection(False)
        try:
            transaction = ReadTransaction(lambda: self._on_connection_released(False, connection), connection)
        except BaseException:
            # The transaction never took ownership of the connection, so return it here to avoid leaking it
            self._on_connection_released(False, connection)
            raise
        try:
            yield transaction
            if not transaction.is_closed():
                print("Warning: Rolling open transaction back after execution but there was no exception (Always "
                      "close the transaction explicitly)")
                transaction.close()
        except BaseException:
            # BaseException (not only Exception) so that e.g. KeyboardInterrupt or GeneratorExit cannot leave the
            # transaction open and its pooled connection leaked
            if not transaction.is_closed():
                _COUNTERS_TRANSACTION_ABORTED_BY_USER["read"].trigger()
                transaction.on_error_close_transaction()
            raise

    @contextmanager
    def start_write_transaction(self) -> WriteTransaction:
        """
        Starts a read-write transaction.

        The transaction should be committed or rolled back explicitly. If it is still open when the ``with`` block
        ends without an exception, it is rolled back with a printed warning. If the block raises, an open
        transaction is rolled back and the exception propagates.

        May raise :class:`NoAvailableConnectionError`
        """
        connection = self._get_connection(True)
        try:
            transaction = WriteTransaction(lambda: self._on_connection_released(True, connection), connection)
        except BaseException:
            # The transaction never took ownership of the connection, so return it here to avoid leaking it
            self._on_connection_released(True, connection)
            raise
        try:
            yield transaction
            if not transaction.is_closed():
                print("Warning: Rolling open transaction back after execution but there was no exception (Always "
                      "commit/rollback the transaction explicitly)")
                transaction.rollback()
        except BaseException:
            # BaseException (not only Exception) so that e.g. KeyboardInterrupt or GeneratorExit cannot leave the
            # transaction open and its pooled connection leaked
            if not transaction.is_closed():
                _COUNTERS_TRANSACTION_ABORTED_BY_USER["write"].trigger()
                transaction.on_error_close_transaction()
            raise