from mysql.connector.cursor import MySQLCursor
from mysql.connector.errors import ProgrammingError, DataError, NotSupportedError, InterfaceError
from mysql.connector.version import VERSION_TEXT
from mysql.connector import MySQLConnection, NUMBER, STRING, BINARY, DATETIME, InternalError
from time import time_ns
from datetime import datetime, date

from api.database import DbValueType, DbResultSet, FilledStatement
from api.database.database import PreparedStatement, DbConnectionFactory, DatabaseWarning, DatabaseError, \
    DbAffectedRows, \
    DatabaseResultRow, TransactionIsolationLevel
from api.database.abstract_py_connector import PythonDbConnection

print("Using MySQL Connector Version: {}".format(VERSION_TEXT))

# _MySqlDbConnection.is_disconnected() skips asking the driver whether the connection is still alive
# if the last successful request is younger than this many nanoseconds (10 seconds).
_GRACE_TIME_NO_PING_NS: int = 10_000_000_000


class _MySqlDbConnection(PythonDbConnection[MySQLConnection, MySQLCursor]):
    """``PythonDbConnection`` implementation backed by a MySQL Connector/Python ``MySQLConnection``.

    Statements use ``%s`` placeholders. ``ProgrammingError``, ``DataError`` and ``NotSupportedError`` are
    handed to the base class as its driver error types (how they are treated is defined there).
    """
    
    def __init__(self, use_multi_query: bool, fetch_warnings: bool, py_connection: MySQLConnection):
        """
        :param use_multi_query: Value returned by ``_can_use_multi_query`` (presumably the base class then
            batches statements via ``_db_execute_multiple_statements`` - confirm there).
        :param fetch_warnings: Whether ``_fetch_warnings`` collects driver warnings; also re-applied to the
            driver connection's ``get_warnings`` after a reconnect.
        :param py_connection: The driver connection to wrap.
        """
        super().__init__(
            "%s",
            (ProgrammingError, DataError, NotSupportedError),
            py_connection)
        self._use_multi_query = use_multi_query
        self._get_warnings = fetch_warnings
        # Sticky flag: set once a disconnect is detected, cleared only by a successful try_reconnect()
        self._disconnected = False
    
    def _can_use_multi_query(self, statements: list[FilledStatement]):
        """Return the configured ``use_multi_query`` flag; ``statements`` is not inspected."""
        return self._use_multi_query
    
    def is_disconnected(self, force_ping: bool = False) -> bool:
        """Return whether the driver connection is known to be lost.

        Unless ``force_ping`` is set, the driver is not asked at all if the last successful request
        happened less than ``_GRACE_TIME_NO_PING_NS`` nanoseconds ago.

        :raises RuntimeError: If this connection has already been closed.
        """
        if self._closed:
            raise RuntimeError("Already closed")
        if self._disconnected:
            return True
        # _last_successful_request is compared against time_ns(), i.e. it is a nanosecond timestamp
        if not force_ping and self._last_successful_request + _GRACE_TIME_NO_PING_NS > time_ns():
            return False
        if self._py_connection.is_connected():
            return False
        self._disconnected = True
        return True
    
    def try_reconnect(self) -> bool:
        """Try once to re-establish the driver connection and restore the session settings.

        :return: ``True`` if reconnecting and the session setup succeeded, ``False`` on any exception.
        """
        # noinspection PyBroadException
        try:
            self._py_connection.reconnect(attempts=1)
            # A new server session does not keep earlier SET SESSION values, so re-apply the ones this
            # adapter expects
            self._py_connection.cmd_query("SET SESSION TRANSACTION ISOLATION LEVEL REPEATABLE READ")
            self._py_connection.cmd_query("SET SESSION sql_mode = 'ANSI_QUOTES'")
            self._py_connection.get_warnings = self._get_warnings
        except Exception:
            # Best effort: the caller only needs to know whether reconnecting worked
            return False
        self._clear_cache_on_new_connection()
        self._disconnected = False
        return True

    def close(self):
        """Mark this connection as closed and close the driver connection."""
        # Flag first so the connection counts as closed even if the driver's close() raises
        self._closed = True
        self._py_connection.close()
    
    def get_transaction_begin_statements(self, writable: bool,
                                         isolation_level: TransactionIsolationLevel) -> list[str]:
        """Return the statements that start a transaction with the given isolation level and access mode.

        :param writable: ``True`` starts a ``READ WRITE`` transaction, ``False`` a ``READ ONLY`` one.
        :param isolation_level: Its ``value`` is inserted verbatim into the ``SET SESSION`` statement.
        """
        access_mode = "READ WRITE" if writable else "READ ONLY"
        return [
            f"SET SESSION TRANSACTION ISOLATION LEVEL {isolation_level.value}",
            f"START TRANSACTION {access_mode}",
        ]
    
    def get_transaction_end_statement(self, commit: bool) -> str:
        """Return ``COMMIT`` if ``commit`` is true, otherwise ``ROLLBACK``."""
        return "COMMIT" if commit else "ROLLBACK"
    
    def _create_cursor(self, prepared: bool):
        """Create a driver cursor; ``prepared`` is passed straight to ``MySQLConnection.cursor``."""
        return self._py_connection.cursor(prepared=prepared)
    
    def _is_transaction_conflict_exception(self, exception: Exception) -> bool:
        """Return whether ``exception`` signals a conflict after which the transaction may be retried.

        Detected by the driver raising ``InternalError`` whose message contains MySQL's
        "try restarting transaction" hint (e.g. the deadlock error).
        """
        return isinstance(exception, InternalError) and "try restarting transaction" in exception.msg
    
    def _db_execute_single_statement(self,
                                     cursor: MySQLCursor,
                                     statement: str,
                                     values: list[DbValueType]) -> tuple[list[DatabaseWarning], DbAffectedRows, DbResultSet]:
        """Execute one parameterized statement and return ``(warnings, affected rows, result set)``."""
        # Note: Warnings seem to be broken. When a warning is generated, an error is thrown in the library:
        # TypeError: expected string or bytes-like object, got 'Warning'
        cursor.execute(statement, params=values)
        return self._get_result(cursor)
    
    def _db_execute_multiple_statements(self,
                                        cursor: MySQLCursor,
                                        statements: str,
                                        values: list[DbValueType]) -> list[tuple[list[DatabaseWarning], DbAffectedRows, DbResultSet]]:
        """Execute several statements in one ``multi=True`` call and return one result tuple per statement."""
        # With multi=True the warnings seem to break. When a warning is generated we can't seem to be able to retrieve
        # it, but in later queries (without multi) the following error is thrown:
        # InterfaceError: Use cmd_query_iter for statements with multiple queries.
        return [self._get_result(res_cursor)
                for res_cursor in cursor.execute(statements, params=values, multi=True)]
    
    def _get_result(self, cursor) -> tuple[list[DatabaseWarning], DbAffectedRows, DbResultSet]:
        """Collect ``(warnings, affected rows, result set)`` for the statement last run on ``cursor``."""
        row_count = cursor.rowcount
        result_set = self._get_result_set(cursor)
        if row_count == -1:
            # Do not get rowcount after fetchall if it was valid before. If there is no result, mysql sometimes
            # just sets this to 0 in fetchall
            row_count = cursor.rowcount
        return self._fetch_warnings(cursor), row_count, result_set
    
    def _fetch_warnings(self, cursor: MySQLCursor) -> list[DatabaseWarning]:
        """Return the cursor's warnings as ``DatabaseWarning``s; ``[]`` if warning fetching is disabled."""
        if not self._get_warnings:
            return []
        try:
            mysql_warnings = cursor.fetchwarnings()
        except InterfaceError as e:
            if e.msg == "Failed getting warnings; No result set to fetch from.":  # Why ever this throws an exception
                return []
            raise
        if mysql_warnings is None:
            return []
        # mysql warning is tuple (type, id, message)
        return [DatabaseWarning(f"{warn[0]} ({warn[1]}): {warn[2]}") for warn in mysql_warnings]
    
    @staticmethod
    def _get_result_set(cursor: MySQLCursor) -> DbResultSet:
        """Fetch all remaining rows of ``cursor`` and convert them to ``DatabaseResultRow``s.

        Returns ``[]`` when the statement produced no result set or no rows.

        :raises TypeError: For a Python value the driver returned that does not fit its column type.
        :raises ValueError: For a column type code this adapter does not handle.
        """
        try:
            cursor_rows = cursor.fetchall()
        except InterfaceError as e:
            if e.msg == "No result set to fetch from.":  # Why ever this throws an exception
                return []
            raise
        if not cursor_rows:
            return []
        
        column_mapping = PythonDbConnection._create_column_mapping(cursor)
        # The column type codes (index 1 of each description entry, as in the Python DB API v2) are the
        # same for every row, so extract them once instead of per row
        type_codes = [desc[1] for desc in cursor.description]
        
        return [
            DatabaseResultRow(
                column_mapping,
                tuple(_MySqlDbConnection._convert_cursor_value(type_code, cursor_value)
                      for type_code, cursor_value in zip(type_codes, cursor_row)))
            for cursor_row in cursor_rows
        ]
    
    @staticmethod
    def _convert_cursor_value(type_code, cursor_value) -> DbValueType:
        """Convert one value returned by the driver into the value type used in result rows.

        :raises TypeError: If the Python type of ``cursor_value`` is unexpected for ``type_code``.
        :raises ValueError: If ``type_code`` is not NUMBER, STRING, BINARY or DATETIME.
        """
        if cursor_value is None:
            return None
        
        if type_code == NUMBER:
            if isinstance(cursor_value, int):
                return cursor_value
            raise TypeError(f"Unknown type {type(cursor_value).__name__} for NUMBER")  # pragma: no cover
        
        # Unfortunately there is no way to know if it's a BLOB or a TEXT column in MySql. As we currently don't
        # use any BLOBs we just interpret all binary columns as strings
        # (Note that the flags in the cursor.description (8th entry) do not contain this information either)
        if type_code == STRING or type_code == BINARY:
            if isinstance(cursor_value, str):
                return cursor_value
            if isinstance(cursor_value, (bytes, bytearray)):
                return cursor_value.decode()
            raise TypeError(
                f"Unknown type {type(cursor_value).__name__} for STRING OR BINARY")  # pragma: no cover
        
        if type_code == DATETIME:
            # datetime is a subclass of date, so it has to be checked first
            if isinstance(cursor_value, datetime):
                return cursor_value
            if isinstance(cursor_value, date):
                return datetime(
                    year=cursor_value.year,
                    month=cursor_value.month,
                    day=cursor_value.day
                )
            raise TypeError(f"Unknown type {type(cursor_value).__name__} for DATETIME")  # pragma: no cover
        
        raise ValueError(f"Unknown type code {type_code}")  # pragma: no cover
    
    def _db_execute_script(self, cursor: MySQLCursor, script: str):
        """Run a multi-statement script, draining every result.

        :raises DatabaseError: If any statement of the script produced warnings (only detectable when
            warning fetching is enabled).
        """
        # See comment on multi=True in _db_execute_multiple_statements
        for res_cursor in cursor.execute(script, multi=True):
            # NOTE(review): warnings are read before fetchall(); for row-returning statements the driver
            # may only collect warnings once the result is consumed - confirm if scripts ever SELECT
            warnings = self._fetch_warnings(res_cursor)
            if len(warnings) > 0:
                raise DatabaseError(f"Got warnings during execution of script: {warnings}")
            try:
                res_cursor.fetchall()  # We need to always fetch these
            except InterfaceError as e:
                if e.msg != "No result set to fetch from.":  # Why ever this throws an exception
                    raise


