From 8a6ac9d0fb36baba44b28bed41c598e95d4a4213 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Simon=20K=C3=BCnzel?= <simonk@fsmpi.rwth-aachen.de>
Date: Sat, 7 Dec 2024 00:28:44 +0100
Subject: [PATCH] Add Open Telemetry OTLP traces

---
 config/api_example_config.py |   3 +
 requirements.txt             |   5 +
 src/api/routes/route.py      | 206 +++++++++++++++++++++++------------
 3 files changed, 147 insertions(+), 67 deletions(-)

diff --git a/config/api_example_config.py b/config/api_example_config.py
index 259fea6..7db2920 100644
--- a/config/api_example_config.py
+++ b/config/api_example_config.py
@@ -130,3 +130,6 @@ INTERNAL_IP_RANGES = ["127.0.0.0/8", "192.168.155.0/24", "fd78:4d90:6fe4::/48"]
 
 # Only for debugging. In percent, from 0 to 100. With this you need luck to make a request
 #  API_ROULETTE_MODE = 0
+
+# OPEN_TELEMETRY_OLTP_ENDPOINT = "http://localhost:4318"  # Tracing is only enabled if this is set. "/v1/traces" gets appended
+OPEN_TELEMETRY_TRACE_FILTER_ONLY_HTTP_HEADER = "X-Trace-Me"  # If non-empty, only requests carrying this header are traced
diff --git a/requirements.txt b/requirements.txt
index 4019fe3..8d6474d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -17,6 +17,11 @@ mysql-connector-python==8.4.0
 psycopg[c]==3.1.19
 # sqlite is part of the standard library
 
+# Open Telemetry
+opentelemetry-api==1.28.2
+opentelemetry-sdk==1.28.2
+opentelemetry-exporter-otlp-proto-http==1.28.2
+
 # required for testing
 coverage==7.5.1
 pylint==3.2.0
diff --git a/src/api/routes/route.py b/src/api/routes/route.py
index 74dbc8b..aca1758 100644
--- a/src/api/routes/route.py
+++ b/src/api/routes/route.py
@@ -11,10 +11,38 @@ from api.miscellaneous import *
 from api.version import get_api_path, API_LATEST_VERSION, API_OLDEST_ACTIVE_VERSION
 
 
+_SERVER_NAME = api.config["API_SERVER_NAME"]
+
 _API_GLOBAL_RATE_LIMITERS = create_configured_host_rate_limiters("global", api.config["API_GLOBAL_RATE_LIMIT"])
 _DEFAULT_CACHE_CONTROL_MAX_AGE_SECONDS = api.config["DEFAULT_CACHE_CONTROL_MAX_AGE_SECONDS"]
 
 
+def _init_open_telemetry_tracer(oltp_endpoint: str) -> "Tracer":
+    """Create the "videoag.api" tracer which exports batched spans via OTLP/HTTP to f"{oltp_endpoint}/v1/traces"."""
+    from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
+    from opentelemetry.sdk.trace import TracerProvider
+    from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExportResult
+    from opentelemetry.sdk.resources import Resource, SERVICE_NAME
+
+    provider = TracerProvider(resource=Resource(attributes={SERVICE_NAME: "videoag-api"}))
+    exporter = OTLPSpanExporter(endpoint=f"{oltp_endpoint}/v1/traces")
+    provider.add_span_processor(BatchSpanProcessor(exporter))
+
+    try:  # Probe with an empty batch. This runs at import time, so an error raised by the exporter must not escape
+        probe_ok = exporter.export([]) == SpanExportResult.SUCCESS
+    except Exception as e:
+        traceback.print_exception(e)
+        probe_ok = False
+    if not probe_ok:
+        print(f"Warning: Test span export to {oltp_endpoint} failed. Tracing might not work")
+    return provider.get_tracer("videoag.api")
+
+
+open_telemetry_tracer = (  # Stays None, i.e. tracing is disabled, unless an endpoint is configured
+    _init_open_telemetry_tracer(api.config["OPEN_TELEMETRY_OLTP_ENDPOINT"])
+    if "OPEN_TELEMETRY_OLTP_ENDPOINT" in api.config else None)
+
+
 def check_client_int(value: int, name: str, min_value: int = MIN_VALUE_UINT32, max_value: int = MAX_VALUE_UINT32):
     from api.miscellaneous.errors import ApiClientException, ERROR_REQUEST_INVALID_PARAMETER
     if value < min_value:
@@ -72,7 +100,7 @@ def _handle_internal_server_error(e=None):
 
 
 class ApiResponse:
-
+    
     def __init__(self,
                  data,
                  status: int = HTTP_200_OK,
@@ -83,14 +111,14 @@ class ApiResponse:
         if mime_type == "application/json" and (isinstance(data, dict) or isinstance(data, list)):
             data = json.dumps(data)
         self._default_cache_control_max_age_sec = default_cache_control_max_age_sec
-
+        
         self.response = Response(
             data,
             status=status,
             headers=headers,
             mimetype=mime_type
         )
-
+    
     def build_response(self):
         if "Cache-Control" not in self.response.headers:
             cache_params: list[str] = []
@@ -104,7 +132,7 @@ class ApiResponse:
             else:
                 cache_params.append("no-store")
             self.response.headers["Cache-Control"] = ",".join(cache_params)
-
+        
         return self.response
 
 
@@ -117,7 +145,7 @@ def api_route(path: str, methods: list[str],
         func = api_function(allow_while_readonly=allow_while_readonly, allow_while_disabled=allow_while_disabled)(func)
         func = api_add_route(path, methods, min_api_version, max_api_version)(func)
         return func
-
+    
     return decorator
 
 
@@ -127,17 +155,107 @@ def api_add_route(path: str, methods: list[str],
     def decorator(func):
         if not hasattr(func, "is_api_route") or not func.is_api_route:
             raise Exception("@api_add_route() seems to be applied before @api_function()")
-
+        
         for version in range(min_api_version, max_api_version + 1):
             full_path = get_api_path(version, path)
             if DEBUG_ENABLED:
                 print(f"Registering api route: {full_path}")
             api.app.add_url_rule(full_path, methods=methods, view_func=func)
         return func
-
+    
     return decorator
 
 
+def _execute_api_route(
+        call_counter: DiagnosticsCounter | None,
+        allow_while_readonly: bool,
+        allow_while_disabled: bool,
+        rate_limiters: tuple[HostBasedCounterRateLimiter, ...],
+        func, args, kwargs
+):
+    """Run the availability and rate-limit checks, call func(*args, **kwargs) and turn its result into a Response.
+    An Exception raised on the way is answered with the matching api_on_error(...) result instead of propagating."""
+    try:
+        if api.live_config.is_disabled() and not allow_while_disabled:
+            raise ApiClientException(ERROR_SITE_IS_DISABLED)
+        if api.live_config.is_readonly() and not allow_while_readonly:
+            raise ApiClientException(ERROR_SITE_IS_READONLY)
+
+        if call_counter is not None:  # None when the route is not tracked in the diagnostics
+            call_counter.trigger()
+
+        if "X_REAL_IP" in request.headers:  # Requests without this header are not rate limited
+            ip_string = request.headers["X_REAL_IP"]
+            for rate_limiter in rate_limiters:
+                if not rate_limiter.check_new_request(ip_string):
+                    raise ApiClientException(ERROR_RATE_LIMITED)
+
+        if DEBUG_ENABLED and "API_ROULETTE_MODE" in api.config:
+            import random
+            from api.miscellaneous.errors import ALL_ERRORS_RANDOM
+            if random.random() * 100 < int(api.config["API_ROULETTE_MODE"]):
+                raise ApiClientException(random.choice(ALL_ERRORS_RANDOM))
+
+        result = func(*args, **kwargs)
+
+        if isinstance(result, Response):
+            return result
+        elif isinstance(result, ApiResponse):
+            return result.build_response()
+        elif result is None:
+            return ApiResponse(None, HTTP_200_OK, None).build_response()
+        elif isinstance(result, dict):
+            return ApiResponse(result).build_response()
+        elif isinstance(result, tuple) and len(result) == 2:  # (data, status)
+            return ApiResponse(result[0], result[1]).build_response()
+        else:  # pragma: no cover
+            raise Exception(f"Api route {truncate_string(request.path)} returned result of unknown type: {str(result)}")
+    except ApiClientException as e:
+        return api_on_error(e.error)
+    except (TransactionConflictError, NoAvailableConnectionError) as e:
+        print(f"An transaction conflict occurred while handling api request '{truncate_string(request.path, 200)}':")
+        traceback.print_exception(e)
+        return api_on_error(ERROR_SITE_IS_OVERLOADED)
+    except Exception as e:
+        print(f"An exception occurred while handling api request '{truncate_string(request.path, 200)}':")
+        traceback.print_exception(e)
+        return api_on_error(ERROR_INTERNAL_SERVER_ERROR)
+
+
+def _handle_api_request(
+        call_counter: DiagnosticsCounter | None,
+        allow_while_readonly: bool,
+        allow_while_disabled: bool,
+        rate_limiters: tuple[HostBasedCounterRateLimiter, ...],
+        func, args, kwargs
+):
+    """Execute the route through _execute_api_route, inside an "api-request" span when the request is to be traced,
+    and add the Server-Timing (whole milliseconds) and X-Server-Name headers to the returned response."""
+    do_tracing = open_telemetry_tracer is not None
+    if do_tracing:
+        header_filter = api.config.get("OPEN_TELEMETRY_TRACE_FILTER_ONLY_HTTP_HEADER", "")
+        if header_filter and header_filter not in request.headers:  # A missing or empty filter traces every request
+            do_tracing = False
+
+    start_time = time.time_ns()
+
+    if do_tracing:
+        from opentelemetry.sdk.trace import Tracer
+        assert isinstance(open_telemetry_tracer, Tracer)
+        with open_telemetry_tracer.start_as_current_span("api-request", attributes={
+            "url": request.url,
+            "server": _SERVER_NAME,
+        }):
+            resp = _execute_api_route(call_counter, allow_while_readonly, allow_while_disabled, rate_limiters, func, args, kwargs)
+    else:
+        resp = _execute_api_route(call_counter, allow_while_readonly, allow_while_disabled, rate_limiters, func, args, kwargs)
+
+    exec_time_ms = (time.time_ns() - start_time) // 1_000_000  # Includes the checks and error handling, not only func
+    resp.headers["Server-Timing"] = f"api;dur={exec_time_ms}"
+    resp.headers["X-Server-Name"] = _SERVER_NAME
+    return resp
+
+
 def api_function(track_in_diagnostics: bool = True,
                  allow_while_readonly: bool = False,
                  allow_while_disabled: bool = False,
@@ -146,9 +264,8 @@ def api_function(track_in_diagnostics: bool = True,
         if hasattr(func, "is_api_route") and func.is_api_route:
             raise Exception("An @api_function() decorator has already been applied. Are you using multiple @api_route()? "
                             "Use @api_add_route(...)@api_add_route(..)@api_function() instead")
-
+        
         call_counter = None
-        call_time_counter = None
         if track_in_diagnostics:
             func_name: str = func.__name__
             if not func_name.startswith("api_route_"):  # pragma: no cover
@@ -157,67 +274,22 @@ def api_function(track_in_diagnostics: bool = True,
             func_name = func_name[len("api_route_"):]
             if len(func_name) == 0:
                 raise RuntimeError("Api route function has no name (just 'api_route_')")  # pragma: no cover
-
+            
             call_counter = DIAGNOSTICS_TRACKER.register_counter(f"route.{func_name}")
-            call_time_counter = DIAGNOSTICS_TRACKER.register_counter(f"route.{func_name}.time")
-
+        
         @wraps(func)
         def wrapper(*args, **kwargs):
-            try:
-                if api.live_config.is_disabled() and not allow_while_disabled:
-                    raise ApiClientException(ERROR_SITE_IS_DISABLED)
-                if api.live_config.is_readonly() and not allow_while_readonly:
-                    raise ApiClientException(ERROR_SITE_IS_READONLY)
-
-                if call_counter is not None:
-                    call_counter.trigger()
-
-                if "X_REAL_IP" in request.headers:
-                    ip_string = request.headers["X_REAL_IP"]
-                    for rate_limiter in rate_limiters:
-                        if not rate_limiter.check_new_request(ip_string):
-                            raise ApiClientException(ERROR_RATE_LIMITED)
-
-                if DEBUG_ENABLED and "API_ROULETTE_MODE" in api.config:
-                    import random
-                    from api.miscellaneous.errors import ALL_ERRORS_RANDOM
-                    if random.random() * 100 < int(api.config["API_ROULETTE_MODE"]):
-                        raise ApiClientException(random.choice(ALL_ERRORS_RANDOM))
-
-                start_time = time.time_ns()
-                result = func(*args, **kwargs)
-                stop_time = time.time_ns()
-                if call_time_counter is not None:
-                    call_time_counter.trigger(int((stop_time - start_time) / 1000000))
-
-                if isinstance(result, Response):
-                    resp = result
-                elif isinstance(result, ApiResponse):
-                    resp = result.build_response()
-                elif result is None:
-                    resp = ApiResponse(None, HTTP_200_OK, None).build_response()
-                elif isinstance(result, dict):
-                    resp = ApiResponse(result).build_response()
-                elif isinstance(result, tuple) and len(result) == 2:
-                    resp = ApiResponse(result[0], result[1]).build_response()
-                else:  # pragma: no cover
-                    raise Exception(f"Api route {truncate_string(request.path)} returned result of unknown type: {str(result)}")
-
-                if call_time_counter is not None:  # i.e. diagnostics are enabled
-                    resp.headers["Server-Timing"] = f"api;dur={(stop_time - start_time) / 1000000}"
-                return resp
-            except ApiClientException as e:
-                return api_on_error(e.error)
-            except (TransactionConflictError, NoAvailableConnectionError) as e:
-                print(f"An transaction conflict occurred while handling api request '{truncate_string(request.path, 200)}':")
-                traceback.print_exception(e)
-                return api_on_error(ERROR_SITE_IS_OVERLOADED)
-            except Exception as e:
-                print(f"An exception occurred while handling api request '{truncate_string(request.path, 200)}':")
-                traceback.print_exception(e)
-                return api_on_error(ERROR_INTERNAL_SERVER_ERROR)
-
+            return _handle_api_request(
+                call_counter,
+                allow_while_readonly,
+                allow_while_disabled,
+                rate_limiters,
+                func,
+                args,
+                kwargs
+            )
+        
         wrapper.is_api_route = True
         return wrapper
-
+    
     return decorator
-- 
GitLab