Select Git revision
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
route.py 9.48 KiB
from functools import wraps
import traceback
from flask import Response, request
from re import Pattern
import time
import api
from api.authentication import is_moderator
from api.database.database import TransactionConflictError, NoAvailableConnectionError
from api.miscellaneous import *
from api.version import get_api_path, API_LATEST_VERSION, API_OLDEST_ACTIVE_VERSION
# Default host-based rate limiters used by api_function() when a route does not pass its own.
_API_GLOBAL_RATE_LIMITERS = create_configured_host_rate_limiters(
    "global",
    api.config["API_GLOBAL_RATE_LIMIT"],
)
# Default "max-age" (in seconds) that ApiResponse puts into the Cache-Control header of
# cacheable responses.
_DEFAULT_CACHE_CONTROL_MAX_AGE_SECONDS = api.config[
    "DEFAULT_CACHE_CONTROL_MAX_AGE_SECONDS"
]
def check_client_int(value: int, name: str, min_value: int = MIN_VALUE_UINT32, max_value: int = MAX_VALUE_UINT32):
    """Ensure a client-supplied integer lies within ``[min_value, max_value]`` (both inclusive).

    :param value: The integer to validate.
    :param name: Parameter name reported back to the client (e.g. ``"URL.limit"``).
    :param min_value: Smallest accepted value.
    :param max_value: Largest accepted value.
    :raises ApiClientException: With ERROR_REQUEST_INVALID_PARAMETER when out of range.
    """
    from api.miscellaneous.errors import ApiClientException, ERROR_REQUEST_INVALID_PARAMETER
    problem = None
    if value < min_value:
        problem = f"Value must not be less than {min_value}"
    elif value > max_value:
        problem = f"Value must not be greater than {max_value}"
    if problem is not None:
        raise ApiClientException(ERROR_REQUEST_INVALID_PARAMETER(name, problem))
def api_request_get_query_string(id: str, max_length: int, pattern: Pattern | None = None,
                                 default: str | None = None) -> str | None:
    """Read and validate a string parameter from the current request's URL query string.

    :param id: Name of the query parameter; errors report it as ``URL.<id>``.
    :param max_length: Maximum accepted length of the value, in characters.
    :param pattern: Optional compiled regex that the *whole* value must match (``fullmatch``).
    :param default: Value returned when the parameter is absent (not validated).
    :return: The parameter's value, or ``default`` when it is absent.
    :raises ApiClientException: ERROR_REQUEST_INVALID_PARAMETER when the value is longer than
        ``max_length`` or does not match ``pattern``.
    """
    # A single lookup instead of "in" + "[]"; .get() yields None only when the key is absent.
    value = request.args.get(id)
    if value is None:
        return default
    if len(value) > max_length:
        raise ApiClientException(ERROR_REQUEST_INVALID_PARAMETER(
            f"URL.{id}", f"Must not be longer than {max_length} characters"))
    if pattern is not None and pattern.fullmatch(value) is None:
        raise ApiClientException(ERROR_REQUEST_INVALID_PARAMETER(
            f"URL.{id}", "Does not match pattern"))
    return value
def api_request_get_query_int(id: str, default: int | None, min_value: int = MIN_VALUE_UINT32,
                              max_value: int = MAX_VALUE_UINT32) -> int | None:
    """Read an integer parameter from the current request's URL query string and range-check it.

    :param id: Name of the query parameter; errors report it as ``URL.<id>``.
    :param default: Value returned when the parameter is absent (not range-checked).
    :param min_value: Smallest accepted value (inclusive).
    :param max_value: Largest accepted value (inclusive).
    :return: The parsed integer, or ``default`` when the parameter is absent.
    :raises ApiClientException: ERROR_REQUEST_INVALID_PARAMETER when the value cannot be parsed
        by ``int()`` or lies outside ``[min_value, max_value]``.
    """
    raw_value = request.args.get(id)
    if raw_value is None:
        return default
    try:
        value = int(raw_value)
    except ValueError:
        # The ValueError is fully translated into an API error; don't chain it as context.
        raise ApiClientException(ERROR_REQUEST_INVALID_PARAMETER(
            f"URL.{id}", "Cannot parse integer")) from None
    check_client_int(value, f"URL.{id}", min_value, max_value)
    return value
@api.app.errorhandler(400)
def _handle_bad_request(e=None):
    """Flask handler for HTTP 400: answer with the API's generic "Bad request" error."""
    error = ERROR_BAD_REQUEST("Bad request")
    return api_on_error(error)
@api.app.errorhandler(404)
def _handle_not_found(e=None):
    """Flask handler for HTTP 404: answer with the API's ERROR_UNKNOWN_REQUEST_PATH error."""
    unknown_path_error = ERROR_UNKNOWN_REQUEST_PATH
    return api_on_error(unknown_path_error)
@api.app.errorhandler(405)
def _handle_method_not_allowed(e=None):
    """Flask handler for HTTP 405: answer with the API's ERROR_METHOD_NOT_ALLOWED error."""
    method_error = ERROR_METHOD_NOT_ALLOWED
    return api_on_error(method_error)
@api.app.errorhandler(500)
@api.app.errorhandler(Exception)
def _handle_internal_server_error(e=None):
    """Fallback Flask handler for HTTP 500 and for any exception without a more specific handler.

    Because a handler is registered for ``Exception``, Flask considers such exceptions handled
    and does not log them itself. The exception is therefore printed here, in the same format
    the api_function() wrapper uses, before answering with ERROR_INTERNAL_SERVER_ERROR.

    :param e: The exception (or HTTP error) Flask dispatched to this handler, if any.
    """
    if isinstance(e, BaseException):
        print(f"An unhandled exception reached the global error handler for request "
              f"'{truncate_string(request.path, 200)}':")
        traceback.print_exception(e)
    return api_on_error(ERROR_INTERNAL_SERVER_ERROR)
class ApiResponse:
    """Wrapper around a Flask ``Response`` that fills in a default ``Cache-Control`` header."""

    def __init__(self,
                 data,
                 status: int = HTTP_200_OK,
                 mime_type: str | None = "application/json",
                 headers: dict[str, str] | None = None,
                 default_cache_control_max_age_sec: int | None = _DEFAULT_CACHE_CONTROL_MAX_AGE_SECONDS):
        """
        :param data: Response body. A ``dict`` or ``list`` is serialized with ``json.dumps`` when
            ``mime_type`` is ``"application/json"``; anything else is handed to ``Response`` as-is.
        :param status: HTTP status code of the response.
        :param mime_type: Passed to ``Response`` as ``mimetype``.
        :param headers: Extra response headers.
        :param default_cache_control_max_age_sec: ``max-age`` that :meth:`build_response` uses for
            cacheable responses. ``None`` makes every response ``no-store``, unless a
            ``Cache-Control`` header was supplied explicitly.
        """
        super().__init__()
        if mime_type == "application/json" and isinstance(data, (dict, list)):
            data = json.dumps(data)
        self._default_cache_control_max_age_sec = default_cache_control_max_age_sec
        self.response = Response(
            data,
            status=status,
            headers=headers,
            mimetype=mime_type,
        )

    def build_response(self) -> Response:
        """Return the wrapped Flask response, adding ``Cache-Control`` if it is not set yet.

        The response is marked ``max-age=<default>,public`` only when all of these hold:
        - the status is below 400;
        - the request method is GET;
        - no ``Set-Cookie`` header is present;
        - the requester is not a moderator;
        - a default max-age is configured.

        Otherwise it is marked ``no-store``. An explicitly set ``Cache-Control`` header is kept.

        NOTE(review): cookies that Flask adds after the view returns (e.g. a session cookie
        written during response processing) are presumably not visible to the ``Set-Cookie``
        check here. Confirm this is acceptable for this API.
        """
        if "Cache-Control" not in self.response.headers:
            # Conditions are evaluated left to right, so is_moderator() only runs when needed.
            cacheable = (self.response.status_code < 400
                         and request.method == "GET"
                         and "Set-Cookie" not in self.response.headers
                         and not is_moderator()
                         and self._default_cache_control_max_age_sec is not None)
            if cacheable:
                cache_control = f"max-age={self._default_cache_control_max_age_sec},public"
            else:
                cache_control = "no-store"
            self.response.headers["Cache-Control"] = cache_control
        return self.response
def api_route(path: str, methods: list[str],
              allow_while_readonly: bool = False,
              allow_while_disabled: bool = False,
              min_api_version: int = API_OLDEST_ACTIVE_VERSION,
              max_api_version: int = API_LATEST_VERSION):
    """Shorthand decorator: ``@api_add_route(...)`` stacked on top of ``@api_function(...)``.

    The decorated function is first wrapped by api_function() using the given read-only and
    disabled flags. The wrapper is then registered for ``path`` under every API version in
    ``[min_api_version, max_api_version]``.

    :return: A decorator that returns the registered wrapper.
    """
    wrap_as_api_function = api_function(allow_while_readonly=allow_while_readonly,
                                        allow_while_disabled=allow_while_disabled)
    register_routes = api_add_route(path, methods, min_api_version, max_api_version)

    def decorator(func):
        return register_routes(wrap_as_api_function(func))

    return decorator
def api_add_route(path: str, methods: list[str],
                  min_api_version: int = API_OLDEST_ACTIVE_VERSION,
                  max_api_version: int = API_LATEST_VERSION):
    """Decorator that registers an api_function()-wrapped view under a range of API versions.

    The view is registered once per version in ``[min_api_version, max_api_version]``
    (inclusive), at the path produced by ``get_api_path(version, path)``.

    :raises Exception: If the decorated function was not produced by api_function().
    :return: A decorator that returns the function unchanged after registering it.
    """
    def decorator(func):
        if not getattr(func, "is_api_route", False):
            raise Exception("@api_add_route() seems to be applied before @api_function()")
        versions = range(min_api_version, max_api_version + 1)
        for api_version in versions:
            route_path = get_api_path(api_version, path)
            if DEBUG_ENABLED:
                print(f"Registering api route: {route_path}")
            api.app.add_url_rule(route_path, methods=methods, view_func=func)
        return func
    return decorator
def _to_flask_response(result) -> Response:
    """Convert the return value of an API route function into a Flask ``Response``.

    Accepted return values:
    - a Flask ``Response``, returned as-is;
    - an ``ApiResponse``;
    - ``None``, giving a 200 response without body or MIME type;
    - a ``dict``, giving a JSON body;
    - a ``(data, status)`` tuple.

    :raises Exception: For any other return value.
    """
    if isinstance(result, Response):
        return result
    if isinstance(result, ApiResponse):
        return result.build_response()
    if result is None:
        return ApiResponse(None, HTTP_200_OK, None).build_response()
    if isinstance(result, dict):
        return ApiResponse(result).build_response()
    if isinstance(result, tuple) and len(result) == 2:
        return ApiResponse(result[0], result[1]).build_response()
    raise Exception(f"Api route {truncate_string(request.path)} returned result of unknown type: {str(result)}")  # pragma: no cover


