mirror of
https://github.com/suitenumerique/docs.git
synced 2026-05-08 16:12:26 +02:00
We want to give the JS client the ability to use some headers to avoid fetching content it already has. For this, the content endpoint will return ETag and Last-Modified headers corresponding to the file content's ETag and its last modification. For future fetches, the client can use the If-None-Match or If-Modified-Since request headers; if one of these headers is satisfied, the endpoint will return a 304 response. If not, it will still return a 200.
202 lines
6.8 KiB
Python
202 lines
6.8 KiB
Python
"""Util to generate S3 authorization headers for object storage access control"""
|
|
|
|
import time
|
|
from abc import ABC, abstractmethod
|
|
|
|
from django.conf import settings
|
|
from django.core.cache import cache
|
|
from django.core.files.storage import default_storage
|
|
from django.utils.decorators import method_decorator
|
|
|
|
import botocore
|
|
from lasuite.oidc_login.decorators import refresh_oidc_access_token
|
|
from rest_framework.throttling import BaseThrottle
|
|
|
|
|
|
def nest_tree(flat_list, steplen):
    """
    Build a nested tree out of a flat list of serialized documents, taking
    advantage of the `path` field and its step length.

    The parent of a node is looked up under the node's path with its last
    `steplen` characters removed.

    Args:
        flat_list (list of dict): The nodes, each holding a "path" entry.
            The list is sorted in place by path and every node receives a
            "children" list.
        steplen (int): The number of characters in one step of a path.

    Returns:
        dict or None: The single root node, or None when the list is empty.

    Raises:
        ValueError: If more than one node has no parent in the list.
    """
    # A parent's path is a prefix of its children's paths, so ordering by path
    # guarantees a parent is registered before any of its children is visited.
    flat_list.sort(key=lambda item: item["path"])

    nodes_by_path = {}
    top_level = []

    for item in flat_list:
        item["children"] = []
        path = item["path"]
        nodes_by_path[path] = item

        parent = nodes_by_path.get(path[:-steplen])
        if parent is None:
            # No known parent: this node is a root candidate
            top_level.append(item)
        else:
            parent["children"].append(item)

    if not top_level:
        return None

    if len(top_level) > 1:
        raise ValueError("More than one root element detected.")

    return top_level[0]
|
|
|
|
|
|
def filter_root_paths(paths, skip_sorting=False):
    """
    Filters root paths from a collection of paths representing a tree structure.

    A root path is defined as a path that does not start with any other path
    of the collection, i.e. none of its ancestors is part of the collection.

    Args:
        paths (iterable of str): The paths to filter. The collection itself
            is never modified.
        skip_sorting (bool): Set to True when `paths` is already sorted in
            ascending order, to avoid sorting it a second time.

    Returns:
        list of str: The filtered list of root paths, in the order in which
            they were visited (ascending order unless sorting was skipped
            on unsorted input).
    """
    # Work on a sorted copy so the caller's collection is left untouched
    ordered_paths = paths if skip_sorting else sorted(paths)

    root_paths = []
    for path in ordered_paths:
        # In sorted order, all the paths sharing a given prefix directly follow
        # that prefix, so comparing with the last kept root is enough to know
        # whether the current path is a descendant of an already kept root.
        if not root_paths or not path.startswith(root_paths[-1]):
            root_paths.append(path)

    return root_paths
|
|
|
|
|
|
def generate_s3_authorization_headers(key):
    """
    Build a signed request carrying the authorization for an s3 object.

    Using these headers instead of signed urls brings several benefits:
    - the urls of our files never expire and can be stored in our documents' content
    - we don't leak authorized urls that could be shared (file access can only be done
      with cookies)
    - access control is truly realtime
    - the object storage service does not need to be exposed on internet

    Args:
        key (str): The key of the object in the default storage bucket.

    Returns:
        botocore.awsrequest.AWSRequest: The GET request for the object, after
            `S3SigV4Auth.add_auth` has been applied to it.
    """
    # The url of the object is obtained from the storage's unsigned connection
    unsigned_client = default_storage.unsigned_connection.meta.client
    object_url = unsigned_client.generate_presigned_url(
        "get_object",
        ExpiresIn=0,
        Params={"Bucket": default_storage.bucket_name, "Key": key},
    )
    signed_request = botocore.awsrequest.AWSRequest(method="get", url=object_url)

    # The signature relies on the credentials and region of the regular connection
    client = default_storage.connection.meta.client
    # pylint: disable=protected-access
    frozen = client._request_signer._credentials.get_frozen_credentials()  # noqa: SLF001
    signer = botocore.auth.S3SigV4Auth(frozen, "s3", client.meta.region_name)
    signer.add_auth(signed_request)

    return signed_request
|
|
|
|
|
|
def conditional_refresh_oidc_token(func):
    """
    Apply the refresh_oidc_access_token decorator only when it is useful.

    `func` is wrapped only if OIDC_STORE_REFRESH_TOKEN is True, meaning there
    is something we can actually refresh; otherwise it is returned unchanged.
    Broader settings checks are done in settings.py.
    """
    if not settings.OIDC_STORE_REFRESH_TOKEN:
        return func

    decorate = method_decorator(refresh_oidc_access_token)
    return decorate(func)
