Skip to content
Snippets Groups Projects
Select Git revision
  • 49a6c008b0468bd05cef552514a241143456a19e
  • main default
  • full_migration
  • v1.0.9 protected
  • v1.0.8 protected
  • v1.0.7 protected
  • v1.0.6 protected
  • v1.0.5 protected
  • v1.0.4 protected
  • v1.0.3 protected
  • v1.0.2 protected
  • v1.0.1 protected
  • v1.0 protected
13 results

route.py

Blame
  • Code owners
    Assign users and groups as approvers for specific file changes. Learn more.
    route.py 9.48 KiB
    from functools import wraps
    import traceback
    from flask import Response, request
    from re import Pattern
    import time
    
    import api
    from api.authentication import is_moderator
    from api.database.database import TransactionConflictError, NoAvailableConnectionError
    from api.miscellaneous import *
    from api.version import get_api_path, API_LATEST_VERSION, API_OLDEST_ACTIVE_VERSION
    
    
    # Rate limiters that api_function() applies by default to every wrapped route; they are checked
    # against the client IP taken from the X_REAL_IP request header. Built from the
    # "API_GLOBAL_RATE_LIMIT" config entry.
    _API_GLOBAL_RATE_LIMITERS = create_configured_host_rate_limiters("global", api.config["API_GLOBAL_RATE_LIMIT"])
    # Default max-age (in seconds) that ApiResponse uses when it marks a response as publicly cacheable.
    _DEFAULT_CACHE_CONTROL_MAX_AGE_SECONDS = api.config["DEFAULT_CACHE_CONTROL_MAX_AGE_SECONDS"]
    
    
    def check_client_int(value: int, name: str, min_value: int = MIN_VALUE_UINT32, max_value: int = MAX_VALUE_UINT32):
        """
        Ensure a client-supplied integer lies within [min_value, max_value] (both bounds inclusive).

        :param value: Integer received from the client.
        :param name: Parameter name reported back to the client in the error.
        :param min_value: Smallest accepted value.
        :param max_value: Largest accepted value.
        :raises ApiClientException: With ERROR_REQUEST_INVALID_PARAMETER when value is out of range.
        """
        from api.miscellaneous.errors import ApiClientException, ERROR_REQUEST_INVALID_PARAMETER

        problem = None
        if value < min_value:
            problem = f"Value must not be less than {min_value}"
        elif value > max_value:
            problem = f"Value must not be greater than {max_value}"

        if problem is not None:
            raise ApiClientException(ERROR_REQUEST_INVALID_PARAMETER(name, problem))
    
    
    def api_request_get_query_string(id: str, max_length: int, pattern: Pattern or None = None, default: str or None = None) -> str or None:
        """
        Read a string parameter from the current request's query string.

        Returns ``default`` when the parameter is absent. A present value must be at most
        ``max_length`` characters long and, when ``pattern`` is given, must match it completely.

        :raises ApiClientException: With ERROR_REQUEST_INVALID_PARAMETER (reported as ``URL.<id>``)
            when the value is too long or does not match ``pattern``.
        """
        query_args = request.args
        if id in query_args:
            raw_value = query_args[id]
            param_name = f"URL.{id}"
            if len(raw_value) > max_length:
                raise ApiClientException(ERROR_REQUEST_INVALID_PARAMETER(
                    param_name, f"Must not be longer than {max_length} characters"))
            if pattern is not None and pattern.fullmatch(raw_value) is None:
                raise ApiClientException(ERROR_REQUEST_INVALID_PARAMETER(
                    param_name, "Does not match pattern"))
            return raw_value
        return default
    
    
    def api_request_get_query_int(id: str, default: int or None, min_value: int = MIN_VALUE_UINT32, max_value: int = MAX_VALUE_UINT32) -> int or None:
        """
        Read an integer parameter from the current request's query string.

        Returns ``default`` when the parameter is absent. Otherwise the value is parsed with ``int()``
        and range-checked with check_client_int() (bounds inclusive).

        :raises ApiClientException: With ERROR_REQUEST_INVALID_PARAMETER (reported as ``URL.<id>``)
            when the value cannot be parsed or is out of range.
        """
        args = request.args
        if id in args:
            param_name = f"URL.{id}"
            try:
                parsed_value = int(args[id])
            except ValueError:
                raise ApiClientException(ERROR_REQUEST_INVALID_PARAMETER(param_name, "Cannot parse integer"))
            check_client_int(parsed_value, param_name, min_value, max_value)
            return parsed_value
        return default
    
    
    @api.app.errorhandler(400)
    def _handle_bad_request(e=None):
        """Flask handler for HTTP 400: answer with ERROR_BAD_REQUEST through api_on_error()."""
        bad_request_error = ERROR_BAD_REQUEST("Bad request")
        return api_on_error(bad_request_error)
    
    
    @api.app.errorhandler(404)
    def _handle_not_found(e=None):
        """Flask handler for HTTP 404: answer with ERROR_UNKNOWN_REQUEST_PATH through api_on_error()."""
        unknown_path_error = ERROR_UNKNOWN_REQUEST_PATH
        return api_on_error(unknown_path_error)
    
    
    @api.app.errorhandler(405)
    def _handle_method_not_allowed(e=None):
        """Flask handler for HTTP 405: answer with ERROR_METHOD_NOT_ALLOWED through api_on_error()."""
        wrong_method_error = ERROR_METHOD_NOT_ALLOWED
        return api_on_error(wrong_method_error)
    
    
    @api.app.errorhandler(500)
    @api.app.errorhandler(Exception)
    def _handle_internal_server_error(e=None):
        """
        Fallback Flask handler for HTTP 500 and for any exception no more specific handler claimed.

        Because a handler is registered for ``Exception``, Flask passes otherwise-unhandled exceptions
        (and HTTP errors that have no handler of their own) straight to this function instead of
        logging them itself. The traceback is therefore printed here, so that such failures are not
        silently turned into an anonymous ERROR_INTERNAL_SERVER_ERROR response.

        :param e: The exception Flask is handling (None only when called directly).
        :return: The api_on_error() response for ERROR_INTERNAL_SERVER_ERROR.
        """
        if e is not None:
            # Flask invokes error handlers inside the request context, so request.path is available.
            print(f"An unhandled exception occurred while handling request '{truncate_string(request.path, 200)}':")
            # Three-argument form works on every Python 3 version (the single-argument form needs 3.10+).
            traceback.print_exception(type(e), e, e.__traceback__)
        return api_on_error(ERROR_INTERNAL_SERVER_ERROR)
    
    
    class ApiResponse:
        """
        Thin wrapper around a Flask Response that adds a default Cache-Control header when built.

        dict and list payloads are serialized with json.dumps() when the mime type is
        "application/json"; any other payload is handed to Response unchanged.
        """

        def __init__(self,
                     data,
                     status: int = HTTP_200_OK,
                     mime_type: str or None = "application/json",
                     headers: dict[str, str] or None = None,
                     default_cache_control_max_age_sec: int or None = _DEFAULT_CACHE_CONTROL_MAX_AGE_SECONDS):
            super().__init__()
            serialize_as_json = mime_type == "application/json"
            if serialize_as_json and isinstance(data, (dict, list)):
                data = json.dumps(data)
            # None disables the public-cache branch in build_response().
            self._default_cache_control_max_age_sec = default_cache_control_max_age_sec

            self.response = Response(data, status=status, headers=headers, mimetype=mime_type)

        def build_response(self):
            """
            Return the wrapped Flask Response, adding a Cache-Control header if none is present yet.

            The response is marked "max-age=<default>,public" only when all of these hold:
            status < 400, request method is GET, no Set-Cookie header, the caller is not a
            moderator, and a default max age is configured. Otherwise it gets "no-store".
            """
            response = self.response
            if "Cache-Control" in response.headers:
                return response

            max_age = self._default_cache_control_max_age_sec
            # Short-circuit order matters: is_moderator() is only consulted when the cheaper checks pass.
            publicly_cacheable = (response.status_code < 400
                                  and request.method == "GET"
                                  and "Set-Cookie" not in response.headers
                                  and not is_moderator()
                                  and max_age is not None)
            if publicly_cacheable:
                directives = [f"max-age={max_age}", "public"]
            else:
                directives = ["no-store"]
            response.headers["Cache-Control"] = ",".join(directives)

            return response
    
    
    def api_route(path: str, methods: list[str],
                  allow_while_readonly: bool = False,
                  allow_while_disabled: bool = False,
                  min_api_version: int = API_OLDEST_ACTIVE_VERSION,
                  max_api_version: int = API_LATEST_VERSION):
        """
        Shorthand decorator: wrap a view with @api_function() and register it with @api_add_route().

        Use it for a view exposed under a single path. For several paths, stack
        @api_add_route() decorators on top of one @api_function() instead.
        """
        def register_api_route(view):
            wrapped_view = api_function(allow_while_readonly=allow_while_readonly,
                                        allow_while_disabled=allow_while_disabled)(view)
            return api_add_route(path, methods, min_api_version, max_api_version)(wrapped_view)

        return register_api_route
    
    
    def api_add_route(path: str, methods: list[str],
                      min_api_version: int = API_OLDEST_ACTIVE_VERSION,
                      max_api_version: int = API_LATEST_VERSION):
        """
        Register an @api_function()-wrapped view under ``path`` for every API version in
        [min_api_version, max_api_version] (inclusive), using get_api_path() to build each URL.

        May be stacked several times on the same view to expose it under multiple paths.

        :raises Exception: At decoration time if the view was not wrapped by @api_function() first.
        """
        def register(view):
            if not getattr(view, "is_api_route", False):
                raise Exception("@api_add_route() seems to be applied before @api_function()")

            for api_version in range(min_api_version, max_api_version + 1):
                versioned_path = get_api_path(api_version, path)
                if DEBUG_ENABLED:
                    print(f"Registering api route: {versioned_path}")
                api.app.add_url_rule(versioned_path, methods=methods, view_func=view)
            return view

        return register
    
    
    def _api_result_to_response(result) -> Response:
        """
        Convert the value returned by an @api_function() route into a Flask Response.

        Accepted return values:
          * a Flask Response         -> returned unchanged
          * an ApiResponse           -> its build_response()
          * None                     -> status 200, no body and no explicit mime type
          * a dict                   -> JSON body (ApiResponse defaults), status 200
          * a (data, status) 2-tuple -> ApiResponse(data, status)

        :raises Exception: For any other return type; the caller reports it as an internal error.
        """
        if isinstance(result, Response):
            return result
        if isinstance(result, ApiResponse):
            return result.build_response()
        if result is None:
            return ApiResponse(None, HTTP_200_OK, None).build_response()
        if isinstance(result, dict):
            return ApiResponse(result).build_response()
        if isinstance(result, tuple) and len(result) == 2:
            return ApiResponse(result[0], result[1]).build_response()
        raise Exception(f"Api route {truncate_string(request.path)} returned result of unknown type: {str(result)}")  # pragma: no cover


    def api_function(track_in_diagnostics: bool = True,
                     allow_while_readonly: bool = False,
                     allow_while_disabled: bool = False,
                     rate_limiters: tuple[HostBasedCounterRateLimiter, ...] = _API_GLOBAL_RATE_LIMITERS):
        """
        Decorator that turns a view function into an API route handler.

        The wrapped function:
          * rejects requests with ERROR_SITE_IS_DISABLED / ERROR_SITE_IS_READONLY while the site is
            disabled / read-only, unless allow_while_disabled / allow_while_readonly is set;
          * counts calls and execution time (ms) in DIAGNOSTICS_TRACKER when track_in_diagnostics;
          * applies ``rate_limiters`` to the IP in the X_REAL_IP header (skipped if the header is absent);
          * converts the view's return value with _api_result_to_response();
          * turns ApiClientException into its error response, database conflicts/connection
            exhaustion into ERROR_SITE_IS_OVERLOADED, and any other exception into
            ERROR_INTERNAL_SERVER_ERROR (printing the traceback).

        With track_in_diagnostics, the view's name must start with ``api_route_``; the remainder is
        used as the diagnostics counter name.

        :raises Exception: If the function is already wrapped by @api_function().
        :raises RuntimeError: If diagnostics tracking is on and the function name is invalid.
        """
        def decorator(func):
            if getattr(func, "is_api_route", False):
                raise Exception("An @api_function() decorator has already been applied. Are you using multiple @api_route()? "
                                "Use @api_add_route(...)@api_add_route(..)@api_function() instead")

            call_counter = None
            call_time_counter = None
            if track_in_diagnostics:
                func_name: str = func.__name__
                if not func_name.startswith("api_route_"):  # pragma: no cover
                    raise RuntimeError("Api route function names must start with 'api_route_' "
                                       "(These names are used in diagnostics data)")
                func_name = func_name[len("api_route_"):]
                if len(func_name) == 0:
                    raise RuntimeError("Api route function has no name (just 'api_route_')")  # pragma: no cover

                call_counter = DIAGNOSTICS_TRACKER.register_counter(f"route.{func_name}")
                call_time_counter = DIAGNOSTICS_TRACKER.register_counter(f"route.{func_name}.time")

            @wraps(func)
            def wrapper(*args, **kwargs):
                try:
                    if api.live_config.is_disabled() and not allow_while_disabled:
                        raise ApiClientException(ERROR_SITE_IS_DISABLED)
                    if api.live_config.is_readonly() and not allow_while_readonly:
                        raise ApiClientException(ERROR_SITE_IS_READONLY)

                    if call_counter is not None:
                        call_counter.trigger()

                    # Presumably set by a reverse proxy in front of the app; without it no rate limiting happens.
                    if "X_REAL_IP" in request.headers:
                        ip_string = request.headers["X_REAL_IP"]
                        for rate_limiter in rate_limiters:
                            if not rate_limiter.check_new_request(ip_string):
                                raise ApiClientException(ERROR_RATE_LIMITED)

                    # Debug-only fault injection: fail roughly API_ROULETTE_MODE percent of requests
                    # with a random client error.
                    if DEBUG_ENABLED and "API_ROULETTE_MODE" in api.config:
                        import random
                        from api.miscellaneous.errors import ALL_ERRORS_RANDOM
                        if random.random() * 100 < int(api.config["API_ROULETTE_MODE"]):
                            raise ApiClientException(random.choice(ALL_ERRORS_RANDOM))

                    # perf_counter() is monotonic, so wall-clock adjustments cannot corrupt the timing.
                    start_time = time.perf_counter()
                    result = func(*args, **kwargs)
                    if call_time_counter is not None:
                        call_time_counter.trigger(int((time.perf_counter() - start_time) * 1000))

                    return _api_result_to_response(result)
                except ApiClientException as e:
                    return api_on_error(e.error)
                except (TransactionConflictError, NoAvailableConnectionError) as e:
                    # Database contention / connection exhaustion is reported as overload, not a server bug.
                    print(f"A {type(e).__name__} occurred while handling api request '{truncate_string(request.path, 200)}':")
                    # Three-argument form works on every Python 3 version (single-argument form needs 3.10+).
                    traceback.print_exception(type(e), e, e.__traceback__)
                    return api_on_error(ERROR_SITE_IS_OVERLOADED)
                except Exception as e:
                    print(f"An exception occurred while handling api request '{truncate_string(request.path, 200)}':")
                    traceback.print_exception(type(e), e, e.__traceback__)
                    return api_on_error(ERROR_INTERNAL_SERVER_ERROR)

            wrapper.is_api_route = True
            return wrapper

        return decorator