class MySqlDbConnectionFactory(DbConnectionFactory):
    """Creates ``_MySqlDbConnection`` instances for a fixed set of MySQL connection parameters."""
    
    def __init__(self, use_multi_query: bool, fetch_warnings: bool,
                 host: str, port: int, unix_socket: str,
                 username: str, password: str, database: str):
        """
        :param use_multi_query: Passed on to every created connection.
        :param fetch_warnings: Whether created connections fetch driver warnings.
        :param host: Server host name.
        :param port: Server port.
        :param unix_socket: Unix socket path handed to the driver.
        :param username: Database user name.
        :param password: Database password.
        :param database: Default database (schema) to select.
        :raises ValueError: If both ``use_multi_query`` and ``fetch_warnings`` are requested.
        """
        super().__init__()
        if use_multi_query and fetch_warnings:
            # In this case MySQL Connector seems buggy/throws exceptions
            raise ValueError("Warnings cannot be fetched with multi queries")
        self._use_multi_query = use_multi_query
        self._fetch_warnings = fetch_warnings
        self._host = host
        self._port = port
        self._unix_socket = unix_socket
        self._username = username
        self._password = password
        self._database = database
    
    def supports_per_transaction_writeable_flag(self) -> bool:
        """Return ``True``: read-only vs. read-write is chosen per transaction, not per connection."""
        return True
    
    def new_connection(self, writable: bool = True) -> _MySqlDbConnection:
        """Open a new MySQL connection, configure its session and wrap it.

        :param writable: Not used; write access is selected per transaction instead.
        :raises DatabaseError: If connecting or the initial session setup fails. A connection that was
            opened but could not be configured is closed before raising.
        """
        try:
            mysql_connection = MySQLConnection(
                host=self._host,
                port=self._port,
                unix_socket=self._unix_socket,
                user=self._username,
                password=self._password,
                database=self._database
            )
        except Exception as e:
            raise DatabaseError("An exception occurred while connecting to database") from e
        try:
            mysql_connection.cmd_query("SET SESSION sql_mode = 'ANSI_QUOTES'")
            mysql_connection.get_warnings = self._fetch_warnings
        except Exception as e:
            # Do not leak the freshly opened connection when its session setup fails
            # noinspection PyBroadException
            try:
                mysql_connection.close()
            except Exception:
                pass  # Best effort; the setup failure below is the error worth reporting
            raise DatabaseError("An exception occurred while connecting to database") from e
        return _MySqlDbConnection(self._use_multi_query, self._fetch_warnings, mysql_connection)