From 103d37d1d670b83b571b5c1d6086a1b2de9de4ba Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Simon=20K=C3=BCnzel?= <simonk@fsmpi.rwth-aachen.de>
Date: Sat, 22 Jun 2024 18:58:10 +0200
Subject: [PATCH] Allow different transaction isolation levels in db and set
 default for write transactions to SERIALIZABLE

---
 src/api/database/__init__.py           |   1 +
 src/api/database/database.py           | 101 ++++++++++++++++++-------
 src/api/database/mysql_connector.py    |  24 ++++--
 src/api/database/postgres_connector.py |  11 +--
 src/api/database/sqlite_connector.py   |   9 ++-
 5 files changed, 104 insertions(+), 42 deletions(-)

diff --git a/src/api/database/__init__.py b/src/api/database/__init__.py
index 73ad0b0..815690c 100644
--- a/src/api/database/__init__.py
+++ b/src/api/database/__init__.py
@@ -5,6 +5,7 @@ from api.database.database import (PreparedStatement, FilledStatement,
                                    AbstractTransaction, ReadTransaction, WriteTransaction,
                                    DbConnection,
                                    DbConnectionPool,
+                                   TransactionIsolationLevel,
                                    DbResult, DbResultSet, DbResultRow, DbValueType, DatabaseWarning,
                                    DB_RESULT_EXCEPTION, DB_RESULT_WARNINGS, DB_RESULT_SET, DB_RESULT_AFFECTED_ROWS)
 
diff --git a/src/api/database/database.py b/src/api/database/database.py
index 18fbbec..89ae887 100644
--- a/src/api/database/database.py
+++ b/src/api/database/database.py
@@ -1,7 +1,7 @@
 import time
 from abc import ABC, ABCMeta, abstractmethod
-from contextlib import contextmanager
 from datetime import datetime
+from enum import StrEnum
 from threading import Condition
 from traceback import StackSummary, extract_stack
 from typing import Callable, Literal, TypeVar, TypeVarTuple
@@ -16,6 +16,13 @@ _T = TypeVar("_T")
 _A = TypeVarTuple("_A")
 
 
