mirror of
https://github.com/suitenumerique/docs.git
synced 2026-05-10 09:02:35 +02:00
When a PATCH and a GET on the content endpoint are made at the same time by different users, a race condition can happen: the metadata returned by the S3 head_object call can be outdated by the time the object is fetched, which raises an error because the Content-Length header does not match the size of the response body. To avoid this, we no longer use head_object followed by get_object; everything is now handled in a single get_object call. get_object also accepts an ETag or a Last-Modified value as parameters and returns a 304 if the content has not changed, so we can use this to avoid returning the entire body when it has not changed.
231 lines
7.9 KiB
Python
231 lines
7.9 KiB
Python
"""Util to generate S3 authorization headers for object storage access control"""
|
|
|
|
import datetime as dt
|
|
import time
|
|
from abc import ABC, abstractmethod
|
|
|
|
from django.conf import settings
|
|
from django.core.cache import cache
|
|
from django.core.files.storage import default_storage
|
|
from django.utils.decorators import method_decorator
|
|
|
|
import botocore
|
|
from lasuite.oidc_login.decorators import refresh_oidc_access_token
|
|
from rest_framework.throttling import BaseThrottle
|
|
|
|
|
|
def nest_tree(flat_list, steplen):
    """
    Turn a flat list of serialized documents into a single nested tree.

    The hierarchy is derived from each node's `path` field: dropping the last
    `steplen` characters of a path gives the path of its parent.

    Side effects: `flat_list` is sorted in place by path and every node gets a
    fresh `children` list.

    Returns the root node, or None when `flat_list` is empty.
    Raises ValueError when more than one node has no parent in the list.
    """
    # A parent's path is a prefix of its children's paths, so sorting by path
    # guarantees a parent is registered before any of its descendants.
    flat_list.sort(key=lambda item: item["path"])

    by_path = {}
    top_level = []
    for item in flat_list:
        item["children"] = []
        path = item["path"]
        by_path[path] = item

        # Look up the parent by truncating the last path step
        parent = by_path.get(path[:-steplen])
        if parent is None:
            top_level.append(item)  # No known parent: this is a root
        else:
            parent["children"].append(item)

    if len(top_level) > 1:
        raise ValueError("More than one root element detected.")

    return top_level[0] if top_level else None
|
|
|
|
|
|
def filter_root_paths(paths, skip_sorting=False):
    """
    Keep only the root paths from a list of paths representing a tree structure.

    A path is dropped when it starts with the most recently kept path; on
    sorted input this removes every path that extends an already kept one.

    Args:
        paths (list of str): The list of paths. It is sorted in place unless
            `skip_sorting` is True.
        skip_sorting (bool): Set to True when `paths` is already sorted, to
            skip the in-place sort.

    Returns:
        list of str: The filtered list of root paths.
    """
    if not skip_sorting:
        paths.sort()

    kept = []
    for candidate in paths:
        # Skip paths that extend the last kept root
        if kept and candidate.startswith(kept[-1]):
            continue
        kept.append(candidate)

    return kept
|
|
|
|
|
|
def generate_s3_authorization_headers(key):
    """
    Generate authorization headers for an s3 object.

    These headers can be used as an alternative to signed urls with many benefits:
    - the urls of our files never expire and can be stored in our documents' content
    - we don't leak authorized urls that could be shared (file access can only be done
      with cookies)
    - access control is truly realtime
    - the object storage service does not need to be exposed on internet

    Returns the botocore `AWSRequest` built for `key` after `S3SigV4Auth.add_auth`
    has been applied to it.
    """
    # Build the plain object url with the unsigned client
    unsigned_client = default_storage.unsigned_connection.meta.client
    object_url = unsigned_client.generate_presigned_url(
        "get_object",
        ExpiresIn=0,
        Params={"Bucket": default_storage.bucket_name, "Key": key},
    )
    signed_request = botocore.awsrequest.AWSRequest(method="get", url=object_url)

    # Sign the request with the credentials held by the regular client
    client = default_storage.connection.meta.client
    # pylint: disable=protected-access
    signer_credentials = client._request_signer._credentials  # noqa: SLF001
    signer = botocore.auth.S3SigV4Auth(
        signer_credentials.get_frozen_credentials(),
        "s3",
        client.meta.region_name,
    )
    signer.add_auth(signed_request)

    return signed_request
|
|
|
|
|
|
def conditional_refresh_oidc_token(func):
    """
    Apply the refresh_oidc_access_token decorator only when it is useful.

    When OIDC_STORE_REFRESH_TOKEN is falsy there is nothing to refresh, so
    `func` is returned untouched. Otherwise `func` is wrapped with
    `method_decorator(refresh_oidc_access_token)`. Broader settings checks are
    done in settings.py.
    """
    if not settings.OIDC_STORE_REFRESH_TOKEN:
        return func

    return method_decorator(refresh_oidc_access_token)(func)
|
|
|
|
|
|
class AIBaseRateThrottle(BaseThrottle, ABC):
    """
    Base throttle for AI-related endpoints, with per-minute, per-hour and
    per-day limits and a backoff delay.

    Request timestamps are kept in the cache under the key returned by
    `get_cache_key`. `rates` is indexed with the keys "minute", "hour" and "day".
    """

    def __init__(self, rates):
        """Store the configured rates and reset the per-request counters."""
        super().__init__()
        self.rates = rates
        self.cache_key = None
        self.recent_requests_minute = 0
        self.recent_requests_hour = 0
        self.recent_requests_day = 0

    @abstractmethod
    def get_cache_key(self, request, view):
        """Abstract method to generate cache key for throttling."""

    def allow_request(self, request, view):
        """
        Return True when the request fits within all configured limits.

        An allowed request has its timestamp appended to the cached history;
        a rejected one leaves the cache untouched.
        """
        self.cache_key = self.get_cache_key(request, view)
        if not self.cache_key:
            return True  # Allow if no cache key is generated

        now = time.time()

        # Only timestamps from the last 24 hours are relevant
        timestamps = [
            stamp for stamp in cache.get(self.cache_key, []) if stamp > now - 86400
        ]

        # Count the requests seen in each window
        self.recent_requests_minute = sum(1 for stamp in timestamps if stamp > now - 60)
        self.recent_requests_hour = sum(1 for stamp in timestamps if stamp > now - 3600)
        self.recent_requests_day = len(timestamps)

        # Reject as soon as one window is exhausted, smallest window first
        usage = (
            (self.recent_requests_minute, "minute"),
            (self.recent_requests_hour, "hour"),
            (self.recent_requests_day, "day"),
        )
        if any(count >= self.rates[period] for count, period in usage):
            return False

        # Record this request
        timestamps.append(now)
        cache.set(self.cache_key, timestamps, timeout=86400)
        return True

    def wait(self):
        """
        Return the backoff delay in seconds for the largest window whose limit
        is reached, or None when no limit is reached.
        """
        backoff = (
            (self.recent_requests_day, "day", 86400),
            (self.recent_requests_hour, "hour", 3600),
            (self.recent_requests_minute, "minute", 60),
        )
        for count, period, delay in backoff:
            if count >= self.rates[period]:
                return delay
        return None
|
|
|
|
|
|
class AIDocumentRateThrottle(AIBaseRateThrottle):
    """Throttle limiting AI requests per document, with backoff."""

    def __init__(self, *args, **kwargs):
        """Use the rates from settings; positional and keyword arguments are ignored."""
        super().__init__(settings.AI_DOCUMENT_RATE_THROTTLE_RATES)

    def get_cache_key(self, request, view):
        """Build the cache key from the `pk` url kwarg of the view."""
        return f"document_{view.kwargs['pk']}_throttle_ai"
|
|
|
|
|
|
class AIUserRateThrottle(AIBaseRateThrottle):
    """Throttle limiting AI requests per user, or per IP for anonymous users."""

    def __init__(self, *args, **kwargs):
        """Use the rates from settings; positional and keyword arguments are ignored."""
        super().__init__(settings.AI_USER_RATE_THROTTLE_RATES)

    def get_cache_key(self, request, view=None):
        """Build the cache key from the user ID, or from the IP for anonymous users."""
        if not request.user.is_authenticated:
            return f"anonymous_{self.get_ident(request)}_throttle_ai"
        return f"user_{request.user.id!s}_throttle_ai"

    def get_ident(self, request):
        """
        Return the request IP address.

        The first entry of the `X-Forwarded-For` header is used when the header
        is set, otherwise `REMOTE_ADDR`.
        """
        forwarded = request.META.get("HTTP_X_FORWARDED_FOR")
        if forwarded:
            return forwarded.split(",")[0]
        return request.META.get("REMOTE_ADDR")
|
|
|
|
|
|
def get_content_metadata_cache_key(document_id):
    """Build the cache key under which a document's content metadata is stored."""
    return "docs:content-metadata:" + str(document_id)
|
|
|
|
|
|
def parse_http_conditional_headers(request):
    """Read `If-None-Match` and `If-Modified-Since` from the request.

    A leading `W/` weak prefix is removed from the ETag because reverse proxies
    (e.g. nginx with gzip) rewrite strong ETags into weak ones, which would
    otherwise break a strict equality check in production.

    Returns a `(if_none_match, if_modified_since)` tuple. The first item is the
    header value (or None when absent). The second is a timezone-aware
    datetime, or None when the header is absent or does not match the
    `%a, %d %b %Y %H:%M:%S %Z` format.
    """
    headers = request.META

    etag = headers.get("HTTP_IF_NONE_MATCH")
    if etag:
        etag = etag.removeprefix("W/")

    raw_date = headers.get("HTTP_IF_MODIFIED_SINCE")
    if not raw_date:
        return etag, None

    try:
        modified_since = dt.datetime.strptime(raw_date, "%a, %d %b %Y %H:%M:%S %Z")
    except ValueError:
        # Unparseable date: behave as if the header was not sent
        return etag, None

    # A naive result is interpreted as UTC
    if not modified_since.tzinfo:
        modified_since = modified_since.replace(tzinfo=dt.timezone.utc)

    return etag, modified_since
|