Files
authentik/packages/django-channels-postgres/django_channels_postgres/layer.py

576 lines
22 KiB
Python

import asyncio
import functools
import types
import zlib
from base64 import b64decode
from contextlib import AbstractAsyncContextManager
from datetime import UTC, datetime, timedelta
from re import Pattern
from typing import Any, cast
from uuid import uuid4
import msgpack
from channels.layers import BaseChannelLayer
from django.db import DEFAULT_DB_ALIAS, connections
from django.utils.timezone import now
from psycopg import AsyncConnection, Notify, sql
from psycopg.conninfo import make_conninfo
from psycopg.errors import Error as PsycopgError
from psycopg_pool import AsyncConnectionPool
from structlog.stdlib import get_logger
from django_channels_postgres.models import NOTIFY_CHANNEL, GroupChannel, Message
# Module-level structlog logger shared by everything in this file.
LOGGER = get_logger()
# Database table names of the GroupChannel and Message models. They are
# interpolated as SQL identifiers into the raw queries issued by this module.
GROUP_CHANNEL_TABLE = GroupChannel._meta.db_table
MESSAGE_TABLE = Message._meta.db_table
async def _async_proxy(
obj: PostgresChannelLayerLoopProxy,
name: str,
*args: Any,
**kwargs: Any,
) -> Any:
# Must be defined as a function and not a method due to
# https://bugs.python.org/issue38364
layer = obj._get_layer(allow_sync=False)
return await getattr(layer, name)(*args, **kwargs)
def _wrap_close(proxy: PostgresChannelLayerLoopProxy, loop: asyncio.AbstractEventLoop) -> None:
original_impl = loop.close
def _wrapper(self: asyncio.AbstractEventLoop, *args: Any, **kwargs: Any) -> None:
if loop in proxy._layers:
layer = proxy._layers[loop]
del proxy._layers[loop]
loop.run_until_complete(layer.flush())
self.close = original_impl # type: ignore[method-assign]
return self.close(*args, **kwargs)
loop.close = types.MethodType(_wrapper, loop) # type: ignore[method-assign]
class PostgresChannelLayerLoopProxy:
    """Facade that keeps one ``PostgresChannelLoopLayer`` per asyncio event loop.

    The coroutine methods of the channel layer API are dispatched to the layer
    of the currently running loop. Every other attribute is resolved on a
    layer that may also be obtained when no loop is running. Message
    (de)serialization is implemented here and shared by all layers.
    """

    def __init__(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        # The constructor arguments are replayed for every per-loop layer.
        self._args = args
        self._kwargs = kwargs
        # Each layer receives a back-reference to this proxy.
        self._kwargs["channel_layer"] = self
        # Keyed by event loop; the None key holds the layer created without a running loop.
        self._layers: dict[asyncio.AbstractEventLoop | None, PostgresChannelLoopLayer] = {}

    def __getattr__(self, name: str) -> Any:
        """Resolve attributes that are not set on the proxy itself."""
        coroutine_api = (
            "new_channel",
            "send",
            "receive",
            "group_add",
            "group_discard",
            "group_send",
            "flush",
        )
        if name not in coroutine_api:
            # Plain attribute or sync method: usable without a running loop.
            return getattr(self._get_layer(allow_sync=True), name)
        # Coroutine methods pick their layer lazily, at call time.
        return functools.partial(_async_proxy, self, name)

    def serialize(self, message: dict[str, Any]) -> bytes:
        """Serializes message to a byte string (msgpack, then zlib level 6)."""
        packed = cast(bytes, msgpack.packb(message, use_bin_type=True))
        return zlib.compress(packed, 6)

    def deserialize(self, message: bytes) -> dict[str, Any]:
        """Deserializes from a byte string produced by ``serialize``."""
        unpacked = msgpack.unpackb(zlib.decompress(message), raw=False)
        return cast(dict[str, Any], unpacked)

    def _get_layer(self, allow_sync: bool) -> "PostgresChannelLoopLayer":
        """Return the layer for the running loop, creating it on first use.

        Without a running loop the ``None``-keyed layer is used when
        ``allow_sync`` is true; otherwise the ``RuntimeError`` is re-raised.
        """
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            if not allow_sync:
                raise
            # No loop configured, we will only allow sync APIs
            loop = None

        layer = self._layers.get(loop)
        if layer is None:
            layer = PostgresChannelLoopLayer(*self._args, **self._kwargs)
            self._layers[loop] = layer
            if loop is not None:
                # Make sure the layer gets flushed when its loop is closed.
                _wrap_close(self, loop)
        return layer
# Alias: ``PostgresChannelLayer`` is the loop-aware proxy class.
PostgresChannelLayer = PostgresChannelLayerLoopProxy
class PostgresChannelLoopLayer(BaseChannelLayer):
    """
    Postgres channel layer.

    It uses the NOTIFY/LISTEN functionality of postgres to broadcast messages.

    It also makes use of an internal message table to overcome the
    8000 bytes limit of Postgres' NOTIFY messages, which is a far cry from
    the channels standard of 1MB. This table has a trigger that sends out
    the `NOTIFY` signal.

    Using a database also means messages are durable and will always be
    available to consumers (as long as they're not expired).
    """

    extensions = ["groups", "flush"]

    def __init__(
        self,
        channel_layer: PostgresChannelLayerLoopProxy,
        prefix: str = "asgi",
        expiry: int = 60,
        group_expiry: int = 86400,
        capacity: int = 100,
        channel_capacity: dict[Pattern[str] | str, int] | None = None,
        using: str = DEFAULT_DB_ALIAS,
    ) -> None:
        """
        :param channel_layer: proxy owning this layer; used for (de)serialization.
        :param prefix: first component of generated channel names and group keys.
        :param expiry: lifetime of a message in seconds.
        :param group_expiry: lifetime of a group membership in seconds.
        :param capacity: forwarded to ``BaseChannelLayer``.
        :param channel_capacity: forwarded to ``BaseChannelLayer``.
        :param using: alias of the Django database connection to use.
        """
        super().__init__(expiry=expiry, capacity=capacity, channel_capacity=channel_capacity)
        self.group_expiry = group_expiry
        self.prefix = prefix
        assert isinstance(self.prefix, str), "Prefix must be unicode"  # nosec
        self.channel_layer = channel_layer
        self.using = using
        # Serializes the lazy creation of the connection pool.
        self._pool_lock = asyncio.Lock()
        # Each consumer gets its own *specific* channel, created with the `new_channel()` method.
        # This dict maps `channel_name` to a queue of messages for that channel.
        # Queue items are `(message_id, payload)`; a payload of None means the
        # payload still has to be fetched from the message table.
        self.channels: dict[str, asyncio.Queue[tuple[str, bytes | None]]] = {}
        self._pool: AsyncConnectionPool | None = None
        self.receiver = PostgresChannelLayerReceiver(self.using, self)

    def make_conninfo(self) -> str:
        """
        Build a psycopg connection string from the Django database settings.

        A ``connect_timeout`` of 10 is applied unless the settings already
        provide one.
        """
        db_params = connections[self.using].get_connection_params()
        # Prevent psycopg from using the custom synchronous cursor factory from django.
        # Tolerate the keys being absent instead of failing with KeyError.
        db_params.pop("cursor_factory", None)
        db_params.pop("context", None)
        # Passing connect_timeout both through **db_params and as an explicit
        # keyword would raise TypeError, so only fill it in when it is missing.
        db_params.setdefault("connect_timeout", 10)
        return make_conninfo(conninfo="", **db_params)

    @staticmethod
    async def _configure_connection(conn: AsyncConnection) -> None:
        """Configure a freshly created pool connection."""
        await conn.set_autocommit(True)
        conn.prepare_threshold = 0  # All statements should be prepared
        conn.prepared_max = None  # No limit on the number of prepared statements

    async def _open_pool(self) -> AsyncConnectionPool:
        """Create a connection pool and wait until it is ready for use."""
        pool = AsyncConnectionPool(
            conninfo=self.make_conninfo(),
            open=False,
            configure=self._configure_connection,
            min_size=1,
            max_size=4,
        )
        try:
            await pool.open(wait=True)
        except BaseException:
            # Do not leave a half-opened pool behind when opening fails.
            await pool.close()
            raise
        return pool

    async def connection(self) -> AbstractAsyncContextManager[AsyncConnection]:
        """
        Return a context manager providing a connection from the pool.

        The pool is created on first use. It is only stored on the instance
        once it has been opened successfully, so a failed attempt is retried
        by the next call instead of leaving an unusable pool cached.
        """
        pool = self._pool
        if pool is None:
            async with self._pool_lock:
                # Another task may have created the pool while we waited for the lock.
                pool = self._pool
                if pool is None:
                    pool = await self._open_pool()
                    self._pool = pool
        return pool.connection()

    async def _subscribe_to_channel(self, channel: str) -> None:
        """Create the local queue for ``channel`` and register it with the receiver."""
        self.channels[channel] = asyncio.Queue()
        await self.receiver.subscribe(channel)

    @staticmethod
    def _insert_message_query() -> sql.Composed:
        """Statement inserting one ``(id, channel, message, expires)`` message row."""
        return sql.SQL("""
            INSERT INTO {table}
            ({id}, {channel}, {message}, {expires})
            VALUES (%s, %s, %s, %s)
        """).format(
            table=sql.Identifier(MESSAGE_TABLE),
            id=sql.Identifier("id"),
            channel=sql.Identifier("channel"),
            message=sql.Identifier("message"),
            expires=sql.Identifier("expires"),
        )

    @staticmethod
    def _group_channels_query() -> sql.Composed:
        """Statement selecting the distinct channels stored for one group key."""
        return sql.SQL("""
            SELECT DISTINCT {table}.{channel}
            FROM {table}
            WHERE {table}.{group_key} = %s
        """).format(
            table=sql.Identifier(GROUP_CHANNEL_TABLE),
            channel=sql.Identifier("channel"),
            group_key=sql.Identifier("group_key"),
        )

    def _message_row(self, channel: str, payload: bytes) -> tuple[Any, ...]:
        """Parameters for ``_insert_message_query``: a new id and an expiry from now."""
        return (uuid4(), channel, payload, now() + timedelta(seconds=self.expiry))

    ### Channel layer API ###

    async def send(self, channel: str, message: dict[str, Any]) -> None:
        """
        Send a message onto a (general or specific) channel.
        """
        # Typecheck
        assert isinstance(message, dict), "message is not a dict"  # nosec
        assert self.require_valid_channel_name(channel), "Channel name not valid"  # nosec
        # Make sure the message does not contain reserved keys
        assert "__asgi_channel__" not in message  # nosec
        row = self._message_row(channel, self.channel_layer.serialize(message))
        async with await self.connection() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute(self._insert_message_query(), row)

    async def new_channel(self, prefix: str = "specific") -> str:
        """
        Returns a new channel name that can be used by something in our
        process as a specific channel.
        """
        channel = f"{self.prefix}.{prefix}.{uuid4().hex}"
        await self._subscribe_to_channel(channel)
        return channel

    async def receive(self, channel: str) -> dict[str, Any]:
        """
        Receive the first message that arrives on the channel.
        If more than one coroutine waits on the same channel, the first waiter
        will be given the message when it arrives.
        """
        if channel not in self.channels:
            await self._subscribe_to_channel(channel)
        q = self.channels[channel]
        try:
            while True:
                message_id, message = await q.get()
                async with await self.connection() as conn:
                    async with conn.cursor() as cursor:
                        if message is None:
                            # The payload was not delivered with the queue item:
                            # claim the row and read the payload in one statement.
                            await cursor.execute(
                                sql.SQL("""
                                    DELETE
                                    FROM {table}
                                    WHERE {table}.{id} = %s
                                    RETURNING {table}.{message}
                                """).format(
                                    table=sql.Identifier(MESSAGE_TABLE),
                                    id=sql.Identifier("id"),
                                    message=sql.Identifier("message"),
                                ),
                                (message_id,),
                            )
                            row = await cursor.fetchone()
                            if row is None:
                                # The row no longer exists; wait for the next message.
                                continue
                            message = row[0]
                        else:
                            # The payload is already known; just remove the stored row.
                            await cursor.execute(
                                sql.SQL("""
                                    DELETE
                                    FROM {table}
                                    WHERE {table}.{id} = %s
                                """).format(
                                    table=sql.Identifier(MESSAGE_TABLE),
                                    id=sql.Identifier("id"),
                                ),
                                (message_id,),
                            )
                break
        except (asyncio.CancelledError, TimeoutError, GeneratorExit):
            # We assume here that the reason we are cancelled is because the consumer
            # is exiting, therefore we need to cleanup by unsubscribe below. Indeed,
            # currently the way that Django Channels works, this is a safe assumption.
            # In the future, Django Channels could change to call a *new* method that
            # would serve as the antithesis of `new_channel()`; this new method might
            # be named `delete_channel()`. If that were the case, we would do the
            # following cleanup from that new `delete_channel()` method, but, since
            # that's not how Django Channels works (yet), we do the cleanup below:
            if channel in self.channels:
                del self.channels[channel]
                try:
                    await self.receiver.unsubscribe(channel)
                except BaseException as exc:  # noqa: BLE001
                    LOGGER.warning("Unexpected exception while cleaning-up channel", exc=exc)
                    # We don't re-raise here because we want the CancelledError to be the one
                    # re-raised
            raise
        return self.channel_layer.deserialize(message)

    # ==============================================================
    # Groups extension
    # ==============================================================

    async def group_add(self, group: str, channel: str) -> None:
        """
        Adds the channel name to a group.
        """
        # Check the inputs
        assert self.require_valid_group_name(group), "Group name not valid"  # nosec
        assert self.require_valid_channel_name(channel), "Channel name not valid"  # nosec
        group_key = self._group_key(group)
        async with await self.connection() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute(
                    sql.SQL("""
                        INSERT INTO {table}
                        ({id}, {group_key}, {channel}, {expires})
                        VALUES (%s, %s, %s, %s)
                    """).format(
                        table=sql.Identifier(GROUP_CHANNEL_TABLE),
                        id=sql.Identifier("id"),
                        group_key=sql.Identifier("group_key"),
                        channel=sql.Identifier("channel"),
                        expires=sql.Identifier("expires"),
                    ),
                    (
                        uuid4(),
                        group_key,
                        channel,
                        now() + timedelta(seconds=self.group_expiry),
                    ),
                )

    async def group_discard(self, group: str, channel: str) -> None:
        """
        Removes the channel from the named group if it is in the group;
        does nothing otherwise (does not error)
        """
        # Check the inputs
        assert self.require_valid_group_name(group), "Group name not valid"  # nosec
        assert self.require_valid_channel_name(channel), "Channel name not valid"  # nosec
        group_key = self._group_key(group)
        async with await self.connection() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute(
                    sql.SQL("""
                        DELETE
                        FROM {table}
                        WHERE {table}.{group_key} = %s
                        AND {table}.{channel} = %s
                    """).format(
                        table=sql.Identifier(GROUP_CHANNEL_TABLE),
                        group_key=sql.Identifier("group_key"),
                        channel=sql.Identifier("channel"),
                    ),
                    (group_key, channel),
                )

    async def group_send(self, group: str, message: dict[str, Any]) -> None:
        """
        Sends a message to the entire group.
        """
        assert self.require_valid_group_name(group), "Group name not valid"  # nosec
        group_key = self._group_key(group)
        serialized_message = self.channel_layer.serialize(message)
        async with await self.connection() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute(self._group_channels_query(), (group_key,))
                channels = [row[0] for row in await cursor.fetchall()]
            if not channels:
                # Nobody is in the group: nothing to insert.
                return
            # One row per member channel, all carrying the same serialized payload.
            messages = [self._message_row(channel, serialized_message) for channel in channels]
            async with conn.cursor() as cursor:
                await cursor.executemany(self._insert_message_query(), messages)

    def group_send_blocking(self, group: str, message: dict[str, Any]) -> None:
        """
        Sends a message to the entire group, blocking version.

        Uses the synchronous Django connection instead of the async pool.
        """
        assert self.require_valid_group_name(group), "Group name not valid"  # nosec
        group_key = self._group_key(group)
        serialized_message = self.channel_layer.serialize(message)
        with connections[self.using].cursor() as cursor:
            cursor.execute(self._group_channels_query(), (group_key,))
            channels = [row[0] for row in cursor.fetchall()]
        if not channels:
            # Nobody is in the group: nothing to insert.
            return
        messages = [self._message_row(channel, serialized_message) for channel in channels]
        with connections[self.using].cursor() as cursor:
            cursor.executemany(self._insert_message_query(), messages)

    def _group_key(self, group: str) -> str:
        """
        Common function to make the storage key for the group.
        """
        return f"{self.prefix}.group.{group}"

    ### Flush extension ###

    async def flush(self) -> None:
        """
        Resets the local state of this layer.

        Forgets all locally tracked channel queues and flushes the receiver.
        Nothing is deleted from the database by this method.
        """
        self.channels = {}
        await self.receiver.flush()
class PostgresChannelLayerReceiver:
def __init__(self, using: str, channel_layer: PostgresChannelLoopLayer) -> None:
self.using = using
self.channel_layer = channel_layer
self._subscribed_to: set[str] = set()
self._lock = asyncio.Lock()
self._receive_task: asyncio.Task[None] | None = None
async def subscribe(self, channel: str) -> None:
async with self._lock:
if channel not in self._subscribed_to:
self._ensure_receiver()
self._subscribed_to.add(channel)
async def unsubscribe(self, channel: str) -> None:
async with self._lock:
if channel in self._subscribed_to:
self._ensure_receiver()
self._subscribed_to.remove(channel)
async def flush(self) -> None:
async with self._lock:
if self._receive_task is not None:
self._receive_task.cancel()
try:
await self._receive_task
except asyncio.CancelledError:
pass
self._receive_task = None
self._subscribed_to = set()
async def _do_receiving(self) -> None:
while True:
try:
async with await AsyncConnection.connect(
conninfo=self.channel_layer.make_conninfo(),
autocommit=True,
) as conn:
await self._process_backlog(conn)
await conn.execute(
sql.SQL("LISTEN {channel}").format(channel=sql.Identifier(NOTIFY_CHANNEL))
)
while True:
async for notify in conn.notifies(timeout=30):
await self._receive_notify(notify)
except asyncio.CancelledError, TimeoutError, GeneratorExit:
raise
except PsycopgError as exc:
LOGGER.warning("Postgres connection is not healthy", exc=exc)
except BaseException as exc: # noqa: BLE001
LOGGER.warning("Unexpected exception in receive task", exc=exc, exc_info=True)
await asyncio.sleep(1)
async def _process_backlog(self, conn: AsyncConnection) -> None:
if not self._subscribed_to:
return
async with conn.cursor() as cursor:
await cursor.execute(
sql.SQL("""
DELETE
FROM {table}
WHERE {table}.{channel} = ANY(%s)
AND {table}.{expires} >= %s
RETURNING {table}.{id}, {table}.{channel}, {table}.{message}
""").format(
table=sql.Identifier(MESSAGE_TABLE),
id=sql.Identifier("id"),
channel=sql.Identifier("channel"),
expires=sql.Identifier("expires"),
message=sql.Identifier("message"),
),
(list(self._subscribed_to), now()),
)
async for row in cursor:
message_id, channel, message = row
self._receive_message(channel, message_id, message)
async def _receive_notify(self, notify: Notify) -> None:
payload = notify.payload
split_payload = payload.split(":")
message: bytes | None = None
match len(split_payload):
case 4:
message_id, channel, timestamp, base64_message = split_payload
if channel not in self._subscribed_to:
return
expires = datetime.fromtimestamp(float(timestamp), tz=UTC)
if expires < now():
return
message = b64decode(base64_message)
case 3:
message_id, channel, timestamp = split_payload
if channel not in self._subscribed_to:
return
expires = datetime.fromtimestamp(float(timestamp), tz=UTC)
if expires < now():
return
message = None
case _:
return
self._receive_message(channel, message_id, message)
def _receive_message(self, channel: str, message_id: str, message: bytes | None) -> None:
if (q := self.channel_layer.channels.get(channel)) is not None:
q.put_nowait((message_id, message))
def _ensure_receiver(self) -> None:
if self._receive_task is None:
self._receive_task = asyncio.ensure_future(self._do_receiving())