Skip to content
Snippets Groups Projects
Commit d0c87174 authored by Simon Künzel's avatar Simon Künzel
Browse files

py: Misc Improvements

parent 7e57d259
No related branches found
No related tags found
No related merge requests found
......@@ -272,7 +272,7 @@ def authenticate_fsmpi(username: str, password: str) -> {}:
if len(user_list_db) < 1:
result = trans.execute_statement_full_result_and_commit(
_SQL_INSERT_USER, user_id, user_info["givenName"], user_id)
user_db_id = result[3]
user_db_id = result[DB_RESULT_AUTO_INCREMENT]
else:
trans.commit()
user_db = user_list_db[0]
......
......@@ -4,7 +4,9 @@ from api.database.database import (PreparedStatement, FilledStatement,
AbstractTransaction, ReadTransaction, WriteTransaction,
DbConnection,
DbConnectionPool,
DbResult, DbResultSet, DbResultRow, DbValueType, DbWarning)
DbResult, DbResultSet, DbResultRow, DbValueType, DbWarning,
DB_RESULT_EXCEPTION, DB_RESULT_WARNINGS, DB_RESULT_SET, DB_RESULT_AFFECTED_ROWS,
DB_RESULT_AUTO_INCREMENT)
db_pool: DbConnectionPool
......
......@@ -65,15 +65,16 @@ class PythonDbConnection(DbConnection, Generic[Connection, Cursor], ABC):
else:
assert isinstance(filled_stat.statement, str)
cursor = self._execute_unprepared_statement(filled_stat.statement, filled_stat.values)
warnings, auto_increment, result_set = self._get_result(cursor)
warnings, auto_increment, affected_rows, result_set = self._get_result(cursor)
except self._caller_exception as e:
exception = e
warnings = []
auto_increment = None
affected_rows = None
result_set = []
except Exception as e:
raise DatabaseError("An exception occurred while executing statement") from e
results.append((exception, warnings, result_set, auto_increment))
results.append((exception, warnings, result_set, affected_rows, auto_increment))
if _DEBUG_PRINT_STATEMENT_EXECUTION:
# noinspection PyUnboundLocalVariable
......@@ -112,11 +113,11 @@ class PythonDbConnection(DbConnection, Generic[Connection, Cursor], ABC):
try:
cursor_iterator = self._unprepared_cursor.execute(full_query, params=tuple(all_values), multi=True)
for i in range(0, len(statements)):
warnings, auto_increment, result_set = self._get_result(next(cursor_iterator))
results.append((None, warnings, result_set, auto_increment))
warnings, auto_increment, affected_rows, result_set = self._get_result(next(cursor_iterator))
results.append((None, warnings, result_set, affected_rows, auto_increment))
except self._caller_exception as e:
for i in range(0, len(statements)):
results.append((e, [], None, []))
results.append((e, [], [], None, None))
except Exception as e:
raise DatabaseError("An exception occurred while executing statements") from e
......@@ -164,16 +165,21 @@ class PythonDbConnection(DbConnection, Generic[Connection, Cursor], ABC):
def _fetch_auto_increment_value(self, cursor: Cursor) -> int or None:
pass
@abstractmethod
def _fetch_affected_rows_count(self, cursor: Cursor) -> int or None:
pass
@abstractmethod
def _fetch_all_rows(self, cursor: Cursor) -> list[tuple]:
pass
def _get_result(self, cursor: Cursor) -> tuple[list[DbWarning], int, DbResultSet]:
def _get_result(self, cursor: Cursor) -> tuple[list[DbWarning], int or None, int or None, DbResultSet]:
warnings = self._fetch_warnings(cursor)
auto_increment_value = self._fetch_auto_increment_value(cursor)
affected_rows_count = self._fetch_affected_rows_count(cursor)
cursor_rows = self._fetch_all_rows(cursor)
if len(cursor_rows) == 0:
return warnings, auto_increment_value, []
return warnings, auto_increment_value, affected_rows_count, []
column_names = []
column_type_codes = []
......@@ -231,4 +237,4 @@ class PythonDbConnection(DbConnection, Generic[Connection, Cursor], ABC):
result_row[column_name] = result_value
result_set.append(result_row)
return warnings, auto_increment_value, result_set
return warnings, auto_increment_value, affected_rows_count, result_set
......@@ -10,8 +10,13 @@ DbResultSet = list[DbResultRow]
# First string is warning type, int is warning id, second string is message
DbWarning = tuple[str, int, str]
DbAutoIncrementValue = int or None
DbResult = tuple[Exception or None, list[DbWarning], DbResultSet, DbAutoIncrementValue]
DbResult = tuple[Exception or None, list[DbWarning], DbResultSet, int or None, int or None]
DB_RESULT_EXCEPTION = 0
DB_RESULT_WARNINGS = 1
DB_RESULT_SET = 2
DB_RESULT_AFFECTED_ROWS = 3
DB_RESULT_AUTO_INCREMENT = 4
SQL_PARAMETER_INDICATOR = "?"
......@@ -175,7 +180,7 @@ class AbstractTransaction(ABC):
"""
Does the same as :func:`get_full_result` with default raise_exception but only returns the result set
"""
return self.get_full_result(stat, ignore_warnings=ignore_warnings)[2]
return self.get_full_result(stat, ignore_warnings=ignore_warnings)[DB_RESULT_SET]
def get_full_result(self,
stat: FilledStatement,
......@@ -203,11 +208,11 @@ class AbstractTransaction(ABC):
result = self._statement_results[stat]
del self._statement_results[stat]
if raise_exception and result[0] is not None:
if raise_exception and result[DB_RESULT_EXCEPTION] is not None:
raise RuntimeError("An exception had occurred while executing statement "
"(Or other statement in multi-query)") from result[0]
if not ignore_warnings and len(result[1]) > 0:
raise WarningError(result[1])
"(Or other statement in multi-query)") from result[DB_RESULT_EXCEPTION]
if not ignore_warnings and len(result[DB_RESULT_WARNINGS]) > 0:
raise WarningError(result[DB_RESULT_WARNINGS])
return result
def execute_statement(self,
......
......@@ -5,7 +5,7 @@ from mysql.connector import MySQLConnection, NUMBER, STRING, BINARY, DATETIME
from time import time_ns
from api.database.database import PreparedStatement, DbWarning, DbConnectionFactory, DatabaseError
from api.database.abstract_py_connector import PythonDbConnection
from api.database.abstract_py_connector import PythonDbConnection, Cursor
print(f"Using MySQL Connector Version: {VERSION_TEXT}")
......@@ -86,6 +86,9 @@ class _MySqlDbConnection(PythonDbConnection[MySQLConnection, MySQLCursor]):
def _fetch_auto_increment_value(self, cursor: MySQLCursor) -> int or None:
return cursor.getlastrowid()
def _fetch_affected_rows_count(self, cursor: Cursor) -> int or None:
return cursor.rowcount
def _fetch_all_rows(self, cursor: MySQLCursor) -> list[tuple]:
try:
return cursor.fetchall()
......
......@@ -79,6 +79,9 @@ class SqLiteDbConnection(PythonDbConnection[Connection, Cursor]):
def _fetch_auto_increment_value(self, cursor: Cursor) -> int or None:
return cursor.lastrowid
def _fetch_affected_rows_count(self, cursor: Cursor) -> int or None:
return cursor.rowcount
def _fetch_all_rows(self, cursor: Cursor) -> list[tuple]:
return cursor.fetchall()
......
from api.miscellaneous.util import DEBUG_ENABLED, API_DATETIME_FORMAT, truncate_string, flat_map
from api.miscellaneous.util import (DEBUG_ENABLED, API_DATETIME_FORMAT, ID_STRING_REGEX_NO_LENGTH,
ID_STRING_PATTERN_NO_LENGTH, truncate_string, flat_map)
from api.miscellaneous.db_util import (DB_ID_PATTERN, db_ensure_valid_id,
db_collect_id_sorted_data, db_collect_unsorted_data,
db_group_data_by_id)
......
from typing import Callable, TypeVar
import re
import server
from api.miscellaneous.constants import *
DEBUG_ENABLED: bool = server.config.get("DEBUG", False)
API_DATETIME_FORMAT = "%Y-%m-%dT%H:%M:%S"
ID_STRING_REGEX_NO_LENGTH = "(?=.*[a-z_-])[a-z0-9_-]*"
ID_STRING_PATTERN_NO_LENGTH = re.compile(ID_STRING_REGEX_NO_LENGTH)
print(f"Debug is enabled: {DEBUG_ENABLED}")
......
......@@ -378,7 +378,7 @@ class ObjectClass:
object_id = transaction.execute_statement_full_result(f"""
INSERT INTO `{self._db_table}` ({",".join(map(lambda t: f"`{t[0]}`", column_updates))})
VALUES ({"" if len(column_updates) == 0 else ("?," * (len(column_updates) - 1) + "?")})
""", *[new_value_db for _, new_value_db in column_updates])[3]
""", *[new_value_db for _, new_value_db in column_updates])[DB_RESULT_AUTO_INCREMENT]
changelog_update.add_entry(CreationEntry(
self._creation_field,
object_id,
......
import re
from api.miscellaneous import ID_STRING_REGEX_NO_LENGTH
from api.objects.type import (NumberType, IntType, LongType,
SimpleType, ObjectIdType,
StringType, SimpleStringType, UniqueStringType, MappedStringType,
DatetimeType)
ID_STRING_REGEX_NO_LENGTH = "(?=.*[a-z_-])[a-z0-9_-]*"
ID_STRING_PATTERN_NO_LENGTH = re.compile(ID_STRING_REGEX_NO_LENGTH)
SEMESTER_STRING_REGEX = "([0-9]{4}(ws|ss)|none)"
SEMESTER_STRING_PATTERN = re.compile(SEMESTER_STRING_REGEX)
......
from api.miscellaneous import *
from api.database import *
from api.routes.route import api_route, api_function, api_add_route, ApiResponse, check_client_int, api_request_get_query_int
from api.routes.route import (api_route, api_function, api_add_route, ApiResponse, check_client_int,
api_request_get_query_int, api_request_get_query_string)
from api.authentication import api_moderator_route, is_moderator, get_user_id, check_csrf_token
import api.routes.authentication
......
from functools import wraps
import traceback
from flask import Response, request
from re import Pattern
import server
from api.authentication import is_moderator
......@@ -28,13 +29,16 @@ def check_client_int(value: int, name: str, min_value: int = MIN_VALUE_UINT32, m
name, f"Value must not be less than {max_value}"))
def api_request_get_query_string(id: str, max_length: int, default: str):
def api_request_get_query_string(id: str, max_length: int, pattern: Pattern or None = None, default: str or None = None) -> str or None:
if id not in request.args:
return default
value = request.args[id]
if len(value) > max_length:
raise ApiClientException(ERROR_REQUEST_INVALID_PARAMETER(
f"URL.{id}", f"Must not be longer than {max_length} characters"))
if pattern is not None and pattern.fullmatch(value) is None:
raise ApiClientException(ERROR_REQUEST_INVALID_PARAMETER(
f"URL.{id}", f"Does not match pattern"))
return value
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment