from threading import Condition
from datetime import datetime
from typing import Callable, Literal
from abc import ABC, ABCMeta, abstractmethod
from contextlib import contextmanager
from traceback import StackSummary, extract_stack

from api.miscellaneous import DEBUG_ENABLED, DIAGNOSTICS_TRACKER, DiagnosticsCounter

# Python types that are bound as statement parameters and held as values of result rows
DbValueType = int | str | bytes | datetime

# Kind of transaction; selects which of the per-type diagnostics counters is triggered
_TransactionType = Literal["read", "write"]


def _create_diagnostics_counters(id: str) -> dict[_TransactionType, DiagnosticsCounter]:
    """
    Registers a pair of diagnostics counters named ``database.transaction.read.<id>`` and
    ``database.transaction.write.<id>`` (registered in that order).
    
    :param id: Suffix identifying the measured quantity (e.g. ``"count"``)
    :return: The new counters, keyed by transaction type
    """
    read_counter = DIAGNOSTICS_TRACKER.register_counter(f"database.transaction.read.{id}")
    write_counter = DIAGNOSTICS_TRACKER.register_counter(f"database.transaction.write.{id}")
    return {"read": read_counter, "write": write_counter}


# Diagnostics counters, each registered once per transaction type ("read"/"write") when this module is imported.
# count: transactions started; errors: transactions that hit a DatabaseError (counted at most once each);
# attempted_reconnects / successful_reconnects: reconnects tried after the connection was found disconnected before
# the transaction had executed anything; errors_after_reconnect: retries that still failed after reconnecting;
# aborted_by_user: transactions closed because the caller's `with` block raised while they were still open.
_COUNTERS_TRANSACTION_COUNT = _create_diagnostics_counters("count")
_COUNTERS_TRANSACTION_ERRORS = _create_diagnostics_counters("errors")
_COUNTERS_TRANSACTION_ATTEMPTED_RECONNECTS = _create_diagnostics_counters("attempted_reconnects")
_COUNTERS_TRANSACTION_SUCCESSFUL_RECONNECTS = _create_diagnostics_counters("successful_reconnects")
_COUNTERS_TRANSACTION_ERRORS_AFTER_RECONNECT = _create_diagnostics_counters("errors_after_reconnect")
_COUNTERS_TRANSACTION_ABORTED_BY_USER = _create_diagnostics_counters("aborted_by_user")


class DatabaseResultRow:
    """
    One row of a result set. Values are accessible by column name (``row["name"]``) or by position (``row[0]``).
    """
    
    def __init__(self, column_mapping: dict[str, int], values: tuple[DbValueType, ...]):
        """
        :param column_mapping: Maps each column name to its index in ``values``
        :param values: The values of this row, in column order
        """
        super().__init__()
        self._column_mapping = column_mapping
        self._values = values
    
    def __contains__(self, item: str | int) -> bool:
        """
        True if ``item`` is a known column name (str) or an index in ``range(len(values))`` (int).
        
        Raises :class:`TypeError` for any other key type (an ``assert`` would be skipped under ``python -O``)
        """
        if isinstance(item, str):
            return item in self._column_mapping
        if isinstance(item, int):
            return 0 <= item < len(self._values)
        raise TypeError(f"Result row keys must be str or int, not {type(item).__name__}")
    
    def __getitem__(self, key: str | int) -> DbValueType:
        """
        Returns the value of the column named ``key`` (str) or at position ``key`` (int; indexing follows the
        semantics of ``values``, so negative indices count from the end).
        
        Raises :class:`KeyError` for an unknown column name, :class:`IndexError` for an out-of-range index and
        :class:`TypeError` for any other key type
        """
        if isinstance(key, str):
            return self._values[self._column_mapping[key]]
        if isinstance(key, int):
            return self._values[key]
        raise TypeError(f"Result row keys must be str or int, not {type(key).__name__}")
    
    def __str__(self):
        return str(self._values)
    
    def __repr__(self):
        return self.__str__()


class DatabaseWarning:
    """A warning reported by the database for an executed statement."""
    
    def __init__(self, message: str):
        super().__init__()
        self._message = message
    
    @property
    def message(self) -> str:
        """The warning text"""
        return self._message
    
    def __str__(self):
        return str(self._message)
    
    def __repr__(self):
        # Containers (e.g. the list passed to WarningError) format their elements with repr, so make it readable
        return f"DatabaseWarning({self._message!r})"


DbResultRow = DatabaseResultRow
DbResultSet = list[DbResultRow]

DbAffectedRows = int
# (exception or None, warnings, result set, affected row count); index it with the DB_RESULT_* constants.
# Note: `Exception | None` (not `Exception or None`, which evaluates to just `Exception`)
DbResult = tuple[Exception | None, list[DatabaseWarning], DbResultSet, DbAffectedRows]

# Positions of the fields within a DbResult tuple
DB_RESULT_EXCEPTION = 0
DB_RESULT_WARNINGS = 1
DB_RESULT_SET = 2
DB_RESULT_AFFECTED_ROWS = 3

# Placeholder that marks a positional parameter in a statement
SQL_PARAMETER_INDICATOR = "?"


class DatabaseError(Exception):
    """
    Raised for database failures that the caller did not cause, such as a lost or broken connection.
    """


class WarningError(Exception):
    """Raised when a statement produced warnings and the caller did not ask to ignore them."""
    
    def __init__(self, warnings: list[DatabaseWarning]):
        message = str(warnings)
        super().__init__(message)
        # Keep the warning objects themselves so handlers can inspect them
        self.warnings = warnings


class NoAvailableConnectionError(Exception):
    """
    Raised when no connection can be handed out right now (too many callers waiting, or the wait timed out).
    """


class PreparedStatement:
    """
    A reusable SQL statement whose parameters are marked with ``SQL_PARAMETER_INDICATOR``.
    
    The parameter count is computed once here, so queuing the statement does not need to rescan the SQL text.
    """
    
    def __init__(self, statement: str):
        super().__init__()
        placeholder_count = statement.count(SQL_PARAMETER_INDICATOR)
        self.statement = statement
        self.parameter_count = placeholder_count
    
    def __str__(self):
        # The raw SQL text
        return self.statement


class FilledStatement:
    """
    A statement together with the values for its parameters, as queued in a transaction.
    
    Instances are compared and hashed by identity; transactions use them as keys to look up results.
    """
    
    def __init__(self, statement: PreparedStatement | str, values: list[DbValueType]):
        """
        :param statement: The statement to execute
        :param values: One value per parameter placeholder, in order
        """
        super().__init__()
        self.statement = statement
        self.values = values
        # Stack at the time of queuing; only filled in when DEBUG_ENABLED is set (see queue_statement)
        self.queue_traceback: StackSummary | None = None


class DbConnection(ABC):
    """
    A single database connection. Instances are created by :func:`DbConnectionFactory.new_connection` and pooled by
    :class:`DbConnectionPool`.
    """
    __metaclass__ = ABCMeta  # Python 2 leftover; in Python 3 the metaclass is already provided by ABC
    
    def __init__(self):
        super().__init__()
    
    @abstractmethod
    def is_disconnected(self, force_ping: bool = False) -> bool:
        """
        Returns True if this connection is definitely closed. False indicates that it is probably still intact.
        This may ping the database.
        
        :param force_ping: If this is true the database is guaranteed to be pinged
        """
        pass  # pragma: no cover
    
    def try_reconnect(self) -> bool:
        """
        Tries to reconnect to the database. Note that any ongoing transaction is rolled back
        
        Overriding this is optional; the default implementation does not reconnect and returns False.
        
        :return: True if reconnection was successful
        """
        return False  # pragma: no cover
    
    @abstractmethod
    def close(self):
        """Closes this connection, freeing all resources"""
        pass  # pragma: no cover
    
    @abstractmethod
    def get_transaction_begin_statement(self, writable: bool) -> PreparedStatement | str:
        """
        Returns the statement to begin a transaction
        :param writable: Specifies whether this transaction is writable.
                         If :func:`supports_per_transaction_writeable_flag` of the :class:`DbConnectionFactory` is False
                         this parameter is ignored.
        """
        pass  # pragma: no cover
    
    @abstractmethod
    def get_transaction_end_statement(self, commit: bool) -> PreparedStatement | str:
        """
        Returns the statement to end a transaction
        :param commit: True to commit the transaction, False to end it without committing
        """
        pass  # pragma: no cover
    
    @abstractmethod
    def execute_statements(self, statements: list[FilledStatement]) -> list[DbResult]:
        """
        Executes the given statements and returns their results.
        
        Callers pair the results with ``statements`` by position, so exactly one result per statement must be
        returned, in the same order. Failures of individual statements may be reported in each result's exception
        field (index ``DB_RESULT_EXCEPTION``).
        
        May raise :class:`DatabaseError`
        """
        pass  # pragma: no cover

    @abstractmethod
    def execute_script(self, script: str):
        """
        See execute_script of :class:`DbConnectionPool`
        """
        pass  # pragma: no cover


class DbConnectionFactory(ABC):
    """Creates the :class:`DbConnection` instances used by a :class:`DbConnectionPool`."""
    __metaclass__ = ABCMeta  # Python 2 leftover; in Python 3 the metaclass is already provided by ABC
    
    def __init__(self):
        super().__init__()
    
    @abstractmethod
    def supports_per_transaction_writeable_flag(self) -> bool:
        """
        Whether one connection can serve both read-only and read-write transactions.
        
        When this is False, every connection is either read-only or read-write for as long as it exists.
        """
        pass  # pragma: no cover
    
    @abstractmethod
    def new_connection(self, writable: bool = True) -> DbConnection:
        """
        Opens and returns a fresh connection.
        
        :param writable: Only relevant when :func:`supports_per_transaction_writeable_flag` is False; it then decides
                         whether the new connection may modify the database. Ignored otherwise.
        :return: The new connection
        
        May raise :class:`DatabaseError`
        """
        pass  # pragma: no cover


class AbstractTransaction(ABC):
    """
    Common base of :class:`ReadTransaction` and :class:`WriteTransaction`.
    
    Statements are queued with :func:`queue_statement` and sent to the connection as one batch as soon as a result
    is requested (or the transaction is ended). ``release_connection`` is called when the transaction is closed.
    """
    __metaclass__ = ABCMeta  # Python 2 leftover; in Python 3 the metaclass is already provided by ABC
    
    def __init__(self, type: _TransactionType, release_connection: Callable, connection: DbConnection):
        """
        :param type: Transaction kind, selects the diagnostics counters to trigger
        :param release_connection: Called without arguments once this transaction is closed
        :param connection: Connection on which all statements of this transaction are executed
        """
        super().__init__()
        self._type = type
        self._release_connection = release_connection
        self._closed = False
        self._connection = connection
        self._has_executed_statements = False
        self._has_encountered_errors = False
        self._queued_statements: list[FilledStatement] = []
        self._statement_results: dict[FilledStatement, DbResult] = {}
        _COUNTERS_TRANSACTION_COUNT[type].trigger()
    
    def is_closed(self) -> bool:
        """True once the transaction was ended (committed, rolled back or closed after an error)"""
        return self._closed
    
    def queue_statement(self, statement: PreparedStatement | str, *values: DbValueType) -> FilledStatement:
        """
        Queues the specified statement to be executed. The result can later be queried with :func:`get_result`.
        Note that the statements are always executed in the order in which they were queued.
        
        :return: An object with which the result can be retrieved later
        
        Throws :class:`ValueError` if the parameter count (count of ?) does not match the length of values
        
        Throws :class:`RuntimeError` if the transaction was already closed (committed/rolled back)
        """
        if self._closed:
            raise RuntimeError("Transaction already closed")  # pragma: no cover
        
        parameter_count = statement.parameter_count\
            if isinstance(statement, PreparedStatement)\
            else statement.count(SQL_PARAMETER_INDICATOR)
        if parameter_count != len(values):
            raise ValueError(f"Parameter ({parameter_count}) and argument count ({len(values)}) don't match")  # pragma: no cover
        
        queued = FilledStatement(statement, list(values))
        if DEBUG_ENABLED:
            queued.queue_traceback = extract_stack()
        self._queued_statements.append(queued)
        return queued
    
    def get_result(self, stat: FilledStatement, ignore_warnings: bool = False) -> DbResultSet:
        """
        Does the same as :func:`get_full_result` with ``raise_exception=True`` but only returns the result set
        """
        return self.get_full_result(stat, ignore_warnings=ignore_warnings)[DB_RESULT_SET]
    
    def get_full_result(self,
                        stat: FilledStatement,
                        raise_exception: bool = True,
                        ignore_warnings: bool = False) -> DbResult:
        """
        Returns the result of the specified statement. If the statement has not been executed yet, the statement and
        possibly other queued statements are executed. Each result can only be retrieved once.
        :param stat: The filled statement returned by :func:`queue_statement`
        :param raise_exception: If True and the statement execution produced an exception, a :class:`RuntimeError`
                                chained to that exception is raised. Note that if the statement was executed in a
                                multi-query, the exception may have been caused by another statement
        :param ignore_warnings: If False a :class:`WarningError` is raised if there are any warnings. Note that if
                                the statement was executed in a multi-query, the warnings may have been caused by
                                another statement
        :return: The result. You may use :func:`get_result` if you are only interested in the result set
        
        Note that depending on the underlying connection, warnings may not be fetched.
        
        Raises :class:`ValueError` if the statement was not queued in this transaction or its result was already
        retrieved.
        
        May raise :class:`DatabaseError` (e.g. connection was permanently lost) (regardless of raise_exception)
        This causes the transaction to be closed (rolled back)
        """
        if stat not in self._statement_results:
            self._execute_queued()
        if stat not in self._statement_results:  # pragma: no cover
            raise ValueError("Statement was not queued or result was already queried")
        result = self._statement_results.pop(stat)
        
        if raise_exception and result[DB_RESULT_EXCEPTION] is not None:  # pragma: no cover
            raise RuntimeError("An exception had occurred while executing statement "
                               "(Or other statement in multi-query)") from result[DB_RESULT_EXCEPTION]
        if not ignore_warnings and len(result[DB_RESULT_WARNINGS]) > 0:
            raise WarningError(result[DB_RESULT_WARNINGS])
        return result
    
    def execute_statement(self,
                          statement: PreparedStatement | str,
                          *values: DbValueType,
                          ignore_warnings: bool = False) -> DbResultSet:
        """
        Queues the statement and immediately retrieves the result.
        See :func:`queue_statement` and :func:`get_result`
        """
        return self.get_result(self.queue_statement(statement, *values), ignore_warnings)
    
    def execute_statement_full_result(self,
                                      statement: PreparedStatement | str,
                                      *values: DbValueType,
                                      raise_exception: bool = True,
                                      ignore_warnings: bool = False) -> DbResult:
        """
        Queues the statement and immediately retrieves the result.
        See :func:`queue_statement` and :func:`get_full_result`
        """
        return self.get_full_result(self.queue_statement(statement, *values), raise_exception, ignore_warnings)
    
    def _execute_statements(self,
                            *statements: tuple[PreparedStatement | str, list[DbValueType]],
                            after_queuing: Callable | None) -> tuple[DbResultSet, ...]:
        """
        Queues every (statement, values) pair, calls ``after_queuing`` (e.g. close/commit) if given, and then
        retrieves all result sets in queuing order.
        """
        filled_list = [self.queue_statement(stat[0], *stat[1]) for stat in statements]
        
        if after_queuing is not None:
            after_queuing()
        
        return tuple(self.get_result(filled) for filled in filled_list)
    
    def on_error_close_transaction(self):
        """
        Ends this transaction without committing (best effort), discards all queued statements and releases the
        connection.
        
        Does nothing if the transaction is already closed, so the connection is never released twice (a double
        release would make the pool hand the same connection to two transactions).
        """
        if self._closed:
            return
        # noinspection PyBroadException
        try:
            self._connection.execute_statements(
                [FilledStatement(self._connection.get_transaction_end_statement(commit=False), [])])
        except Exception:
            pass  # Best effort: the connection may well be broken already, which is often why we get here
        self._closed = True
        self._queued_statements = []
        self._release_connection()
    
    def _execute_queued(self):
        """
        Sends all queued statements to the connection as one batch and stores their results.
        
        If the batch fails with a :class:`DatabaseError` before this transaction has executed anything and the
        connection is found to be disconnected, one reconnect and retry is attempted. On any failure the transaction
        is closed via :func:`on_error_close_transaction` and the original exception is raised.
        """
        if not self._has_queued_statements():
            return
        try:
            results = self._connection.execute_statements(self._queued_statements)
        except DatabaseError as e:
            if not self._has_encountered_errors:
                _COUNTERS_TRANSACTION_ERRORS[self._type].trigger()
            self._has_encountered_errors = True
            
            # Retrying is only safe when no earlier batch of this transaction has run and the connection is really gone
            if self._has_executed_statements or not self._connection.is_disconnected(force_ping=True):
                self.on_error_close_transaction()
                raise
            # All triggers after here can be triggered at most once for a single transaction because after this code block
            # we have either executed statements or closed the transaction
            _COUNTERS_TRANSACTION_ATTEMPTED_RECONNECTS[self._type].trigger()
            if not self._connection.try_reconnect():
                self.on_error_close_transaction()
                raise
            _COUNTERS_TRANSACTION_SUCCESSFUL_RECONNECTS[self._type].trigger()
            try:
                results = self._connection.execute_statements(self._queued_statements)
            except Exception:
                _COUNTERS_TRANSACTION_ERRORS_AFTER_RECONNECT[self._type].trigger()
                self.on_error_close_transaction()
                raise e  # Raise original exception, not the one from the retry
        except Exception:
            self.on_error_close_transaction()
            raise
        
        self._has_executed_statements = True
        
        for result, stat in zip(results, self._queued_statements):
            self._statement_results[stat] = result
        self._queued_statements = []
    
    def _remove_all_statements(self):
        """Discards all queued statements and all results that have not been retrieved yet"""
        self._queued_statements = []
        self._statement_results = {}
        
    def _has_queued_statements(self):
        return len(self._queued_statements) > 0
    
    def _close(self):
        """Marks the transaction as closed and releases the connection; the queue must be empty by now"""
        if self._closed:
            return
        if self._has_queued_statements():
            raise RuntimeError("Subclass must commit/rollback transaction")  # pragma: no cover
        self._closed = True
        self._release_connection()


