Skip to content
Snippets Groups Projects
Select Git revision
  • d44bdd57676a4de7e8872bc254290d26f2a27493
  • main default
  • full_migration
  • v1.0.9 protected
  • v1.0.8 protected
  • v1.0.7 protected
  • v1.0.6 protected
  • v1.0.5 protected
  • v1.0.4 protected
  • v1.0.3 protected
  • v1.0.2 protected
  • v1.0.1 protected
  • v1.0 protected
13 results

mysql_connector.py

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    mysql_connector.py 10.00 KiB
    from mysql.connector.cursor import MySQLCursor
    from mysql.connector.errors import ProgrammingError, DataError, NotSupportedError, InterfaceError
    from mysql.connector.version import VERSION_TEXT
    from mysql.connector import MySQLConnection, NUMBER, STRING, BINARY, DATETIME
    from time import time_ns
    from datetime import datetime, date
    
    from api.database import DbValueType, DbResultSet, FilledStatement
    from api.database.database import PreparedStatement, DbConnectionFactory, DatabaseWarning, DatabaseError, DbAffectedRows, \
        DatabaseResultRow
    from api.database.abstract_py_connector import PythonDbConnection, Cursor
    
    # Module-level side effect: report the installed MySQL Connector/Python version once, when this module is imported
    print(f"Using MySQL Connector Version: {VERSION_TEXT}")
    
    _GRACE_TIME_NO_PING_NS: int = 10 * 1000 * 1000 * 1000
    
    
    class _MySqlDbConnection(PythonDbConnection[MySQLConnection, MySQLCursor]):
        """
        Connection implementation on top of MySQL Connector/Python.

        Wraps a :class:`MySQLConnection` and converts the rows returned by the driver into
        :class:`DatabaseResultRow` objects. Instances are created by ``MySqlDbConnectionFactory``.
        """

        def __init__(self, use_multi_query: bool, fetch_warnings: bool, py_connection: MySQLConnection):
            """
            :param use_multi_query: Value returned by :meth:`_can_use_multi_query`
            :param fetch_warnings: Whether warnings are read from the cursor after each statement
            :param py_connection: The already established driver connection
            """
            # NOTE(review): "%s" is presumably the parameter placeholder and the tuple the driver exceptions the
            # base class treats specially - confirm against PythonDbConnection.__init__
            super().__init__(
                "%s",
                (ProgrammingError, DataError, NotSupportedError),
                py_connection)
            self._use_multi_query = use_multi_query
            self._get_warnings = fetch_warnings
            self._disconnected = False

        def _can_use_multi_query(self, statements: list[FilledStatement]) -> bool:
            """Return the configured multi query flag; the given statements are not inspected."""
            return self._use_multi_query

        def is_disconnected(self, force_ping: bool = False) -> bool:
            """
            Check whether the connection to the server is lost.

            :param force_ping: If set, the driver is always asked, even within the grace period after the last
                               successful request
            :return: True if the connection is (or was already known to be) disconnected
            :raises RuntimeError: If this connection was closed
            """
            if self._closed:
                raise RuntimeError("Already closed")
            if self._disconnected:
                return True
            # Shortly after a successful request we assume the connection is still alive and skip the check
            if not force_ping and self._last_successful_request + _GRACE_TIME_NO_PING_NS > time_ns():
                return False
            if self._py_connection.is_connected():
                return False
            # Remember the result so later calls don't need to ask the driver again
            self._disconnected = True
            return True

        def try_reconnect(self) -> bool:
            """
            Try once to reconnect and to re-apply the session settings (isolation level, sql mode, warnings flag).

            :return: True if the connection is usable again, False if any step raised an exception
            """
            # noinspection PyBroadException
            try:
                self._py_connection.reconnect(attempts=1)
                self._py_connection.cmd_query("SET SESSION TRANSACTION ISOLATION LEVEL REPEATABLE READ")
                self._py_connection.cmd_query("SET SESSION sql_mode = 'ANSI_QUOTES'")
                self._py_connection.get_warnings = self._get_warnings
            except Exception:
                # Deliberately broad: any failure just means the reconnect did not work
                return False
            self._clear_cache_on_new_connection()
            self._disconnected = False
            return True

        def close(self):
            """Mark this connection as closed and close the underlying driver connection."""
            self._closed = True
            self._py_connection.close()

        def get_transaction_begin_statement(self, writable: bool) -> str:
            """Return the statement starting a read-write (``writable``) or read-only transaction."""
            if writable:
                return "START TRANSACTION READ WRITE"
            return "START TRANSACTION READ ONLY"

        def get_transaction_end_statement(self, commit: bool) -> str:
            """Return ``COMMIT`` if ``commit`` is set, ``ROLLBACK`` otherwise."""
            if commit:
                return "COMMIT"
            return "ROLLBACK"

        def _create_cursor(self, prepared: bool) -> MySQLCursor:
            """Create a new driver cursor; ``prepared`` is passed through to the driver."""
            return self._py_connection.cursor(prepared=prepared)

        def _db_execute_single_statement(self,
                                         cursor: MySQLCursor,
                                         statement: str,
                                         values: list[DbValueType]) -> tuple[list[DatabaseWarning], DbAffectedRows, DbResultSet]:
            """Execute one statement with the given parameter values and return (warnings, row count, result set)."""
            cursor.execute(statement, params=values)
            return self._get_result(cursor)

        def _db_execute_multiple_statements(self,
                                            cursor: MySQLCursor,
                                            statements: str,
                                            values: list[DbValueType]) -> list[tuple[list[DatabaseWarning], DbAffectedRows, DbResultSet]]:
            """
            Execute several statements in one call (``multi=True``) and return one
            (warnings, row count, result set) tuple per result yielded by the driver.
            """
            return [
                self._get_result(res_cursor)
                for res_cursor in cursor.execute(statements, params=values, multi=True)
            ]

        def _get_result(self, cursor) -> tuple[list[DatabaseWarning], DbAffectedRows, DbResultSet]:
            """Read row count, result set and warnings (in this order) from a cursor which executed a statement."""
            row_count = cursor.rowcount
            result_set = self._get_result_set(cursor)
            if row_count == -1:
                # Do not get rowcount after fetchall if it was valid before. If there is no result, mysql sometimes
                # just sets this to 0 in fetchall
                row_count = cursor.rowcount
            return self._fetch_warnings(cursor), row_count, result_set

        def _fetch_warnings(self, cursor: MySQLCursor) -> list[DatabaseWarning]:
            """
            Return the warnings of the last statement, or an empty list if warning fetching is disabled or
            there are none.

            :raises InterfaceError: If the driver fails for another reason than a missing result set
            """
            if not self._get_warnings:
                return []
            try:
                mysql_warnings = cursor.fetchwarnings()
            except InterfaceError as e:
                if e.msg == "Failed getting warnings; No result set to fetch from.":  # Why ever this throws an exception
                    return []
                raise
            if mysql_warnings is None:
                return []
            # mysql warning is tuple (type, id, message)
            return [DatabaseWarning(f"{warn[0]} ({warn[1]}): {warn[2]}") for warn in mysql_warnings]

        @staticmethod
        def _convert_mysql_value(type_code, cursor_value):
            """
            Convert a single column value returned by the driver.

            ints, strings and datetimes are passed through, bytes/bytearrays are decoded to strings and plain
            dates become datetimes at midnight. None stays None.

            :param type_code: The type code from the cursor description of the column
            :param cursor_value: The value as returned by the driver
            :raises TypeError: If the value has an unexpected Python type for its type code
            :raises ValueError: If the type code is unknown
            """
            if cursor_value is None:
                return None

            if type_code == NUMBER:
                if isinstance(cursor_value, int):
                    return cursor_value
                raise TypeError(f"Unknown type {type(cursor_value).__name__} for NUMBER")  # pragma: no cover

            # Unfortunately there is no way to know if it's a BLOB or a TEXT column in MySql. As we currently don't
            # use any BLOBs we just interpret all binary columns as strings
            # (Note that the flags in the cursor.description (8th entry) do not contain this information either)
            if type_code == STRING or type_code == BINARY:
                if isinstance(cursor_value, str):
                    return cursor_value
                if isinstance(cursor_value, (bytes, bytearray)):
                    return cursor_value.decode()
                raise TypeError(
                    f"Unknown type {type(cursor_value).__name__} for STRING OR BINARY")  # pragma: no cover

            if type_code == DATETIME:
                # datetime must be checked first since it is a subclass of date
                if isinstance(cursor_value, datetime):
                    return cursor_value
                if isinstance(cursor_value, date):
                    return datetime(
                        year=cursor_value.year,
                        month=cursor_value.month,
                        day=cursor_value.day
                    )
                raise TypeError(f"Unknown type {type(cursor_value).__name__} for DATETIME")  # pragma: no cover

            raise ValueError(f"Unknown type code {type_code}")  # pragma: no cover

        @staticmethod
        def _get_result_set(cursor: MySQLCursor) -> DbResultSet:
            """
            Fetch all rows from the cursor and convert them to result rows.

            :return: The converted rows; an empty list if the statement produced no result set or no rows
            :raises InterfaceError: If fetching fails for another reason than a missing result set
            """
            try:
                cursor_rows = cursor.fetchall()
            except InterfaceError as e:
                if e.msg == "No result set to fetch from.":  # Why ever this throws an exception
                    return []
                raise
            if not cursor_rows:
                return []

            column_mapping = PythonDbConnection._create_column_mapping(cursor)
            # The type code is the second entry of each column description (as in the Python DB Api v2). It is the
            # same for every row, so extract it only once
            type_codes = [desc[1] for desc in cursor.description]

            convert = _MySqlDbConnection._convert_mysql_value
            result_set = []
            for cursor_row in cursor_rows:
                result_row = tuple(
                    convert(type_code, cursor_value)
                    for type_code, cursor_value in zip(type_codes, cursor_row)
                )
                result_set.append(DatabaseResultRow(column_mapping, result_row))

            return result_set

        def _db_execute_script(self, cursor: MySQLCursor, script: str):
            """
            Execute a script consisting of multiple statements, discarding all results.

            :raises DatabaseError: If any statement produced warnings
            :raises InterfaceError: If fetching fails for another reason than a missing result set
            """
            for res_cursor in cursor.execute(script, multi=True):
                warnings = self._fetch_warnings(res_cursor)
                if warnings:
                    raise DatabaseError(f"Got warnings during execution of script: {warnings}")
                try:
                    res_cursor.fetchall()  # We need to always fetch these
                except InterfaceError as e:
                    if e.msg == "No result set to fetch from.":  # Why ever this throws an exception
                        continue
                    raise
    
    
    class MySqlDbConnectionFactory(DbConnectionFactory):
        """Factory creating MySQL connections with a fixed set of connection parameters."""

        def __init__(self, use_multi_query: bool, fetch_warnings: bool,
                     host: str, port: int, unix_socket: str,
                     username: str, password: str, database: str):
            """
            :param use_multi_query: Passed to every created connection
            :param fetch_warnings: Whether created connections fetch warnings. Can't be combined with
                                   ``use_multi_query``
            :param host: Passed as ``host`` to the driver
            :param port: Passed as ``port`` to the driver
            :param unix_socket: Passed as ``unix_socket`` to the driver
            :param username: Passed as ``user`` to the driver
            :param password: Passed as ``password`` to the driver
            :param database: Passed as ``database`` to the driver
            :raises ValueError: If both ``use_multi_query`` and ``fetch_warnings`` are set
            """
            super().__init__()
            if use_multi_query and fetch_warnings:
                # In this case MySQL Connector seems buggy/throws exceptions
                raise ValueError("Warnings cannot be fetched with multi queries")
            self._use_multi_query = use_multi_query
            self._fetch_warnings = fetch_warnings
            self._host = host
            self._port = port
            self._unix_socket = unix_socket
            self._username = username
            self._password = password
            self._database = database

        def supports_per_transaction_writeable_flag(self) -> bool:
            """Always True for this factory."""
            return True

        def new_connection(self, writable: bool = True) -> _MySqlDbConnection:
            """
            Open a new connection and apply the session settings (REPEATABLE READ isolation, ANSI_QUOTES sql mode,
            warnings flag).

            :param writable: Not used by this implementation
            :return: The wrapped connection
            :raises DatabaseError: If connecting or applying the session settings fails. The original exception
                                   is chained as cause
            """
            mysql_connection = None
            try:
                mysql_connection = MySQLConnection(
                    host=self._host,
                    port=self._port,
                    unix_socket=self._unix_socket,
                    user=self._username,
                    password=self._password,
                    database=self._database
                )
                mysql_connection.cmd_query("SET SESSION TRANSACTION ISOLATION LEVEL REPEATABLE READ")
                mysql_connection.cmd_query("SET SESSION sql_mode = 'ANSI_QUOTES'")
                mysql_connection.get_warnings = self._fetch_warnings
            except Exception as e:
                if mysql_connection is not None:
                    # The connection was established but the session setup failed. Close it so it does not leak
                    # noinspection PyBroadException
                    try:
                        mysql_connection.close()
                    except Exception:
                        # Best effort only, the original error below is the one the caller needs to see
                        pass
                raise DatabaseError("An exception occurred while connecting to database") from e
            return _MySqlDbConnection(self._use_multi_query, self._fetch_warnings, mysql_connection)