From 99b75708aeaa1c9f6bc988d6fa4aa957d3187062 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20K=C3=BCnzel?= <simonk@fsmpi.rwth-aachen.de> Date: Sun, 2 Jun 2024 20:18:40 +0200 Subject: [PATCH] Retry transaction if there are conflicts, Closes #28 --- config/api_example_config.py | 6 +- src/api/announcement.py | 5 +- src/api/authentication.py | 40 +-- src/api/course.py | 13 +- src/api/database/__init__.py | 4 +- src/api/database/abstract_py_connector.py | 8 +- src/api/database/database.py | 158 ++++++++---- src/api/database/mysql_connector.py | 7 +- src/api/database/postgres_connector.py | 5 +- src/api/database/sqlite_connector.py | 5 +- src/api/job.py | 72 +++--- src/api/objects/current_changelog.py | 27 +- src/api/objects/object.py | 290 ++++++++++++---------- src/api/routes/courses.py | 257 +++++++++++-------- src/api/routes/object_modifications.py | 73 ++++-- src/api/routes/site.py | 60 +++-- src/api/routes/sorter.py | 18 +- src/api/routes/user.py | 3 +- src/api/user.py | 18 +- 19 files changed, 630 insertions(+), 439 deletions(-) diff --git a/config/api_example_config.py b/config/api_example_config.py index 9a21cd9..78446d8 100644 --- a/config/api_example_config.py +++ b/config/api_example_config.py @@ -24,7 +24,11 @@ DB_CONNECTIONS = { # Maximum time to wait for a free connection (An API request will probably fail if this times out) "max_wait_time_sec": 10, # Maximum amount of transaction requests which may wait concurrently. More incoming requests will fail immediately - "max_waiting_count": 25 + "max_waiting_count": 25, + # Maximum amount of attempts for a read transaction if there are conflicts between the transactions + "max_read_attempts": 2, + # Maximum amount of attempts for a write transaction if there are conflicts between the transactions + "max_write_attempts": 2 } # DB_ENGINE = "mysql" diff --git a/src/api/announcement.py b/src/api/announcement.py index 3964771..0b0ed8d 100644 --- a/src/api/announcement.py +++ b/src/api/announcement.py @@ -18,9 +18,8 @@ WHERE NOT "deleted" def query_announcements(): is_mod = is_moderator() current_time: datetime = datetime.now() - with db_pool.start_read_transaction() as trans: - announcements_db = trans.execute_statement_and_close( - _SQL_GET_ANNOUNCEMENTS, is_mod, current_time, current_time) + announcements_db = db_pool.execute_read_statement_in_transaction( + _SQL_GET_ANNOUNCEMENTS, is_mod, current_time, current_time) announcements_json = [] for announcement_db in announcements_db: diff --git a/src/api/authentication.py b/src/api/authentication.py index f4eb559..e166511 100644 --- a/src/api/authentication.py +++ b/src/api/authentication.py @@ -411,6 +411,27 @@ RETURNING "id" """) +def _db_get_user_id(trans: ReadTransaction, user_id: str) -> int: + user_list_db = trans.execute_statement_and_close(_SQL_GET_USER_BY_NAME, user_id) + if len(user_list_db) < 1: + raise ApiClientException(ERROR_AUTHENTICATION_NOT_AVAILABLE( + "Site is read-only and we can not create a new account for you in the database")) + user_db = user_list_db[0] + return user_db["id"] + + +def _db_get_or_create_user_id(trans: WriteTransaction, user_id: str, given_name: str) -> int: + user_list_db = trans.execute_statement(_SQL_GET_USER_BY_NAME, user_id) + if len(user_list_db) < 1: + result = trans.execute_statement_and_commit( + _SQL_INSERT_USER, user_id, given_name, user_id) + return result[0]["id"] + else: + trans.commit() + user_db = user_list_db[0] + return user_db["id"] + + def authenticate_fsmpi(username: str, password: str) -> {}: """ May throw APIClientException. @@ -425,24 +446,9 @@ def authenticate_fsmpi(username: str, password: str) -> {}: raise ApiClientException(ERROR_AUTHENTICATION_FAILED) if api.live_config.is_readonly(): - with db_pool.start_read_transaction() as trans: - user_list_db = trans.execute_statement_and_close(_SQL_GET_USER_BY_NAME, user_id) - if len(user_list_db) < 1: - raise ApiClientException(ERROR_AUTHENTICATION_NOT_AVAILABLE( - "Site is read-only and we can not create a new account for you in the database")) - user_db = user_list_db[0] - user_db_id = user_db["id"] + user_db_id = db_pool.execute_read_transaction(_db_get_user_id, user_id) else: - with db_pool.start_write_transaction() as trans: - user_list_db = trans.execute_statement(_SQL_GET_USER_BY_NAME, user_id) - if len(user_list_db) < 1: - result = trans.execute_statement_and_commit( - _SQL_INSERT_USER, user_id, given_name, user_id) - user_db_id = result[0]["id"] - else: - trans.commit() - user_db = user_list_db[0] - user_db_id = user_db["id"] + user_db_id = db_pool.execute_write_transaction(_db_get_or_create_user_id, user_id, given_name) session["user"] = { "uid": user_id, diff --git a/src/api/course.py b/src/api/course.py index 56b0a22..e1ebd2b 100644 --- a/src/api/course.py +++ b/src/api/course.py @@ -112,8 +112,7 @@ def course_query_auth(course_id: int, transaction: AbstractTransaction or None = Returns a result even if course is not visible """ if transaction is None: - with db_pool.start_read_transaction() as trans: - return trans.execute_statement(_SQL_GET_COURSE_AUTH, course_id) + return db_pool.execute_read_statement_in_transaction(_SQL_GET_COURSE_AUTH, course_id) else: return transaction.execute_statement(_SQL_GET_COURSE_AUTH, course_id) @@ -172,16 +171,14 @@ def lecture_query_auth(lecture_id: int, transaction: AbstractTransaction = None) -> DbResultSet: if course_id: if transaction is None: - with db_pool.start_read_transaction() as trans: - return trans.execute_statement_and_close( - _SQL_GET_LECTURE_AUTH_WITH_COURSE_ID, lecture_id, course_id) + return db_pool.execute_read_statement_in_transaction( + _SQL_GET_LECTURE_AUTH_WITH_COURSE_ID, lecture_id, course_id) else: return transaction.execute_statement(_SQL_GET_LECTURE_AUTH_WITH_COURSE_ID, lecture_id, course_id) else: if transaction is None: - with db_pool.start_read_transaction() as trans: - return trans.execute_statement_and_close( - _SQL_GET_LECTURE_AUTH_NO_COURSE_ID, lecture_id) + return db_pool.execute_read_statement_in_transaction( + _SQL_GET_LECTURE_AUTH_NO_COURSE_ID, lecture_id) else: return transaction.execute_statement(_SQL_GET_LECTURE_AUTH_NO_COURSE_ID, lecture_id) diff --git a/src/api/database/__init__.py b/src/api/database/__init__.py index 79e4c07..73ad0b0 100644 --- a/src/api/database/__init__.py +++ b/src/api/database/__init__.py @@ -74,7 +74,9 @@ def __create_pool() -> DbConnectionPool: connection_config["max_count"], connection_config["readonly_percent"], connection_config["max_wait_time_sec"], - connection_config["max_waiting_count"] + connection_config["max_waiting_count"], + connection_config["max_read_attempts"], + connection_config["max_write_attempts"] ) diff --git a/src/api/database/abstract_py_connector.py b/src/api/database/abstract_py_connector.py index 1ed5d4f..c8cbdcc 100644 --- a/src/api/database/abstract_py_connector.py +++ b/src/api/database/abstract_py_connector.py @@ -7,7 +7,7 @@ from functools import reduce from api.database.database import (PreparedStatement, FilledStatement, DbConnection, DbResult, DbResultSet, DbValueType, DbAffectedRows, DatabaseWarning, DatabaseError, - SQL_PARAMETER_INDICATOR) + SQL_PARAMETER_INDICATOR, TransactionConflictError) _DEBUG_PRINT_STATEMENT_EXECUTION: bool = False @@ -65,6 +65,8 @@ class PythonDbConnection(DbConnection, Generic[Connection, Cursor], ABC): affected_rows = None result_set = [] except Exception as e: + if self._is_transaction_conflict_exception(e): + raise TransactionConflictError("Conflict in transactions") from e if filled_stat.queue_traceback is not None: raise DatabaseError(f"An exception occurred while executing statement, queued at:\n\n" f"{reduce(lambda res, s: res + s, traceback.format_list(filled_stat.queue_traceback), '')}") from e @@ -164,6 +166,10 @@ class PythonDbConnection(DbConnection, Generic[Connection, Cursor], ABC): def _create_cursor(self, prepared: bool): pass # pragma: no cover + @abstractmethod + def _is_transaction_conflict_exception(self, exception: Exception) -> bool: + pass # pragma: no cover + @abstractmethod def _db_execute_single_statement(self, cursor: Cursor, diff --git a/src/api/database/database.py b/src/api/database/database.py index 9139102..db4e8f4 100644 --- a/src/api/database/database.py +++ b/src/api/database/database.py @@ -1,9 +1,10 @@ -from threading import Condition -from datetime import datetime -from typing import Callable, Literal +import time from abc import ABC, ABCMeta, abstractmethod from contextlib import contextmanager +from datetime import datetime +from threading import Condition from traceback import StackSummary, extract_stack +from typing import Callable, Literal, TypeVar, TypeVarTuple from api.miscellaneous import DEBUG_ENABLED, DIAGNOSTICS_TRACKER, DiagnosticsCounter @@ -11,6 +12,9 @@ DbValueType = int | str | bytes | datetime _TransactionType = Literal["read", "write"] +_T = TypeVar("_T") +_A = TypeVarTuple("_A") + def _create_diagnostics_counters(id: str) -> dict[_TransactionType, DiagnosticsCounter]: return { @@ -25,6 +29,8 @@ _COUNTERS_TRANSACTION_ATTEMPTED_RECONNECTS = _create_diagnostics_counters("attem _COUNTERS_TRANSACTION_SUCCESSFUL_RECONNECTS = _create_diagnostics_counters("successful_reconnects") _COUNTERS_TRANSACTION_ERRORS_AFTER_RECONNECT = _create_diagnostics_counters("errors_after_reconnect") _COUNTERS_TRANSACTION_ABORTED_BY_USER = _create_diagnostics_counters("aborted_by_user") +_COUNTERS_TRANSACTION_CONFLICTS = _create_diagnostics_counters("conflicts") +_COUNTERS_TRANSACTION_ABORTS_AFTER_REPEATED_CONFLICTS = _create_diagnostics_counters("aborts_after_repeated_conflicts") class DatabaseResultRow: @@ -78,6 +84,10 @@ class DatabaseError(Exception): """Error for any problems of the database not directly caused by the caller, e.g. connection problems, etc.""" +class TransactionConflictError(DatabaseError): + """Error if a transaction cannot commit/close due to conflicts with other transactions.""" + + class WarningError(Exception): def __init__(self, warnings: list[DatabaseWarning]): @@ -160,7 +170,7 @@ class DbConnection(ABC): May raise :class:`DatabaseError` """ pass # pragma: no cover - + @abstractmethod def execute_script(self, script: str): """ @@ -230,11 +240,12 @@ class AbstractTransaction(ABC): if self._closed: raise RuntimeError("Transaction already closed") # pragma: no cover - parameter_count = statement.parameter_count\ - if isinstance(statement, PreparedStatement)\ + parameter_count = statement.parameter_count \ + if isinstance(statement, PreparedStatement) \ else statement.count(SQL_PARAMETER_INDICATOR) if parameter_count != len(values): - raise ValueError(f"Parameter ({parameter_count}) and argument count ({len(values)}) don't match") # pragma: no cover + raise ValueError( + f"Parameter ({parameter_count}) and argument count ({len(values)}) don't match") # pragma: no cover queued = FilledStatement(statement, list(values)) if DEBUG_ENABLED: @@ -309,7 +320,7 @@ class AbstractTransaction(ABC): filled_list = [] for stat in statements: filled_list.append(self.queue_statement(stat[0], *stat[1])) - + if after_queuing is not None: after_queuing() @@ -368,7 +379,7 @@ class AbstractTransaction(ABC): def _remove_all_statements(self): self._queued_statements = [] self._statement_results = {} - + def _has_queued_statements(self): return len(self._queued_statements) > 0 @@ -442,7 +453,7 @@ class ReadTransaction(AbstractTransaction): return if ignore_unused_statements: self._remove_all_statements() - self.queue_statement(self._connection.get_transaction_end_statement(False)) + self.queue_statement(self._connection.get_transaction_end_statement(True)) self._execute_queued() self._close() @@ -581,18 +592,22 @@ class _ConnectionCache: def get_connection(self) -> DbConnection: with self._available_condition: connection = self._try_get_connection() - if connection is not None: - return connection - if self._currently_waiting_count >= self._max_waiting_count: - raise NoAvailableConnectionError("Too many are already waiting") - self._currently_waiting_count += 1 - if not self._available_condition.wait(timeout=self._max_wait_time_sec): + wait_time_sec = self._max_wait_time_sec + while connection is None: + if self._currently_waiting_count >= self._max_waiting_count: + raise NoAvailableConnectionError("Too many are already waiting") + if wait_time_sec <= 0: + raise NoAvailableConnectionError("Timed out") + self._currently_waiting_count += 1 + start_time = time.time() + if not self._available_condition.wait(timeout=wait_time_sec): + self._currently_waiting_count -= 1 + raise NoAvailableConnectionError("Timed out") self._currently_waiting_count -= 1 - raise NoAvailableConnectionError("Timed out") - self._currently_waiting_count -= 1 - connection = self._try_get_connection() - # We were notified, so there is at least one slot to be reused or filled with a new connection - assert connection is not None + wait_time_sec -= time.time() - start_time + connection = self._try_get_connection() + # We were notified, but the connection may still be None since notify() sometimes wakes up more than one + # thread, or allows a different thread to acquire the lock before the notified thread wakes... return connection def return_connection(self, connection: DbConnection): @@ -605,6 +620,9 @@ class _ConnectionCache: self._available_condition.notify() +_Trans = TypeVar("_Trans", ReadTransaction, WriteTransaction) + + class DbConnectionPool: def __init__(self, @@ -612,9 +630,17 @@ class DbConnectionPool: max_connections: int, readonly_percent: float, max_wait_time_sec: float, - max_waiting_count: int): + max_waiting_count: int, + max_read_attempts: int, + max_write_attempts: int): super().__init__() self._factory = factory + if max_read_attempts < 1: + raise ValueError("max_read_attempts must be >= 1") + if max_write_attempts < 1: + raise ValueError("max_write_attempts must be >= 1") + self._max_read_attempts = max_read_attempts + self._max_write_attempts = max_write_attempts if self._factory.supports_per_transaction_writeable_flag(): if max_connections < 0: @@ -714,27 +740,75 @@ class DbConnectionPool: transaction.on_error_close_transaction() raise - @contextmanager - def start_write_transaction(self) -> WriteTransaction: + def execute_read_statement_in_transaction(self, + statement: PreparedStatement or str, + *values: DbValueType) -> DbResultSet: + return self.execute_read_transaction(lambda trans: trans.execute_statement_and_close(statement, *values)) + + def execute_read_statements_in_transaction(self, + *statements: tuple[PreparedStatement | str, list[DbValueType]]) -> tuple[DbResultSet, ...]: + return self.execute_read_transaction(lambda trans: trans.execute_statements_and_close(*statements)) + + def execute_read_transaction(self, function: Callable[[ReadTransaction, *_A], _T], *args: *_A) -> _T: + """ + Executes a read transaction with the given function. The function may be called multiple times if the read + transaction fails due to conflicts with other transaction. + + May raise :class:`NoAvailableConnectionError` + """ + return self._execute_transaction(False, function, *args) + + def execute_write_transaction(self, function: Callable[[WriteTransaction, *_A], _T], *args: *_A) -> _T: """ - Starts a read-write transaction. + Executes a read-write transaction with the given function. The function may be called multiple times if the write + transaction fails due to conflicts with other transaction. The function is never called again if the transaction + committed successfully May raise :class:`NoAvailableConnectionError` """ - transaction = None - try: - if self._factory.supports_per_transaction_writeable_flag(): - connection = self._cache.get_connection() - else: - connection = self._write_cache.get_connection() - transaction = WriteTransaction(lambda: self._on_connection_released(True, connection), connection) - yield transaction - if not transaction.is_closed(): - print("Warning: Rolling open transaction back after execution but there was no exception (Always " - "commit/rollback the transaction explicitly)") - transaction.rollback() - except Exception: - if transaction is not None and not transaction.is_closed(): - _COUNTERS_TRANSACTION_ABORTED_BY_USER["write"].trigger() - transaction.on_error_close_transaction() - raise + return self._execute_transaction(True, function, *args) + + def _execute_transaction(self, writeable: bool, function: Callable[[_Trans, *_A], _T], *args: *_A) -> _T: + attempts = 0 + while True: + attempts += 1 + transaction = None + try: + if self._factory.supports_per_transaction_writeable_flag(): + connection = self._cache.get_connection() + else: + connection = (self._write_cache if writeable else self._read_cache).get_connection() + if writeable: + transaction = WriteTransaction(lambda: self._on_connection_released(True, connection), connection) + else: + transaction = ReadTransaction(lambda: self._on_connection_released(False, connection), connection) + result = function(transaction, *args) + if not transaction.is_closed(): + # noinspection PyBroadException + try: + if writeable: + print("Warning: Rolling open transaction back after execution but there was no exception (Always " + "commit/rollback the transaction explicitly)") + transaction.rollback() + else: + print("Warning: Rolling open transaction back after execution but there was no exception (Always " + "close the transaction explicitly)") + transaction.close() + except Exception: + pass + return result + except TransactionConflictError: + if transaction is not None and not transaction.is_closed(): + # Should never happen, but just in case + transaction.on_error_close_transaction() + _COUNTERS_TRANSACTION_CONFLICTS["write" if writeable else "read"].trigger() + if attempts >= (self._max_write_attempts if writeable else self._max_read_attempts): + _COUNTERS_TRANSACTION_ABORTS_AFTER_REPEATED_CONFLICTS["write" if writeable else "read"].trigger() + raise + continue + except Exception: + if transaction is not None and not transaction.is_closed(): + _COUNTERS_TRANSACTION_ABORTED_BY_USER["write" if writeable else "read"].trigger() + transaction.on_error_close_transaction() + raise + raise AssertionError("This should be unreachable code") diff --git a/src/api/database/mysql_connector.py b/src/api/database/mysql_connector.py index 31f5fe7..c3a3be0 100644 --- a/src/api/database/mysql_connector.py +++ b/src/api/database/mysql_connector.py @@ -1,14 +1,14 @@ from mysql.connector.cursor import MySQLCursor from mysql.connector.errors import ProgrammingError, DataError, NotSupportedError, InterfaceError from mysql.connector.version import VERSION_TEXT -from mysql.connector import MySQLConnection, NUMBER, STRING, BINARY, DATETIME +from mysql.connector import MySQLConnection, NUMBER, STRING, BINARY, DATETIME, InternalError from time import time_ns from datetime import datetime, date from api.database import DbValueType, DbResultSet, FilledStatement from api.database.database import PreparedStatement, DbConnectionFactory, DatabaseWarning, DatabaseError, DbAffectedRows, \ DatabaseResultRow -from api.database.abstract_py_connector import PythonDbConnection, Cursor +from api.database.abstract_py_connector import PythonDbConnection print(f"Using MySQL Connector Version: {VERSION_TEXT}") @@ -73,6 +73,9 @@ class _MySqlDbConnection(PythonDbConnection[MySQLConnection, MySQLCursor]): def _create_cursor(self, prepared: bool): return self._py_connection.cursor(prepared=prepared) + def _is_transaction_conflict_exception(self, exception: Exception) -> bool: + return isinstance(exception, InternalError) and "try restarting transaction" in exception.msg + def _db_execute_single_statement(self, cursor: MySQLCursor, statement: str, diff --git a/src/api/database/postgres_connector.py b/src/api/database/postgres_connector.py index 25c2fde..ae7ea82 100644 --- a/src/api/database/postgres_connector.py +++ b/src/api/database/postgres_connector.py @@ -1,6 +1,6 @@ from typing import Callable -from psycopg.errors import Diagnostic +from psycopg.errors import Diagnostic, SerializationFailure from psycopg import Connection, Cursor from psycopg.cursor import TUPLES_OK import psycopg @@ -104,6 +104,9 @@ source_function: {diagnostic.source_function} def _create_cursor(self, prepared: bool): return self._py_connection.cursor() + def _is_transaction_conflict_exception(self, exception: Exception) -> bool: + return isinstance(exception, SerializationFailure) + def _db_execute_single_statement(self, cursor: Cursor, statement: str, diff --git a/src/api/database/sqlite_connector.py b/src/api/database/sqlite_connector.py index c24bc06..d22945d 100644 --- a/src/api/database/sqlite_connector.py +++ b/src/api/database/sqlite_connector.py @@ -1,5 +1,5 @@ import sqlite3 -from sqlite3 import Connection, Cursor, DatabaseError, DataError, ProgrammingError, NotSupportedError +from sqlite3 import Connection, Cursor, DatabaseError, DataError, ProgrammingError, NotSupportedError, OperationalError from datetime import datetime from os import PathLike @@ -73,6 +73,9 @@ class SqLiteDbConnection(PythonDbConnection[Connection, Cursor]): def _create_cursor(self, prepared: bool): return self._py_connection.cursor() # SQLite does not have prepared cursors + def _is_transaction_conflict_exception(self, exception: Exception) -> bool: + return isinstance(exception, OperationalError) and "database is locked" in str(exception) + def _db_execute_single_statement(self, cursor: Cursor, statement: str, diff --git a/src/api/job.py b/src/api/job.py index b9975e1..7109f48 100644 --- a/src/api/job.py +++ b/src/api/job.py @@ -49,28 +49,29 @@ def get_jobs( if state is None: state = "%" - with db_pool.start_read_transaction() as trans: - if worker_id == "%": - worker_condition = "" - worker_condition_values = [] - else: - worker_condition = "AND \"worker\" LIKE ? OR (\"worker\" IS NULL AND ? = 'none')" - worker_condition_values = [worker_id, worker_id] - row_count_set, entries_db = trans.execute_statements_and_close( - (f""" - SELECT COUNT(*) AS "count" FROM "jobs" - WHERE "type" LIKE ? - AND "state" LIKE ? - {worker_condition} - """, [type, state, *worker_condition_values]), - (f""" - SELECT * FROM "jobs" - WHERE "type" LIKE ? - AND "state" LIKE ? - {worker_condition} - ORDER BY "time_created" DESC, "id" DESC - LIMIT {entries_per_page} OFFSET {entries_per_page * page} - """, [type, state, *worker_condition_values])) + if worker_id == "%": + worker_condition = "" + worker_condition_values = [] + else: + worker_condition = "AND \"worker\" LIKE ? OR (\"worker\" IS NULL AND ? = 'none')" + worker_condition_values = [worker_id, worker_id] + + row_count_set, entries_db = db_pool.execute_read_statements_in_transaction( + (f""" + SELECT COUNT(*) AS "count" FROM "jobs" + WHERE "type" LIKE ? + AND "state" LIKE ? + {worker_condition} + """, [type, state, *worker_condition_values]), + (f""" + SELECT * FROM "jobs" + WHERE "type" LIKE ? + AND "state" LIKE ? + {worker_condition} + ORDER BY "time_created" DESC, "id" DESC + LIMIT {entries_per_page} OFFSET {entries_per_page * page} + """, [type, state, *worker_condition_values]) + ) return (math.ceil(row_count_set[0]["count"] / entries_per_page), list(map(_job_db_to_json, entries_db))) @@ -84,8 +85,7 @@ ORDER BY "last_ping" DESC def get_workers() -> list[dict]: - with db_pool.start_read_transaction() as trans: - workers_db = trans.execute_statement_and_close(_SQL_GET_WORKERS) + workers_db = db_pool.execute_read_statement_in_transaction(_SQL_GET_WORKERS) workers_json = [] for worker_db in workers_db: workers_json.append({ @@ -125,28 +125,28 @@ INSERT INTO "jobs" ("type", "priority", "queue", "state", "time_created", "data" def delete_ready_failed_job(id: int) -> bool: - with db_pool.start_write_transaction() as trans: - affected_count = trans.execute_statement_full_result_and_commit( - _SQL_DELETE_READY_FAILED_JOB, id)[DB_RESULT_AFFECTED_ROWS] + affected_count = db_pool.execute_write_transaction( + lambda trans: trans.execute_statement_full_result_and_commit( + _SQL_DELETE_READY_FAILED_JOB, id)[DB_RESULT_AFFECTED_ROWS]) return affected_count > 0 def restart_failed_job(id: int) -> bool: - with db_pool.start_write_transaction() as trans: - affected_count = trans.execute_statement_full_result_and_commit( - _SQL_RESTART_FAILED_JOB, id)[DB_RESULT_AFFECTED_ROWS] + affected_count = db_pool.execute_write_transaction( + lambda trans: trans.execute_statement_full_result_and_commit( + _SQL_RESTART_FAILED_JOB, id)[DB_RESULT_AFFECTED_ROWS]) return affected_count > 0 def cancel_running_job(id: int) -> bool: - with db_pool.start_write_transaction() as trans: - affected_count = trans.execute_statement_full_result_and_commit( - _SQL_CANCEL_RUNNING_JOB, id)[DB_RESULT_AFFECTED_ROWS] + affected_count = db_pool.execute_write_transaction( + lambda trans: trans.execute_statement_full_result_and_commit( + _SQL_CANCEL_RUNNING_JOB, id)[DB_RESULT_AFFECTED_ROWS]) return affected_count > 0 def copy_finished_deleted_job(id: int) -> bool: - with db_pool.start_write_transaction() as trans: - affected_count = trans.execute_statement_full_result_and_commit( - _SQL_COPY_FINISHED_DELETED_JOB, datetime.now(), id)[DB_RESULT_AFFECTED_ROWS] + affected_count = db_pool.execute_write_transaction( + lambda trans: trans.execute_statement_full_result_and_commit( + _SQL_COPY_FINISHED_DELETED_JOB, datetime.now(), id)[DB_RESULT_AFFECTED_ROWS]) return affected_count > 0 diff --git a/src/api/objects/current_changelog.py b/src/api/objects/current_changelog.py index efb094d..3636ba7 100644 --- a/src/api/objects/current_changelog.py +++ b/src/api/objects/current_changelog.py @@ -93,20 +93,19 @@ def get_changelog(entries_per_page: int, if len(clauses) > 0: where_clause = f"WHERE ({') AND ('.join(clauses)})" - with db_pool.start_read_transaction() as trans: - set_row_count, rows = trans.execute_statements_and_close( - (f""" - SELECT COUNT(*) AS "count" FROM "changelog" - {where_clause} - """, all_values), - (f""" - SELECT * - FROM "changelog" - {where_clause} - ORDER BY "when" DESC, "id" DESC - LIMIT {entries_per_page} OFFSET {entries_per_page * page} - """, all_values) - ) + set_row_count, rows = db_pool.execute_read_statements_in_transaction( + (f""" + SELECT COUNT(*) AS "count" FROM "changelog" + {where_clause} + """, all_values), + (f""" + SELECT * + FROM "changelog" + {where_clause} + ORDER BY "when" DESC, "id" DESC + LIMIT {entries_per_page} OFFSET {entries_per_page * page} + """, all_values) + ) return (math.ceil(set_row_count[0]["count"]/entries_per_page), list(map(_changelog_entry_db_to_json, rows))) diff --git a/src/api/objects/object.py b/src/api/objects/object.py index 2138d97..39e5449 100644 --- a/src/api/objects/object.py +++ b/src/api/objects/object.py @@ -171,21 +171,18 @@ class ObjectClass: def is_get_config_allowed(self) -> bool: return self._allow_get_config - def get_current_config(self, object_id: int) -> JsonTypes or None: - if not self._allow_get_config: - return None - - with db_pool.start_read_transaction() as transaction: - indirect_statements: dict[ObjectIndirectField, tuple[FilledStatement, ...]] = {} - config_extra_statements: list[tuple[ConfigurationExtra, tuple[FilledStatement, ...]]] = [] - for field in self._indirect_fields: - if field.include_in_config: - indirect_statements[field] = field.queue_get_current_statements(transaction, object_id) - for extra in self._config_extras: - config_extra_statements.append((extra, extra.queue_get_current_statements(transaction, object_id))) - object_db_list = transaction.execute_statement_and_close( - self._object_query_request, object_id - ) + def _db_execute_get_config(self, transaction: ReadTransaction, object_id: int) -> dict: + indirect_statements: dict[ObjectIndirectField, tuple[FilledStatement, ...]] = {} + config_extra_statements: list[tuple[ConfigurationExtra, tuple[FilledStatement, ...]]] = [] + for field in self._indirect_fields: + if field.include_in_config: + indirect_statements[field] = field.queue_get_current_statements(transaction, object_id) + for extra in self._config_extras: + config_extra_statements.append((extra, extra.queue_get_current_statements(transaction, object_id))) + object_db_list = transaction.execute_statement_and_close( + self._object_query_request, object_id + ) + if len(object_db_list) < 1: raise ApiClientException(ERROR_UNKNOWN_OBJECT) if len(object_db_list) > 1: @@ -224,6 +221,12 @@ class ObjectClass: return config + def get_current_config(self, object_id: int) -> JsonTypes or None: + if not self._allow_get_config: + return None + + return db_pool.execute_read_transaction(self._db_execute_get_config, object_id) + def modify_current_config(self, transaction: WriteTransaction, modifying_user_id: int, @@ -322,129 +325,150 @@ class ObjectClass: if self._variant_id_type is not None and object_variant_client is None: raise ApiClientException(ERROR_OBJECT_ERROR("Missing object variant")) - with db_pool.start_write_transaction() as transaction: - from api.objects.current_changelog import CreationChangelogUpdate - changelog_update: CreationChangelogUpdate = CreationChangelogUpdate() - - object_variant_db = None - object_variant_json: str or None = None - if self._variant_id_type is not None: - object_variant_json = object_variant_client.as_string(self._variant_id_type.get_max_string_length()) - if object_variant_json not in self._variants_dict: + return db_pool.execute_write_transaction( + self._db_execute_create, + modifying_user_id, + parent_id, + values, + object_variant_client, + parent_class, + parent_id_column + ) + + def _db_execute_create(self, + transaction: WriteTransaction, + modifying_user_id: int, + parent_id: int or None, + values: CJsonObject, + object_variant_client: CJsonValue or None, + parent_class: "ObjectClass" or None, + parent_id_column: int or None) -> int: + from api.objects.current_changelog import CreationChangelogUpdate + changelog_update: CreationChangelogUpdate = CreationChangelogUpdate() + + object_variant_db = None + object_variant_json: str or None = None + if self._variant_id_type is not None: + object_variant_json = object_variant_client.as_string(self._variant_id_type.get_max_string_length()) + if object_variant_json not in self._variants_dict: + raise ApiClientException(ERROR_OBJECT_ERROR( + f"Unknown variant {truncate_string(object_variant_json)}")) + try: + object_variant_db = self._variant_id_type.validate_and_convert_client_value( + transaction, None, object_variant_client) + except Exception as e: # pragma: no cover + raise RuntimeError(f"Variant {truncate_string(object_variant_json)} is declared as possible " + f"variant but variant type does not accept it") from e + + column_updates: list[tuple[str, int]] = [] + indirect_updates: list[tuple[ObjectIndirectField, object]] = [] + for variant_id, field in self._all_fields.values(): + if not values.has(field.api_id): + if not field.include_in_config: + continue + if variant_id is not None and variant_id != object_variant_json: + continue + if field.default_value_json is None: + raise ApiClientException(ERROR_OBJECT_ERROR(f"Missing field {truncate_string(field.api_id)}")) + field_value = CJsonValue(field.default_value_json if field.default_value_json != (None,) else None) + else: + if not field.include_in_config: raise ApiClientException(ERROR_OBJECT_ERROR( - f"Unknown variant {truncate_string(object_variant_json)}")) - try: - object_variant_db = self._variant_id_type.validate_and_convert_client_value( - transaction, None, object_variant_client) - except Exception as e: # pragma: no cover - raise RuntimeError(f"Variant {truncate_string(object_variant_json)} is declared as possible " - f"variant but variant type does not accept it") from e - - column_updates: list[tuple[str, int]] = [] - indirect_updates: list[tuple[ObjectIndirectField, object]] = [] - for variant_id, field in self._all_fields.values(): - if not values.has(field.api_id): - if not field.include_in_config: - continue - if variant_id is not None and variant_id != object_variant_json: - continue - if field.default_value_json is None: - raise ApiClientException(ERROR_OBJECT_ERROR(f"Missing field {truncate_string(field.api_id)}")) - field_value = CJsonValue(field.default_value_json if field.default_value_json != (None, ) else None) - else: - if not field.include_in_config: - raise ApiClientException(ERROR_OBJECT_ERROR( - f"Field {truncate_string(field.api_id)} may not be included in creation")) - if variant_id is not None and variant_id != object_variant_json: - raise ApiClientException(ERROR_OBJECT_ERROR(f"Field {truncate_string(field.api_id)} is for other " - f"variant")) - field_value = values.get(field.api_id) - - if isinstance(field, ObjectColumnField): - update = field.do_value_creation(transaction, changelog_update, field_value) - if update is not None: - column_updates.append(update) - else: - assert isinstance(field, ObjectIndirectField) - indirect_updates.append(( - field, - field.validate_and_convert_client_value(transaction, None, field_value) - )) + f"Field {truncate_string(field.api_id)} may not be included in creation")) + if variant_id is not None and variant_id != object_variant_json: + raise ApiClientException(ERROR_OBJECT_ERROR(f"Field {truncate_string(field.api_id)} is for other " + f"variant")) + field_value = values.get(field.api_id) - if parent_class is not None: - assert isinstance(parent_class, ObjectClass) - parent_list_db = transaction.execute_statement(parent_class._object_query_request, parent_id) - if len(parent_list_db) == 0: - raise ApiClientException(ERROR_OBJECT_ERROR("Parent does not exist")) - column_updates.append((parent_id_column, parent_id)) - - if object_variant_db is not None: - column_updates.append((self._variant_db_column, object_variant_db)) - - object_id = transaction.execute_statement(f""" - INSERT INTO "{self._db_table}" ({",".join(map(lambda t: f'"{t[0]}"', column_updates))}) - VALUES ({"" if len(column_updates) == 0 else ("?," * (len(column_updates) - 1) + "?")}) - RETURNING "id" - """, *[new_value_db for _, new_value_db in column_updates])[0]["id"] - changelog_update.add_entry(CreationEntry( - self._creation_subject, + if isinstance(field, ObjectColumnField): + update = field.do_value_creation(transaction, changelog_update, field_value) + if update is not None: + column_updates.append(update) + else: + assert isinstance(field, ObjectIndirectField) + indirect_updates.append(( + field, + field.validate_and_convert_client_value(transaction, None, field_value) + )) + + if parent_class is not None: + assert isinstance(parent_class, ObjectClass) + parent_list_db = transaction.execute_statement(parent_class._object_query_request, parent_id) + if len(parent_list_db) == 0: + raise ApiClientException(ERROR_OBJECT_ERROR("Parent does not exist")) + column_updates.append((parent_id_column, parent_id)) + + if object_variant_db is not None: + column_updates.append((self._variant_db_column, object_variant_db)) + + object_id = transaction.execute_statement(f""" + INSERT INTO "{self._db_table}" ({",".join(map(lambda t: f'"{t[0]}"', column_updates))}) + VALUES ({"" if len(column_updates) == 0 else ("?," * (len(column_updates) - 1) + "?")}) + RETURNING "id" + """, *[new_value_db for _, new_value_db in column_updates])[0]["id"] + changelog_update.add_entry(CreationEntry( + self._creation_subject, + object_id, + object_variant_json, + parent_class, + parent_id + )) + for field, value in indirect_updates: + field.queue_creation( + transaction, + changelog_update, object_id, - object_variant_json, - parent_class, - parent_id - )) - for field, value in indirect_updates: - field.queue_creation( - transaction, - changelog_update, - object_id, - value - ) - changelog_update.queue_log_statements( - transaction, modifying_user_id, object_id) - transaction.commit() + value + ) + changelog_update.queue_log_statements( + transaction, modifying_user_id, object_id) + transaction.commit() return object_id def delete(self, modifying_user_id: int, object_id: int): - with db_pool.start_write_transaction() as transaction: - if len(transaction.execute_statement(self._object_query_request, object_id)) == 0: - raise ApiClientException(ERROR_UNKNOWN_OBJECT) - from api.objects.current_changelog import UpdateChangelogUpdate - changelog = UpdateChangelogUpdate() - - changelog.add_entry(DeletionEntry( - self._deletion_subject, - object_id, - True - )) - transaction.queue_statement(f""" - UPDATE "{self._db_table}" - SET "deleted" = true - WHERE "id" = ? - """, object_id) - changelog.queue_log_statements(transaction, modifying_user_id) - transaction.commit() + db_pool.execute_write_transaction(self._db_execute_delete, modifying_user_id, object_id) + + def _db_execute_delete(self, transaction: WriteTransaction, modifying_user_id: int, object_id: int): + if len(transaction.execute_statement(self._object_query_request, object_id)) == 0: + raise ApiClientException(ERROR_UNKNOWN_OBJECT) + from api.objects.current_changelog import UpdateChangelogUpdate + changelog = UpdateChangelogUpdate() + + changelog.add_entry(DeletionEntry( + self._deletion_subject, + object_id, + True + )) + transaction.queue_statement(f""" + UPDATE "{self._db_table}" + SET "deleted" = true + WHERE "id" = ? + """, object_id) + changelog.queue_log_statements(transaction, modifying_user_id) + transaction.commit() def undo_deletion(self, modifying_user_id: int, object_id: int): - with db_pool.start_write_transaction() as transaction: - object_list_db = transaction.execute_statement(self._object_query_request_no_deletion_check, object_id) - if len(object_list_db) == 0: - raise ApiClientException(ERROR_UNKNOWN_OBJECT) - if not bool(object_list_db[0]["deleted"]): - raise ApiClientException(ERROR_OBJECT_ERROR("Object not deleted")) - - from api.objects.current_changelog import UpdateChangelogUpdate - changelog = UpdateChangelogUpdate() - - changelog.add_entry(DeletionEntry( - self._deletion_subject, - object_id, - False - )) - transaction.queue_statement(f""" - UPDATE "{self._db_table}" - SET "deleted" = false - WHERE "id" = ? - """, object_id) - changelog.queue_log_statements(transaction, modifying_user_id) - transaction.commit() + db_pool.execute_write_transaction(self._db_execute_undo_deletion, modifying_user_id, object_id) + + def _db_execute_undo_deletion(self, transaction: WriteTransaction, modifying_user_id: int, object_id: int): + object_list_db = transaction.execute_statement(self._object_query_request_no_deletion_check, object_id) + if len(object_list_db) == 0: + raise ApiClientException(ERROR_UNKNOWN_OBJECT) + if not bool(object_list_db[0]["deleted"]): + raise ApiClientException(ERROR_OBJECT_ERROR("Object not deleted")) + + from api.objects.current_changelog import UpdateChangelogUpdate + changelog = UpdateChangelogUpdate() + + changelog.add_entry(DeletionEntry( + self._deletion_subject, + object_id, + False + )) + transaction.queue_statement(f""" + UPDATE "{self._db_table}" + SET "deleted" = false + WHERE "id" = ? + """, object_id) + changelog.queue_log_statements(transaction, modifying_user_id) + transaction.commit() diff --git a/src/api/routes/courses.py b/src/api/routes/courses.py index acafc9a..0d46b34 100644 --- a/src/api/routes/courses.py +++ b/src/api/routes/courses.py @@ -33,11 +33,10 @@ ORDER BY "courses"."id" ASC @api_route("/courses", ["GET"], allow_while_readonly=True) def api_route_courses(): is_mod: bool = is_moderator() - with db_pool.start_read_transaction() as trans: - courses_db, auth_db = trans.execute_statements_and_close( - (_SQL_GET_ALL_COURSES, [is_mod]), - (_SQL_GET_ALL_COURSES_AUTH, [is_mod]) - ) + courses_db, auth_db = db_pool.execute_read_statements_in_transaction( + (_SQL_GET_ALL_COURSES, [is_mod]), + (_SQL_GET_ALL_COURSES_AUTH, [is_mod]) + ) courses_list_json = course_list_db_to_json_no_lectures(courses_db, auth_db, is_mod) courses_map_json = {} for course_json in courses_list_json: @@ -100,42 +99,57 @@ ORDER BY "lectures"."id" ASC """) +def _db_execute_get_courses(trans: ReadTransaction, + is_mod: bool, + course_id: int = None, + course_id_string: str = None) -> tuple[ + DbResult, DbResultSet, DbResultSet or None, DbResultSet or None, DbResultSet or None, DbResultSet or None +]: + if course_id: + courses_array = trans.execute_statement(_SQL_GET_COURSE_BY_ID_NO_CHECK, course_id) + else: + courses_array = trans.execute_statement(_SQL_GET_COURSE_BY_ID_STRING_NO_CHECK, course_id_string) + + if len(courses_array) == 0: + trans.close() + raise ApiClientException(ERROR_UNKNOWN_OBJECT) + if len(courses_array) > 1: # pragma: no cover + raise Exception("Got more than one course with id " + str(course_id if course_id else course_id_string)) + course_db = courses_array[0] + if not is_mod and not bool(course_db["visible"]): + raise ApiClientException(ERROR_UNAUTHORIZED) + + course_id = course_db["id"] + + auth_stat = course_queue_query_auth(trans, course_id) + if "include_lectures" not in request.args or request.args["include_lectures"] != "true": + trans.close() + return course_db, trans.get_result(auth_stat), None, None, None, None + + lectures_db, lectures_auth_db, chapters_db, media_db = trans.execute_statements_and_close( + (_SQL_GET_COURSE_LECTURES, [course_id, is_mod]), + (_SQL_GET_COURSE_LECTURES_AUTH, [course_id, is_mod]), + (_SQL_GET_COURSE_LECTURES_CHAPTERS, [course_id, is_mod]), + (_SQL_GET_COURSE_LECTURES_MEDIA, [course_id, is_mod]), + ) + return course_db, trans.get_result(auth_stat), lectures_db, lectures_auth_db, chapters_db, media_db + + @api_add_route("/course/<int:course_id>", ["GET"]) @api_add_route("/course/<string:course_id_string>", ["GET"]) @api_function(allow_while_readonly=True) def api_route_course(course_id: int = None, course_id_string: str = None): is_mod: bool = is_moderator() - with db_pool.start_read_transaction() as trans: - if course_id: - courses_array = trans.execute_statement(_SQL_GET_COURSE_BY_ID_NO_CHECK, course_id) - else: - courses_array = trans.execute_statement(_SQL_GET_COURSE_BY_ID_STRING_NO_CHECK, course_id_string) - - if len(courses_array) == 0: - trans.close() - raise ApiClientException(ERROR_UNKNOWN_OBJECT) - if len(courses_array) > 1: # pragma: no cover - raise Exception("Got more than one course with id " + str(course_id if course_id else course_id_string)) - course_db = courses_array[0] - if not is_mod and not bool(course_db["visible"]): - raise ApiClientException(ERROR_UNAUTHORIZED) - - course_id = course_db["id"] - - auth_stat = course_queue_query_auth(trans, course_id) - if "include_lectures" not in request.args or request.args["include_lectures"] != "true": - trans.close() - course_json = course_db_to_json_no_lectures(course_db, trans.get_result(auth_stat), is_mod) - return course_json - - lectures_db, lectures_auth_db, chapters_db, media_db = trans.execute_statements_and_close( - (_SQL_GET_COURSE_LECTURES, [course_id, is_mod]), - (_SQL_GET_COURSE_LECTURES_AUTH, [course_id, is_mod]), - (_SQL_GET_COURSE_LECTURES_CHAPTERS, [course_id, is_mod]), - (_SQL_GET_COURSE_LECTURES_MEDIA, [course_id, is_mod]), - ) - course_auth_db = trans.get_result(auth_stat) + + course_db, course_auth_db, lectures_db, lectures_auth_db, chapters_db, media_db = db_pool.execute_read_transaction( + _db_execute_get_courses, + is_mod, + course_id, + course_id_string + ) course_json = course_db_to_json_no_lectures(course_db, course_auth_db, is_mod) + if lectures_db is None: + return course_json lectures_auth_i = 0 chapters_i = 0 @@ -194,17 +208,26 @@ WHERE %s _SQL_GET_LECTURE_MEDIA = PreparedStatement(_SQL_GET_LECTURE_MEDIA_NO_ID_COND % '"videos"."lecture_id" = ?') +def _db_execute_get_lecture(trans: ReadTransaction, is_mod: bool, lecture_id: int) -> tuple[ + DbResultSet, DbResultSet, DbResultSet, DbResultSet +]: + auth_stat = lecture_queue_query_auth(trans, lecture_id) + lecture_list_db, chapters_db, media_db = trans.execute_statements_and_close( + (_SQL_GET_LECTURE_WITH_COURSE_NO_CHECK, [lecture_id]), + (_SQL_GET_LECTURE_CHAPTERS, [lecture_id, is_mod]), + (_SQL_GET_LECTURE_MEDIA, [lecture_id, is_mod]), + ) + return lecture_list_db, trans.get_result(auth_stat), chapters_db, media_db + + @api_route("/lecture/<int:lecture_id>", ["GET"], allow_while_readonly=True) def api_route_lecture(lecture_id: int): is_mod: bool = is_moderator() - with db_pool.start_read_transaction() as trans: - auth_stat = lecture_queue_query_auth(trans, lecture_id) - lecture_list_db, chapters_db, media_db = trans.execute_statements_and_close( - (_SQL_GET_LECTURE_WITH_COURSE_NO_CHECK, [lecture_id]), - (_SQL_GET_LECTURE_CHAPTERS, [lecture_id, is_mod]), - (_SQL_GET_LECTURE_MEDIA, [lecture_id, is_mod]), - ) - auth_db = trans.get_result(auth_stat) + lecture_list_db, auth_db, chapters_db, media_db = db_pool.execute_read_transaction( + _db_execute_get_lecture, + is_mod, + lecture_id + ) if len(lecture_list_db) == 0: raise ApiClientException(ERROR_UNKNOWN_OBJECT) @@ -258,28 +281,82 @@ _CHAPTER_SUGGESTIONS_RATE_LIMITER = IntervalRateLimiter( ) +def _db_execute_chapter_suggestion(trans: WriteTransaction, is_mod: bool, lecture_id: int, start_time: int, name: str): + lecture_count_db, chapter_count_db = trans.execute_statements( + (_SQL_EXISTS_LECTURE, [lecture_id, is_mod]), + (_SQL_GET_CHAPTER_COUNT, [lecture_id]) + ) + if lecture_count_db[0]["count"] == 0: + raise ApiClientException(ERROR_UNKNOWN_OBJECT) + if chapter_count_db[0]["count"] >= _CHAPTER_LIMIT_FOR_SUGGESTIONS: + raise ApiClientException(ERROR_TOO_MANY_SUGGESTIONS) + trans.execute_statement_and_commit(_SQL_PUT_CHAPTER_SUGGESTION, lecture_id, start_time, name, False) + + @api_route("/lecture/<int:lecture_id>/chapter_suggestion", ["PUT"]) def api_route_lecture_chapter_suggestion(lecture_id: int): is_mod = is_moderator() json_request = get_client_json(request) start_time = json_request.get_int("start_time", 0, MAX_VALUE_SINT32) name = json_request.get_string("name", 128) - - with db_pool.start_write_transaction() as trans: - lecture_count_db, chapter_count_db = trans.execute_statements( - (_SQL_EXISTS_LECTURE, [lecture_id, is_mod]), - (_SQL_GET_CHAPTER_COUNT, [lecture_id]) - ) - if lecture_count_db[0]["count"] == 0: - raise ApiClientException(ERROR_UNKNOWN_OBJECT) - if chapter_count_db[0]["count"] >= _CHAPTER_LIMIT_FOR_SUGGESTIONS: - raise ApiClientException(ERROR_TOO_MANY_SUGGESTIONS) - if not _CHAPTER_SUGGESTIONS_RATE_LIMITER.check_new_request(): - raise ApiClientException(ERROR_TOO_MANY_SUGGESTIONS) - trans.execute_statement_and_commit(_SQL_PUT_CHAPTER_SUGGESTION, lecture_id, start_time, name, False) + + if not _CHAPTER_SUGGESTIONS_RATE_LIMITER.check_new_request(): + raise ApiClientException(ERROR_TOO_MANY_SUGGESTIONS) + db_pool.execute_write_transaction(_db_execute_chapter_suggestion, is_mod, lecture_id, start_time, name) return {}, HTTP_201_CREATED +def _db_execute_search(trans: ReadTransaction, is_mod: bool, search_term: str): + course_list_statement = course_queue_search(trans, search_term, is_mod) + lecture_list_statement = lecture_queue_search(trans, search_term, is_mod) + + course_list_db = trans.get_result(course_list_statement) + lecture_list_db = trans.get_result(lecture_list_statement) + + if len(course_list_db) == 0 and len(lecture_list_db) == 0: + trans.close() + return course_list_db, lecture_list_db, None, None, None, None + + course_ids: dict[int, None] = {} + for course_db in course_list_db: + course_ids[course_db["id"]] = None + + lecture_ids: dict[int, None] = {} + for lecture_db in lecture_list_db: + lecture_ids[lecture_db["id"]] = None + course_ids[lecture_db["course_id"]] = None + + assert len(course_ids) > 0 + course_auth_statement = trans.queue_statement(f""" + SELECT * FROM "perm" + WHERE "course_id" IN ({",".join(map(lambda _: "?", course_ids))}) + ORDER BY "course_id" ASC + """, *course_ids) + if len(lecture_ids) == 0: + trans.close() + return course_list_db, lecture_list_db, trans.get_result(course_auth_statement), None, None, None + + lecture_where_clause = f""""lecture_id" IN ({",".join(map(lambda _: "?", lecture_ids))})""" + lecture_auth_statement = trans.queue_statement(f""" + SELECT * FROM "perm" + WHERE {lecture_where_clause} + ORDER BY "lecture_id" ASC + """, *lecture_ids) + lecture_chapters_statement = trans.queue_statement( + _SQL_GET_LECTURE_CHAPTERS_NO_ID_COND % lecture_where_clause, *lecture_ids, is_mod) + lecture_media_sources_statement = trans.queue_statement( + _SQL_GET_LECTURE_MEDIA_NO_ID_COND % lecture_where_clause, *lecture_ids, is_mod) + trans.close() + return ( + course_list_db, + lecture_list_db, + trans.get_result(course_auth_statement), + trans.get_result(lecture_auth_statement), + trans.get_result(lecture_chapters_statement), + trans.get_result(lecture_media_sources_statement), + ) + + @api_route("/search", ["GET"], allow_while_readonly=True) def api_route_search(): is_mod = is_moderator() @@ -292,53 +369,21 @@ def api_route_search(): if search_term.isspace(): raise ApiClientException(ERROR_REQUEST_INVALID_PARAMETER("URL.q", "Only whitespace is not allowed")) - with db_pool.start_read_transaction() as trans: - course_list_statement = course_queue_search(trans, search_term, is_mod) - lecture_list_statement = lecture_queue_search(trans, search_term, is_mod) - - course_list_db = trans.get_result(course_list_statement) - lecture_list_db = trans.get_result(lecture_list_statement) - - if len(course_list_db) == 0 and len(lecture_list_db) == 0: - trans.close() - return { - "courses": [], - "lectures": [], - "courses_context": {} - } - - course_ids: dict[int, None] = {} - for course_db in course_list_db: - course_ids[course_db["id"]] = None - - lecture_ids: dict[int, None] = {} - for lecture_db in lecture_list_db: - lecture_ids[lecture_db["id"]] = None - course_ids[lecture_db["course_id"]] = None - - assert len(course_ids) > 0 - course_auth_statement = trans.queue_statement(f""" - SELECT * FROM "perm" - WHERE "course_id" IN ({",".join(map(lambda _: "?", course_ids))}) - ORDER BY "course_id" ASC - """, *course_ids) - lecture_auth_statement = None - lecture_chapters_statement = None - lecture_media_sources_statement = None - if len(lecture_ids) > 0: - lecture_where_clause = f""""lecture_id" IN ({",".join(map(lambda _: "?", lecture_ids))})""" - lecture_auth_statement = trans.queue_statement(f""" - SELECT * FROM "perm" - WHERE {lecture_where_clause} - ORDER BY "lecture_id" ASC - """, *lecture_ids) - lecture_chapters_statement = trans.queue_statement( - _SQL_GET_LECTURE_CHAPTERS_NO_ID_COND % lecture_where_clause, *lecture_ids, is_mod) - lecture_media_sources_statement = trans.queue_statement( - _SQL_GET_LECTURE_MEDIA_NO_ID_COND % lecture_where_clause, *lecture_ids, is_mod) - trans.close() + course_list_db, lecture_list_db, course_auth_db, lecture_auth_db, lecture_chapters_db, lecture_media_sources_db = \ + db_pool.execute_read_transaction( + _db_execute_search, + is_mod, + search_term + ) + + if len(course_list_db) == 0 and len(lecture_list_db) == 0: + return { + "courses": [], + "lectures": [], + "courses_context": {} + } - course_auth_by_id = db_group_data_by_id(trans.get_result(course_auth_statement), "course_id") + course_auth_by_id = db_group_data_by_id(course_auth_db, "course_id") courses_context_json: dict[str, JsonTypes] = {} course_id_list_json: list[int] = [] @@ -348,16 +393,16 @@ def api_route_search(): course_auth = course_auth_by_id.get(id, []) courses_context_json[str(id)] = course_db_to_json_no_lectures(course_db, course_auth, is_mod) - if lecture_auth_statement is None: + if len(lecture_list_db) == 0: return { "courses": course_id_list_json, "lectures": [], "courses_context": courses_context_json } - lecture_auth_by_id = db_group_data_by_id(trans.get_result(lecture_auth_statement), "lecture_id") - chapters_by_id = db_group_data_by_id(trans.get_result(lecture_chapters_statement), "lecture_id") - media_sources_by_id = db_group_data_by_id(trans.get_result(lecture_media_sources_statement), "lecture_id") + lecture_auth_by_id = db_group_data_by_id(lecture_auth_db, "lecture_id") + chapters_by_id = db_group_data_by_id(lecture_chapters_db, "lecture_id") + media_sources_by_id = db_group_data_by_id(lecture_media_sources_db, "lecture_id") lecture_list_json = [] for lecture_db in lecture_list_db: diff --git a/src/api/routes/object_modifications.py b/src/api/routes/object_modifications.py index b3fa54f..11537af 100644 --- a/src/api/routes/object_modifications.py +++ b/src/api/routes/object_modifications.py @@ -1,10 +1,21 @@ from flask import request +from api.objects.object import ObjectClass from api.routes import * from api.objects import * from api.authentication import check_csrf_token +def _db_execute_path_configuration(trans: WriteTransaction, + object_class: ObjectClass, + modifying_user_id: int, + object_id: int, + expected_current_values_json: CJsonObject, + updates_json: CJsonObject): + object_class.modify_current_config(trans, modifying_user_id, object_id, expected_current_values_json, updates_json) + trans.commit() + + @api_route(f"""\ /object_management\ /<any({','.join(key for key, obj in object_classes_by_id.items())}):object_type>\ @@ -18,13 +29,12 @@ def api_route_object_management_configuration(object_type: str, object_id: int): check_csrf_token() check_client_int(object_id, "URL.object_id") json_request = get_client_json(request) - updates_json = json_request.get_object("updates") - expected_current_values_json = json_request.get_object("expected_current_values") - modifying_user_id: int = get_user_id() - with db_pool.start_write_transaction() as trans: - object_classes_by_id[object_type].modify_current_config(trans, modifying_user_id, object_id, - expected_current_values_json, updates_json) - trans.commit() + db_pool.execute_write_transaction(_db_execute_path_configuration, + object_classes_by_id[object_type], + get_user_id(), + object_id, + json_request.get_object("expected_current_values"), + json_request.get_object("updates")) return {} else: object_class = object_classes_by_id[object_type] @@ -67,6 +77,15 @@ def api_route_object_management_new(object_type: str): _MAX_OBJECT_TYPE_LENGTH = max(100, max(map(lambda key: len(key), object_classes_by_id.keys()))) +def _db_execute_modify_many(trans: WriteTransaction, + modifying_user_id: int, + updates: list[tuple[ObjectClass, int, CJsonObject, CJsonObject]]): + for object_class, object_id, expected_current_values_json, updates_json in updates: + object_class.modify_current_config( + trans, modifying_user_id, object_id, expected_current_values_json, updates_json) + trans.commit() + + @api_route("/object_management/modify_many", ["POST"]) @api_moderator_route(require_csrf_token=True) def api_route_object_management_modify_many(): @@ -77,25 +96,27 @@ def api_route_object_management_modify_many(): if objects_list_json.length() > 32: raise ApiClientException(ERROR_OBJECT_ERROR("Too many updates. Only 32 updates in one query are allowed")) - updated_objects: dict[tuple[str, int], None] = {} - with db_pool.start_write_transaction() as trans: - for object_json_raw in objects_list_json: - object_json = object_json_raw.as_object() - object_type = object_json.get_string("type", _MAX_OBJECT_TYPE_LENGTH) - object_id = object_json.get_int("id", MIN_VALUE_SINT32, MAX_VALUE_SINT32) - updates_json = object_json.get_object("updates") - expected_current_values_json = object_json.get_object("expected_current_values") - - if object_type not in object_classes_by_id: - raise ApiClientException(ERROR_OBJECT_ERROR(f"Unknown type {truncate_string(object_type)}")) - if (object_type, object_id) in updated_objects: - raise ApiClientException(ERROR_OBJECT_ERROR( - f"Duplicated object of type {truncate_string(object_type)} with id {object_id}")) - updated_objects[(object_type, object_id)] = None - - object_classes_by_id[object_type].modify_current_config( - trans, modifying_user_id, object_id, expected_current_values_json, updates_json) - trans.commit() + seen_objects: dict[tuple[str, int], None] = {} + updates: list[tuple[ObjectClass, int, CJsonObject, CJsonObject]] = [] + for object_json_raw in objects_list_json: + object_json = object_json_raw.as_object() + object_type = object_json.get_string("type", _MAX_OBJECT_TYPE_LENGTH) + object_id = object_json.get_int("id", MIN_VALUE_SINT32, MAX_VALUE_SINT32) + + if object_type not in object_classes_by_id: + raise ApiClientException(ERROR_OBJECT_ERROR(f"Unknown type {truncate_string(object_type)}")) + if (object_type, object_id) in seen_objects: + raise ApiClientException(ERROR_OBJECT_ERROR( + f"Duplicated object of type {truncate_string(object_type)} with id {object_id}")) + seen_objects[(object_type, object_id)] = None + + updates.append(( + object_classes_by_id[object_type], + object_id, + object_json.get_object("expected_current_values"), + object_json.get_object("updates") + )) + db_pool.execute_write_transaction(_db_execute_modify_many, modifying_user_id, updates) return {} diff --git a/src/api/routes/site.py b/src/api/routes/site.py index b447722..194a725 100644 --- a/src/api/routes/site.py +++ b/src/api/routes/site.py @@ -163,32 +163,26 @@ ORDER BY CASE WHEN "order" IS NULL THEN 0 ELSE "order" END ASC, "id" ASC """) -@api_route("/homepage", ["GET"], allow_while_readonly=True) -def api_route_homepage(): - is_mod: bool = is_moderator() - upcoming_start: date = date.today() - upcoming_end: date = upcoming_start + timedelta(days=7) - - with db_pool.start_read_transaction() as transaction: - upcoming_db, \ - upcoming_auth_db, \ - latest_media_db, \ - latest_media_auth_db, \ - active_livestreams_db, \ - active_livestreams_auth_db, \ - featured_list_db \ - = transaction.execute_statements( - (_SQL_GET_UPCOMING_LECTURES, [upcoming_start, upcoming_end, is_mod]), - (_SQL_GET_UPCOMING_LECTURES_AUTH, [upcoming_start, upcoming_end, is_mod]), - (_SQL_GET_LATEST_MEDIA, [is_mod, is_mod]), - (_SQL_GET_LATEST_MEDIA_AUTH, [is_mod, is_mod]), - (_SQL_GET_ACTIVE_LIVESTREAMS, [is_mod]), - (_SQL_GET_ACTIVE_LIVESTREAMS_AUTH, [is_mod]), - (_SQL_GET_FEATURED, [is_mod]) - ) - courses_context = {} - featured = get_homepage_featured(transaction, courses_context, is_mod, featured_list_db) - # Transaction is closed by get_homepage_featured +def _db_execute_get_homepage(transaction: ReadTransaction, is_mod: bool, upcoming_start: date, upcoming_end: date): + upcoming_db, \ + upcoming_auth_db, \ + latest_media_db, \ + latest_media_auth_db, \ + active_livestreams_db, \ + active_livestreams_auth_db, \ + featured_list_db \ + = transaction.execute_statements( + (_SQL_GET_UPCOMING_LECTURES, [upcoming_start, upcoming_end, is_mod]), + (_SQL_GET_UPCOMING_LECTURES_AUTH, [upcoming_start, upcoming_end, is_mod]), + (_SQL_GET_LATEST_MEDIA, [is_mod, is_mod]), + (_SQL_GET_LATEST_MEDIA_AUTH, [is_mod, is_mod]), + (_SQL_GET_ACTIVE_LIVESTREAMS, [is_mod]), + (_SQL_GET_ACTIVE_LIVESTREAMS_AUTH, [is_mod]), + (_SQL_GET_FEATURED, [is_mod]) + ) + courses_context = {} + featured = get_homepage_featured(transaction, courses_context, is_mod, featured_list_db) + # Transaction is closed by get_homepage_featured return { "featured": featured, @@ -205,6 +199,20 @@ def api_route_homepage(): } +@api_route("/homepage", ["GET"], allow_while_readonly=True) +def api_route_homepage(): + is_mod: bool = is_moderator() + upcoming_start: date = date.today() + upcoming_end: date = upcoming_start + timedelta(days=7) + + return db_pool.execute_read_transaction( + _db_execute_get_homepage, + is_mod, + upcoming_start, + upcoming_end + ) + + def get_homepage_upcoming_lectures( courses_context: {}, is_mod: bool, diff --git a/src/api/routes/sorter.py b/src/api/routes/sorter.py index 462cb2f..a2d9e04 100644 --- a/src/api/routes/sorter.py +++ b/src/api/routes/sorter.py @@ -32,11 +32,10 @@ WHERE "id" = ? @api_route("/sorter/log", ["GET"], allow_while_readonly=True) @api_moderator_route() def api_route_sorter_log(): - with db_pool.start_read_transaction() as trans: - log_db, error_log_db = trans.execute_statements_and_close( - (_SQL_GET_SORTER_LOG, []), - (_SQL_GET_SORTER_ERROR_LOG, []) - ) + log_db, error_log_db = db_pool.execute_read_statements_in_transaction( + (_SQL_GET_SORTER_LOG, []), + (_SQL_GET_SORTER_ERROR_LOG, []) + ) log_json = [] for log_entry_db in log_db: @@ -85,9 +84,8 @@ def api_route_sorter_log(): @api_moderator_route(require_csrf_token=True) def api_route_delete_sorter_error_log(log_id: int): check_client_int(log_id, "URL.log_id") - with db_pool.start_write_transaction() as trans: - trans.execute_statement_and_commit( - _SQL_DELETE_ERROR_LOG_ENTRY, - log_id - ) + db_pool.execute_write_transaction(lambda trans: trans.execute_statement_and_commit( + _SQL_DELETE_ERROR_LOG_ENTRY, + log_id + )) return {} diff --git a/src/api/routes/user.py b/src/api/routes/user.py index a7f1f49..9fe9fe5 100644 --- a/src/api/routes/user.py +++ b/src/api/routes/user.py @@ -11,8 +11,7 @@ _SQL_GET_USERS = PreparedStatement(""" @api_route("/users", ["GET"], allow_while_readonly=True) @api_moderator_route() def api_route_users(): - with db_pool.start_read_transaction() as trans: - user_list_db = trans.execute_statement_and_close(_SQL_GET_USERS) + user_list_db = db_pool.execute_read_statement_in_transaction(_SQL_GET_USERS) return { "users": user_info_list_db_to_json(user_list_db) } diff --git a/src/api/user.py b/src/api/user.py index 3b070a3..47ce8d5 100644 --- a/src/api/user.py +++ b/src/api/user.py @@ -1,6 +1,6 @@ from api.database import * from api.miscellaneous import * -from api.settings import Settings, SettingsSection, SettingsSubSectionEntry, SettingsValueEntry, SettingsBooleanEntry +from api.settings import Settings, SettingsSection, SettingsSubSectionEntry, SettingsBooleanEntry def user_info_list_db_to_json(user_list_db: list[DbResultRow]): @@ -35,8 +35,7 @@ WHERE "id" = ? def get_user_settings(user_id: int) -> dict or None: - with db_pool.start_read_transaction() as trans: - user_list_db = trans.execute_statement_and_close(_SQL_GET_USER, user_id) + user_list_db = db_pool.execute_read_statement_in_transaction(_SQL_GET_USER, user_id) if len(user_list_db) == 0: return None user_db = user_list_db[0] @@ -55,9 +54,10 @@ def update_user_settings(user_id: int, updates_json: CJsonObject): )) if len(updates_db) == 0: return - with db_pool.start_write_transaction() as trans: - trans.execute_statement_and_commit(f""" - UPDATE "users" - SET {",".join(map(lambda tup: f'"{tup[0]}" = ?', updates_db))} - WHERE "id" = ? - """, *map(lambda tup: tup[1], updates_db), user_id) + db_pool.execute_write_transaction(lambda trans: trans.execute_statement_and_commit( + f""" + UPDATE "users" + SET {",".join(map(lambda tup: f'"{tup[0]}" = ?', updates_db))} + WHERE "id" = ? + """, *map(lambda tup: tup[1], updates_db), user_id) + ) -- GitLab