class ReadTransaction(AbstractTransaction):
    """
    A read-only transaction. It is begun on construction and ended without committing by :func:`close`.
    """
    
    def __init__(self, release_connection: Callable, connection: DbConnection):
        super().__init__("read", release_connection, connection)
        self.queue_statement(connection.get_transaction_begin_statement(False))
    
    def execute_statement_and_close(self,
                                    statement: PreparedStatement | str,
                                    *values: DbValueType,
                                    ignore_warnings: bool = False) -> DbResultSet:
        """
        Queues the statement, closes the transaction and returns the result.
        See :func:`queue_statement`, :func:`close` and :func:`get_result`
        """
        filled = self.queue_statement(statement, *values)
        self.close()
        return self.get_result(filled, ignore_warnings)
    
    def execute_statement_full_result_and_close(self,
                                                statement: PreparedStatement | str,
                                                *values: DbValueType,
                                                raise_exception: bool = True,
                                                ignore_warnings: bool = False) -> DbResult:
        """
        Queues the statement, closes the transaction and returns the result.
        See :func:`queue_statement`, :func:`close` and :func:`get_full_result`
        """
        filled = self.queue_statement(statement, *values)
        self.close()
        return self.get_full_result(filled, raise_exception, ignore_warnings)
    
    def execute_statements(self,
                           *statements: tuple[PreparedStatement | str, list[DbValueType]]) -> tuple[DbResultSet, ...]:
        """
        Queues all the statements and immediately retrieves all the results.
        See :func:`queue_statement` and :func:`get_result`
        """
        return self._execute_statements(*statements, after_queuing=None)
    
    def execute_statements_and_close(
            self,
            *statements: tuple[PreparedStatement | str, list[DbValueType]]) -> tuple[DbResultSet, ...]:
        """
        Queues all the statements, closes the transaction and returns all the results.
        See :func:`queue_statement`, :func:`close` and :func:`get_result`
        """
        return self._execute_statements(*statements, after_queuing=self.close)
    
    def close(self, ignore_unused_statements=False):
        """
        Executes all queued statements (unless ignore_unused_statements is True, in which case they and any
        unretrieved results are discarded), ends the transaction and closes it.
        
        After this method is called, no new statements may be queued.
        
        May raise :class:`DatabaseError` (e.g. connection was permanently lost)
        This still leaves the transaction closed and its connection released
        """
        if self._closed:
            return
        if ignore_unused_statements:
            self._remove_all_statements()
            if not self._has_executed_statements:
                # The begin statement was discarded before it was ever sent, so the server has no transaction to end
                self._close()
                return
        self.queue_statement(self._connection.get_transaction_end_statement(False))
        self._execute_queued()
        self._close()


