Files
authentik/packages/django-postgres-cache/django_postgres_cache/backend.py

142 lines
4.7 KiB
Python

import base64
import pickle # nosec
from datetime import UTC, datetime
from typing import Any
from django.conf import settings
from django.core.cache.backends.base import DEFAULT_TIMEOUT
from django.core.cache.backends.db import DatabaseCache as BaseDatabaseCache
from django.db import DatabaseError
from django.db.utils import ProgrammingError
from django.utils.module_loading import import_string
from django.utils.timezone import now
from psqlextra.types import ConflictAction
from django_postgres_cache.models import CacheEntry
class DatabaseCache(BaseDatabaseCache):
    """Database cache backend that stores its entries through ``CacheEntry``.

    Extends Django's ``DatabaseCache`` and replaces the read/write paths with
    ORM queries against ``CacheEntry``. It additionally offers pattern based
    key listing (``keys``) and remaining-lifetime lookup (``ttl``).
    """

    def __init__(self, table: str, params: dict[str, Any]) -> None:
        """Initialise the backend.

        Args:
            table: Table name handed to the parent backend. The table actually
                used is overridden below with the ``CacheEntry`` table.
            params: Backend parameters. Must contain ``REVERSE_KEY_FUNCTION``,
                a dotted path to a callable; a missing key raises ``KeyError``.
        """
        super().__init__(table, params)
        self.reverse_key_func = import_string(params["REVERSE_KEY_FUNCTION"])
        # Point the inherited machinery at the CacheEntry model's table,
        # regardless of the ``table`` argument.
        self._table = CacheEntry._meta.db_table
        self.cache_model_class = CacheEntry

    def _cull(self, *args: Any, **kwargs: Any) -> None:
        """Stubbed out cull method as we cull in a background task"""

    def get(self, key: str, default: Any | None = None, version: int | None = None) -> Any:
        """Return the cached value for ``key``.

        Delegates to the parent implementation. A ``ProgrammingError`` raised
        by the database is swallowed and ``default`` is returned instead
        (presumably to survive the cache table not existing yet - TODO confirm).
        """
        try:
            return super().get(key, default=default, version=version)
        except ProgrammingError:
            return default

    def keys(self, keys_pattern: str, version: int | None = None) -> list[str]:
        """Return the keys matching ``keys_pattern``.

        Thin wrapper around ``_keys`` that returns an empty list when the
        database raises ``ProgrammingError``.
        """
        try:
            return self._keys(keys_pattern, version=version)
        except ProgrammingError:
            return []

    def _keys(self, keys_pattern: str, version: int | None = None) -> list[str]:
        """Look up stored keys matching ``keys_pattern``.

        Every ``*`` in the pattern is turned into the regex ``.*``; the result
        is passed through ``make_key`` and used as a regex filter on
        ``cache_key``. Each match is mapped through ``reverse_key_func``
        before being returned.
        """
        # NOTE(review): only ``*`` is translated; any other regex
        # metacharacter in the pattern (or key prefix) is passed to the
        # database regex unescaped, and the expression is not anchored.
        keys_pattern = self.make_key(keys_pattern.replace("*", ".*"), version=version)
        matching_keys = CacheEntry.objects.filter(cache_key__regex=keys_pattern).values_list(
            "cache_key", flat=True
        )
        return [self.reverse_key_func(key) for key in matching_keys]

    def ttl(self, key: str, version: int | None = None) -> int | None:
        """Get TTL left for a given key and version

        Returns:
            ``None`` when no entry exists for the key, otherwise the whole
            number of seconds between now and the entry's ``expires`` value.
            The number is negative when ``expires`` is already in the past.
        """
        key = self.make_and_validate_key(key, version=version)
        entry = CacheEntry.objects.filter(cache_key=key).first()
        if not entry:
            return None
        return int((entry.expires - now()).total_seconds())

    def _base_set_expiry(self, timeout: float | None) -> datetime:
        """Convert a cache timeout into the expiry datetime to store.

        ``get_backend_timeout`` yields either ``None`` (mapped to
        ``datetime.max``) or a value that is interpreted as a POSIX timestamp.
        The timestamp is converted with UTC as timezone when
        ``settings.USE_TZ`` is enabled and as a naive datetime otherwise.
        The result is truncated to whole seconds.
        """
        timeout = self.get_backend_timeout(timeout)
        if timeout is None:
            exp = datetime.max
        else:
            tz = UTC if settings.USE_TZ else None
            exp = datetime.fromtimestamp(timeout, tz=tz)
        # datetime objects are immutable: replace() returns a new object, so
        # the result has to be rebound for the truncation to take effect.
        exp = exp.replace(microsecond=0)
        return exp

    def _base_set_data(
        self,
        key: Any,
        value: Any,
        timeout: float | None,
        version: int | None = None,
    ) -> tuple[str, str, datetime]:
        """Prepare the column values for storing ``value`` under ``key``.

        Returns:
            A ``(cache_key, encoded_value, expires)`` tuple, where
            ``encoded_value`` is the pickled value, base64 encoded and decoded
            to ``str``.
        """
        key = self.make_and_validate_key(key, version=version)
        pickled = pickle.dumps(value, self.pickle_protocol)
        # The DB column is expecting a string, so make sure the value is a
        # string, not bytes. Refs #19274.
        b64encoded = base64.b64encode(pickled).decode("latin1")
        return (key, b64encoded, self._base_set_expiry(timeout))

    def touch(
        self,
        key: Any,
        timeout: float | None = DEFAULT_TIMEOUT,
        version: int | None = None,
    ) -> bool:
        """Set a new expiry on an existing entry.

        Returns:
            ``True`` when at least one row was updated; ``False`` when no row
            matched or the database raised ``DatabaseError``.
        """
        key = self.make_and_validate_key(key, version=version)
        expiry = self._base_set_expiry(timeout)
        try:
            count = CacheEntry.objects.filter(cache_key=key).update(expires=expiry)
        except DatabaseError:
            return False
        return count != 0

    def add(
        self,
        key: Any,
        value: Any,
        timeout: float | None = DEFAULT_TIMEOUT,
        version: int | None = None,
    ) -> bool:
        """Insert an entry, resolving a ``cache_key`` conflict with an update.

        On conflict the update is given ``update_values`` containing only
        ``expires`` (presumably leaving the stored value untouched - verify
        against the django-postgres-extra documentation).

        Returns:
            ``True`` when the statement succeeded, ``False`` when the database
            raised ``DatabaseError``.
        """
        # NOTE(review): Django's cache API describes add() as storing only
        # when the key is absent and reporting whether it stored; this
        # implementation always reports success - confirm that is intended.
        key, value, expiry = self._base_set_data(key, value, timeout, version)
        try:
            CacheEntry.objects.on_conflict(
                ["cache_key"],
                ConflictAction.UPDATE,
                update_values={"expires": expiry},
            ).insert(
                cache_key=key,
                value=value,
                expires=expiry,
            )
        except DatabaseError:
            return False
        # We don't know if the row already existed, we just return True for success
        return True

    def set(
        self,
        key: Any,
        value: Any,
        timeout: float | None = DEFAULT_TIMEOUT,
        version: int | None = None,
    ) -> None:
        """Store ``value`` under ``key`` with an insert-or-update statement.

        Database errors are not caught here and propagate to the caller.
        """
        key, value, expiry = self._base_set_data(key, value, timeout, version)
        CacheEntry.objects.on_conflict(
            ["cache_key"],
            ConflictAction.UPDATE,
        ).insert(
            cache_key=key,
            value=value,
            expires=expiry,
        )

    def clear(self) -> None:
        """Remove every cache entry by truncating the ``CacheEntry`` table."""
        CacheEntry.objects.truncate()