mirror of
https://github.com/goauthentik/authentik
synced 2026-04-28 10:28:22 +02:00
368 lines
12 KiB
Python
368 lines
12 KiB
Python
import contextvars
|
|
import os
|
|
import socket
|
|
from collections.abc import Callable
|
|
from http.server import BaseHTTPRequestHandler
|
|
from http.server import HTTPServer as BaseHTTPServer
|
|
from ipaddress import IPv6Address, ip_address
|
|
from threading import Thread, current_thread
|
|
from typing import TYPE_CHECKING, Any, cast
|
|
|
|
from django.db import DatabaseError, close_old_connections, connections
|
|
from dramatiq.actor import Actor
|
|
from dramatiq.broker import Broker
|
|
from dramatiq.common import current_millis
|
|
from dramatiq.message import Message
|
|
from dramatiq.middleware.middleware import Middleware
|
|
from structlog.stdlib import get_logger
|
|
|
|
from django_dramatiq_postgres.conf import Conf
|
|
from django_dramatiq_postgres.models import TaskBase, TaskState
|
|
|
|
if TYPE_CHECKING:
|
|
from django_dramatiq_postgres.broker import PostgresBroker
|
|
|
|
|
|
class HTTPServerThread(Thread):
    """Base class for a thread which runs an HTTP Server. Mainly used for typing
    the `server` instance variable."""

    # The HTTP server owned by this thread, set by `MetricsMiddleware.run`;
    # None until the server has been created.
    # Annotation is quoted: `HTTPServer` is defined further down in this file,
    # so an unquoted annotation would be evaluated at class-creation time and
    # raise NameError on interpreters without lazy annotation evaluation.
    server: "HTTPServer | None" = None
|
|
|
|
|
|
class HTTPServer(BaseHTTPServer):
    """HTTP server that rebinds its listening socket with dual-stack
    (IPv4 + IPv6) support where the platform allows it."""

    def server_bind(self) -> None:
        """Replace the socket created by the base class with one created via
        `socket.create_server`, upgrading wildcard binds to dual-stack IPv6."""
        # Discard the socket the base HTTPServer constructor already opened;
        # it is re-created below with the correct address family.
        self.socket.close()

        host, port = self.server_address[:2]
        host = cast(str, host)
        # An IPv4 wildcard bind is widened to the IPv6 wildcard so a single
        # dual-stack socket accepts both IPv4 and IPv6 connections.
        if host == "0.0.0.0" and socket.has_dualstack_ipv6():  # nosec
            host = "::"  # nosec

        # Strip IPv6 brackets (e.g. "[::1]" -> "::1") so ip_address() can parse it
        if host.startswith("[") and host.endswith("]"):
            host = host[1:-1]

        self.server_address = (host, port)

        # AF_INET6 only when dual-stack is available AND the host parses as an
        # IPv6 address; otherwise plain IPv4.
        # NOTE(review): ip_address() raises ValueError for non-IP hostnames
        # (e.g. "localhost") — this assumes the configured host is always a
        # literal IP address; confirm against callers.
        self.address_family = (
            socket.AF_INET6
            if socket.has_dualstack_ipv6() and isinstance(ip_address(host), IPv6Address)
            else socket.AF_INET
        )

        self.socket = socket.create_server(
            self.server_address,
            family=self.address_family,
            dualstack_ipv6=self.address_family == socket.AF_INET6,
        )

        # Mirror what the base class's server_bind would have set.
        self.server_name = socket.getfqdn(host)
        self.server_port = port
|
|
|
|
|
|
class DbConnectionMiddleware(Middleware):
    """Keep Django database connections healthy across dramatiq hook points.

    Stale connections are recycled around message processing, and all
    connections are closed outright when threads/workers shut down.
    """

    def _close_old_connections(self, *args: Any, **kwargs: Any) -> None:
        # Skipped entirely in test mode (Conf().test); otherwise let Django
        # drop connections that have exceeded their max age or gone unusable.
        if not Conf().test:
            close_old_connections()

    before_process_message = _close_old_connections
    after_process_message = _close_old_connections

    def _close_connections(self, *args: Any, **kwargs: Any) -> None:
        # Unconditionally close every connection on shutdown paths.
        connections.close_all()

    before_consumer_thread_shutdown = _close_connections
    before_worker_thread_shutdown = _close_connections
    before_worker_shutdown = _close_connections
|
|
|
|
|
|
class TaskStateBeforeMiddleware(Middleware):
    """Move a task's persisted row from CONSUMED to PREPROCESS before processing."""

    def before_process_message(self, broker: PostgresBroker, message: Message[Any]) -> None:
        # Filter on the current state so only rows still in CONSUMED transition.
        consumed = broker.query_set.filter(
            message_id=message.message_id,
            queue_name=message.queue_name,
            state=TaskState.CONSUMED,
        )
        consumed.update(state=TaskState.PREPROCESS)
|
|
|
|
|
|
class TaskStateAfterMiddleware(Middleware):
    """Advance a task's persisted state around message processing:
    PREPROCESS -> RUNNING before, RUNNING -> POSTPROCESS after."""

    def _advance(
        self,
        broker: PostgresBroker,
        message: Message[Any],
        current: TaskState,
        target: TaskState,
    ) -> None:
        # Filtering on the current state makes the transition a no-op for
        # rows that are not where we expect them to be.
        broker.query_set.filter(
            message_id=message.message_id,
            queue_name=message.queue_name,
            state=current,
        ).update(state=target)

    def before_process_message(self, broker: PostgresBroker, message: Message[Any]) -> None:
        self._advance(broker, message, TaskState.PREPROCESS, TaskState.RUNNING)

    def after_skip_message(self, broker: PostgresBroker, message: Message[Any]) -> None:
        self._advance(broker, message, TaskState.RUNNING, TaskState.POSTPROCESS)

    def after_process_message(
        self,
        broker: PostgresBroker,
        message: Message[Any],
        *,
        result: Any | None = None,
        exception: Exception | None = None,
    ) -> None:
        # A processed message gets the same RUNNING -> POSTPROCESS transition
        # as a skipped one.
        self.after_skip_message(broker, message)
|
|
|
|
|
|
class FullyQualifiedActorName(Middleware):
    """Rename actors to their fully dotted path (module + function name) at
    declaration time, so actor names are unique across modules."""

    def before_declare_actor(self, broker: Broker, actor: Actor[Any, Any]) -> None:
        wrapped = actor.fn
        actor.actor_name = f"{wrapped.__module__}.{wrapped.__name__}"
|
|
|
|
|
|
class CurrentTaskNotFound(Exception):
    """
    No current task found. Did you call get_task outside a running task?
    """
|
|
|
|
|
|
class CurrentTask(Middleware):
    """Middleware that exposes the task currently being processed.

    Tasks are kept on a per-context stack (a ContextVar holding a list), and
    `get_task()` returns the innermost one. After processing, any changes made
    to the task object are persisted back to the database (best-effort).
    """

    def __init__(self) -> None:
        self.logger = get_logger(__name__, type(self))

    # This is a list of tasks, so that in tests, when a task calls another task,
    # this acts as a pile (stack): the innermost running task is at the end.
    _TASKS: contextvars.ContextVar[list[TaskBase] | None] = contextvars.ContextVar(
        "_TASKS",
        default=None,
    )

    @classmethod
    def get_task(cls) -> TaskBase:
        """Return the task currently being processed.

        Raises:
            CurrentTaskNotFound: when called with no task on the stack
                (i.e. outside a running task).
        """
        task = cls._TASKS.get()
        if not task:
            raise CurrentTaskNotFound()
        # Last element is the innermost (currently running) task.
        return task[-1]

    def before_process_message(self, broker: Broker, message: Message[Any]) -> None:
        # Push this message's task onto the stack.
        # NOTE(review): assumes message.options["task"] was populated upstream
        # (presumably by the broker/consumer) — confirm against the broker.
        tasks = self._TASKS.get()
        if tasks is None:
            tasks = []
        tasks.append(message.options["task"])
        self._TASKS.set(tasks)

    def after_process_message(
        self,
        broker: Broker,
        message: Message[Any],
        *,
        result: Any | None = None,
        exception: Exception | None = None,
    ) -> None:
        """Persist changes made to the current task during processing, then pop it."""
        tasks: list[TaskBase] | None = self._TASKS.get()
        if tasks is None or len(tasks) == 0:
            return

        task = tasks[-1]
        # Fields managed by the broker/scheduler machinery; changes to these
        # must not be written back from here.
        fields_to_exclude = {
            "message_id",
            "queue_name",
            "actor_name",
            "message",
            "state",
            "mtime",
            "retries",
            "eta",
            "result",
            "result_expiry",
        }
        # Only concrete, directly-declared, non-M2M fields are written back.
        fields_to_update = [
            f.name
            for f in task._meta.get_fields()
            if f.name not in fields_to_exclude
            and f.concrete
            and not f.auto_created
            and not f.many_to_many
        ]
        if fields_to_update:
            try:
                task.save(update_fields=fields_to_update)
            except DatabaseError:
                # Deliberate best-effort save: ignore database failures
                # (e.g. the task row no longer exists — TODO confirm).
                pass
        # Pop the finished task off the stack.
        self._TASKS.set(tasks[:-1])

    def after_skip_message(self, broker: Broker, message: Message[Any]) -> None:
        # A skipped message still needs the same save-and-pop bookkeeping.
        self.after_process_message(broker, message)
|
|
|
|
|
|
class _MetricsHandler(BaseHTTPRequestHandler):
    """Request handler that serves aggregated prometheus metrics on any GET."""

    def do_GET(self) -> None:
        """Collect metrics from all worker processes and write them out."""
        # Imported lazily so prometheus_client is only required when the
        # metrics endpoint is actually served.
        from prometheus_client import (
            CONTENT_TYPE_LATEST,
            CollectorRegistry,
            generate_latest,
            multiprocess,
        )

        # Fresh registry per request, aggregating the multiprocess metric
        # files from every worker process.
        registry = CollectorRegistry()
        multiprocess.MultiProcessCollector(registry)  # type: ignore[no-untyped-call]
        output = generate_latest(registry)
        self.send_response(200)
        self.send_header("Content-Type", CONTENT_TYPE_LATEST)
        self.end_headers()
        self.wfile.write(output)

    def log_message(self, format: str, *args: Any) -> None:
        # Route BaseHTTPRequestHandler's default stderr logging through
        # structlog at debug level instead.
        logger = get_logger(__name__, type(self))
        logger.debug(format, *args)
|
|
|
|
|
|
class MetricsMiddleware(Middleware):
    """Record prometheus metrics about task processing and serve them over HTTP.

    Metric objects are created lazily in `before_worker_boot` so that
    prometheus_client is only imported inside worker processes. The metrics
    HTTP endpoint itself is run by `run()` in a dedicated thread (spawned via
    the `forks` hook).
    """

    # Request handler used by the metrics HTTP server; override to customize.
    handler_class: type[BaseHTTPRequestHandler] = _MetricsHandler

    def __init__(
        self,
        prefix: str,
        labels: list[str] | None = None,
    ) -> None:
        """
        Args:
            prefix: prefix for every metric name (e.g. "<prefix>_tasks_total").
            labels: label names applied to every metric. Defaults to
                ["queue_name", "actor_name"].
                NOTE(review): `_make_labels` always emits (queue_name,
                actor_name); custom labels require overriding it to match.
        """
        super().__init__()
        self.prefix = prefix
        self.labels: list[str] = labels if labels is not None else ["queue_name", "actor_name"]

        # message_ids currently held in memory awaiting their delay
        self.delayed_messages: set[str] = set()
        # message_id -> processing start time (milliseconds)
        self.message_start_times: dict[str, int] = {}

    @property
    def forks(self) -> list[Callable[[], None]]:
        # Imported lazily to avoid importing the forks module at load time.
        from django_dramatiq_postgres.forks import worker_metrics

        return [worker_metrics]

    def before_worker_boot(self, broker: Broker, worker: Any) -> None:
        """Create all metric objects once the worker process has booted."""
        from prometheus_client import Counter, Gauge, Histogram

        self.total_messages = Counter(
            f"{self.prefix}_tasks_total",
            "The total number of tasks processed.",
            self.labels,
        )
        self.total_errored_messages = Counter(
            f"{self.prefix}_tasks_errors_total",
            "The total number of errored tasks.",
            self.labels,
        )
        self.total_retried_messages = Counter(
            f"{self.prefix}_tasks_retries_total",
            "The total number of retried tasks.",
            self.labels,
        )
        self.total_rejected_messages = Counter(
            f"{self.prefix}_tasks_rejected_total",
            "The total number of dead-lettered tasks.",
            self.labels,
        )
        self.in_progress_messages = Gauge(
            f"{self.prefix}_tasks_in_progress",
            "The number of tasks in progress.",
            self.labels,
            # Sum live values across worker processes in multiprocess mode.
            multiprocess_mode="livesum",
        )
        self.in_progress_delayed_messages = Gauge(
            f"{self.prefix}_tasks_delayed_in_progress",
            "The number of delayed tasks in memory.",
            self.labels,
        )
        self.messages_durations = Histogram(
            f"{self.prefix}_tasks_duration_milliseconds",
            "The time spent processing tasks.",
            self.labels,
            # Buckets span 5 ms to 1 hour.
            buckets=(
                5,
                10,
                25,
                50,
                75,
                100,
                250,
                500,
                750,
                1_000,
                2_500,
                5_000,
                7_500,
                10_000,
                30_000,
                60_000,
                600_000,
                900_000,
                1_800_000,
                3_600_000,
                float("inf"),
            ),
        )

    def after_worker_shutdown(self, broker: Broker, worker: Any) -> None:
        from prometheus_client import multiprocess

        # Drop this process' live data from the multiprocess registry.
        # TODO: worker_id
        multiprocess.mark_process_dead(os.getpid())  # type: ignore[no-untyped-call]

    def _make_labels(self, message: Message[Any]) -> list[str]:
        # Order must match self.labels (default: queue_name, actor_name).
        return [message.queue_name, message.actor_name]

    def after_nack(self, broker: Broker, message: Message[Any]) -> None:
        # A nacked message is dead-lettered.
        self.total_rejected_messages.labels(*self._make_labels(message)).inc()

    def after_enqueue(self, broker: Broker, message: Message[Any], delay: int) -> None:
        # Messages re-enqueued for a retry carry a "retries" option.
        if "retries" in message.options:
            self.total_retried_messages.labels(*self._make_labels(message)).inc()

    def before_delay_message(self, broker: Broker, message: Message[Any]) -> None:
        self.delayed_messages.add(message.message_id)
        self.in_progress_delayed_messages.labels(*self._make_labels(message)).inc()

    def before_process_message(self, broker: Broker, message: Message[Any]) -> None:
        labels = self._make_labels(message)
        if message.message_id in self.delayed_messages:
            # The delay has elapsed: the message leaves the delayed set.
            self.delayed_messages.remove(message.message_id)
            self.in_progress_delayed_messages.labels(*labels).dec()

        self.in_progress_messages.labels(*labels).inc()
        self.message_start_times[message.message_id] = current_millis()  # type: ignore[no-untyped-call]

    def after_process_message(
        self,
        broker: Broker,
        message: Message[Any],
        *,
        result: Any | None = None,
        exception: Exception | None = None,
    ) -> None:
        labels = self._make_labels(message)

        # Fall back to "now" (zero duration) if no start time was recorded.
        message_start_time = self.message_start_times.pop(message.message_id, current_millis())  # type: ignore[no-untyped-call]
        message_duration = current_millis() - message_start_time  # type: ignore[no-untyped-call]
        self.messages_durations.labels(*labels).observe(message_duration)

        self.in_progress_messages.labels(*labels).dec()
        self.total_messages.labels(*labels).inc()
        if exception is not None:
            self.total_errored_messages.labels(*labels).inc()

    # Skipped messages get the same accounting as processed ones.
    after_skip_message = after_process_message

    @classmethod
    def run(cls, addr: str, port: int) -> None:
        """Run the metrics HTTP server in the current thread, blocking forever.

        If the port cannot be bound (OSError), a warning is logged and the
        server is simply not started.
        """
        try:
            server = HTTPServer((addr, port), cls.handler_class)
            # Expose the server on the thread object so it can be shut down.
            thread = cast(HTTPServerThread, current_thread())
            thread.server = server
            server.serve_forever()
        except OSError:
            # Fix: pass the class itself. The previous `type(MetricsMiddleware)`
            # evaluated to the metaclass `type`, not this class, breaking the
            # `get_logger(__name__, type(self))` convention used in this module.
            get_logger(__name__, cls).warning(
                "Port is already in use, not starting metrics server"
            )
|