class WriteTransaction(AbstractTransaction):
    """
    A read-write transaction. It is begun on construction and ended by either :func:`commit` or :func:`rollback`.
    """
    
    def __init__(self, release_connection: Callable, connection: DbConnection):
        super().__init__("write", release_connection, connection)
        # Only set once the commit statement has actually been executed successfully
        self._committed = False
        self.queue_statement(connection.get_transaction_begin_statement(True))
    
    def execute_statement_and_commit(self,
                                     statement: PreparedStatement | str,
                                     *values: DbValueType,
                                     ignore_warnings: bool = False) -> DbResultSet:
        """
        Queues the statement, commits and returns the result.
        See :func:`queue_statement`, :func:`commit` and :func:`get_result`
        """
        filled = self.queue_statement(statement, *values)
        self.commit()
        return self.get_result(filled, ignore_warnings)
    
    def execute_statement_full_result_and_commit(self,
                                                 statement: PreparedStatement | str,
                                                 *values: DbValueType,
                                                 raise_exception: bool = True,
                                                 ignore_warnings: bool = False) -> DbResult:
        """
        Queues the statement, commits and returns the result.
        See :func:`queue_statement`, :func:`commit` and :func:`get_full_result`
        """
        filled = self.queue_statement(statement, *values)
        self.commit()
        return self.get_full_result(filled, raise_exception, ignore_warnings)
    
    def execute_statements(self,
                           *statements: tuple[PreparedStatement | str, list[DbValueType]]) -> tuple[DbResultSet, ...]:
        """
        Queues all the statements and immediately retrieves all the results.
        See :func:`queue_statement` and :func:`get_result`
        """
        return self._execute_statements(*statements, after_queuing=None)
    
    def execute_statements_and_commit(
            self,
            *statements: tuple[PreparedStatement | str, list[DbValueType]]) -> tuple[DbResultSet, ...]:
        """
        Queues all the statements, commits and returns all the results.
        See :func:`queue_statement`, :func:`commit` and :func:`get_result`
        """
        return self._execute_statements(*statements, after_queuing=self.commit)
    
    def commit(self, ignore_statement_exceptions: bool = False):
        """
        Executes all queued statements and then commits this transaction.
        
        After this method is called, no new statements may be queued.
        
        :param ignore_statement_exceptions If False and a statement exception occurs/has occurred in any results which
                                           have not been retrieved yet, the transaction will be rolled back and the
                                           exception is raised by this method
        
        Raises :class:`RuntimeError` if the transaction was already rolled back (including a rollback caused by a
        failed commit)
        
        May raise :class:`DatabaseError` (e.g. connection was permanently lost).
        This causes the transaction to be rolled back (:func:`rollback`)
        """
        if self._closed:
            if not self._committed:
                raise RuntimeError("Transaction was already rolled back")
            return
        if not ignore_statement_exceptions:
            self._execute_queued()
            for result in self._statement_results.values():
                if result[DB_RESULT_EXCEPTION] is not None:
                    self.on_error_close_transaction()
                    print("Not committing due to exception during statement execution")
                    raise result[DB_RESULT_EXCEPTION]
        self.queue_statement(self._connection.get_transaction_end_statement(True))
        self._execute_queued()
        # Mark as committed only after the end statement ran; a failure above leaves the transaction rolled back
        self._committed = True
        self._close()
    
    def rollback(self):
        """
        Rolls back this transaction. All queued statements are discarded
        
        After this method is called, no new statements may be queued.
        
        Raises :class:`RuntimeError` if the transaction was already committed
        
        May raise :class:`DatabaseError` (e.g. connection was permanently lost).
        """
        if self._closed:
            if self._committed:
                raise RuntimeError("Transaction was already committed")
            return
        self._remove_all_statements()
        if not self._has_executed_statements:
            # Nothing (not even the begin statement) has been sent yet, so there is nothing to roll back on the server
            self._close()
            return
        self.queue_statement(self._connection.get_transaction_end_statement(False))
        self._execute_queued()
        self._close()


