Skip to content
Snippets Groups Projects
Commit 46e09122 authored by Simon Künzel's avatar Simon Künzel
Browse files

Improve database, move some mysql dependent code to mysql_connector

parent 036add53
No related branches found
No related tags found
No related merge requests found
......@@ -7,9 +7,8 @@ from api.database.database import (PreparedStatement, FilledStatement,
AbstractTransaction, ReadTransaction, WriteTransaction,
DbConnection,
DbConnectionPool,
DbResult, DbResultSet, DbResultRow, DbValueType, DbWarning,
DB_RESULT_EXCEPTION, DB_RESULT_WARNINGS, DB_RESULT_SET, DB_RESULT_AFFECTED_ROWS,
DB_RESULT_AUTO_INCREMENT)
DbResult, DbResultSet, DbResultRow, DbValueType, DatabaseWarning,
DB_RESULT_EXCEPTION, DB_RESULT_WARNINGS, DB_RESULT_SET, DB_RESULT_AFFECTED_ROWS)
db_pool: DbConnectionPool
......
import time
from datetime import datetime, date
from typing import TypeVar, Generic
from abc import ABC, ABCMeta, abstractmethod
import traceback
from functools import reduce
from api.database.database import (PreparedStatement, FilledStatement, DbConnection,
DbResult, DbResultSet, DbValueType, DbWarning,
DatabaseError,
DbResult, DbResultSet, DbValueType, DbAffectedRows,
DatabaseWarning, DatabaseError,
SQL_PARAMETER_INDICATOR)
_DEBUG_PRINT_STATEMENT_EXECUTION: bool = False
......@@ -23,21 +22,11 @@ class PythonDbConnection(DbConnection, Generic[Connection, Cursor], ABC):
use_multi_query: bool,
parameter_indicator: str,
caller_exception: tuple[type, ...],
no_type_checks: bool,
type_binary,
type_number,
type_string,
type_datetime,
_py_connection: Connection):
super().__init__()
self._use_multi_query = use_multi_query
self._parameter_indicator = parameter_indicator
self._caller_exception = caller_exception
self._no_type_checks = no_type_checks
self._type_binary = type_binary
self._type_number = type_number
self._type_string = type_string
self._type_datetime = type_datetime
self._py_connection = _py_connection
self._closed = False
# str is statement with SQL_PARAMETER_INDICATOR replaced
......@@ -63,15 +52,14 @@ class PythonDbConnection(DbConnection, Generic[Connection, Cursor], ABC):
exception = None
try:
if isinstance(filled_stat.statement, PreparedStatement):
cursor = self._execute_prepared_statement(filled_stat.statement, filled_stat.values)
res = self._execute_prepared_statement(filled_stat.statement, filled_stat.values)
else:
assert isinstance(filled_stat.statement, str)
cursor = self._execute_unprepared_statement(filled_stat.statement, filled_stat.values)
warnings, auto_increment, affected_rows, result_set = self._get_result(cursor)
res = self._execute_unprepared_statement(filled_stat.statement, filled_stat.values)
warnings, affected_rows, result_set = res
except self._caller_exception as e: # pragma: no cover
exception = e
warnings = []
auto_increment = None
affected_rows = None
result_set = []
except Exception as e:
......@@ -107,7 +95,7 @@ class PythonDbConnection(DbConnection, Generic[Connection, Cursor], ABC):
else:
replaced_stat = self._replace_parameter_indicator(filled_stat.statement)
full_query += replaced_stat
all_values.extend(list(filled_stat.values))
all_values.extend(filled_stat.values)
if self._unprepared_cursor is None:
self._unprepared_cursor = self._create_cursor(False)
......@@ -117,10 +105,9 @@ class PythonDbConnection(DbConnection, Generic[Connection, Cursor], ABC):
results: list[DbResult] = []
try:
cursor_iterator = self._unprepared_cursor.execute(full_query, params=tuple(all_values), multi=True)
for i in range(0, len(statements)):
warnings, auto_increment, affected_rows, result_set = self._get_result(next(cursor_iterator))
results.append((None, warnings, result_set, affected_rows, auto_increment))
for warnings, affected_rows, result_set in self._db_execute_multiple_statements(
self._unprepared_cursor, full_query, all_values):
results.append((None, warnings, result_set, affected_rows))
except self._caller_exception as e:
for i in range(0, len(statements)):
results.append((e, [], [], None, None))
......@@ -143,104 +130,49 @@ class PythonDbConnection(DbConnection, Generic[Connection, Cursor], ABC):
self._prepared_statement_cache[statement] = data
return data
@abstractmethod
def _create_cursor(self, prepared: bool):
pass # pragma: no cover
def _execute_prepared_statement(self, statement: PreparedStatement, *values: DbValueType) -> Cursor:
def _execute_prepared_statement(self,
statement: PreparedStatement,
values: list[DbValueType]) -> tuple[list[DatabaseWarning], DbAffectedRows, DbResultSet]:
if statement not in self._prepared_statement_cache:
cursor = self._create_cursor(True)
replaced_stat, cursor = self._get_prepared_statement_data(statement, cursor)
else:
replaced_stat, cursor = self._prepared_statement_cache[statement]
cursor.execute(replaced_stat, *values)
return cursor
return self._db_execute_single_statement(cursor, replaced_stat, values)
def _execute_unprepared_statement(self, statement: str, *values: DbValueType) -> Cursor:
def _execute_unprepared_statement(self,
statement: str,
values: list[DbValueType]) -> tuple[list[DatabaseWarning], DbAffectedRows, DbResultSet]:
if self._unprepared_cursor is None:
self._unprepared_cursor = self._create_cursor(False)
replaced_stat = self._replace_parameter_indicator(statement)
self._unprepared_cursor.execute(replaced_stat, *values)
return self._unprepared_cursor
return self._db_execute_single_statement(self._unprepared_cursor, replaced_stat, values)
@abstractmethod
def _fetch_warnings(self, cursor: Cursor) -> list[DbWarning]:
pass # pragma: no cover
@staticmethod
def _create_column_mapping(cursor: Cursor) -> dict[str, int]:
column_mapping = {}
i = -1
for column_description in cursor.description:
i += 1
name = column_description[0] # As in the Python DB Api v2
column_mapping[name] = i
return column_mapping
@abstractmethod
def _fetch_auto_increment_value(self, cursor: Cursor) -> int or None:
def _create_cursor(self, prepared: bool):
pass # pragma: no cover
@abstractmethod
def _fetch_affected_rows_count(self, cursor: Cursor) -> int or None:
def _db_execute_single_statement(self,
cursor: Cursor,
statement: str,
values: list[DbValueType]) -> tuple[list[DatabaseWarning], DbAffectedRows, DbResultSet]:
pass # pragma: no cover
@abstractmethod
def _fetch_all_rows(self, cursor: Cursor) -> list[tuple]:
def _db_execute_multiple_statements(self,
cursor: Cursor,
statements: str,
values: list[DbValueType]) -> list[tuple[list[DatabaseWarning], DbAffectedRows, DbResultSet]]:
pass # pragma: no cover
def _get_result(self, cursor: Cursor) -> tuple[list[DbWarning], int or None, int or None, DbResultSet]:
warnings = self._fetch_warnings(cursor)
auto_increment_value = self._fetch_auto_increment_value(cursor)
affected_rows_count = self._fetch_affected_rows_count(cursor)
cursor_rows = self._fetch_all_rows(cursor)
if len(cursor_rows) == 0:
return warnings, auto_increment_value, affected_rows_count, []
column_names = []
column_type_codes = []
for column_description in cursor.description:
name = column_description[0] # As in the Python DB Api v2
type_code = column_description[1]
column_names.append(name)
column_type_codes.append(type_code)
result_set = []
for cursor_row in cursor_rows:
result_row = {}
for column_name, type_code, cursor_value in zip(column_names, column_type_codes, cursor_row):
if self._no_type_checks: # Because sqlite doesn't honor the type_code (it's always None)!?
result_value = cursor_value
elif cursor_value is None:
result_value = None
elif type_code == self._type_number:
if isinstance(cursor_value, int):
result_value = cursor_value
else:
raise TypeError(f"Unknown type {type(cursor_value).__name__} for NUMBER") # pragma: no cover
# Unfortunately there is no way to know if it's a BLOB or a TEXT column in MySql. As we currently don't
# use any BLOBs we just interpret all binary columns as strings
# (Note that the flags in the cursor.description (8th entry) do not contain this information either)
elif type_code == self._type_string or type_code == self._type_binary:
if isinstance(cursor_value, str):
result_value = cursor_value
elif isinstance(cursor_value, bytes):
result_value = cursor_value.decode()
elif isinstance(cursor_value, bytearray):
result_value = cursor_value.decode()
else:
raise TypeError(f"Unknown type {type(cursor_value).__name__} for STRING OR BINARY") # pragma: no cover
elif type_code == self._type_datetime:
if isinstance(cursor_value, datetime):
result_value = cursor_value
elif isinstance(cursor_value, date):
result_value = datetime(
year=cursor_value.year,
month=cursor_value.month,
day=cursor_value.day
)
else:
raise TypeError(f"Unknown type {type(cursor_value).__name__} for DATETIME") # pragma: no cover
else:
raise ValueError(f"Unknown type code {type_code}") # pragma: no cover
result_row[column_name] = result_value
result_set.append(result_row)
return warnings, auto_increment_value, affected_rows_count, result_set
......@@ -8,18 +8,51 @@ from traceback import StackSummary, extract_stack
from api.miscellaneous import DEBUG_ENABLED
DbValueType = int | str | bytes | datetime
DbResultRow = dict[str, DbValueType]
class DatabaseResultRow:
def __init__(self, column_mapping: dict[str, int], values: tuple[DbValueType, ...]):
super().__init__()
self._column_mapping = column_mapping
self._values = values
def __contains__(self, item: str or int):
if isinstance(item, str):
return item in self._column_mapping
assert isinstance(item, int)
return 0 <= item < len(self._values)
def __getitem__(self, key: str or int) -> DbValueType:
if isinstance(key, str):
return self._values[self._column_mapping[key]]
assert isinstance(key, int)
return self._values[key]
def __str__(self):
return self._values.__str__()
def __repr__(self):
return self.__str__()
class DatabaseWarning:
def __init__(self, message: str):
super().__init__()
self._message = message
DbResultRow = DatabaseResultRow
DbResultSet = list[DbResultRow]
# First string is warning type, int is warning id, second string is message
DbWarning = tuple[str, int, str]
DbResult = tuple[Exception or None, list[DbWarning], DbResultSet, int or None, int or None]
DbAffectedRows = int
DbResult = tuple[Exception or None, list[DatabaseWarning], DbResultSet, DbAffectedRows]
DB_RESULT_EXCEPTION = 0
DB_RESULT_WARNINGS = 1
DB_RESULT_SET = 2
DB_RESULT_AFFECTED_ROWS = 3
DB_RESULT_AUTO_INCREMENT = 4
SQL_PARAMETER_INDICATOR = "?"
......@@ -30,7 +63,7 @@ class DatabaseError(Exception):
class WarningError(Exception):
def __init__(self, warnings: list[DbWarning]):
def __init__(self, warnings: list[DatabaseWarning]):
super().__init__(str(warnings))
self.warnings = warnings
......@@ -52,7 +85,7 @@ class PreparedStatement:
class FilledStatement:
def __init__(self, statement: PreparedStatement or str, *values: DbValueType):
def __init__(self, statement: PreparedStatement or str, values: list[DbValueType]):
super().__init__()
self.statement = statement
self.values = values
......@@ -176,7 +209,7 @@ class AbstractTransaction(ABC):
if parameter_count != len(values):
raise ValueError(f"Parameter ({parameter_count}) and argument count ({len(values)}) don't match") # pragma: no cover
queued = FilledStatement(statement, *values)
queued = FilledStatement(statement, list(values))
if DEBUG_ENABLED:
queued.queue_traceback = extract_stack()
self._queued_statements.append(queued)
......@@ -262,7 +295,7 @@ class AbstractTransaction(ABC):
# noinspection PyBroadException
try:
self._connection.execute_statements(
[FilledStatement(self._connection.get_transaction_end_statement(commit=False))])
[FilledStatement(self._connection.get_transaction_end_statement(commit=False), [])])
except Exception:
pass
self._closed = True
......
......@@ -3,8 +3,11 @@ from mysql.connector.errors import ProgrammingError, DataError, NotSupportedErro
from mysql.connector.version import VERSION_TEXT
from mysql.connector import MySQLConnection, NUMBER, STRING, BINARY, DATETIME
from time import time_ns
from datetime import datetime, date
from api.database.database import PreparedStatement, DbWarning, DbConnectionFactory, DatabaseError
from api.database import DbValueType, DbResultSet
from api.database.database import PreparedStatement, DbConnectionFactory, DatabaseWarning, DatabaseError, DbAffectedRows, \
DatabaseResultRow
from api.database.abstract_py_connector import PythonDbConnection, Cursor
print(f"Using MySQL Connector Version: {VERSION_TEXT}")
......@@ -19,11 +22,6 @@ class _MySqlDbConnection(PythonDbConnection[MySQLConnection, MySQLCursor]):
use_multi_query,
"%s",
(ProgrammingError, DataError, NotSupportedError),
False,
BINARY,
NUMBER,
STRING,
DATETIME,
_py_connection)
self._get_warnings = fetch_warnings
self._disconnected = False
......@@ -71,32 +69,110 @@ class _MySqlDbConnection(PythonDbConnection[MySQLConnection, MySQLCursor]):
def _create_cursor(self, prepared: bool):
return self._py_connection.cursor(prepared=prepared)
def _fetch_warnings(self, cursor: MySQLCursor) -> list[DbWarning]:
def _db_execute_single_statement(self,
cursor: MySQLCursor,
statement: str,
values: list[DbValueType]) -> tuple[list[DatabaseWarning], DbAffectedRows, DbResultSet]:
cursor.execute(statement, params=values)
return self._get_result(cursor)
def _db_execute_multiple_statements(self,
cursor: MySQLCursor,
statements: str,
values: list[DbValueType]) -> list[tuple[list[DatabaseWarning], DbAffectedRows, DbResultSet]]:
results = []
for res_cursor in cursor.execute(statements, params=values, multi=True):
results.append(self._get_result(res_cursor))
return results
def _get_result(self, cursor) -> tuple[list[DatabaseWarning], DbAffectedRows, DbResultSet]:
row_count = cursor.rowcount
result_set = self._get_result_set(cursor)
if row_count == -1:
# Do not get rowcount after fetchall if it was valid before. If there is no result, mysql sometimes
# just sets this to 0 in fetchall
row_count = cursor.rowcount
return self._fetch_warnings(cursor), row_count, result_set
def _fetch_warnings(self, cursor: MySQLCursor) -> list[DatabaseWarning]:
if not self._get_warnings:
return []
try:
warnings = cursor.fetchwarnings()
if warnings is None:
mysql_warnings = cursor.fetchwarnings()
if mysql_warnings is None:
warnings = []
else:
# mysql warning is tuple (type, id, message)
warnings = list(map(lambda warn: DatabaseWarning(f"{warn[0]} ({warn[1]}): {warn[2]}"), mysql_warnings))
return warnings
except InterfaceError as e:
if e.msg == "Failed getting warnings; No result set to fetch from.": # Why ever this throws an exception
return []
raise e
def _fetch_auto_increment_value(self, cursor: MySQLCursor) -> int or None:
return cursor.getlastrowid()
def _fetch_affected_rows_count(self, cursor: Cursor) -> int or None:
return cursor.rowcount
def _fetch_all_rows(self, cursor: MySQLCursor) -> list[tuple]:
@staticmethod
def _get_result_set(cursor: MySQLCursor) -> DbResultSet:
try:
return cursor.fetchall()
cursor_rows = cursor.fetchall()
except InterfaceError as e:
if e.msg == "No result set to fetch from.": # Why ever this throws an exception
return []
raise e
if len(cursor_rows) == 0:
return []
column_mapping = PythonDbConnection._create_column_mapping(cursor)
result_set = []
for cursor_row in cursor_rows:
result_row = []
for type_code, cursor_value in zip(
map(lambda desc: desc[1], # As in the Python DB Api v2
cursor.description),
cursor_row):
if cursor_value is None:
result_value = None
elif type_code == NUMBER:
if isinstance(cursor_value, int):
result_value = cursor_value
else:
raise TypeError(f"Unknown type {type(cursor_value).__name__} for NUMBER") # pragma: no cover
# Unfortunately there is no way to know if it's a BLOB or a TEXT column in MySql. As we currently don't
# use any BLOBs we just interpret all binary columns as strings
# (Note that the flags in the cursor.description (8th entry) do not contain this information either)
elif type_code == STRING or type_code == BINARY:
if isinstance(cursor_value, str):
result_value = cursor_value
elif isinstance(cursor_value, bytes):
result_value = cursor_value.decode()
elif isinstance(cursor_value, bytearray):
result_value = cursor_value.decode()
else:
raise TypeError(
f"Unknown type {type(cursor_value).__name__} for STRING OR BINARY") # pragma: no cover
elif type_code == DATETIME:
if isinstance(cursor_value, datetime):
result_value = cursor_value
elif isinstance(cursor_value, date):
result_value = datetime(
year=cursor_value.year,
month=cursor_value.month,
day=cursor_value.day
)
else:
raise TypeError(f"Unknown type {type(cursor_value).__name__} for DATETIME") # pragma: no cover
else:
raise ValueError(f"Unknown type code {type_code}") # pragma: no cover
result_row.append(result_value)
result_set.append(DatabaseResultRow(column_mapping, tuple(result_row)))
return result_set
class MySqlDbConnectionFactory(DbConnectionFactory):
......
......@@ -3,7 +3,9 @@ from sqlite3 import Connection, Cursor, DatabaseError, DataError, ProgrammingErr
from datetime import datetime
from os import PathLike
from api.database.database import PreparedStatement, DbWarning, DbConnectionFactory, DatabaseError
from api.database import DbValueType, DbResultSet
from api.database.database import PreparedStatement, DbConnectionFactory, DatabaseWarning, DatabaseError, DbAffectedRows, \
DatabaseResultRow
from api.database.abstract_py_connector import PythonDbConnection
print(f"Using SQLite Connector Version: {sqlite3.sqlite_version}")
......@@ -43,11 +45,6 @@ class SqLiteDbConnection(PythonDbConnection[Connection, Cursor]):
False, # Not supported by sqlite
"?",
(ProgrammingError, DataError, NotSupportedError),
True,
None, # Value does not matter
None, # Value does not matter
None, # Value does not matter
None, # Value does not matter
_py_connection)
def is_disconnected(self, force_ping: bool = False) -> bool:
......@@ -74,17 +71,23 @@ class SqLiteDbConnection(PythonDbConnection[Connection, Cursor]):
def _create_cursor(self, prepared: bool):
return self._py_connection.cursor() # SQLite does not have prepared cursors
def _fetch_warnings(self, cursor: Cursor) -> list[DbWarning]:
return [] # Warnings seem to be thrown as exceptions
def _fetch_auto_increment_value(self, cursor: Cursor) -> int or None:
return cursor.lastrowid
def _fetch_affected_rows_count(self, cursor: Cursor) -> int or None:
return cursor.rowcount
def _fetch_all_rows(self, cursor: Cursor) -> list[tuple]:
return cursor.fetchall()
def _db_execute_single_statement(self,
cursor: Cursor,
statement: str,
values: list[DbValueType]) -> tuple[list[DatabaseWarning], DbAffectedRows, DbResultSet]:
cursor.execute(statement, values)
all_rows = cursor.fetchall()
if len(all_rows) > 0:
column_mapping = PythonDbConnection._create_column_mapping(cursor)
return ([], # Warnings seem to be thrown as exceptions
cursor.rowcount,
list(map(lambda row: DatabaseResultRow(column_mapping, tuple(row)), all_rows)))
def _db_execute_multiple_statements(self,
cursor: Cursor,
statements: str,
values: list[DbValueType]) -> list[tuple[list[DatabaseWarning], DbAffectedRows, DbResultSet]]:
raise RuntimeError("Multi query is not supported. This method should not have be called") # pragma: no cover
class SqLiteDbConnectionFactory(DbConnectionFactory):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment