import contextvars import os import socket from collections.abc import Callable from http.server import BaseHTTPRequestHandler from http.server import HTTPServer as BaseHTTPServer from ipaddress import IPv6Address, ip_address from threading import Thread, current_thread from typing import TYPE_CHECKING, Any, cast from django.db import DatabaseError, close_old_connections, connections from dramatiq.actor import Actor from dramatiq.broker import Broker from dramatiq.common import current_millis from dramatiq.message import Message from dramatiq.middleware.middleware import Middleware from structlog.stdlib import get_logger from django_dramatiq_postgres.conf import Conf from django_dramatiq_postgres.models import TaskBase, TaskState if TYPE_CHECKING: from django_dramatiq_postgres.broker import PostgresBroker class HTTPServerThread(Thread): """Base class for a thread which runs an HTTP Server. Mainly used for typing the `server` instance variable.""" server: HTTPServer | None = None class HTTPServer(BaseHTTPServer): def server_bind(self) -> None: self.socket.close() host, port = self.server_address[:2] host = cast(str, host) if host == "0.0.0.0" and socket.has_dualstack_ipv6(): # nosec host = "::" # nosec # Strip IPv6 brackets if host.startswith("[") and host.endswith("]"): host = host[1:-1] self.server_address = (host, port) self.address_family = ( socket.AF_INET6 if socket.has_dualstack_ipv6() and isinstance(ip_address(host), IPv6Address) else socket.AF_INET ) self.socket = socket.create_server( self.server_address, family=self.address_family, dualstack_ipv6=self.address_family == socket.AF_INET6, ) self.server_name = socket.getfqdn(host) self.server_port = port class DbConnectionMiddleware(Middleware): def _close_old_connections(self, *args: Any, **kwargs: Any) -> None: if Conf().test: return close_old_connections() before_process_message = _close_old_connections after_process_message = _close_old_connections def _close_connections(self, *args: Any, **kwargs: Any) -> None: connections.close_all() before_consumer_thread_shutdown = _close_connections before_worker_thread_shutdown = _close_connections before_worker_shutdown = _close_connections class TaskStateBeforeMiddleware(Middleware): def before_process_message(self, broker: PostgresBroker, message: Message[Any]) -> None: broker.query_set.filter( message_id=message.message_id, queue_name=message.queue_name, state=TaskState.CONSUMED, ).update( state=TaskState.PREPROCESS, ) class TaskStateAfterMiddleware(Middleware): def before_process_message(self, broker: PostgresBroker, message: Message[Any]) -> None: broker.query_set.filter( message_id=message.message_id, queue_name=message.queue_name, state=TaskState.PREPROCESS, ).update( state=TaskState.RUNNING, ) def after_skip_message(self, broker: PostgresBroker, message: Message[Any]) -> None: broker.query_set.filter( message_id=message.message_id, queue_name=message.queue_name, state=TaskState.RUNNING, ).update( state=TaskState.POSTPROCESS, ) def after_process_message( self, broker: PostgresBroker, message: Message[Any], *, result: Any | None = None, exception: Exception | None = None, ) -> None: self.after_skip_message(broker, message) class FullyQualifiedActorName(Middleware): def before_declare_actor(self, broker: Broker, actor: Actor[Any, Any]) -> None: actor.actor_name = f"{actor.fn.__module__}.{actor.fn.__name__}" class CurrentTaskNotFound(Exception): """ Not current task found. Did you call get_task outside a running task? """ class CurrentTask(Middleware): def __init__(self) -> None: self.logger = get_logger(__name__, type(self)) # This is a list of tasks, so that in tests, when a task calls another task, this acts as a pile _TASKS: contextvars.ContextVar[list[TaskBase] | None] = contextvars.ContextVar( "_TASKS", default=None, ) @classmethod def get_task(cls) -> TaskBase: task = cls._TASKS.get() if not task: raise CurrentTaskNotFound() return task[-1] def before_process_message(self, broker: Broker, message: Message[Any]) -> None: tasks = self._TASKS.get() if tasks is None: tasks = [] tasks.append(message.options["task"]) self._TASKS.set(tasks) def after_process_message( self, broker: Broker, message: Message[Any], *, result: Any | None = None, exception: Exception | None = None, ) -> None: tasks: list[TaskBase] | None = self._TASKS.get() if tasks is None or len(tasks) == 0: return task = tasks[-1] fields_to_exclude = { "message_id", "queue_name", "actor_name", "message", "state", "mtime", "retries", "eta", "result", "result_expiry", } fields_to_update = [ f.name for f in task._meta.get_fields() if f.name not in fields_to_exclude and f.concrete and not f.auto_created and not f.many_to_many ] if fields_to_update: try: task.save(update_fields=fields_to_update) except DatabaseError: pass self._TASKS.set(tasks[:-1]) def after_skip_message(self, broker: Broker, message: Message[Any]) -> None: self.after_process_message(broker, message) class _MetricsHandler(BaseHTTPRequestHandler): def do_GET(self) -> None: from prometheus_client import ( CONTENT_TYPE_LATEST, CollectorRegistry, generate_latest, multiprocess, ) registry = CollectorRegistry() multiprocess.MultiProcessCollector(registry) # type: ignore[no-untyped-call] output = generate_latest(registry) self.send_response(200) self.send_header("Content-Type", CONTENT_TYPE_LATEST) self.end_headers() self.wfile.write(output) def log_message(self, format: str, *args: Any) -> None: logger = get_logger(__name__, type(self)) logger.debug(format, *args) class MetricsMiddleware(Middleware): handler_class: type[BaseHTTPRequestHandler] = _MetricsHandler def __init__( self, prefix: str, labels: list[str] | None = None, ) -> None: super().__init__() self.prefix = prefix self.labels: list[str] = labels if labels is not None else ["queue_name", "actor_name"] self.delayed_messages: set[str] = set() self.message_start_times: dict[str, int] = {} @property def forks(self) -> list[Callable[[], None]]: from django_dramatiq_postgres.forks import worker_metrics return [worker_metrics] def before_worker_boot(self, broker: Broker, worker: Any) -> None: from prometheus_client import Counter, Gauge, Histogram self.total_messages = Counter( f"{self.prefix}_tasks_total", "The total number of tasks processed.", self.labels, ) self.total_errored_messages = Counter( f"{self.prefix}_tasks_errors_total", "The total number of errored tasks.", self.labels, ) self.total_retried_messages = Counter( f"{self.prefix}_tasks_retries_total", "The total number of retried tasks.", self.labels, ) self.total_rejected_messages = Counter( f"{self.prefix}_tasks_rejected_total", "The total number of dead-lettered tasks.", self.labels, ) self.in_progress_messages = Gauge( f"{self.prefix}_tasks_in_progress", "The number of tasks in progress.", self.labels, multiprocess_mode="livesum", ) self.in_progress_delayed_messages = Gauge( f"{self.prefix}_tasks_delayed_in_progress", "The number of delayed tasks in memory.", self.labels, ) self.messages_durations = Histogram( f"{self.prefix}_tasks_duration_milliseconds", "The time spent processing tasks.", self.labels, buckets=( 5, 10, 25, 50, 75, 100, 250, 500, 750, 1_000, 2_500, 5_000, 7_500, 10_000, 30_000, 60_000, 600_000, 900_000, 1_800_000, 3_600_000, float("inf"), ), ) def after_worker_shutdown(self, broker: Broker, worker: Any) -> None: from prometheus_client import multiprocess # TODO: worker_id multiprocess.mark_process_dead(os.getpid()) # type: ignore[no-untyped-call] def _make_labels(self, message: Message[Any]) -> list[str]: return [message.queue_name, message.actor_name] def after_nack(self, broker: Broker, message: Message[Any]) -> None: self.total_rejected_messages.labels(*self._make_labels(message)).inc() def after_enqueue(self, broker: Broker, message: Message[Any], delay: int) -> None: if "retries" in message.options: self.total_retried_messages.labels(*self._make_labels(message)).inc() def before_delay_message(self, broker: Broker, message: Message[Any]) -> None: self.delayed_messages.add(message.message_id) self.in_progress_delayed_messages.labels(*self._make_labels(message)).inc() def before_process_message(self, broker: Broker, message: Message[Any]) -> None: labels = self._make_labels(message) if message.message_id in self.delayed_messages: self.delayed_messages.remove(message.message_id) self.in_progress_delayed_messages.labels(*labels).dec() self.in_progress_messages.labels(*labels).inc() self.message_start_times[message.message_id] = current_millis() # type: ignore[no-untyped-call] def after_process_message( self, broker: Broker, message: Message[Any], *, result: Any | None = None, exception: Exception | None = None, ) -> None: labels = self._make_labels(message) message_start_time = self.message_start_times.pop(message.message_id, current_millis()) # type: ignore[no-untyped-call] message_duration = current_millis() - message_start_time # type: ignore[no-untyped-call] self.messages_durations.labels(*labels).observe(message_duration) self.in_progress_messages.labels(*labels).dec() self.total_messages.labels(*labels).inc() if exception is not None: self.total_errored_messages.labels(*labels).inc() after_skip_message = after_process_message @classmethod def run(cls, addr: str, port: int) -> None: try: server = HTTPServer((addr, port), cls.handler_class) thread = cast(HTTPServerThread, current_thread()) thread.server = server server.serve_forever() except OSError: get_logger(__name__, type(MetricsMiddleware)).warning( "Port is already in use, not starting metrics server" )