class _ConnectionCache:
    """
    A bounded set of connections created by ``factory``, guarded by one condition variable.
    
    At most ``max_count`` connections exist at once. A caller that finds none available waits up to
    ``max_wait_time_sec`` for one to be returned, unless ``max_waiting_count`` callers are already waiting.
    """
    
    def __init__(self,
                 max_wait_time_sec: float,
                 max_waiting_count: int,
                 max_count: int,
                 factory: Callable[[], DbConnection]):
        super().__init__()
        self._max_wait_time_sec = max_wait_time_sec
        self._max_waiting_count = max_waiting_count
        self._max_count = max_count
        self._factory = factory
        self._all_connections: list[DbConnection] = []
        self._available_connections: list[DbConnection] = []
        self._available_condition: Condition = Condition()
        self._currently_waiting_count: int = 0
    
    def _can_provide_connection(self) -> bool:
        """
        True if :func:`_try_get_connection` will return a connection (unless the factory raises).
        Must be called while holding ``_available_condition``.
        """
        # A disconnected idle connection still counts: _try_get_connection discards it, which frees a slot
        return len(self._available_connections) > 0 or len(self._all_connections) < self._max_count
    
    def _try_get_connection(self) -> DbConnection | None:
        """
        Returns an idle connection that is not known to be disconnected, or a new one if the limit allows it,
        otherwise None. Must be called while holding ``_available_condition``.
        """
        while len(self._available_connections) > 0:
            connection = self._available_connections.pop()
            if not connection.is_disconnected():
                return connection
            self._all_connections.remove(connection)
            connection.close()
        
        if len(self._all_connections) >= self._max_count:
            return None
        connection = self._factory()
        self._all_connections.append(connection)
        return connection
    
    def get_connection(self) -> DbConnection:
        """
        Returns a connection, waiting for one to be returned if necessary.
        
        Raises :class:`NoAvailableConnectionError` if too many callers are already waiting or the wait timed out
        """
        with self._available_condition:
            connection = self._try_get_connection()
            if connection is not None:
                return connection
            if self._currently_waiting_count >= self._max_waiting_count:
                raise NoAvailableConnectionError("Too many are already waiting")
            self._currently_waiting_count += 1
            try:
                # wait_for re-checks the predicate after every wake-up: if another thread takes the returned
                # connection between notify() and this thread re-acquiring the lock, we keep waiting (within the same
                # overall timeout) instead of proceeding without a connection
                if not self._available_condition.wait_for(self._can_provide_connection,
                                                          timeout=self._max_wait_time_sec):
                    raise NoAvailableConnectionError("Timed out")
            finally:
                self._currently_waiting_count -= 1
            connection = self._try_get_connection()
            # The predicate held and we still own the lock, so a slot is guaranteed
            assert connection is not None
            return connection
    
    def return_connection(self, connection: DbConnection):
        """
        Hands a connection back. Disconnected connections are closed and dropped, freeing their slot.
        Either way one waiting caller is woken up.
        """
        with self._available_condition:
            if connection.is_disconnected():
                self._all_connections.remove(connection)
                connection.close()
            else:
                self._available_connections.append(connection)
            self._available_condition.notify()


