From 9ec0bf3ad0ae5a7abf6ababef82f6e0ea230ee39 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Simon=20K=C3=BCnzel?= <simonk@fsmpi.rwth-aachen.de>
Date: Mon, 10 Feb 2025 00:24:46 +0100
Subject: [PATCH] Rework when drift check is skipped

---
 src/videoag_common/database/database.py | 37 +++++++++++++++----------
 1 file changed, 22 insertions(+), 15 deletions(-)

diff --git a/src/videoag_common/database/database.py b/src/videoag_common/database/database.py
index 2cc536a..64fed7d 100644
--- a/src/videoag_common/database/database.py
+++ b/src/videoag_common/database/database.py
@@ -44,9 +44,9 @@ class Database:
         
         check_drift = dict_get_check_type(engine_config, "check_drift", bool, True)
         auto_migration = dict_get_check_type(engine_config, "auto_migration", bool, False)
-        ignore_no_connection = dict_get_check_type(engine_config, "ignore_no_connection", bool, False)
-        if os.environ.get("API_IGNORE_NO_DB_CONNECTION", "false").lower() == "true":
-            ignore_no_connection = True
+        skip_drift_check_when_no_connection = dict_get_check_type(engine_config, "skip_drift_check_when_no_connection", bool, False)
+        if os.environ.get("API_SKIP_DRIFT_CHECK_WHEN_NO_CONNECTION", "false").lower() == "true":
+            skip_drift_check_when_no_connection = True
         
         self._max_read_attempts = dict_get_check_type(engine_config, "max_read_attempts", int, 2)
         self._max_write_attempts = dict_get_check_type(engine_config, "max_write_attempts", int, 2)
@@ -81,18 +81,7 @@ class Database:
             )
         
         if check_drift:
-            # Ensure objects are loaded
-            import videoag_common.objects
-            drifted = False
-            try:
-                drifted = not check_for_drift_and_migrate(Base.metadata, self._engine, auto_migration)
-            except Exception as e:
-                if not ignore_no_connection:
-                    raise e
-                print(f"Exception while checking for drift. ignore_no_connection is set. Exception: {e}")
-            
-            if drifted:
-                raise Exception("Database schema has drifted!")
+            _startup_check_drift(auto_migration, skip_drift_check_when_no_connection)
         
         if config.get("log_all_statements", False):
             import logging
@@ -101,6 +90,24 @@ class Database:
             logger.addHandler(handler)
             logger.setLevel(logging.INFO)
     
+    def _startup_check_drift(self, auto_migration: bool, skip_drift_check_when_no_connection: bool):
+        if skip_drift_check_when_no_connection:
+            session = SessionDb(bind=self._write_engines_by_level[TransactionIsolationLevel.REPEATABLE_READ])
+            try:
+                with session.begin():
+                    session.execute(sqlachemy.text("SELECT 1"))
+            except Exception as e:
+                print(f"Unable to connect with database. Skipping drift check because "
+                      f"skip_drift_check_when_no_connection is set. Exception: {e}")
+                return
+        
+        # Ensure objects are loaded
+        import videoag_common.objects
+        
+        drifted = not check_for_drift_and_migrate(Base.metadata, self._engine, auto_migration)
+        if drifted:
+            raise Exception("Database schema has drifted!")
+    
     def query_all_and_expunge(self, stmt: sqlalchemy.Select[_T]) -> Sequence[_T]:
         def _trans(session: SessionDb):
             res = session.scalars(stmt).all()
-- 
GitLab