Files
authentik/packages/django-dramatiq-postgres/django_dramatiq_postgres/middleware.py
2026-03-28 20:57:46 +01:00

368 lines
12 KiB
Python

import contextvars
import os
import socket
from collections.abc import Callable
from http.server import BaseHTTPRequestHandler
from http.server import HTTPServer as BaseHTTPServer
from ipaddress import IPv6Address, ip_address
from threading import Thread, current_thread
from typing import TYPE_CHECKING, Any, cast
from django.db import DatabaseError, close_old_connections, connections
from dramatiq.actor import Actor
from dramatiq.broker import Broker
from dramatiq.common import current_millis
from dramatiq.message import Message
from dramatiq.middleware.middleware import Middleware
from structlog.stdlib import get_logger
from django_dramatiq_postgres.conf import Conf
from django_dramatiq_postgres.models import TaskBase, TaskState
if TYPE_CHECKING:
from django_dramatiq_postgres.broker import PostgresBroker
class HTTPServerThread(Thread):
"""Base class for a thread which runs an HTTP Server. Mainly used for typing
the `server` instance variable."""
server: HTTPServer | None = None
class HTTPServer(BaseHTTPServer):
def server_bind(self) -> None:
self.socket.close()
host, port = self.server_address[:2]
host = cast(str, host)
if host == "0.0.0.0" and socket.has_dualstack_ipv6(): # nosec
host = "::" # nosec
# Strip IPv6 brackets
if host.startswith("[") and host.endswith("]"):
host = host[1:-1]
self.server_address = (host, port)
self.address_family = (
socket.AF_INET6
if socket.has_dualstack_ipv6() and isinstance(ip_address(host), IPv6Address)
else socket.AF_INET
)
self.socket = socket.create_server(
self.server_address,
family=self.address_family,
dualstack_ipv6=self.address_family == socket.AF_INET6,
)
self.server_name = socket.getfqdn(host)
self.server_port = port
class DbConnectionMiddleware(Middleware):
def _close_old_connections(self, *args: Any, **kwargs: Any) -> None:
if Conf().test:
return
close_old_connections()
before_process_message = _close_old_connections
after_process_message = _close_old_connections
def _close_connections(self, *args: Any, **kwargs: Any) -> None:
connections.close_all()
before_consumer_thread_shutdown = _close_connections
before_worker_thread_shutdown = _close_connections
before_worker_shutdown = _close_connections
class TaskStateBeforeMiddleware(Middleware):
def before_process_message(self, broker: PostgresBroker, message: Message[Any]) -> None:
broker.query_set.filter(
message_id=message.message_id,
queue_name=message.queue_name,
state=TaskState.CONSUMED,
).update(
state=TaskState.PREPROCESS,
)
class TaskStateAfterMiddleware(Middleware):
def before_process_message(self, broker: PostgresBroker, message: Message[Any]) -> None:
broker.query_set.filter(
message_id=message.message_id,
queue_name=message.queue_name,
state=TaskState.PREPROCESS,
).update(
state=TaskState.RUNNING,
)
def after_skip_message(self, broker: PostgresBroker, message: Message[Any]) -> None:
broker.query_set.filter(
message_id=message.message_id,
queue_name=message.queue_name,
state=TaskState.RUNNING,
).update(
state=TaskState.POSTPROCESS,
)
def after_process_message(
self,
broker: PostgresBroker,
message: Message[Any],
*,
result: Any | None = None,
exception: Exception | None = None,
) -> None:
self.after_skip_message(broker, message)
class FullyQualifiedActorName(Middleware):
def before_declare_actor(self, broker: Broker, actor: Actor[Any, Any]) -> None:
actor.actor_name = f"{actor.fn.__module__}.{actor.fn.__name__}"
class CurrentTaskNotFound(Exception):
"""
Not current task found. Did you call get_task outside a running task?
"""
class CurrentTask(Middleware):
def __init__(self) -> None:
self.logger = get_logger(__name__, type(self))
# This is a list of tasks, so that in tests, when a task calls another task, this acts as a pile
_TASKS: contextvars.ContextVar[list[TaskBase] | None] = contextvars.ContextVar(
"_TASKS",
default=None,
)
@classmethod
def get_task(cls) -> TaskBase:
task = cls._TASKS.get()
if not task:
raise CurrentTaskNotFound()
return task[-1]
def before_process_message(self, broker: Broker, message: Message[Any]) -> None:
tasks = self._TASKS.get()
if tasks is None:
tasks = []
tasks.append(message.options["task"])
self._TASKS.set(tasks)
def after_process_message(
self,
broker: Broker,
message: Message[Any],
*,
result: Any | None = None,
exception: Exception | None = None,
) -> None:
tasks: list[TaskBase] | None = self._TASKS.get()
if tasks is None or len(tasks) == 0:
return
task = tasks[-1]
fields_to_exclude = {
"message_id",
"queue_name",
"actor_name",
"message",
"state",
"mtime",
"retries",
"eta",
"result",
"result_expiry",
}
fields_to_update = [
f.name
for f in task._meta.get_fields()
if f.name not in fields_to_exclude
and f.concrete
and not f.auto_created
and not f.many_to_many
]
if fields_to_update:
try:
task.save(update_fields=fields_to_update)
except DatabaseError:
pass
self._TASKS.set(tasks[:-1])
def after_skip_message(self, broker: Broker, message: Message[Any]) -> None:
self.after_process_message(broker, message)
class _MetricsHandler(BaseHTTPRequestHandler):
def do_GET(self) -> None:
from prometheus_client import (
CONTENT_TYPE_LATEST,
CollectorRegistry,
generate_latest,
multiprocess,
)
registry = CollectorRegistry()
multiprocess.MultiProcessCollector(registry) # type: ignore[no-untyped-call]
output = generate_latest(registry)
self.send_response(200)
self.send_header("Content-Type", CONTENT_TYPE_LATEST)
self.end_headers()
self.wfile.write(output)
def log_message(self, format: str, *args: Any) -> None:
logger = get_logger(__name__, type(self))
logger.debug(format, *args)
class MetricsMiddleware(Middleware):
handler_class: type[BaseHTTPRequestHandler] = _MetricsHandler
def __init__(
self,
prefix: str,
labels: list[str] | None = None,
) -> None:
super().__init__()
self.prefix = prefix
self.labels: list[str] = labels if labels is not None else ["queue_name", "actor_name"]
self.delayed_messages: set[str] = set()
self.message_start_times: dict[str, int] = {}
@property
def forks(self) -> list[Callable[[], None]]:
from django_dramatiq_postgres.forks import worker_metrics
return [worker_metrics]
def before_worker_boot(self, broker: Broker, worker: Any) -> None:
from prometheus_client import Counter, Gauge, Histogram
self.total_messages = Counter(
f"{self.prefix}_tasks_total",
"The total number of tasks processed.",
self.labels,
)
self.total_errored_messages = Counter(
f"{self.prefix}_tasks_errors_total",
"The total number of errored tasks.",
self.labels,
)
self.total_retried_messages = Counter(
f"{self.prefix}_tasks_retries_total",
"The total number of retried tasks.",
self.labels,
)
self.total_rejected_messages = Counter(
f"{self.prefix}_tasks_rejected_total",
"The total number of dead-lettered tasks.",
self.labels,
)
self.in_progress_messages = Gauge(
f"{self.prefix}_tasks_in_progress",
"The number of tasks in progress.",
self.labels,
multiprocess_mode="livesum",
)
self.in_progress_delayed_messages = Gauge(
f"{self.prefix}_tasks_delayed_in_progress",
"The number of delayed tasks in memory.",
self.labels,
)
self.messages_durations = Histogram(
f"{self.prefix}_tasks_duration_milliseconds",
"The time spent processing tasks.",
self.labels,
buckets=(
5,
10,
25,
50,
75,
100,
250,
500,
750,
1_000,
2_500,
5_000,
7_500,
10_000,
30_000,
60_000,
600_000,
900_000,
1_800_000,
3_600_000,
float("inf"),
),
)
def after_worker_shutdown(self, broker: Broker, worker: Any) -> None:
from prometheus_client import multiprocess
# TODO: worker_id
multiprocess.mark_process_dead(os.getpid()) # type: ignore[no-untyped-call]
def _make_labels(self, message: Message[Any]) -> list[str]:
return [message.queue_name, message.actor_name]
def after_nack(self, broker: Broker, message: Message[Any]) -> None:
self.total_rejected_messages.labels(*self._make_labels(message)).inc()
def after_enqueue(self, broker: Broker, message: Message[Any], delay: int) -> None:
if "retries" in message.options:
self.total_retried_messages.labels(*self._make_labels(message)).inc()
def before_delay_message(self, broker: Broker, message: Message[Any]) -> None:
self.delayed_messages.add(message.message_id)
self.in_progress_delayed_messages.labels(*self._make_labels(message)).inc()
def before_process_message(self, broker: Broker, message: Message[Any]) -> None:
labels = self._make_labels(message)
if message.message_id in self.delayed_messages:
self.delayed_messages.remove(message.message_id)
self.in_progress_delayed_messages.labels(*labels).dec()
self.in_progress_messages.labels(*labels).inc()
self.message_start_times[message.message_id] = current_millis() # type: ignore[no-untyped-call]
def after_process_message(
self,
broker: Broker,
message: Message[Any],
*,
result: Any | None = None,
exception: Exception | None = None,
) -> None:
labels = self._make_labels(message)
message_start_time = self.message_start_times.pop(message.message_id, current_millis()) # type: ignore[no-untyped-call]
message_duration = current_millis() - message_start_time # type: ignore[no-untyped-call]
self.messages_durations.labels(*labels).observe(message_duration)
self.in_progress_messages.labels(*labels).dec()
self.total_messages.labels(*labels).inc()
if exception is not None:
self.total_errored_messages.labels(*labels).inc()
after_skip_message = after_process_message
@classmethod
def run(cls, addr: str, port: int) -> None:
try:
server = HTTPServer((addr, port), cls.handler_class)
thread = cast(HTTPServerThread, current_thread())
thread.server = server
server.serve_forever()
except OSError:
get_logger(__name__, type(MetricsMiddleware)).warning(
"Port is already in use, not starting metrics server"
)