+class TransactionIsolationLevel(StrEnum):  # values are SQL keywords interpolated verbatim into begin statements
+    READ_UNCOMMITTED = "READ UNCOMMITTED"
+    READ_COMMITTED = "READ COMMITTED"
+    REPEATABLE_READ = "REPEATABLE READ"
+    SERIALIZABLE = "SERIALIZABLE"
+
+
 def _create_diagnostics_counters(id: str) -> dict[_TransactionType, DiagnosticsCounter]:
     return {
         "read": DIAGNOSTICS_TRACKER.register_counter(f"database.transaction.read.{id}"),
@@ -148,12 +155,14 @@ class DbConnection(ABC):
         pass  # pragma: no cover
     
     @abstractmethod
-    def get_transaction_begin_statement(self, writable: bool) -> PreparedStatement or str:
+    def get_transaction_begin_statements(self, writable: bool,
+                                         isolation_level: TransactionIsolationLevel) -> list[PreparedStatement or str]:
         """
-        Returns the statement to begin a transaction
+        Returns the statements to begin a transaction, in the order in which they must be queued
         :param writable: Specifies whether this transaction is writable.
                          If :func:`supports_per_transaction_writeable_flag` of the :class:`DbConnectionFactory` is False
                          this parameter is ignored.
+        :param isolation_level: Specifies the minimum isolation level; an implementation may use a stricter one
         """
         pass  # pragma: no cover
     
@@ -394,9 +403,11 @@ class AbstractTransaction(ABC):
 
 class ReadTransaction(AbstractTransaction):
     
-    def __init__(self, release_connection: Callable, connection: DbConnection):
+    def __init__(self, isolation_level: TransactionIsolationLevel, release_connection: Callable,
+                 connection: DbConnection):
         super().__init__("read", release_connection, connection)
-        self.queue_statement(connection.get_transaction_begin_statement(False))
+        for stat in connection.get_transaction_begin_statements(False, isolation_level):
+            self.queue_statement(stat)
     
     def execute_statement_and_close(self,
                                     statement: PreparedStatement or str,
@@ -460,10 +471,12 @@ class ReadTransaction(AbstractTransaction):
 
 class WriteTransaction(AbstractTransaction):
     
-    def __init__(self, release_connection: Callable, connection: DbConnection):
+    def __init__(self, isolation_level: TransactionIsolationLevel, release_connection: Callable,
+                 connection: DbConnection):
         super().__init__("write", release_connection, connection)
         self._committed = False
-        self.queue_statement(connection.get_transaction_begin_statement(True))
+        for stat in connection.get_transaction_begin_statements(True, isolation_level):
+            self.queue_statement(stat)
     
     def execute_statement_and_commit(self,
                                      statement: PreparedStatement or str,
@@ -513,9 +526,9 @@ class WriteTransaction(AbstractTransaction):
         
         After this method is called, no new statements may be queued.
         
-        :param ignore_statement_exceptions If False and a statement exception occurs/has occurred in any results which
-                                           have not been retrieved yet, the transaction will be rolled back and the
-                                           exception is raised by this method
+        :param ignore_statement_exceptions: If False and a statement exception occurs/has occurred in any results which
+                                            have not been retrieved yet, the transaction will be rolled back and the
+                                            exception is raised by this method
         
         Raises :class:`RuntimeError` if the transaction was already rolled back
         
@@ -681,8 +694,8 @@ class DbConnectionPool:
         
         By default, NO TRANSACTION MANAGEMENT IS DONE. The script must start and commit a transaction
         itself. After the script, any ongoing transaction is rolled back. If wrap_in_transaction is True, the
-        transaction start and end statements are put at the beginning and end of the script (with simple string
-        concatenation; the last statement must have a ';').
+        transaction start and end statements are put at the beginning and end of the script (serializable isolation;
+        with simple string concatenation; the last statement must have a ';').
         
         Unlike other functions this immediately raises an exception for any error (usually :class:`DatabaseError`)
         and then rolls back.
@@ -695,9 +708,14 @@ class DbConnectionPool:
             connection = self._write_cache.get_connection()
         
         if wrap_in_transaction:
-            script = (connection.get_transaction_begin_statement(writable=True) + ";\n"
-                      + script + "\n"
-                      + connection.get_transaction_end_statement(commit=True) + ";")
+            script = (
+                    ";".join(connection.get_transaction_begin_statements(
+                        writable=True,
+                        isolation_level=TransactionIsolationLevel.SERIALIZABLE)
+                    )
+                    + ";\n"
+                    + script + "\n"
+                    + connection.get_transaction_end_statement(commit=True) + ";")
         
         rollback = not wrap_in_transaction
         try:
@@ -717,33 +735,60 @@ class DbConnectionPool:
     
     def execute_read_statement_in_transaction(self,
                                               statement: PreparedStatement or str,
-                                              *values: DbValueType) -> DbResultSet:
-        return self.execute_read_transaction(lambda trans: trans.execute_statement_and_close(statement, *values))
+                                              *values: DbValueType,
+                                              isolation_level: TransactionIsolationLevel = TransactionIsolationLevel.REPEATABLE_READ
+                                              ) -> DbResultSet:
+        return self.execute_read_transaction(
+            lambda trans: trans.execute_statement_and_close(statement, *values),
+            isolation_level=isolation_level
+        )
     
     def execute_read_statements_in_transaction(self,
-                                               *statements: tuple[PreparedStatement | str, list[DbValueType]]) -> tuple[DbResultSet, ...]:
-        return self.execute_read_transaction(lambda trans: trans.execute_statements_and_close(*statements))
-    
-    def execute_read_transaction(self, function: Callable[[ReadTransaction, *_A], _T], *args: *_A) -> _T:
+                                               *statements: tuple[PreparedStatement | str, list[DbValueType]],
+                                               isolation_level: TransactionIsolationLevel = TransactionIsolationLevel.REPEATABLE_READ) -> \
+            tuple[DbResultSet, ...]:
+        return self.execute_read_transaction(
+            lambda trans: trans.execute_statements_and_close(*statements),
+            isolation_level=isolation_level
+        )
+    
+    def execute_read_transaction(self,
+                                 function: Callable[[ReadTransaction, *_A], _T],
+                                 *args: *_A,
+                                 isolation_level: TransactionIsolationLevel = TransactionIsolationLevel.REPEATABLE_READ
+                                 ) -> _T:
         """
         Executes a read transaction with the given function. The function may be called multiple times if the read
         transaction fails due to conflicts with other transaction.
+        
+        Note that the specified isolation level is only a minimum. A higher isolation level may be used for the transaction.
 
         May raise :class:`NoAvailableConnectionError`
         """
-        return self._execute_transaction(False, function, *args)
+        return self._execute_transaction(False, function, *args, isolation_level=isolation_level)
     
-    def execute_write_transaction(self, function: Callable[[WriteTransaction, *_A], _T], *args: *_A) -> _T:
+    def execute_write_transaction(self,
+                                  function: Callable[[WriteTransaction, *_A], _T],
+                                  *args: *_A,
+                                  isolation_level: TransactionIsolationLevel = TransactionIsolationLevel.SERIALIZABLE
+                                  ) -> _T:
         """
         Executes a read-write transaction with the given function. The function may be called multiple times if the write
         transaction fails due to conflicts with other transaction. The function is never called again if the transaction
         committed successfully
         
+        Note that the specified isolation level is only a minimum. A higher isolation level may be used for the transaction.
+        
         May raise :class:`NoAvailableConnectionError`
         """
-        return self._execute_transaction(True, function, *args)
+        return self._execute_transaction(True, function, *args, isolation_level=isolation_level)
     
-    def _execute_transaction(self, writeable: bool, function: Callable[[_Trans, *_A], _T], *args: *_A) -> _T:
+    def _execute_transaction(self,
+                             writeable: bool,
+                             function: Callable[[_Trans, *_A], _T],
+                             *args: *_A,
+                             isolation_level: TransactionIsolationLevel = TransactionIsolationLevel.SERIALIZABLE
+                             ) -> _T:
         attempts = 0
         while True:
             attempts += 1
@@ -754,9 +799,11 @@ class DbConnectionPool:
                 else:
                     connection = (self._write_cache if writeable else self._read_cache).get_connection()
                 if writeable:
-                    transaction = WriteTransaction(lambda: self._on_connection_released(True, connection), connection)
+                    transaction = WriteTransaction(isolation_level,
+                                                   lambda: self._on_connection_released(True, connection), connection)
                 else:
-                    transaction = ReadTransaction(lambda: self._on_connection_released(False, connection), connection)
+                    transaction = ReadTransaction(isolation_level,
+                                                  lambda: self._on_connection_released(False, connection), connection)
                 result = function(transaction, *args)
                 if not transaction.is_closed():
                     # noinspection PyBroadException
diff --git a/src/api/database/mysql_connector.py b/src/api/database/mysql_connector.py
index c3a3be0..9a5ed04 100644
--- a/src/api/database/mysql_connector.py
+++ b/src/api/database/mysql_connector.py
@@ -6,8 +6,9 @@ from time import time_ns
 from datetime import datetime, date
 
 from api.database import DbValueType, DbResultSet, FilledStatement
-from api.database.database import PreparedStatement, DbConnectionFactory, DatabaseWarning, DatabaseError, DbAffectedRows, \
-    DatabaseResultRow
+from api.database.database import PreparedStatement, DbConnectionFactory, DatabaseWarning, DatabaseError, \
+    DbAffectedRows, \
+    DatabaseResultRow, TransactionIsolationLevel
 from api.database.abstract_py_connector import PythonDbConnection
 
 print(f"Using MySQL Connector Version: {VERSION_TEXT}")
@@ -58,11 +59,17 @@ class _MySqlDbConnection(PythonDbConnection[MySQLConnection, MySQLCursor]):
         self._closed = True
         self._py_connection.close()
     
-    def get_transaction_begin_statement(self, writable: bool) -> PreparedStatement or str:
+    def get_transaction_begin_statements(self, writable: bool, isolation_level: TransactionIsolationLevel) -> list[PreparedStatement or str]:
         if writable:
-            return "START TRANSACTION READ WRITE"
+            return [
+                f"SET SESSION TRANSACTION ISOLATION LEVEL {isolation_level.value}",  # SESSION scope persists, so re-issue per transaction
+                "START TRANSACTION READ WRITE",
+            ]
         else:
-            return "START TRANSACTION READ ONLY"
+            return [
+                f"SET SESSION TRANSACTION ISOLATION LEVEL {isolation_level.value}",
+                "START TRANSACTION READ ONLY",
+            ]
     
     def get_transaction_end_statement(self, commit: bool) -> PreparedStatement or str:
         if commit:
@@ -80,6 +87,8 @@ class _MySqlDbConnection(PythonDbConnection[MySQLConnection, MySQLCursor]):
                                      cursor: MySQLCursor,
                                      statement: str,
                                      values: list[DbValueType]) -> tuple[list[DatabaseWarning], DbAffectedRows, DbResultSet]:
+        # Note: Warnings seem to be broken. When a warning is generated, an error is thrown in the library:
+        # TypeError: expected string or bytes-like object, got 'Warning'
         cursor.execute(statement, params=values)
         return self._get_result(cursor)
     
@@ -87,6 +96,9 @@ class _MySqlDbConnection(PythonDbConnection[MySQLConnection, MySQLCursor]):
                                         cursor: MySQLCursor,
                                         statements: str,
                                         values: list[DbValueType]) -> list[tuple[list[DatabaseWarning], DbAffectedRows, DbResultSet]]:
+        # With multi=True the warnings seem to break. When a warning is generated we don't seem to be able to retrieve
+        # it, but in later queries (without multi) the following error is thrown:
+        # InterfaceError: Use cmd_query_iter for statements with multiple queries.
         results = []
         for res_cursor in cursor.execute(statements, params=values, multi=True):
             results.append(self._get_result(res_cursor))
@@ -182,6 +194,7 @@ class _MySqlDbConnection(PythonDbConnection[MySQLConnection, MySQLCursor]):
         return result_set
     
     def _db_execute_script(self, cursor: MySQLCursor, script: str):
+        # See comment on multi=True in _db_execute_multiple_statements
         for res_cursor in cursor.execute(script, multi=True):
             warnings = self._fetch_warnings(res_cursor)
             if len(warnings) > 0:
@@ -225,7 +238,6 @@ class MySqlDbConnectionFactory(DbConnectionFactory):
                 password=self._password,
                 database=self._database
             )
-            mysql_connection.cmd_query("SET SESSION TRANSACTION ISOLATION LEVEL REPEATABLE READ")
             mysql_connection.cmd_query("SET SESSION sql_mode = 'ANSI_QUOTES'")
             mysql_connection.get_warnings = self._fetch_warnings
         except Exception as e:
diff --git a/src/api/database/postgres_connector.py b/src/api/database/postgres_connector.py
index ae7ea82..7975e82 100644
--- a/src/api/database/postgres_connector.py
+++ b/src/api/database/postgres_connector.py
@@ -6,8 +6,9 @@ from psycopg.cursor import TUPLES_OK
 import psycopg
 
 from api.database import DbValueType, DbResultSet, FilledStatement
-from api.database.database import PreparedStatement, DbConnectionFactory, DatabaseWarning, DatabaseError, DbAffectedRows, \
-    DatabaseResultRow
+from api.database.database import PreparedStatement, DbConnectionFactory, DatabaseWarning, DatabaseError, \
+    DbAffectedRows, \
+    DatabaseResultRow, TransactionIsolationLevel
 from api.database.abstract_py_connector import PythonDbConnection
 
 print(f"Using psycopg (Postgres) Connector Version: {psycopg.version.__version__}")
@@ -89,11 +90,11 @@ source_function: {diagnostic.source_function}
         self._closed = True
         self._py_connection.close()
     
-    def get_transaction_begin_statement(self, writable: bool) -> PreparedStatement or str:
+    def get_transaction_begin_statements(self, writable: bool, isolation_level: TransactionIsolationLevel) -> list[PreparedStatement or str]:
         if writable:
-            return "START TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ WRITE"
+            return [f"START TRANSACTION ISOLATION LEVEL {isolation_level.value}, READ WRITE"]
         else:
-            return "START TRANSACTION ISOLATION LEVEL REPEATABLE READ, READ ONLY"
+            return [f"START TRANSACTION ISOLATION LEVEL {isolation_level.value}, READ ONLY"]
     
     def get_transaction_end_statement(self, commit: bool) -> PreparedStatement or str:
         if commit:
diff --git a/src/api/database/sqlite_connector.py b/src/api/database/sqlite_connector.py
index d22945d..3f5282e 100644
--- a/src/api/database/sqlite_connector.py
+++ b/src/api/database/sqlite_connector.py
@@ -4,8 +4,9 @@ from datetime import datetime
 from os import PathLike
 
 from api.database import DbValueType, DbResultSet, FilledStatement
-from api.database.database import PreparedStatement, DbConnectionFactory, DatabaseWarning, DatabaseError, DbAffectedRows, \
-    DatabaseResultRow
+from api.database.database import PreparedStatement, DbConnectionFactory, DatabaseWarning, DatabaseError, \
+    DbAffectedRows, \
+    DatabaseResultRow, TransactionIsolationLevel
 from api.database.abstract_py_connector import PythonDbConnection
 
 print(f"Using SQLite Connector Version: {sqlite3.sqlite_version}")
@@ -61,8 +62,8 @@ class SqLiteDbConnection(PythonDbConnection[Connection, Cursor]):
         self._closed = True
         self._py_connection.close()
     
-    def get_transaction_begin_statement(self, writable: bool) -> PreparedStatement or str:
-        return "BEGIN DEFERRED TRANSACTION"
+    def get_transaction_begin_statements(self, writable: bool, isolation_level: TransactionIsolationLevel) -> list[PreparedStatement or str]:
+        return ["BEGIN DEFERRED TRANSACTION"]  # isolation_level is unused: transactions in SQLite are serializable by default
     
     def get_transaction_end_statement(self, commit: bool) -> PreparedStatement or str:
         if commit:
-- 
GitLab