|
|
|
|
|
|
class AIBaseRateThrottle(BaseThrottle, ABC):
    """
    Base throttle for AI-related rate limiting with backoff.

    The timestamps of the accepted requests are kept in the cache under the
    key returned by `get_cache_key` and compared with per-minute, per-hour
    and per-day limits.
    """

    def __init__(self, rates):
        """
        Store the configurable rates and reset the request counters.

        `rates` is read with the keys "minute", "hour" and "day".
        """
        super().__init__()
        self.rates = rates
        self.cache_key = None
        self.recent_requests_minute = 0
        self.recent_requests_hour = 0
        self.recent_requests_day = 0

    @abstractmethod
    def get_cache_key(self, request, view):
        """
        Return the cache key under which the request history is stored.

        A falsy return value makes `allow_request` accept the request
        without any throttling.
        """

    def allow_request(self, request, view):
        """
        Tell whether the request fits within the minute, hour and day limits.

        The request is recorded in the cached history only when it is allowed.
        """
        key = self.get_cache_key(request, view)
        self.cache_key = key
        if not key:
            # Nothing to throttle on without a cache key
            return True

        now = time.time()
        day_ago = now - 86400
        hour_ago = now - 3600
        minute_ago = now - 60

        # Anything older than 24 hours is irrelevant for every limit
        timestamps = [stamp for stamp in cache.get(key, []) if stamp > day_ago]

        self.recent_requests_minute = sum(
            1 for stamp in timestamps if stamp > minute_ago
        )
        self.recent_requests_hour = sum(1 for stamp in timestamps if stamp > hour_ago)
        self.recent_requests_day = len(timestamps)

        # Limits are checked from the shortest window to the longest one
        usage = (
            ("minute", self.recent_requests_minute),
            ("hour", self.recent_requests_hour),
            ("day", self.recent_requests_day),
        )
        if any(count >= self.rates[period] for period, count in usage):
            return False

        # Record the accepted request
        timestamps.append(now)
        cache.set(key, timestamps, timeout=86400)
        return True

    def wait(self):
        """
        Return the backoff delay in seconds matching the widest limit reached.

        The counters computed by the last `allow_request` call are used;
        None is returned when no limit has been reached.
        """
        windows = (
            ("day", self.recent_requests_day, 86400),
            ("hour", self.recent_requests_hour, 3600),
            ("minute", self.recent_requests_minute, 60),
        )
        for period, count, delay in windows:
            if count >= self.rates[period]:
                return delay
        return None
|
|
|
|
|
|
class AIDocumentRateThrottle(AIBaseRateThrottle):
    """Throttle limiting the AI requests made on a given document, with backoff."""

    def __init__(self, *args, **kwargs):
        """
        Set the rates from `settings.AI_DOCUMENT_RATE_THROTTLE_RATES`.

        Positional and keyword arguments are accepted but not used.
        """
        rates = settings.AI_DOCUMENT_RATE_THROTTLE_RATES
        super().__init__(rates)

    def get_cache_key(self, request, view):
        """Build the cache key from the "pk" keyword argument of the view."""
        pk = view.kwargs["pk"]
        return f"document_{pk}_throttle_ai"
|
|
|
|
|
|
class AIUserRateThrottle(AIBaseRateThrottle):
    """Throttle limiting the AI requests per user, or per IP for anonymous users."""

    def __init__(self, *args, **kwargs):
        """
        Set the rates from `settings.AI_USER_RATE_THROTTLE_RATES`.

        Positional and keyword arguments are accepted but not used.
        """
        rates = settings.AI_USER_RATE_THROTTLE_RATES
        super().__init__(rates)

    def get_cache_key(self, request, view=None):
        """
        Build the cache key from the user id when the user is authenticated,
        or from the request IP address otherwise.
        """
        user = request.user
        if not user.is_authenticated:
            return f"anonymous_{self.get_ident(request)}_throttle_ai"
        return f"user_{user.id!s}_throttle_ai"

    def get_ident(self, request):
        """
        Return the IP address identifying the request.

        The first comma-separated entry of the X-Forwarded-For header is used
        when the header is set, REMOTE_ADDR otherwise.
        """
        # NOTE(review): the first X-Forwarded-For entry is used as is (no
        # whitespace stripping, no proxy count handling) - confirm that the
        # upstream proxies sanitize this header.
        forwarded = request.META.get("HTTP_X_FORWARDED_FOR")
        if forwarded:
            return forwarded.split(",")[0]
        return request.META.get("REMOTE_ADDR")
|
|
|
|
|
|
def get_content_metadata_cache_key(document_id):
    """
    Build the cache key under which the content metadata of a document is stored.

    The key is the "docs:content-metadata:" prefix followed by the string
    form of `document_id`.
    """
    prefix = "docs:content-metadata:"
    return prefix + str(document_id)
|