From 98203e12f9c44772d37e958d1c2dd56b81dea5ac Mon Sep 17 00:00:00 2001
From: Dorian Koch <doriank@fsmpi.rwth-aachen.de>
Date: Tue, 24 Sep 2024 15:08:14 +0200
Subject: [PATCH] add signal handler

---
 src/event_queue.py    |  5 ++++-
 src/kubernetes_api.py |  2 ++
 src/main.py           | 32 +++++++++++++++++++++++++++++---
 3 files changed, 35 insertions(+), 4 deletions(-)

diff --git a/src/event_queue.py b/src/event_queue.py
index 4f4a88f..6206a6e 100644
--- a/src/event_queue.py
+++ b/src/event_queue.py
@@ -74,7 +74,10 @@ class EventQueue():
 
     def put(self, event: Event):
         self.queue.put((event.due_at, self.counter, event))
-        self.counter += 1  # this enforces stability in the queue
+        self.counter += 1  # this enforces FIFO in the queue (after comparing due_at)
 
     def get(self):
         return self.queue.get()[2]
+
+    def empty(self):
+        return self.queue.empty()
diff --git a/src/kubernetes_api.py b/src/kubernetes_api.py
index 216727c..5abbfcc 100644
--- a/src/kubernetes_api.py
+++ b/src/kubernetes_api.py
@@ -15,9 +15,11 @@ class K8sApi():
         if os.path.isfile(config.incluster_config.SERVICE_TOKEN_FILENAME):
             print("Using incluster config")
             config.load_incluster_config()
+            self.config_used = "incluster"
         else:
             print("Using local config")
             config.load_kube_config()
+            self.config_used = "local"
         self.api = client.ApiClient()
         self.v1 = client.CoreV1Api(self.api)
         self.batch_v1 = client.BatchV1Api(self.api)
diff --git a/src/main.py b/src/main.py
index ffadddc..de2b741 100644
--- a/src/main.py
+++ b/src/main.py
@@ -8,14 +8,20 @@ import datetime
 import time
 import argparse
 import sys
+import signal
 
 
 def main():
     parser = argparse.ArgumentParser(description='Run the job controller')
     parser.add_argument('--purge_existing_jobs', action='store_true', help='Delete existing jobs in k8s (dangerous!)')
+    parser.add_argument('--incluster', action='store_true', help='Require running in k8s cluster')
     args = parser.parse_args()
 
     cstate = ControllerState()
+    if args.incluster and cstate.k8s.config_used != "incluster":
+        print("Incluster required, but not running in k8s cluster")
+        sys.exit(1)
+        return
 
     # check existing jobs
     print("Checking for existing jobs...")
@@ -51,6 +57,7 @@ def main():
             cstate.event_queue.put(WatchJob(watch.metadata.labels["job_id"]))
 
     print("Done checking for existing jobs")
+    # TODO: requeue jobs that the db lists in spawning state but that do not exist in k8s (jobs in ready state will be picked up by FindReadyJobs)
 
     # make some dummy jobs
     start_id = int(datetime.datetime.now().timestamp())
@@ -59,13 +66,26 @@ def main():
 
     cstate.event_queue.put(FindReadyJobs())
 
-    while True:
+    run_event_loop = True
+
+    def signal_handler(sig, frame):
+        nonlocal run_event_loop
+        print("Stopping event loop...")
+        run_event_loop = False
+    signal.signal(signal.SIGINT, signal_handler)
+
+    while run_event_loop:
         evt = cstate.event_queue.get()
-        while not evt.canExecute():
+        while not evt.canExecute() and run_event_loop:
             # because the queue is sorted by due_at, we can wait until the this event is due
             tts = (evt.due_at - datetime.datetime.now()).total_seconds()
             print(f" >> Sleeping for {tts} seconds until next event is due: {evt}")
-            time.sleep(tts)
+            while run_event_loop and tts > 0:
+                time.sleep(1)
+                tts = (evt.due_at - datetime.datetime.now()).total_seconds()
+        if not run_event_loop:
+            cstate.event_queue.put(evt)  # put back before quitting
+            break
         try:
             start = datetime.datetime.now()
             ret = evt(cstate)
@@ -80,6 +100,12 @@ def main():
             print("###")
             print(f"Error in event {evt}: {e}")
             print("###")
+    print("Event loop stopped")
+    # drain the queue, printing every event that was still pending when the loop stopped
+    print("Remaining events in queue:")
+    while not cstate.event_queue.empty():
+        print(cstate.event_queue.get())
+    sys.exit(0)
 
 
 if __name__ == "__main__":
-- 
GitLab