# Mirror of https://github.com/goauthentik/authentik
# Synced 2026-04-27 18:07:15 +02:00
from __future__ import annotations

import asyncio
import functools
import types
import zlib
from base64 import b64decode
from contextlib import AbstractAsyncContextManager
from datetime import UTC, datetime, timedelta
from re import Pattern
from typing import Any, cast
from uuid import uuid4

import msgpack
from channels.layers import BaseChannelLayer
from django.db import DEFAULT_DB_ALIAS, connections
from django.utils.timezone import now
from psycopg import AsyncConnection, Notify, sql
from psycopg.conninfo import make_conninfo
from psycopg.errors import Error as PsycopgError
from psycopg_pool import AsyncConnectionPool
from structlog.stdlib import get_logger

from django_channels_postgres.models import NOTIFY_CHANNEL, GroupChannel, Message
|
|
|
|
LOGGER = get_logger()
|
|
|
|
|
|
GROUP_CHANNEL_TABLE = GroupChannel._meta.db_table
|
|
MESSAGE_TABLE = Message._meta.db_table
|
|
|
|
|
|
async def _async_proxy(
|
|
obj: PostgresChannelLayerLoopProxy,
|
|
name: str,
|
|
*args: Any,
|
|
**kwargs: Any,
|
|
) -> Any:
|
|
# Must be defined as a function and not a method due to
|
|
# https://bugs.python.org/issue38364
|
|
layer = obj._get_layer(allow_sync=False)
|
|
return await getattr(layer, name)(*args, **kwargs)
|
|
|
|
|
|
def _wrap_close(proxy: PostgresChannelLayerLoopProxy, loop: asyncio.AbstractEventLoop) -> None:
|
|
original_impl = loop.close
|
|
|
|
def _wrapper(self: asyncio.AbstractEventLoop, *args: Any, **kwargs: Any) -> None:
|
|
if loop in proxy._layers:
|
|
layer = proxy._layers[loop]
|
|
del proxy._layers[loop]
|
|
loop.run_until_complete(layer.flush())
|
|
self.close = original_impl # type: ignore[method-assign]
|
|
return self.close(*args, **kwargs)
|
|
|
|
loop.close = types.MethodType(_wrapper, loop) # type: ignore[method-assign]
|
|
|
|
|
|
class PostgresChannelLayerLoopProxy:
|
|
def __init__(
|
|
self,
|
|
*args: Any,
|
|
**kwargs: Any,
|
|
) -> None:
|
|
self._args = args
|
|
self._kwargs = kwargs
|
|
self._kwargs["channel_layer"] = self
|
|
self._layers: dict[asyncio.AbstractEventLoop | None, PostgresChannelLoopLayer] = {}
|
|
|
|
def __getattr__(self, name: str) -> Any:
|
|
if name in (
|
|
"new_channel",
|
|
"send",
|
|
"receive",
|
|
"group_add",
|
|
"group_discard",
|
|
"group_send",
|
|
"flush",
|
|
):
|
|
return functools.partial(_async_proxy, self, name)
|
|
else:
|
|
return getattr(self._get_layer(allow_sync=True), name)
|
|
|
|
def serialize(self, message: dict[str, Any]) -> bytes:
|
|
"""Serializes message to a byte string."""
|
|
m = cast(bytes, msgpack.packb(message, use_bin_type=True))
|
|
c = zlib.compress(m, 6)
|
|
return c
|
|
|
|
def deserialize(self, message: bytes) -> dict[str, Any]:
|
|
"""Deserializes from a byte string."""
|
|
m = zlib.decompress(message)
|
|
return cast(dict[str, Any], msgpack.unpackb(m, raw=False))
|
|
|
|
def _get_layer(self, allow_sync: bool) -> PostgresChannelLoopLayer:
|
|
try:
|
|
loop = asyncio.get_running_loop()
|
|
except RuntimeError as exc:
|
|
if allow_sync:
|
|
# No loop configured, we will only allow sync APIs
|
|
loop = None
|
|
else:
|
|
raise exc
|
|
|
|
try:
|
|
layer = self._layers[loop]
|
|
except KeyError:
|
|
layer = PostgresChannelLoopLayer(*self._args, **self._kwargs)
|
|
self._layers[loop] = layer
|
|
if loop is not None:
|
|
_wrap_close(self, loop)
|
|
|
|
return layer
|
|
|
|
|
|
PostgresChannelLayer = PostgresChannelLayerLoopProxy
|
|
|
|
|
|
class PostgresChannelLoopLayer(BaseChannelLayer):
|
|
"""
|
|
Postgres channel layer.
|
|
|
|
It uses the NOTIFY/LISTEN functionality of postgres to broadcast messages
|
|
|
|
It also makes use of an internal message table to overcome the
|
|
8000bytes limit of Postgres' NOTIFY messages.
|
|
Which is a far cry from the channels standard of 1MB
|
|
This table has a trigger that sends out the `NOTIFY` signal.
|
|
|
|
Using a database also means messages are durable and will always be
|
|
available to consumers (as long as they're not expired).
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
channel_layer: PostgresChannelLayerLoopProxy,
|
|
prefix: str = "asgi",
|
|
expiry: int = 60,
|
|
group_expiry: int = 86400,
|
|
capacity: int = 100,
|
|
channel_capacity: dict[Pattern[str] | str, int] | None = None,
|
|
using: str = DEFAULT_DB_ALIAS,
|
|
) -> None:
|
|
super().__init__(expiry=expiry, capacity=capacity, channel_capacity=channel_capacity)
|
|
|
|
self.group_expiry = group_expiry
|
|
self.prefix = prefix
|
|
assert isinstance(self.prefix, str), "Prefix must be unicode" # nosec
|
|
self.channel_layer = channel_layer
|
|
self.using = using
|
|
self._pool_lock = asyncio.Lock()
|
|
|
|
# Each consumer gets its own *specific* channel, created with the `new_channel()` method.
|
|
# This dict maps `channel_name` to a queue of messages for that channel.
|
|
self.channels: dict[str, asyncio.Queue[tuple[str, bytes | None]]] = {}
|
|
|
|
self._pool: AsyncConnectionPool | None = None
|
|
self.receiver = PostgresChannelLayerReceiver(self.using, self)
|
|
|
|
def make_conninfo(self) -> str:
|
|
db_params = connections[self.using].get_connection_params()
|
|
# Prevent psycopg from using the custom synchronous cursor factory from django
|
|
db_params.pop("cursor_factory")
|
|
db_params.pop("context")
|
|
return make_conninfo(conninfo="", **db_params, connect_timeout=10)
|
|
|
|
async def connection(self) -> AbstractAsyncContextManager[AsyncConnection]:
|
|
if self._pool is None:
|
|
async with self._pool_lock:
|
|
|
|
async def _configure_connection(conn: AsyncConnection) -> None:
|
|
await conn.set_autocommit(True)
|
|
conn.prepare_threshold = 0 # All statements should be prepared
|
|
conn.prepared_max = None # No limit on the number of prepared statements
|
|
|
|
self._pool = AsyncConnectionPool(
|
|
conninfo=self.make_conninfo(),
|
|
open=False,
|
|
configure=_configure_connection,
|
|
min_size=1,
|
|
max_size=4,
|
|
)
|
|
await self._pool.open(wait=True)
|
|
return self._pool.connection()
|
|
|
|
async def _subscribe_to_channel(self, channel: str) -> None:
|
|
self.channels[channel] = asyncio.Queue()
|
|
await self.receiver.subscribe(channel)
|
|
|
|
extensions = ["groups", "flush"]
|
|
|
|
### Channel layer API ###
|
|
|
|
async def send(self, channel: str, message: dict[str, Any]) -> None:
|
|
"""
|
|
Send a message onto a (general or specific) channel.
|
|
"""
|
|
# Typecheck
|
|
assert isinstance(message, dict), "message is not a dict" # nosec
|
|
assert self.require_valid_channel_name(channel), "Channel name not valid" # nosec
|
|
# Make sure the message does not contain reserved keys
|
|
assert "__asgi_channel__" not in message # nosec
|
|
|
|
async with await self.connection() as conn:
|
|
async with conn.cursor() as cursor:
|
|
await cursor.execute(
|
|
sql.SQL("""
|
|
INSERT INTO {table}
|
|
({id}, {channel}, {message}, {expires})
|
|
VALUES (%s, %s, %s, %s)
|
|
""").format(
|
|
table=sql.Identifier(MESSAGE_TABLE),
|
|
id=sql.Identifier("id"),
|
|
channel=sql.Identifier("channel"),
|
|
message=sql.Identifier("message"),
|
|
expires=sql.Identifier("expires"),
|
|
),
|
|
(
|
|
uuid4(),
|
|
channel,
|
|
self.channel_layer.serialize(message),
|
|
now() + timedelta(seconds=self.expiry),
|
|
),
|
|
)
|
|
|
|
async def new_channel(self, prefix: str = "specific") -> str:
|
|
"""
|
|
Returns a new channel name that can be used by something in our
|
|
process as a specific channel.
|
|
"""
|
|
channel = f"{self.prefix}.{prefix}.{uuid4().hex}"
|
|
await self._subscribe_to_channel(channel)
|
|
return channel
|
|
|
|
async def receive(self, channel: str) -> dict[str, Any]:
|
|
"""
|
|
Receive the first message that arrives on the channel.
|
|
If more than one coroutine waits on the same channel, the first waiter
|
|
will be given the message when it arrives.
|
|
"""
|
|
if channel not in self.channels:
|
|
await self._subscribe_to_channel(channel)
|
|
|
|
q = self.channels[channel]
|
|
try:
|
|
while True:
|
|
message_id, message = await q.get()
|
|
async with await self.connection() as conn:
|
|
async with conn.cursor() as cursor:
|
|
if message is None:
|
|
await cursor.execute(
|
|
sql.SQL("""
|
|
DELETE
|
|
FROM {table}
|
|
WHERE {table}.{id} = %s
|
|
RETURNING {table}.{message}
|
|
""").format(
|
|
table=sql.Identifier(MESSAGE_TABLE),
|
|
id=sql.Identifier("id"),
|
|
message=sql.Identifier("message"),
|
|
),
|
|
(message_id,),
|
|
)
|
|
row = await cursor.fetchone()
|
|
if row is None:
|
|
continue
|
|
message = row[0]
|
|
else:
|
|
await cursor.execute(
|
|
sql.SQL("""
|
|
DELETE
|
|
FROM {table}
|
|
WHERE {table}.{id} = %s
|
|
""").format(
|
|
table=sql.Identifier(MESSAGE_TABLE),
|
|
id=sql.Identifier("id"),
|
|
),
|
|
(message_id,),
|
|
)
|
|
break
|
|
except asyncio.CancelledError, TimeoutError, GeneratorExit:
|
|
# We assume here that the reason we are cancelled is because the consumer
|
|
# is exiting, therefore we need to cleanup by unsubscribe below. Indeed,
|
|
# currently the way that Django Channels works, this is a safe assumption.
|
|
# In the future, Django Channels could change to call a *new* method that
|
|
# would serve as the antithesis of `new_channel()`; this new method might
|
|
# be named `delete_channel()`. If that were the case, we would do the
|
|
# following cleanup from that new `delete_channel()` method, but, since
|
|
# that's not how Django Channels works (yet), we do the cleanup below:
|
|
if channel in self.channels:
|
|
del self.channels[channel]
|
|
try:
|
|
await self.receiver.unsubscribe(channel)
|
|
except BaseException as exc: # noqa: BLE001
|
|
LOGGER.warning("Unexpected exception while cleaning-up channel", exc=exc)
|
|
# We don't re-raise here because we want the CancelledError to be the one
|
|
# re-raised
|
|
raise
|
|
return self.channel_layer.deserialize(message)
|
|
|
|
# ==============================================================
|
|
# Groups extension
|
|
# ==============================================================
|
|
|
|
async def group_add(self, group: str, channel: str) -> None:
|
|
"""
|
|
Adds the channel name to a group.
|
|
"""
|
|
# Check the inputs
|
|
assert self.require_valid_group_name(group), "Group name not valid" # nosec
|
|
assert self.require_valid_channel_name(channel), "Channel name not valid" # nosec
|
|
|
|
group_key = self._group_key(group)
|
|
|
|
async with await self.connection() as conn:
|
|
async with conn.cursor() as cursor:
|
|
await cursor.execute(
|
|
sql.SQL("""
|
|
INSERT INTO {table}
|
|
({id}, {group_key}, {channel}, {expires})
|
|
VALUES (%s, %s, %s, %s)
|
|
""").format(
|
|
table=sql.Identifier(GROUP_CHANNEL_TABLE),
|
|
id=sql.Identifier("id"),
|
|
group_key=sql.Identifier("group_key"),
|
|
channel=sql.Identifier("channel"),
|
|
expires=sql.Identifier("expires"),
|
|
),
|
|
(
|
|
uuid4(),
|
|
group_key,
|
|
channel,
|
|
now() + timedelta(seconds=self.group_expiry),
|
|
),
|
|
)
|
|
|
|
async def group_discard(self, group: str, channel: str) -> None:
|
|
"""
|
|
Removes the channel from the named group if it is in the group;
|
|
does nothing otherwise (does not error)
|
|
"""
|
|
# Check the inputs
|
|
assert self.require_valid_group_name(group), "Group name not valid" # nosec
|
|
assert self.require_valid_channel_name(channel), "Channel name not valid" # nosec
|
|
|
|
group_key = self._group_key(group)
|
|
|
|
async with await self.connection() as conn:
|
|
async with conn.cursor() as cursor:
|
|
await cursor.execute(
|
|
sql.SQL("""
|
|
DELETE
|
|
FROM {table}
|
|
WHERE {table}.{group_key} = %s
|
|
AND {table}.{channel} = %s
|
|
""").format(
|
|
table=sql.Identifier(GROUP_CHANNEL_TABLE),
|
|
group_key=sql.Identifier("group_key"),
|
|
channel=sql.Identifier("channel"),
|
|
),
|
|
(group_key, channel),
|
|
)
|
|
|
|
async def group_send(self, group: str, message: dict[str, Any]) -> None:
|
|
"""
|
|
Sends a message to the entire group.
|
|
"""
|
|
assert self.require_valid_group_name(group), "Group name not valid" # nosec
|
|
|
|
group_key = self._group_key(group)
|
|
|
|
serialized_message = self.channel_layer.serialize(message)
|
|
|
|
async with await self.connection() as conn:
|
|
async with conn.cursor() as cursor:
|
|
await cursor.execute(
|
|
sql.SQL("""
|
|
SELECT DISTINCT {table}.{channel}
|
|
FROM {table}
|
|
WHERE {table}.{group_key} = %s
|
|
""").format(
|
|
table=sql.Identifier(GROUP_CHANNEL_TABLE),
|
|
channel=sql.Identifier("channel"),
|
|
group_key=sql.Identifier("group_key"),
|
|
),
|
|
(group_key,),
|
|
)
|
|
channels = [row[0] for row in await cursor.fetchall()]
|
|
messages = [
|
|
(uuid4(), channel, serialized_message, now() + timedelta(seconds=self.expiry))
|
|
for channel in channels
|
|
]
|
|
async with conn.cursor() as cursor:
|
|
await cursor.executemany(
|
|
sql.SQL("""
|
|
INSERT INTO {table}
|
|
({id}, {channel}, {message}, {expires})
|
|
VALUES (%s, %s, %s, %s)
|
|
""").format(
|
|
table=sql.Identifier(MESSAGE_TABLE),
|
|
id=sql.Identifier("id"),
|
|
channel=sql.Identifier("channel"),
|
|
message=sql.Identifier("message"),
|
|
expires=sql.Identifier("expires"),
|
|
),
|
|
messages,
|
|
)
|
|
|
|
def group_send_blocking(self, group: str, message: dict[str, Any]) -> None:
|
|
"""
|
|
Sends a message to the entire group, blocking version.
|
|
"""
|
|
assert self.require_valid_group_name(group), "Group name not valid" # nosec
|
|
|
|
group_key = self._group_key(group)
|
|
|
|
serialized_message = self.channel_layer.serialize(message)
|
|
|
|
with connections[self.using].cursor() as cursor:
|
|
cursor.execute(
|
|
sql.SQL("""
|
|
SELECT DISTINCT {table}.{channel}
|
|
FROM {table}
|
|
WHERE {table}.{group_key} = %s
|
|
""").format(
|
|
table=sql.Identifier(GROUP_CHANNEL_TABLE),
|
|
channel=sql.Identifier("channel"),
|
|
group_key=sql.Identifier("group_key"),
|
|
),
|
|
(group_key,),
|
|
)
|
|
channels = [row[0] for row in cursor.fetchall()]
|
|
messages = [
|
|
(uuid4(), channel, serialized_message, now() + timedelta(seconds=self.expiry))
|
|
for channel in channels
|
|
]
|
|
with connections[self.using].cursor() as cursor:
|
|
cursor.executemany(
|
|
sql.SQL("""
|
|
INSERT INTO {table}
|
|
({id}, {channel}, {message}, {expires})
|
|
VALUES (%s, %s, %s, %s)
|
|
""").format(
|
|
table=sql.Identifier(MESSAGE_TABLE),
|
|
id=sql.Identifier("id"),
|
|
channel=sql.Identifier("channel"),
|
|
message=sql.Identifier("message"),
|
|
expires=sql.Identifier("expires"),
|
|
),
|
|
messages,
|
|
)
|
|
|
|
def _group_key(self, group: str) -> str:
|
|
"""
|
|
Common function to make the storage key for the group.
|
|
"""
|
|
return f"{self.prefix}.group.{group}"
|
|
|
|
### Flush extension ###
|
|
|
|
async def flush(self) -> None:
|
|
"""
|
|
Deletes all messages and groups.
|
|
"""
|
|
self.channels = {}
|
|
await self.receiver.flush()
|
|
|
|
|
|
class PostgresChannelLayerReceiver:
|
|
def __init__(self, using: str, channel_layer: PostgresChannelLoopLayer) -> None:
|
|
self.using = using
|
|
self.channel_layer = channel_layer
|
|
self._subscribed_to: set[str] = set()
|
|
self._lock = asyncio.Lock()
|
|
self._receive_task: asyncio.Task[None] | None = None
|
|
|
|
async def subscribe(self, channel: str) -> None:
|
|
async with self._lock:
|
|
if channel not in self._subscribed_to:
|
|
self._ensure_receiver()
|
|
self._subscribed_to.add(channel)
|
|
|
|
async def unsubscribe(self, channel: str) -> None:
|
|
async with self._lock:
|
|
if channel in self._subscribed_to:
|
|
self._ensure_receiver()
|
|
self._subscribed_to.remove(channel)
|
|
|
|
async def flush(self) -> None:
|
|
async with self._lock:
|
|
if self._receive_task is not None:
|
|
self._receive_task.cancel()
|
|
try:
|
|
await self._receive_task
|
|
except asyncio.CancelledError:
|
|
pass
|
|
self._receive_task = None
|
|
self._subscribed_to = set()
|
|
|
|
async def _do_receiving(self) -> None:
|
|
while True:
|
|
try:
|
|
async with await AsyncConnection.connect(
|
|
conninfo=self.channel_layer.make_conninfo(),
|
|
autocommit=True,
|
|
) as conn:
|
|
await self._process_backlog(conn)
|
|
await conn.execute(
|
|
sql.SQL("LISTEN {channel}").format(channel=sql.Identifier(NOTIFY_CHANNEL))
|
|
)
|
|
while True:
|
|
async for notify in conn.notifies(timeout=30):
|
|
await self._receive_notify(notify)
|
|
except asyncio.CancelledError, TimeoutError, GeneratorExit:
|
|
raise
|
|
except PsycopgError as exc:
|
|
LOGGER.warning("Postgres connection is not healthy", exc=exc)
|
|
except BaseException as exc: # noqa: BLE001
|
|
LOGGER.warning("Unexpected exception in receive task", exc=exc, exc_info=True)
|
|
await asyncio.sleep(1)
|
|
|
|
async def _process_backlog(self, conn: AsyncConnection) -> None:
|
|
if not self._subscribed_to:
|
|
return
|
|
async with conn.cursor() as cursor:
|
|
await cursor.execute(
|
|
sql.SQL("""
|
|
DELETE
|
|
FROM {table}
|
|
WHERE {table}.{channel} = ANY(%s)
|
|
AND {table}.{expires} >= %s
|
|
RETURNING {table}.{id}, {table}.{channel}, {table}.{message}
|
|
""").format(
|
|
table=sql.Identifier(MESSAGE_TABLE),
|
|
id=sql.Identifier("id"),
|
|
channel=sql.Identifier("channel"),
|
|
expires=sql.Identifier("expires"),
|
|
message=sql.Identifier("message"),
|
|
),
|
|
(list(self._subscribed_to), now()),
|
|
)
|
|
async for row in cursor:
|
|
message_id, channel, message = row
|
|
self._receive_message(channel, message_id, message)
|
|
|
|
async def _receive_notify(self, notify: Notify) -> None:
|
|
payload = notify.payload
|
|
split_payload = payload.split(":")
|
|
message: bytes | None = None
|
|
match len(split_payload):
|
|
case 4:
|
|
message_id, channel, timestamp, base64_message = split_payload
|
|
if channel not in self._subscribed_to:
|
|
return
|
|
expires = datetime.fromtimestamp(float(timestamp), tz=UTC)
|
|
if expires < now():
|
|
return
|
|
message = b64decode(base64_message)
|
|
case 3:
|
|
message_id, channel, timestamp = split_payload
|
|
if channel not in self._subscribed_to:
|
|
return
|
|
expires = datetime.fromtimestamp(float(timestamp), tz=UTC)
|
|
if expires < now():
|
|
return
|
|
message = None
|
|
case _:
|
|
return
|
|
self._receive_message(channel, message_id, message)
|
|
|
|
def _receive_message(self, channel: str, message_id: str, message: bytes | None) -> None:
|
|
if (q := self.channel_layer.channels.get(channel)) is not None:
|
|
q.put_nowait((message_id, message))
|
|
|
|
def _ensure_receiver(self) -> None:
|
|
if self._receive_task is None:
|
|
self._receive_task = asyncio.ensure_future(self._do_receiving())
|