class DbConnectionPool:
    """
    Hands out transactions backed by pooled connections from a :class:`DbConnectionFactory`.
    
    If the factory supports a per-transaction writable flag, one cache of up to ``max_connections`` connections is
    shared by all transactions. Otherwise, ``round(max_connections * readonly_percent)`` connections are reserved for
    read transactions and the remaining ones for write transactions.
    """
    
    def __init__(self,
                 factory: DbConnectionFactory,
                 max_connections: int,
                 readonly_percent: float,
                 max_wait_time_sec: float,
                 max_waiting_count: int):
        super().__init__()
        self._factory = factory
        
        if self._factory.supports_per_transaction_writeable_flag():
            if max_connections < 1:
                raise ValueError("There must be at least one connection allowed")  # pragma: no cover
            self._cache = _ConnectionCache(
                max_wait_time_sec,
                max_waiting_count,
                max_connections,
                lambda: self._factory.new_connection())
        else:
            readonly_count = round(max_connections * readonly_percent)
            writeable_count = max_connections - readonly_count
            if readonly_count <= 0 or writeable_count <= 0:
                raise ValueError("Count too small. There must be at least one readonly and one writable connection "
                                 "allowed")  # pragma: no cover
            self._read_cache = _ConnectionCache(
                max_wait_time_sec,
                max_waiting_count,
                readonly_count,
                lambda: self._factory.new_connection(False))
            self._write_cache = _ConnectionCache(
                max_wait_time_sec,
                max_waiting_count,
                writeable_count,
                lambda: self._factory.new_connection(True))
    
    def _get_connection(self, writable: bool) -> DbConnection:
        """Takes a connection from the cache that serves transactions of the given kind"""
        if self._factory.supports_per_transaction_writeable_flag():
            return self._cache.get_connection()
        if writable:
            return self._write_cache.get_connection()
        return self._read_cache.get_connection()
    
    def _on_connection_released(self, writable: bool, connection: DbConnection):
        """Returns a connection to the cache it was taken from"""
        if self._factory.supports_per_transaction_writeable_flag():
            self._cache.return_connection(connection)
        elif writable:
            self._write_cache.return_connection(connection)
        else:
            self._read_cache.return_connection(connection)
    
    def execute_script(self, script: str, wrap_in_transaction: bool = False):
        """
        Executes the given script.
        
        By default, NO TRANSACTION MANAGEMENT IS DONE. The script must start and commit a transaction
        itself. After the script, any ongoing transaction is rolled back. If wrap_in_transaction is True, the
        transaction start and end statements are put at the beginning and end of the script (with simple string
        concatenation; the last statement must have a ';').
        
        Unlike other functions this immediately raises an exception for any error (usually :class:`DatabaseError`)
        and then rolls back.
        
        No results are returned
        
        May raise :class:`NoAvailableConnectionError`
        """
        connection = self._get_connection(writable=True)
        rollback = not wrap_in_transaction
        try:
            if wrap_in_transaction:
                # str(): the connection may return a PreparedStatement, which does not support '+'
                script = (str(connection.get_transaction_begin_statement(writable=True)) + ";\n"
                          + script + "\n"
                          + str(connection.get_transaction_end_statement(commit=True)) + ";")
            connection.execute_script(script)
        except BaseException:
            # Also on KeyboardInterrupt etc., so a connection is never returned with a transaction left open
            rollback = True
            raise
        finally:
            if rollback:
                # noinspection PyBroadException
                try:
                    connection.execute_statements(
                        [FilledStatement(connection.get_transaction_end_statement(commit=False), [])])
                except Exception:
                    pass  # Best effort; the original error (if any) is more useful to the caller
            self._on_connection_released(True, connection)
    
    @contextmanager
    def start_read_transaction(self) -> ReadTransaction:
        """
        Starts a read-only transaction. Use as ``with pool.start_read_transaction() as transaction: ...``.
        
        If the transaction is still open when the ``with`` block ends normally, it is closed (with a warning). If the
        block raises, an open transaction is closed without committing and the exception is propagated.
        
        May raise :class:`NoAvailableConnectionError`
        """
        connection = self._get_connection(writable=False)
        try:
            transaction = ReadTransaction(lambda: self._on_connection_released(False, connection), connection)
        except BaseException:
            # The transaction never took ownership of the connection, so hand it back here
            self._on_connection_released(False, connection)
            raise
        try:
            yield transaction
            if not transaction.is_closed():
                print("Warning: Rolling open transaction back after execution but there was no exception (Always "
                      "close the transaction explicitly)")
                transaction.close()
        except BaseException:
            # BaseException so that e.g. KeyboardInterrupt does not leak the connection
            if not transaction.is_closed():
                _COUNTERS_TRANSACTION_ABORTED_BY_USER["read"].trigger()
                transaction.on_error_close_transaction()
            raise
    
    @contextmanager
    def start_write_transaction(self) -> WriteTransaction:
        """
        Starts a read-write transaction. Use as ``with pool.start_write_transaction() as transaction: ...``.
        
        If the transaction is still open when the ``with`` block ends normally, it is rolled back (with a warning).
        If the block raises, an open transaction is rolled back and the exception is propagated.
        
        May raise :class:`NoAvailableConnectionError`
        """
        connection = self._get_connection(writable=True)
        try:
            transaction = WriteTransaction(lambda: self._on_connection_released(True, connection), connection)
        except BaseException:
            # The transaction never took ownership of the connection, so hand it back here
            self._on_connection_released(True, connection)
            raise
        try:
            yield transaction
            if not transaction.is_closed():
                print("Warning: Rolling open transaction back after execution but there was no exception (Always "
                      "commit/rollback the transaction explicitly)")
                transaction.rollback()
        except BaseException:
            # BaseException so that e.g. KeyboardInterrupt does not leak the connection
            if not transaction.is_closed():
                _COUNTERS_TRANSACTION_ABORTED_BY_USER["write"].trigger()
                transaction.on_error_close_transaction()
            raise