def api_function(track_in_diagnostics: bool = True,
                 allow_while_readonly: bool = False,
                 allow_while_disabled: bool = False,
                 rate_limiters: tuple[HostBasedCounterRateLimiter, ...] = _API_GLOBAL_RATE_LIMITERS):
    """Decorator factory that turns a plain function into an API view function.

    The wrapper it produces:
    - rejects the request while the site is disabled or read-only, unless explicitly allowed;
    - counts calls, and the wrapped function's run time in milliseconds, in
      DIAGNOSTICS_TRACKER when ``track_in_diagnostics`` is set. The run time is only
      recorded for calls that return normally;
    - checks ``rate_limiters`` against the ``X_REAL_IP`` header, when that header is present;
    - converts the function's return value into a Flask ``Response``;
    - maps ApiClientException, TransactionConflictError/NoAvailableConnectionError and any
      other exception to API error responses.

    Must be applied before (i.e. below) any @api_add_route() decorator.

    :param track_in_diagnostics: Register diagnostics counters named ``route.<name>``. This
        requires the function name to start with ``api_route_``.
    :param allow_while_readonly: Serve the route even when the site is read-only.
    :param allow_while_disabled: Serve the route even when the site is disabled.
    :param rate_limiters: Limiters checked per request; the first one to refuse rejects it.
    :raises Exception: At decoration time, if the function is already an api_function() wrapper.
    :raises RuntimeError: At decoration time, if diagnostics are tracked and the name is invalid.
    """
    def decorator(func):
        if getattr(func, "is_api_route", False):
            raise Exception("An @api_function() decorator has already been applied. Are you using multiple @api_route()? "
                            "Use @api_add_route(...)@api_add_route(..)@api_function() instead")
        call_counter = None
        call_time_counter = None
        if track_in_diagnostics:
            func_name: str = func.__name__
            if not func_name.startswith("api_route_"):  # pragma: no cover
                raise RuntimeError("Api route function names must start with 'api_route_' "
                                   "(These names are used in diagnostics data)")
            func_name = func_name.removeprefix("api_route_")
            if not func_name:
                raise RuntimeError("Api route function has no name (just 'api_route_')")  # pragma: no cover
            call_counter = DIAGNOSTICS_TRACKER.register_counter(f"route.{func_name}")
            call_time_counter = DIAGNOSTICS_TRACKER.register_counter(f"route.{func_name}.time")

        @wraps(func)
        def wrapper(*args, **kwargs):
            try:
                if api.live_config.is_disabled() and not allow_while_disabled:
                    raise ApiClientException(ERROR_SITE_IS_DISABLED)
                if api.live_config.is_readonly() and not allow_while_readonly:
                    raise ApiClientException(ERROR_SITE_IS_READONLY)
                if call_counter is not None:
                    call_counter.trigger()
                if "X_REAL_IP" in request.headers:
                    ip_string = request.headers["X_REAL_IP"]
                    # Stop at the first limiter that refuses, so later limiters don't count the request.
                    for rate_limiter in rate_limiters:
                        if not rate_limiter.check_new_request(ip_string):
                            raise ApiClientException(ERROR_RATE_LIMITED)
                if DEBUG_ENABLED and "API_ROULETTE_MODE" in api.config:
                    # Debug aid: fail a configurable percentage of requests with a random API error.
                    import random
                    from api.miscellaneous.errors import ALL_ERRORS_RANDOM
                    if random.random() * 100 < int(api.config["API_ROULETTE_MODE"]):
                        raise ApiClientException(random.choice(ALL_ERRORS_RANDOM))
                # perf_counter() is monotonic, unlike time.time(), so wall-clock adjustments
                # cannot produce wrong or negative durations.
                start_time = time.perf_counter()
                result = func(*args, **kwargs)
                if call_time_counter is not None:
                    call_time_counter.trigger(int((time.perf_counter() - start_time) * 1000))
                return _to_flask_response(result)
            except ApiClientException as e:
                return api_on_error(e.error)
            except (TransactionConflictError, NoAvailableConnectionError) as e:
                print(f"A database overload error ({type(e).__name__}) occurred while handling api request "
                      f"'{truncate_string(request.path, 200)}':")
                traceback.print_exception(e)
                return api_on_error(ERROR_SITE_IS_OVERLOADED)
            except Exception as e:
                print(f"An exception occurred while handling api request '{truncate_string(request.path, 200)}':")
                traceback.print_exception(e)
                return api_on_error(ERROR_INTERNAL_SERVER_ERROR)

        # Marker checked by api_add_route() and by this decorator to detect wrong stacking.
        wrapper.is_api_route = True
        return wrapper
    return decorator