From 961e4b4cba24ac962adef3daf2a50f7972eb56b5 Mon Sep 17 00:00:00 2001
From: Dorian Koch <doriank@fsmpi.rwth-aachen.de>
Date: Thu, 26 Sep 2024 11:46:54 +0200
Subject: [PATCH] Recover spawning jobs

---
 src/job_controller.py   |  5 +++++
 src/job_database_api.py | 15 +++++++++++----
 src/main.py             | 42 +++++++++++++++++++++++++++++------------
 3 files changed, 46 insertions(+), 16 deletions(-)

diff --git a/src/job_controller.py b/src/job_controller.py
index 294638a..afbeb4e 100644
--- a/src/job_controller.py
+++ b/src/job_controller.py
@@ -4,6 +4,7 @@ from job_database_api import DummyJobDatabaseApi, JobData, JobDatabaseApi
 from jobs.dummy_job import DummyJob
 from kubernetes_api import K8sApi
 import os
+import datetime
 
 
 def load_config():
@@ -32,6 +33,10 @@ class ControllerState():
         db_engine = self.config.get("DB_ENGINE")
         if db_engine == "dummy":
             self.job_api = DummyJobDatabaseApi()
+            # make some dummy jobs
+            start_id = int(datetime.datetime.now().timestamp())
+            for i in range(start_id, start_id + 4):
+                self.job_api.create_job(JobData("job{}".format(i), "dummy"))
         else:
             raise Exception(f"Unknown DB_ENGINE: {db_engine}")
         self.event_queue = EventQueue()
diff --git a/src/job_database_api.py b/src/job_database_api.py
index b4c8562..93d904c 100644
--- a/src/job_database_api.py
+++ b/src/job_database_api.py
@@ -34,19 +34,23 @@ class JobDatabaseApi(metaclass=ABCMeta):
 
     @abstractmethod
     def get_next_jobs_and_set_spawning(self, limit: int) -> list[JobData]:
-        pass
+        pass  # atomically retrieve jobs and set their state to SPAWNING; limit presumably caps the count (confirm)
+
+    @abstractmethod
+    def get_all_spawning_jobs(self) -> list[JobData]:
+        pass  # return every job whose state is SPAWNING; used at startup to recover jobs missing from k8s
 
     @abstractmethod
     def get_job_by_id(self, job_id: str) -> Optional[JobData]:
-        pass
+        pass  # fetch the job's current record from the db (refreshes state); None if no such job, as in the dummy backend
 
     @abstractmethod
     def create_job(self, job: JobData):
-        pass
+        pass  # insert the given job into the db
 
     @abstractmethod
     def update_job_state(self, job_id: str, new_state: JobState):
-        pass
+        pass  # set the state of the job identified by job_id to new_state in the db
 
 
 class DummyJobDatabaseApi(JobDatabaseApi):
@@ -65,6 +69,9 @@ class DummyJobDatabaseApi(JobDatabaseApi):
             ret.append(next)
         return ret
 
+    def get_all_spawning_jobs(self) -> list[JobData]:  # recovery helper; returns copies, like get_job_by_id
+        return [copy.deepcopy(j) for j in self.db_state.values() if j.job_state == JobState.SPAWNING]
+
     def get_job_by_id(self, job_id: str) -> Optional[JobData]:
         return copy.deepcopy(self.db_state.get(job_id, None))
 
diff --git a/src/main.py b/src/main.py
index e0cf58d..47e8da3 100644
--- a/src/main.py
+++ b/src/main.py
@@ -1,7 +1,7 @@
-from actions.spawn_job import WatchJob
+from actions.spawn_job import SpawnJob, WatchJob
 from event_queue import EventResult
 from actions.find_ready_jobs import FindReadyJobs
-from job_database_api import JobData
+from job_database_api import JobData, JobState
 from job_controller import ControllerState
 
 import datetime
@@ -60,12 +60,23 @@ def main():
             cstate.event_queue.put(WatchJob(watch.metadata.labels["job_id"]))
 
     print("Done checking for existing jobs")
-    # TODO: check for existing jobs in spawning state in db that are not in k8s and requeue them (ready state will be picked up by FindReadyJobs)
-
-    # make some dummy jobs
-    start_id = int(datetime.datetime.now().timestamp())
-    for i in range(start_id, start_id + 4):
-        cstate.job_api.create_job(JobData("job{}".format(i), "dummy"))
+    # find spawning jobs that are not in k8s
+    spawning_jobs = cstate.job_api.get_all_spawning_jobs()
+    num_resetted = 0
+    for job in spawning_jobs:
+        exists = False
+        for k8sjob in existing_worker_jobs.items:
+            if k8sjob.metadata.labels["job_id"] == job.job_id:
+                exists = True
+                break
+        if not exists:
+            # reset to ready
+            # TODO: maybe reset them to an error state?
+            job.update_state(cstate, JobState.READY)
+            num_resetted += 1
+    if num_resetted > 0:
+        print(f"Resetted {num_resetted} spawning jobs to ready state")
+    print("Reconcilation done")
 
     cstate.event_queue.put(FindReadyJobs())
 
@@ -107,10 +118,17 @@ def main():
             print(f"Error in event {evt}: {e}")
             print("###")
     print("Event loop stopped")
-    # print all remaining events
-    print("Remaining events in queue:")
-    while not cstate.event_queue.empty():
-        print(cstate.event_queue.get())
+    # set all jobs that were supposed to be spawned back to ready
+    num_readied = 0
+    for evt in cstate.event_queue.queue.queue:
+        if isinstance(evt, SpawnJob):
+            # get up to date job data (it may have been canceled)
+            job = cstate.job_api.get_job_by_id(evt.job.jobData.job_id)
+            if job is not None and job.job_state == JobState.SPAWNING:
+                job.update_state(cstate, JobState.READY)
+                num_readied += 1
+    if num_readied > 0:
+        print(f"Readied {num_readied} jobs that were supposed to be spawned")
     sys.exit(0)
 
 
-- 
GitLab