Compare commits

...

2 Commits

Author SHA1 Message Date
Manuel Raynaud
294b1a0dc5 🐛(backend) replace document creation table locks with retry strategy
We have situation where the number of locks in the database can increase
dangerously creating deadlock situation. To remove this situation we
decided to change the strategy to manage document creation concurrency.
We decided to use a retry strategy, trying to create the document
multiple times while a usable path is found. To avoid having an
infinite loop, we use a max_attempts counter configurable using the
setting TREEBEARD_PATH_COMPUTE_RETRY_MAX_ATTEMPTS
2026-05-06 12:53:53 +02:00
Manuel Raynaud
cbd2705c9f ♻️(backend) split core/utils.py module
We need to split the core/utils.py in multiple submodule created in
core/utils/*.py. We need to do this to avoid circular imports between
this module and the models module.
2026-05-06 12:42:29 +02:00
21 changed files with 400 additions and 245 deletions

View File

@@ -9,6 +9,7 @@ and this project adheres to
### Changed
- 🐛(frontend) sanitize pasted and dropped content in document title #2210
- 🐛(backend) replace document creation table locks with retry strategy
## [v5.0.0] - 2026-04-08

View File

@@ -134,6 +134,7 @@ These are the environment variables you can set for the `impress-backend` contai
| THEME_CUSTOMIZATION_CACHE_TIMEOUT | Cache duration for the customization settings | 86400 |
| THEME_CUSTOMIZATION_FILE_PATH | Full path to the file customizing the theme. An example is provided in src/backend/impress/configuration/theme/default.json | BASE_DIR/impress/configuration/theme/default.json |
| TRASHBIN_CUTOFF_DAYS | Trashbin cutoff | 30 |
| TREEBEARD_PATH_COMPUTE_RETRY_MAX_ATTEMPTS | Number of attempts to create a document before failing. | 10 |
| USER_OIDC_ESSENTIAL_CLAIMS | Essential claims in OIDC token | [] |
| USER_ONBOARDING_DOCUMENTS | A list of document IDs for which a read-only access will be created for new users | [] |
| USER_ONBOARDING_SANDBOX_DOCUMENT | ID of a template sandbox document that will be duplicated for new users | |

View File

@@ -7,7 +7,6 @@ from base64 import b64decode
from os.path import splitext
from django.conf import settings
from django.db import connection, transaction
from django.db.models import Q
from django.utils.functional import lazy
from django.utils.text import slugify
@@ -24,6 +23,7 @@ from core.services.converter_services import (
ConversionError,
Converter,
)
from core.utils.treebeard import create_tree_node_with_retry
class UserSerializer(serializers.ModelSerializer):
@@ -467,18 +467,12 @@ class ServerCreateDocumentSerializer(serializers.Serializer):
{"content": ["Could not convert content"]}
) from err
with transaction.atomic():
# locks the table to ensure safe concurrent access
with connection.cursor() as cursor:
cursor.execute(
f'LOCK TABLE "{models.Document._meta.db_table}" ' # noqa: SLF001
"IN SHARE ROW EXCLUSIVE MODE;"
)
document = models.Document.add_root(
document = create_tree_node_with_retry(
lambda: models.Document.add_root(
title=validated_data["title"],
creator=user,
)
)
if user:
# Associate the document with the pre-existing user

View File

@@ -67,11 +67,10 @@ from core.services.search_indexers import (
get_visited_document_ids_of,
)
from core.tasks.mail import send_ask_for_access_mail
from core.utils import (
extract_attachments,
filter_descendants,
users_sharing_documents_with,
)
from core.utils.paths import filter_descendants
from core.utils.treebeard import create_tree_node_with_retry
from core.utils.users import users_sharing_documents_with
from core.utils.yjs import extract_attachments
from ..enums import FeatureFlag, SearchType
from . import permissions, serializers, utils
@@ -708,18 +707,12 @@ class DocumentViewSet(
{"file": ["Could not convert file content"]}
) from err
with transaction.atomic():
# locks the table to ensure safe concurrent access
with connection.cursor() as cursor:
cursor.execute(
f'LOCK TABLE "{models.Document._meta.db_table}" ' # noqa: SLF001
"IN SHARE ROW EXCLUSIVE MODE;"
)
obj = models.Document.add_root(
obj = create_tree_node_with_retry(
lambda: models.Document.add_root(
creator=self.request.user,
**serializer.validated_data,
)
)
serializer.instance = obj
models.DocumentAccess.objects.create(
document=obj,
@@ -1023,16 +1016,12 @@ class DocumentViewSet(
)
serializer.is_valid(raise_exception=True)
with transaction.atomic():
# "select_for_update" locks the table to ensure safe concurrent access
locked_parent = models.Document.objects.select_for_update().get(
pk=document.pk
)
child_document = locked_parent.add_child(
child_document = create_tree_node_with_retry(
lambda: document.add_child(
creator=request.user,
**serializer.validated_data,
)
)
# Set the created instance to the serializer
serializer.instance = child_document

View File

@@ -9,7 +9,7 @@ from django.db import migrations, models
from botocore.exceptions import ClientError
import core.models
from core.utils import extract_attachments
from core.utils.yjs import extract_attachments
def populate_attachments_on_all_documents(apps, schema_editor):

View File

@@ -19,7 +19,7 @@ from django.core.cache import cache
from django.core.files.base import ContentFile
from django.core.files.storage import default_storage
from django.core.mail import send_mail
from django.db import connection, models, transaction
from django.db import models, transaction
from django.db.models.functions import Left, Length
from django.template.loader import render_to_string
from django.utils import timezone
@@ -39,6 +39,7 @@ from core.choices import (
RoleChoices,
get_equivalent_link_definition,
)
from core.utils.treebeard import create_tree_node_with_retry
from core.validators import sub_validator
logger = getLogger(__name__)
@@ -265,8 +266,6 @@ class User(AbstractBaseUser, BaseModel, auth_models.PermissionsMixin):
duplicate the sandbox document for the user
"""
if settings.USER_ONBOARDING_SANDBOX_DOCUMENT:
# transaction.atomic is used in a context manager to avoid a transaction if
# the settings USER_ONBOARDING_SANDBOX_DOCUMENT is unused
sandbox_id = settings.USER_ONBOARDING_SANDBOX_DOCUMENT
try:
template_document = Document.objects.get(id=sandbox_id)
@@ -276,20 +275,15 @@ class User(AbstractBaseUser, BaseModel, auth_models.PermissionsMixin):
sandbox_id,
)
return
with transaction.atomic():
# locks the table to ensure safe concurrent access
with connection.cursor() as cursor:
cursor.execute(
f'LOCK TABLE "{Document._meta.db_table}" ' # noqa: SLF001
"IN SHARE ROW EXCLUSIVE MODE;"
sandbox_document = create_tree_node_with_retry(
lambda: Document.add_root(
title=template_document.title,
content=template_document.content,
attachments=template_document.attachments,
duplicated_from=template_document,
creator=self,
)
sandbox_document = Document.add_root(
title=template_document.title,
content=template_document.content,
attachments=template_document.attachments,
duplicated_from=template_document,
creator=self,
)
DocumentAccess.objects.create(

View File

@@ -12,8 +12,11 @@ from django.utils.module_loading import import_string
import requests
from core import models, utils
from core import models
from core.enums import SearchType
from core.utils.dicts import get_value_by_pattern
from core.utils.paths import get_ancestor_to_descendants_map
from core.utils.yjs import base64_yjs_to_text
logger = logging.getLogger(__name__)
@@ -44,7 +47,7 @@ def get_batch_accesses_by_users_and_teams(paths):
Get accesses related to a list of document paths,
grouped by users and teams, including all ancestor paths.
"""
ancestor_map = utils.get_ancestor_to_descendants_map(
ancestor_map = get_ancestor_to_descendants_map(
paths, steplen=models.Document.steplen
)
ancestor_paths = list(ancestor_map.keys())
@@ -297,7 +300,7 @@ class FindDocumentIndexer(BaseDocumentIndexer):
>>> get_title({"id": 1})
""
"""
titles = utils.get_value_by_pattern(source, r"^title\.")
titles = get_value_by_pattern(source, r"^title\.")
for title in titles:
if title:
return title
@@ -318,7 +321,7 @@ class FindDocumentIndexer(BaseDocumentIndexer):
"""
doc_path = document.path
doc_content = document.content
text_content = utils.base64_yjs_to_text(doc_content) if doc_content else ""
text_content = base64_yjs_to_text(doc_content) if doc_content else ""
return {
"id": str(document.id),

View File

@@ -11,7 +11,7 @@ from django.dispatch import receiver
from core import models
from core.tasks.search import trigger_batch_document_indexer
from core.utils import get_users_sharing_documents_with_cache_key
from core.utils.users import get_users_sharing_documents_with_cache_key
@receiver(signals.post_save, sender=models.Document)

View File

@@ -12,13 +12,14 @@ import pytest
import responses
from requests import HTTPError
from core import factories, models, utils
from core import factories, models
from core.services.search_indexers import (
BaseDocumentIndexer,
FindDocumentIndexer,
get_document_indexer,
get_visited_document_ids_of,
)
from core.utils.yjs import base64_yjs_to_text
pytestmark = pytest.mark.django_db
@@ -199,7 +200,7 @@ def test_services_search_indexers_serialize_document_returns_expected_json():
"depth": 1,
"path": document.path,
"numchild": 1,
"content": utils.base64_yjs_to_text(document.content),
"content": base64_yjs_to_text(document.content),
"created_at": document.created_at.isoformat(),
"updated_at": document.updated_at.isoformat(),
"reach": document.link_reach,

View File

@@ -8,7 +8,18 @@ from django.core.cache import cache
import pycrdt
import pytest
from core import factories, utils
from core import factories
from core.utils.dicts import get_value_by_pattern
from core.utils.paths import get_ancestor_to_descendants_map
from core.utils.users import (
get_users_sharing_documents_with_cache_key,
users_sharing_documents_with,
)
from core.utils.yjs import (
base64_yjs_to_text,
base64_yjs_to_xml,
extract_attachments,
)
pytestmark = pytest.mark.django_db
@@ -34,12 +45,12 @@ TEST_BASE64_STRING = (
def test_utils_base64_yjs_to_text():
"""Test extract text from saved yjs document"""
assert utils.base64_yjs_to_text(TEST_BASE64_STRING) == "Hello w or ld"
assert base64_yjs_to_text(TEST_BASE64_STRING) == "Hello w or ld"
def test_utils_base64_yjs_to_xml():
"""Test extract xml from saved yjs document"""
content = utils.base64_yjs_to_xml(TEST_BASE64_STRING)
content = base64_yjs_to_xml(TEST_BASE64_STRING)
assert (
'<heading textAlignment="left" level="1"><italic>Hello</italic></heading>'
in content
@@ -79,13 +90,13 @@ def test_utils_extract_attachments():
update = ydoc.get_update()
base64_string = base64.b64encode(update).decode("utf-8")
# image_key2 is missing the "/media/" part and shouldn't get extracted
assert utils.extract_attachments(base64_string) == [image_key1, image_key3]
assert extract_attachments(base64_string) == [image_key1, image_key3]
def test_utils_get_ancestor_to_descendants_map_single_path():
"""Test ancestor mapping of a single path."""
paths = ["000100020005"]
result = utils.get_ancestor_to_descendants_map(paths, steplen=4)
result = get_ancestor_to_descendants_map(paths, steplen=4)
assert result == {
"0001": {"000100020005"},
@@ -97,7 +108,7 @@ def test_utils_get_ancestor_to_descendants_map_single_path():
def test_utils_get_ancestor_to_descendants_map_multiple_paths():
"""Test ancestor mapping of multiple paths with shared prefixes."""
paths = ["000100020005", "00010003"]
result = utils.get_ancestor_to_descendants_map(paths, steplen=4)
result = get_ancestor_to_descendants_map(paths, steplen=4)
assert result == {
"0001": {"000100020005", "00010003"},
@@ -119,10 +130,10 @@ def test_utils_users_sharing_documents_with_cache_miss():
factories.UserDocumentAccessFactory(user=user2, document=doc1)
factories.UserDocumentAccessFactory(user=user3, document=doc2)
cache_key = utils.get_users_sharing_documents_with_cache_key(user1)
cache_key = get_users_sharing_documents_with_cache_key(user1)
cache.delete(cache_key)
result = utils.users_sharing_documents_with(user1)
result = users_sharing_documents_with(user1)
assert user2.id in result
@@ -139,12 +150,12 @@ def test_utils_users_sharing_documents_with_cache_hit():
factories.UserDocumentAccessFactory(user=user1, document=doc1)
factories.UserDocumentAccessFactory(user=user2, document=doc1)
cache_key = utils.get_users_sharing_documents_with_cache_key(user1)
cache_key = get_users_sharing_documents_with_cache_key(user1)
test_cached_data = {user2.id: "2025-02-10"}
cache.set(cache_key, test_cached_data, 86400)
result = utils.users_sharing_documents_with(user1)
result = users_sharing_documents_with(user1)
assert result == test_cached_data
@@ -156,7 +167,7 @@ def test_utils_users_sharing_documents_with_cache_invalidation_on_create():
doc1 = factories.DocumentFactory()
# Pre-populate cache
cache_key = utils.get_users_sharing_documents_with_cache_key(user1)
cache_key = get_users_sharing_documents_with_cache_key(user1)
cache.set(cache_key, {}, 86400)
# Verify cache exists
@@ -182,7 +193,7 @@ def test_utils_users_sharing_documents_with_cache_invalidation_on_delete():
doc_access = factories.UserDocumentAccessFactory(user=user1, document=doc1)
cache_key = utils.get_users_sharing_documents_with_cache_key(user1)
cache_key = get_users_sharing_documents_with_cache_key(user1)
cache.set(cache_key, {user2.id: "2025-02-10"}, 86400)
assert cache.get(cache_key) is not None
@@ -196,10 +207,10 @@ def test_utils_users_sharing_documents_with_empty_result():
"""Test when user is not sharing any documents."""
user1 = factories.UserFactory()
cache_key = utils.get_users_sharing_documents_with_cache_key(user1)
cache_key = get_users_sharing_documents_with_cache_key(user1)
cache.delete(cache_key)
result = utils.users_sharing_documents_with(user1)
result = users_sharing_documents_with(user1)
assert result == {}
@@ -210,7 +221,7 @@ def test_utils_users_sharing_documents_with_empty_result():
def test_utils_get_value_by_pattern_matching_key():
"""Test extracting value from a dictionary with a matching key pattern."""
data = {"title.extension": "Bonjour", "id": 1, "content": "test"}
result = utils.get_value_by_pattern(data, r"^title\.")
result = get_value_by_pattern(data, r"^title\.")
assert set(result) == {"Bonjour"}
@@ -218,7 +229,7 @@ def test_utils_get_value_by_pattern_matching_key():
def test_utils_get_value_by_pattern_multiple_matches():
"""Test that all matching keys are returned."""
data = {"title.extension_1": "Bonjour", "title.extension_2": "Hello", "id": 1}
result = utils.get_value_by_pattern(data, r"^title\.")
result = get_value_by_pattern(data, r"^title\.")
assert set(result) == {
"Bonjour",
@@ -229,7 +240,7 @@ def test_utils_get_value_by_pattern_multiple_matches():
def test_utils_get_value_by_pattern_multiple_extensions():
"""Test that all matching keys are returned."""
data = {"title.extension_1.extension_2": "Bonjour", "id": 1}
result = utils.get_value_by_pattern(data, r"^title\.")
result = get_value_by_pattern(data, r"^title\.")
assert set(result) == {"Bonjour"}
@@ -237,6 +248,6 @@ def test_utils_get_value_by_pattern_multiple_extensions():
def test_utils_get_value_by_pattern_no_match():
"""Test that empty list is returned when no key matches the pattern."""
data = {"name": "Test", "id": 1}
result = utils.get_value_by_pattern(data, r"^title\.")
result = get_value_by_pattern(data, r"^title\.")
assert result == []

View File

@@ -0,0 +1,89 @@
"""Tests for the create_tree_node_with_retry utils."""
from unittest import mock
from django.core.exceptions import ValidationError as DjangoValidationError
from django.db import IntegrityError
import pytest
from core.factories import UserFactory
from core.models import Document
from core.utils.treebeard import _is_tree_path_collision, create_tree_node_with_retry
pytestmark = pytest.mark.django_db
@pytest.mark.parametrize(
    "exc",
    [
        # Collision surfaced by full_clean()/validate_unique() before the INSERT
        DjangoValidationError({"path": "not unique"}),
        # Collision surfaced by the database unique index on Document.path
        IntegrityError("impress_document_path_key"),
    ],
)
def test_utils_create_tree_node_with_retry_exceed_max_attempts(settings, exc):
    """Test exceeding the max attempts should reraise the exception."""
    settings.TREEBEARD_PATH_COMPUTE_RETRY_MAX_ATTEMPTS = 2
    # The creation callback fails with a path-collision error on every call
    create_fn = mock.MagicMock()
    create_fn.side_effect = exc
    with (
        pytest.raises(exc.__class__),
        mock.patch(
            "core.utils.treebeard._is_tree_path_collision"
        ) as mock__is_tree_path_collision,
    ):
        # Delegate to the real classifier so behavior is unchanged while
        # the mock records how many times it was consulted
        mock__is_tree_path_collision.side_effect = _is_tree_path_collision
        create_tree_node_with_retry(create_fn)

    # Every attempt (max_attempts == 2) classified the error before re-raising
    mock__is_tree_path_collision.assert_called()
    assert mock__is_tree_path_collision.call_count == 2
    assert create_fn.call_count == 2
@pytest.mark.parametrize(
    "exc",
    [
        # ValidationError that does not involve the "path" field
        DjangoValidationError({"foo": "bar"}),
        # IntegrityError unrelated to the Document.path unique index
        IntegrityError("not handled"),
    ],
)
def test_utils_create_tree_node_with_retry_exceed_exception_not_handled(settings, exc):
    """Test that an exception not handled is re-raised immediately."""
    settings.TREEBEARD_PATH_COMPUTE_RETRY_MAX_ATTEMPTS = 2
    # The creation callback always raises an error unrelated to path collisions
    create_fn = mock.MagicMock()
    create_fn.side_effect = exc
    with (
        pytest.raises(exc.__class__),
        mock.patch(
            "core.utils.treebeard._is_tree_path_collision"
        ) as mock__is_tree_path_collision,
    ):
        # Delegate to the real classifier while recording the calls
        mock__is_tree_path_collision.side_effect = _is_tree_path_collision
        create_tree_node_with_retry(create_fn)

    # A non-collision error is re-raised on the first attempt: no retry happened
    mock__is_tree_path_collision.assert_called()
    assert mock__is_tree_path_collision.call_count == 1
    assert create_fn.call_count == 1
def test_utils_create_tree_node_with_retry_success():
    """Test executing successfully the create_fn callback."""
    user = UserFactory()
    # Happy path: the callback succeeds on the first attempt and its
    # return value is passed through unchanged
    document = create_tree_node_with_retry(
        lambda: Document.add_root(
            creator=user,
            title="success",
        )
    )
    assert isinstance(document, Document)
    assert document.title == "success"
    # Treebeard computed a materialized path for the new root node
    assert document.path is not None

View File

@@ -2,7 +2,7 @@
Unit tests for the filter_root_paths utility function.
"""
from core.utils import filter_descendants
from core.utils.paths import filter_descendants
def test_utils_filter_descendants_success():

View File

@@ -4,7 +4,8 @@ from django.utils import timezone
import pytest
from core import factories, utils
from core import factories
from core.utils.users import users_sharing_documents_with
pytestmark = pytest.mark.django_db
@@ -54,7 +55,7 @@ def test_utils_users_sharing_documents_with():
doc_3_pierre_2.created_at = yesterday
doc_3_pierre_2.save()
shared_map = utils.users_sharing_documents_with(user)
shared_map = users_sharing_documents_with(user)
assert shared_map == {
pierre_1.id: last_week,

View File

@@ -1,170 +0,0 @@
"""Utils for the core app."""
import base64
import logging
import re
import time
from collections import defaultdict
from django.core.cache import cache
from django.db import models as db
from django.db.models import Subquery
import pycrdt
from bs4 import BeautifulSoup
from core import enums, models
logger = logging.getLogger(__name__)
def get_value_by_pattern(data, pattern):
"""
Get all values from keys matching a regex pattern in a dictionary.
Args:
data (dict): Source dictionary to search
pattern (str): Regex pattern to match against keys
Returns:
list: List of values for all matching keys, empty list if no matches
Example:
>>> get_value_by_pattern({"title.fr": "Bonjour", "id": 1}, r"^title\\.")
["Bonjour"]
>>> get_value_by_pattern({"title.fr": "Bonjour", "title.en": "Hello"}, r"^title\\.")
["Bonjour", "Hello"]
"""
regex = re.compile(pattern)
return [value for key, value in data.items() if regex.match(key)]
def get_ancestor_to_descendants_map(paths, steplen):
"""
Given a list of document paths, return a mapping of ancestor_path -> set of descendant_paths.
Each path is assumed to use materialized path format with fixed-length segments.
Args:
paths (list of str): List of full document paths.
steplen (int): Length of each path segment.
Returns:
dict[str, set[str]]: Mapping from ancestor path to its descendant paths (including itself).
"""
ancestor_map = defaultdict(set)
for path in paths:
for i in range(steplen, len(path) + 1, steplen):
ancestor = path[:i]
ancestor_map[ancestor].add(path)
return ancestor_map
def filter_descendants(paths, root_paths, skip_sorting=False):
"""
Filters paths to keep only those that are descendants of any path in root_paths.
A path is considered a descendant of a root path if it starts with the root path.
If `skip_sorting` is not set to True, the function will sort both lists before
processing because both `paths` and `root_paths` need to be in lexicographic order
before going through the algorithm.
Args:
paths (iterable of str): List of paths to be filtered.
root_paths (iterable of str): List of paths to check as potential prefixes.
skip_sorting (bool): If True, assumes both `paths` and `root_paths` are already sorted.
Returns:
list of str: A list of sorted paths that are descendants of any path in `root_paths`.
"""
results = []
i = 0
n = len(root_paths)
if not skip_sorting:
paths.sort()
root_paths.sort()
for path in paths:
# Try to find a matching prefix in the sorted accessible paths
while i < n:
if path.startswith(root_paths[i]):
results.append(path)
break
if root_paths[i] < path:
i += 1
else:
# If paths[i] > path, no need to keep searching
break
return results
def base64_yjs_to_xml(base64_string):
"""Extract xml from base64 yjs document."""
decoded_bytes = base64.b64decode(base64_string)
# uint8_array = bytearray(decoded_bytes)
doc = pycrdt.Doc()
doc.apply_update(decoded_bytes)
return str(doc.get("document-store", type=pycrdt.XmlFragment))
def base64_yjs_to_text(base64_string):
"""Extract text from base64 yjs document."""
blocknote_structure = base64_yjs_to_xml(base64_string)
soup = BeautifulSoup(blocknote_structure, "lxml-xml")
return soup.get_text(separator=" ", strip=True)
def extract_attachments(content):
"""Helper method to extract media paths from a document's content."""
if not content:
return []
xml_content = base64_yjs_to_xml(content)
return re.findall(enums.MEDIA_STORAGE_URL_EXTRACT, xml_content)
def get_users_sharing_documents_with_cache_key(user):
"""Generate a unique cache key for each user."""
return f"users_sharing_documents_with_{user.id}"
def users_sharing_documents_with(user):
"""
Returns a map of users sharing documents with the given user,
sorted by last shared date.
"""
start_time = time.time()
cache_key = get_users_sharing_documents_with_cache_key(user)
cached_result = cache.get(cache_key)
if cached_result is not None:
elapsed = time.time() - start_time
logger.info(
"users_sharing_documents_with cache hit for user %s (took %.3fs)",
user.id,
elapsed,
)
return cached_result
user_docs_qs = models.DocumentAccess.objects.filter(user=user).values_list(
"document_id", flat=True
)
shared_qs = (
models.DocumentAccess.objects.filter(document_id__in=Subquery(user_docs_qs))
.exclude(user=user)
.values("user")
.annotate(last_shared=db.Max("created_at"))
)
result = {item["user"]: item["last_shared"] for item in shared_qs}
cache.set(cache_key, result, 86400) # Cache for 1 day
elapsed = time.time() - start_time
logger.info(
"users_sharing_documents_with cache miss for user %s (took %.3fs)",
user.id,
elapsed,
)
return result

View File

@@ -0,0 +1 @@
"""Core utilities package."""

View File

@@ -0,0 +1,24 @@
"""Dictionary utility functions."""
import re
def get_value_by_pattern(data, pattern):
    """Collect the values of every key in `data` whose key matches `pattern`.

    Args:
        data (dict): Dictionary whose keys are tested against the pattern.
        pattern (str): Regex applied with `re.match` (anchored at the start
            of each key).

    Returns:
        list: Values of the matching keys in dictionary iteration order;
            an empty list when no key matches.

    Example:
        >>> get_value_by_pattern({"title.fr": "Bonjour", "id": 1}, r"^title\\.")
        ["Bonjour"]
        >>> get_value_by_pattern({"title.fr": "Bonjour", "title.en": "Hello"}, r"^title\\.")
        ["Bonjour", "Hello"]
    """
    # Compile once, then test every key against the anchored pattern
    compiled = re.compile(pattern)
    return [data[key] for key in data if compiled.match(key)]

View File

@@ -0,0 +1,63 @@
"""Path and tree structure utilities."""
from collections import defaultdict
def get_ancestor_to_descendants_map(paths, steplen):
    """Map every ancestor path to the set of given paths it contains.

    Paths use the materialized-path format with fixed-length segments: each
    prefix of `steplen * k` characters identifies an ancestor (a path is its
    own ancestor here).

    Args:
        paths (list of str): Full document paths.
        steplen (int): Length of one path segment.

    Returns:
        dict[str, set[str]]: ancestor path -> descendant paths (including
            the path itself).
    """
    mapping = defaultdict(set)
    for full_path in paths:
        # One prefix per complete segment: steplen, 2*steplen, ...
        depth = len(full_path) // steplen
        for level in range(1, depth + 1):
            mapping[full_path[: level * steplen]].add(full_path)
    return mapping
def filter_descendants(paths, root_paths, skip_sorting=False):
    """Keep only the paths that descend from some path in `root_paths`.

    A path is a descendant of a root path when it starts with it. The
    single-pass merge below requires both lists in lexicographic order, so
    unless `skip_sorting` is True both inputs are sorted in place first
    (note: this mutates the caller's lists, matching historical behavior).

    Args:
        paths (iterable of str): Paths to filter.
        root_paths (iterable of str): Candidate ancestor prefixes.
        skip_sorting (bool): Set to True when both inputs are already sorted.

    Returns:
        list of str: Sorted paths that are descendants of a root path.
    """
    if not skip_sorting:
        paths.sort()
        root_paths.sort()

    matched = []
    cursor = 0
    total_roots = len(root_paths)
    for candidate in paths:
        # Advance through the sorted roots looking for a prefix of candidate
        while cursor < total_roots:
            root = root_paths[cursor]
            if candidate.startswith(root):
                matched.append(candidate)
                break
            if root < candidate:
                # This root can never match later (larger) candidates either
                cursor += 1
            else:
                # root > candidate: no root can be a prefix of this candidate
                break
    return matched

View File

@@ -0,0 +1,56 @@
"""Treebeard path collision handling utilities."""
import logging
from django.conf import settings
from django.core.exceptions import ValidationError as DjangoValidationError
from django.db import IntegrityError, transaction
logger = logging.getLogger(__name__)
def _is_tree_path_collision(exc):
    """Tell whether `exc` was caused by a Document.path uniqueness conflict.

    Treebeard derives the materialized path from the current siblings, so two
    concurrent callers may compute the same value. Depending on when the
    conflict is detected it surfaces either as:
    - `django.core.exceptions.ValidationError` from `full_clean()` /
      `validate_unique()` before the INSERT (BaseModel.save calls full_clean),
    - or `IntegrityError` from the database unique index when validation
      misses the race.
    """
    if not isinstance(exc, DjangoValidationError):
        # IntegrityError: the unique index name appears in the DB message
        return "impress_document_path_key" in str(exc).lower()

    message_dict = getattr(exc, "message_dict", None)
    if message_dict is None:
        # Non-dict ValidationError: fall back to scanning the message text
        return "path" in str(exc).lower()
    return "path" in message_dict
def create_tree_node_with_retry(create_fn):
    """Call `create_fn` in its own atomic block, retrying on path collisions.

    The unique constraint on Document.path is the source of truth preventing
    duplicate paths. When it fires, the failed transaction rolls back and
    `create_fn` is invoked again so treebeard recomputes the path from the
    latest database state. After TREEBEARD_PATH_COMPUTE_RETRY_MAX_ATTEMPTS
    consecutive collisions the last exception is re-raised.
    """
    max_attempts = settings.TREEBEARD_PATH_COMPUTE_RETRY_MAX_ATTEMPTS
    attempt = 1
    while attempt <= max_attempts:
        try:
            with transaction.atomic():
                return create_fn()
        except (DjangoValidationError, IntegrityError) as exc:
            # Classify first, then check the budget: unrelated errors are
            # re-raised immediately, and the final collision is re-raised too.
            if not _is_tree_path_collision(exc) or attempt == max_attempts:
                raise
            logger.info(
                "tree path collision on attempt %d/%d, retrying",
                attempt,
                max_attempts,
            )
        attempt += 1
    # Only reachable when max_attempts < 1; guard against returning None
    raise RuntimeError("create_tree_node_with_retry exited without result")

View File

@@ -0,0 +1,55 @@
"""User sharing cache utilities."""
import logging
import time
from django.core.cache import cache
from django.db import models as db
from django.db.models import Subquery
from core import models
logger = logging.getLogger(__name__)
def get_users_sharing_documents_with_cache_key(user):
    """Build the per-user cache key for the sharing map of `user`."""
    # The user id makes the key unique across users
    return "users_sharing_documents_with_" + str(user.id)
def users_sharing_documents_with(user):
    """Return a mapping of user id -> most recent date that user was granted
    access to a document also accessible to `user`.

    The result excludes `user` itself and is cached for one day; timing of
    both the cache-hit and cache-miss paths is logged.
    """
    started_at = time.time()
    cache_key = get_users_sharing_documents_with_cache_key(user)

    from_cache = cache.get(cache_key)
    if from_cache is not None:
        logger.info(
            "users_sharing_documents_with cache hit for user %s (took %.3fs)",
            user.id,
            time.time() - started_at,
        )
        return from_cache

    # Documents the user has access to
    own_document_ids = models.DocumentAccess.objects.filter(user=user).values_list(
        "document_id", flat=True
    )
    # Other users with access to any of those documents, with the most
    # recent grant date per user
    accesses = (
        models.DocumentAccess.objects.filter(
            document_id__in=Subquery(own_document_ids)
        )
        .exclude(user=user)
        .values("user")
        .annotate(last_shared=db.Max("created_at"))
    )
    sharing_map = {row["user"]: row["last_shared"] for row in accesses}
    cache.set(cache_key, sharing_map, 86400)  # Cache for 1 day

    logger.info(
        "users_sharing_documents_with cache miss for user %s (took %.3fs)",
        user.id,
        time.time() - started_at,
    )
    return sharing_map

View File

@@ -0,0 +1,36 @@
"""Yjs document conversion utilities."""
import base64
import re
import pycrdt
from bs4 import BeautifulSoup
from core import enums
def base64_yjs_to_xml(base64_string):
    """Extract xml from base64 yjs document.

    The input is a base64-encoded Yjs update; it is applied to a fresh
    pycrdt document and the "document-store" XML fragment is rendered
    as a string.
    """
    decoded_bytes = base64.b64decode(base64_string)
    doc = pycrdt.Doc()
    # Replay the update into the empty document to rebuild its state
    doc.apply_update(decoded_bytes)
    return str(doc.get("document-store", type=pycrdt.XmlFragment))
def base64_yjs_to_text(base64_string):
    """Extract text from base64 yjs document.

    The Yjs payload is first rendered as XML, then stripped of markup,
    returning the space-separated plain text.
    """
    blocknote_structure = base64_yjs_to_xml(base64_string)
    # lxml-xml parser keeps the XML structure intact while extracting text
    soup = BeautifulSoup(blocknote_structure, "lxml-xml")
    return soup.get_text(separator=" ", strip=True)
def extract_attachments(content):
    """Helper method to extract media paths from a document's content.

    Returns an empty list for empty content; otherwise the Yjs payload is
    converted to XML and scanned for media storage URLs.
    """
    if not content:
        return []

    xml_content = base64_yjs_to_xml(content)
    # Collect every media key referenced in the rendered XML
    return re.findall(enums.MEDIA_STORAGE_URL_EXTRACT, xml_content)

View File

@@ -1081,6 +1081,12 @@ class Base(Configuration):
60 * 60 * 24, environ_name="CONTENT_METADATA_CACHE_TIMEOUT", environ_prefix=None
)
TREEBEARD_PATH_COMPUTE_RETRY_MAX_ATTEMPTS = values.IntegerValue(
10,
environ_name="TREEBEARD_PATH_COMPUTE_RETRY_MAX_ATTEMPTS",
environ_prefix=None,
)
# pylint: disable=invalid-name
@property
def ENVIRONMENT(self):