🔒️(ssrf) factorize SSRF code, allow redirects in image proxy

Also use the SSRF code in IMAP imports.
This commit is contained in:
Sylvain Zimmer
2026-04-14 23:28:32 +02:00
parent b4da8a3af7
commit 2847089a6c
9 changed files with 613 additions and 301 deletions

View File

@@ -4,6 +4,7 @@ import logging
from django.http import HttpResponse
from django.utils.decorators import method_decorator
from django.utils.http import content_disposition_header
from django.views.decorators.csrf import csrf_exempt
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
@@ -177,8 +178,8 @@ class BlobViewSet(ViewSet):
)
# Add appropriate headers for download
response["Content-Disposition"] = (
f'attachment; filename="{attachment["name"]}"'
response["Content-Disposition"] = content_disposition_header(
True, attachment["name"]
)
response["Content-Length"] = attachment["size"]
# Enable browser caching for 30 days (inline images benefit from this)
@@ -207,7 +208,9 @@ class BlobViewSet(ViewSet):
)
# Add appropriate headers for download
response["Content-Disposition"] = f'attachment; filename="{filename}"'
response["Content-Disposition"] = content_disposition_header(
True, filename
)
response["Content-Length"] = blob.size
# Enable browser caching for 30 days (inline images benefit from this)
response["Cache-Control"] = "private, max-age=2592000"

View File

@@ -1,9 +1,7 @@
"""API ViewSet for proxying external images."""
import ipaddress
import logging
import socket
from urllib.parse import ParseResult, unquote, urlparse, urlunparse
from urllib.parse import unquote
from django.conf import settings
from django.http import HttpResponse
@@ -11,261 +9,17 @@ from django.http import HttpResponse
import magic
import requests
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
from requests.adapters import HTTPAdapter
from rest_framework import status as http_status
from rest_framework.response import Response
from rest_framework.viewsets import ViewSet
from core import enums, models
from core.api import permissions
from core.services.ssrf import SSRFSafeSession, SSRFValidationError
logger = logging.getLogger(__name__)
class SSRFValidationError(Exception):
"""Exception raised when URL validation fails due to SSRF protection."""
class SSRFProtectedAdapter(HTTPAdapter):
"""
HTTPAdapter that connects to a pre-validated IP address while maintaining
proper TLS certificate verification against the original hostname.
This prevents TOCTOU DNS rebinding attacks by:
1. Connecting to the IP address that was validated (not re-resolving DNS)
2. Verifying TLS certificates against the original hostname (for HTTPS)
3. Setting the Host header correctly for virtual hosting
"""
def __init__(
self,
dest_ip: str,
dest_port: int,
original_hostname: str,
original_scheme: str,
**kwargs,
):
self.dest_ip = dest_ip
self.dest_port = dest_port
self.original_hostname = original_hostname
self.original_scheme = original_scheme
super().__init__(**kwargs)
def init_poolmanager(self, connections, maxsize, block=False, **pool_kwargs):
"""Initialize pool manager with TLS hostname verification settings."""
if self.original_scheme == "https":
# Ensure TLS certificate is verified against the original hostname
# even though we're connecting to an IP address
pool_kwargs["assert_hostname"] = self.original_hostname
pool_kwargs["server_hostname"] = self.original_hostname
super().init_poolmanager(connections, maxsize, block, **pool_kwargs)
def send(
self, request, stream=False, timeout=None, verify=True, cert=None, proxies=None
):
"""Send request, rewriting URL to connect to the validated IP address."""
parsed = urlparse(request.url)
# Build URL with validated IP instead of hostname
# IPv6 addresses need brackets in URLs
if ":" in self.dest_ip:
ip_netloc = f"[{self.dest_ip}]:{self.dest_port}"
else:
ip_netloc = f"{self.dest_ip}:{self.dest_port}"
# Reconstruct URL with IP address
request.url = urlunparse(
(
parsed.scheme,
ip_netloc,
parsed.path,
parsed.params,
parsed.query,
parsed.fragment,
)
)
# Set Host header to original hostname for virtual hosting
# Include port only if non-standard
if parsed.port and parsed.port not in (80, 443):
request.headers["Host"] = f"{self.original_hostname}:{parsed.port}"
else:
request.headers["Host"] = self.original_hostname
return super().send(
request,
stream=stream,
timeout=timeout,
verify=verify,
cert=cert,
proxies=proxies,
)
class SSRFSafeSession:
"""
HTTP Session with built-in SSRF protection.
This class provides a safe way to make HTTP requests by:
1. Validating URL scheme (only http/https allowed)
2. Blocking direct IP addresses (legitimate services use domain names)
3. Resolving hostnames and blocking private/internal IPs
4. Pinning resolved IPs to prevent DNS rebinding attacks (TOCTOU)
Usage:
try:
response = SSRFSafeSession().get("https://example.com/image.png", timeout=10)
except SSRFValidationError:
# URL was blocked for security reasons
pass
"""
def _validate_url(self, parsed_url: ParseResult) -> list[str]:
"""
Validate that a URL is safe to fetch (SSRF protection).
This function prevents Server-Side Request Forgery (SSRF) attacks by
validating URLs before making HTTP requests. It implements a defense-in-depth
approach:
1. Only allows http/https schemes
2. Blocks all IP addresses (legitimate emails use domain names)
3. Resolves hostnames and blocks if they resolve to private/internal IPs
(prevents DNS rebinding attacks where attacker-controlled DNS returns
127.0.0.1 or internal IPs)
Blocked addresses include:
- Any direct IP address (e.g., http://192.168.1.1/)
- Private IP ranges (RFC1918: 10.x.x.x, 172.16-31.x.x, 192.168.x.x)
- Loopback addresses (127.x.x.x, ::1)
- Link-local addresses (169.254.x.x, fe80::/10)
- Multicast and reserved addresses
- Cloud provider metadata endpoints (169.254.169.254, fd00:ec2::254)
Args:
parsed_url: The parsed URL to validate
Returns:
List of validated IP addresses that the hostname resolves to
Raises:
SSRFValidationError: If the URL is unsafe
"""
# Only allow http and https schemes
if parsed_url.scheme not in {"http", "https"}:
raise SSRFValidationError("Invalid URL scheme (only http/https allowed)")
# Require a hostname
if not parsed_url.hostname:
raise SSRFValidationError("Invalid URL (missing hostname)")
# Block all IP addresses (legitimate services use domain names)
try:
ipaddress.ip_address(parsed_url.hostname)
raise SSRFValidationError(
"IP addresses are not allowed (domain name required)"
)
except ValueError:
# Not an IP address, continue validation
pass
# Resolve hostname to IP addresses
try:
addr_info = socket.getaddrinfo(
parsed_url.hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM
)
except socket.gaierror as exc:
raise SSRFValidationError("Unable to resolve hostname") from exc
# Check all resolved IP addresses
valid_ips = []
for _, _, _, _, sockaddr in addr_info:
ip_str = sockaddr[0]
try:
ip_addr = ipaddress.ip_address(ip_str)
if ip_addr.is_private:
raise SSRFValidationError("Domain resolves to private IP address")
if ip_addr.is_loopback:
raise SSRFValidationError("Domain resolves to loopback address")
if ip_addr.is_link_local:
raise SSRFValidationError("Domain resolves to link-local address")
if ip_addr.is_multicast:
raise SSRFValidationError("Domain resolves to multicast address")
if ip_addr.is_reserved:
raise SSRFValidationError("Domain resolves to reserved address")
# Block known cloud metadata IPs
if ip_str in ("169.254.169.254", "fd00:ec2::254"):
raise SSRFValidationError(
"Domain resolves to cloud metadata endpoint"
)
valid_ips.append(ip_str)
except ValueError as exc:
raise SSRFValidationError("Invalid IP address in DNS response") from exc
if not valid_ips:
raise SSRFValidationError("No valid IP addresses found")
return valid_ips
def get(self, url: str, timeout: int, **kwargs) -> requests.Response:
"""
Perform a safe HTTP GET request with SSRF protection and IP pinning.
This method:
1. Parses and validates the URL
2. Resolves DNS and validates all returned IPs
3. Creates a requests Session with a custom HTTPAdapter that:
- Connects directly to the validated IP (preventing DNS rebinding)
- Maintains proper TLS certificate verification against the hostname
- Sets the Host header correctly for virtual hosting
Args:
url: The URL to fetch
timeout: Request timeout in seconds
**kwargs: Additional arguments passed to requests.Session.get()
Returns:
requests.Response object
Raises:
SSRFValidationError: If the URL fails security validation
requests.RequestException: If the HTTP request fails
"""
parsed_url = urlparse(url)
valid_ips = self._validate_url(parsed_url)
# Determine the port (explicit or default based on scheme)
if parsed_url.port:
port = parsed_url.port
elif parsed_url.scheme == "http":
port = 80
else:
port = 443
# Create a session with our SSRF-protected adapter that pins to the validated IP
session = requests.Session()
adapter = SSRFProtectedAdapter(
dest_ip=valid_ips[0],
dest_port=port,
original_hostname=parsed_url.hostname,
original_scheme=parsed_url.scheme,
)
# Mount the adapter for both http and https schemes
session.mount("http://", adapter)
session.mount("https://", adapter)
return session.get(url, timeout=timeout, **kwargs)
class ImageProxySuspiciousResponse(HttpResponse):
"""
Response for suspicious content that has been blocked by our image proxy.
@@ -356,7 +110,6 @@ class ImageProxyViewSet(ViewSet):
timeout=10,
stream=True,
headers={"User-Agent": "Messages-ImageProxy/1.0"},
allow_redirects=False,
)
response.raise_for_status()

View File

@@ -35,9 +35,9 @@ class WidgetAuthentication(BaseAuthentication):
if not channel_id:
raise AuthenticationFailed("Missing channel_id")
# API key authentication for check endpoint
# Only allow widget-type channels
try:
channel = models.Channel.objects.get(id=channel_id)
channel = models.Channel.objects.get(id=channel_id, type="widget")
except models.Channel.DoesNotExist as e:
raise AuthenticationFailed("Invalid channel_id") from e

View File

@@ -22,6 +22,7 @@ from celery.utils.log import get_task_logger
from core.mda.inbound import deliver_inbound_message
from core.mda.rfc5322 import parse_email_message
from core.services.ssrf import SSRFValidationError, validate_hostname
logger = get_task_logger(__name__)
@@ -62,6 +63,21 @@ def decode_imap_utf7(s):
return re.sub(r"&([^-]*)-", decode_match, s)
def _validate_imap_host(server: str) -> None:
"""Validate that the IMAP server hostname is not a private/internal address.
Wraps the shared SSRF validator but allows public IP literals, which are
legitimate addresses for customer-supplied IMAP servers.
Raises:
ValueError: If the hostname resolves to a blocked IP address.
"""
try:
validate_hostname(server, allow_ip_literal=True)
except SSRFValidationError as exc:
raise ValueError(f"IMAP server {server} is not allowed: {exc}") from exc
class IMAPConnectionManager:
"""Context manager for IMAP connections with proper cleanup."""
@@ -76,6 +92,9 @@ class IMAPConnectionManager:
self.connection = None
def __enter__(self):
# Validate the server hostname to prevent SSRF
_validate_imap_host(self.server)
# Port 143 typically uses STARTTLS, port 993 uses SSL direct
# If use_ssl=True and port is 143, use STARTTLS instead of SSL direct
use_starttls = self.use_ssl and self.port == 143

View File

@@ -0,0 +1,245 @@
"""Server-Side Request Forgery (SSRF) protections.
Shared across features that take user-supplied network destinations (image
proxy, IMAP import, etc.). Provides a hostname/IP validator plus an HTTP
session with IP pinning to defeat DNS-rebinding (TOCTOU) attacks.
"""
import ipaddress
import socket
from urllib.parse import urljoin, urlparse, urlunparse
import requests
from requests.adapters import HTTPAdapter
CLOUD_METADATA_IPS = frozenset({"169.254.169.254", "fd00:ec2::254"})
MAX_REDIRECTS = 5
REDIRECT_STATUS_CODES = frozenset({301, 302, 303, 307, 308})
class SSRFValidationError(Exception):
"""Raised when a URL or hostname fails SSRF validation."""
def _check_ip(ip_addr: ipaddress._BaseAddress, hostname: str) -> None:
# Check specific categories before is_private: in Python's ipaddress
# module, loopback/link-local/etc. are subsets of is_private, so checking
# is_private first would mask the more informative error.
if str(ip_addr) in CLOUD_METADATA_IPS:
raise SSRFValidationError(f"{hostname} resolves to cloud metadata endpoint")
if ip_addr.is_loopback:
raise SSRFValidationError(f"{hostname} resolves to loopback address")
if ip_addr.is_link_local:
raise SSRFValidationError(f"{hostname} resolves to link-local address")
if ip_addr.is_multicast:
raise SSRFValidationError(f"{hostname} resolves to multicast address")
if ip_addr.is_reserved:
raise SSRFValidationError(f"{hostname} resolves to reserved address")
if ip_addr.is_private:
raise SSRFValidationError(f"{hostname} resolves to private IP address")
def validate_hostname(hostname: str, *, allow_ip_literal: bool = False) -> list[str]:
"""Resolve hostname and reject private/internal/metadata addresses.
Args:
hostname: A hostname or, when allow_ip_literal=True, an IP literal.
allow_ip_literal: If False (default), IP literals are rejected outright —
legitimate services use domain names. If True, public IP literals
are accepted (used for IMAP where customers may supply raw IPs).
Returns:
List of validated IP addresses the hostname resolves to.
Raises:
SSRFValidationError: If the hostname/IP resolves to a blocked address.
"""
if not hostname:
raise SSRFValidationError("Invalid hostname (missing)")
try:
ip = ipaddress.ip_address(hostname)
except ValueError:
ip = None
if ip is not None:
if not allow_ip_literal:
raise SSRFValidationError(
"IP addresses are not allowed (domain name required)"
)
_check_ip(ip, hostname)
return [str(ip)]
try:
addr_info = socket.getaddrinfo(
hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM
)
except socket.gaierror as exc:
raise SSRFValidationError("Unable to resolve hostname") from exc
valid_ips: list[str] = []
for _, _, _, _, sockaddr in addr_info:
ip_str = sockaddr[0]
try:
ip_addr = ipaddress.ip_address(ip_str)
except ValueError as exc:
raise SSRFValidationError("Invalid IP address in DNS response") from exc
_check_ip(ip_addr, hostname)
valid_ips.append(ip_str)
if not valid_ips:
raise SSRFValidationError("No valid IP addresses found")
return valid_ips
class SSRFProtectedAdapter(HTTPAdapter):
"""HTTPAdapter that pins the connection to a pre-validated IP.
Prevents TOCTOU DNS rebinding by:
1. Connecting to the IP address that was validated (no re-resolving DNS).
2. Verifying TLS certificates against the original hostname (for HTTPS).
3. Setting the Host header correctly for virtual hosting.
"""
def __init__(
self,
dest_ip: str,
dest_port: int,
original_hostname: str,
original_scheme: str,
**kwargs,
):
self.dest_ip = dest_ip
self.dest_port = dest_port
self.original_hostname = original_hostname
self.original_scheme = original_scheme
super().__init__(**kwargs)
def init_poolmanager(self, connections, maxsize, block=False, **pool_kwargs):
if self.original_scheme == "https":
pool_kwargs["assert_hostname"] = self.original_hostname
pool_kwargs["server_hostname"] = self.original_hostname
super().init_poolmanager(connections, maxsize, block, **pool_kwargs)
def send(
self, request, stream=False, timeout=None, verify=True, cert=None, proxies=None
):
parsed = urlparse(request.url)
if ":" in self.dest_ip:
ip_netloc = f"[{self.dest_ip}]:{self.dest_port}"
else:
ip_netloc = f"{self.dest_ip}:{self.dest_port}"
request.url = urlunparse(
(
parsed.scheme,
ip_netloc,
parsed.path,
parsed.params,
parsed.query,
parsed.fragment,
)
)
if parsed.port and parsed.port not in (80, 443):
request.headers["Host"] = f"{self.original_hostname}:{parsed.port}"
else:
request.headers["Host"] = self.original_hostname
return super().send(
request,
stream=stream,
timeout=timeout,
verify=verify,
cert=cert,
proxies=proxies,
)
class SSRFSafeSession:
"""HTTP Session with built-in SSRF protection.
1. Validates URL scheme (only http/https allowed).
2. Blocks direct IP addresses (legitimate services use domain names).
3. Resolves hostnames and blocks private/internal IPs.
4. Pins resolved IPs to prevent DNS rebinding attacks (TOCTOU).
Usage:
try:
response = SSRFSafeSession().get("https://example.com/image.png", timeout=10)
except SSRFValidationError:
# URL was blocked for security reasons
pass
"""
def _validate_and_unpack(self, url: str) -> tuple[str, str, str, int]:
"""Validate a URL and return (validated_ip, hostname, scheme, port).
Raises:
SSRFValidationError: If the URL is unsafe.
"""
parsed = urlparse(url)
if parsed.scheme not in {"http", "https"}:
raise SSRFValidationError("Invalid URL scheme (only http/https allowed)")
if not parsed.hostname:
raise SSRFValidationError("Invalid URL (missing hostname)")
valid_ips = validate_hostname(parsed.hostname, allow_ip_literal=False)
if parsed.port:
port = parsed.port
elif parsed.scheme == "http":
port = 80
else:
port = 443
return valid_ips[0], parsed.hostname, parsed.scheme, port
def get(self, url: str, timeout: int, **kwargs) -> requests.Response:
"""Perform a safe HTTP GET with per-hop SSRF validation on redirects.
Redirects are followed manually up to MAX_REDIRECTS hops. Each Location
URL is re-validated from scratch, so an attacker-controlled server
cannot redirect to an internal address or a different private target
on a later hop.
"""
# We always handle redirects ourselves — strip any caller override so
# the underlying requests session never follows a redirect unchecked.
kwargs.pop("allow_redirects", None)
current_url = url
for _ in range(MAX_REDIRECTS + 1):
validated_ip, hostname, scheme, port = self._validate_and_unpack(
current_url
)
session = requests.Session()
adapter = SSRFProtectedAdapter(
dest_ip=validated_ip,
dest_port=port,
original_hostname=hostname,
original_scheme=scheme,
)
session.mount("http://", adapter)
session.mount("https://", adapter)
response = session.get(
current_url, timeout=timeout, allow_redirects=False, **kwargs
)
if response.status_code not in REDIRECT_STATUS_CODES:
return response
location = response.headers.get("Location")
if not location:
# Redirect without a Location — hand the response back unchanged.
return response
next_url = urljoin(current_url, location)
response.close()
current_url = next_url
raise SSRFValidationError(f"Too many redirects (max {MAX_REDIRECTS})")

View File

@@ -191,7 +191,7 @@ class TestImageProxyViewSet:
"https://localhost:8080/image.jpg",
],
)
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
@patch("core.services.ssrf.socket.getaddrinfo")
def test_api_image_proxy_localhost_hostname_blocked(
self, mock_getaddrinfo, api_client, user_mailbox, test_url
):
@@ -210,7 +210,7 @@ class TestImageProxyViewSet:
assert response["Content-Type"] == "image/svg+xml"
@override_settings(IMAGE_PROXY_ENABLED=True)
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
@patch("core.services.ssrf.socket.getaddrinfo")
def test_api_image_proxy_domain_resolving_to_private_ip_blocked(
self, mock_getaddrinfo, api_client, user_mailbox
):
@@ -229,7 +229,7 @@ class TestImageProxyViewSet:
assert response["Content-Type"] == "image/svg+xml"
@override_settings(IMAGE_PROXY_ENABLED=True)
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
@patch("core.services.ssrf.socket.getaddrinfo")
def test_api_image_proxy_unresolvable_hostname(
self, mock_getaddrinfo, api_client, user_mailbox
):
@@ -249,7 +249,7 @@ class TestImageProxyViewSet:
# Content-Type Validation Tests
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=10)
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
@patch("core.services.ssrf.socket.getaddrinfo")
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
def test_api_image_proxy_non_image_content_type_blocked(
self, mock_get, mock_getaddrinfo, api_client, user_mailbox
@@ -273,7 +273,7 @@ class TestImageProxyViewSet:
assert response["Content-Type"] == "image/svg+xml"
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=10)
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
@patch("core.services.ssrf.socket.getaddrinfo")
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
@patch("magic.from_buffer")
def test_api_image_proxy_content_not_actually_image(
@@ -301,7 +301,7 @@ class TestImageProxyViewSet:
assert response["Content-Type"] == "image/svg+xml"
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=10)
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
@patch("core.services.ssrf.socket.getaddrinfo")
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
@patch("magic.from_buffer")
@pytest.mark.parametrize(
@@ -342,7 +342,7 @@ class TestImageProxyViewSet:
# Size Limit Tests
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=1)
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
@patch("core.services.ssrf.socket.getaddrinfo")
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
def test_api_image_proxy_image_too_large_via_content_length(
self, mock_get, mock_getaddrinfo, api_client, user_mailbox
@@ -367,7 +367,7 @@ class TestImageProxyViewSet:
assert "Image too large" in str(response.data)
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=4096)
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
@patch("core.services.ssrf.socket.getaddrinfo")
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
@patch("magic.from_buffer")
def test_api_image_proxy_image_too_large_actual_content(
@@ -395,39 +395,13 @@ class TestImageProxyViewSet:
assert response.status_code == status.HTTP_413_REQUEST_ENTITY_TOO_LARGE
assert "Image too large" in str(response.data)
# Redirect Tests
@override_settings(IMAGE_PROXY_ENABLED=True)
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
def test_api_image_proxy_redirects_blocked(
self, mock_get, mock_getaddrinfo, api_client, user_mailbox
):
"""Test that redirects are not followed (SSRF protection)."""
# Mock DNS resolution
mock_getaddrinfo.return_value = [(2, 1, 6, "", ("1.2.3.4", 80))]
client, _ = api_client
url = self._get_image_proxy_url(
user_mailbox.id, "http://example.com/redirect.jpg"
)
mock_get.return_value = self._mock_requests_response(
content=b"", content_type="text/plain", status_code=302
)
# Call the endpoint
client.get(url)
# Verify that allow_redirects=False was passed to requests.get
mock_get.assert_called_once()
call_kwargs = mock_get.call_args[1]
assert call_kwargs["allow_redirects"] is False
# Redirect handling is covered end-to-end in core/tests/services/test_ssrf.py:
# SSRFSafeSession follows redirects internally, re-validating every hop.
# Success Cases
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_CACHE_TTL=3600)
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
@patch("core.services.ssrf.socket.getaddrinfo")
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
@patch("magic.from_buffer")
def test_api_image_proxy_successfully_jpg_image(
@@ -459,7 +433,7 @@ class TestImageProxyViewSet:
assert response["Cache-Control"] == "public, max-age=3600"
@override_settings(IMAGE_PROXY_ENABLED=True)
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
@patch("core.services.ssrf.socket.getaddrinfo")
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
@patch("magic.from_buffer")
def test_api_image_proxy_successfully_png_image(
@@ -488,7 +462,7 @@ class TestImageProxyViewSet:
assert response["Content-Type"] == "image/png"
@override_settings(IMAGE_PROXY_ENABLED=True)
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
@patch("core.services.ssrf.socket.getaddrinfo")
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
@patch("magic.from_buffer")
def test_api_image_proxy_successfully_with_octet_stream_content_type(
@@ -517,7 +491,7 @@ class TestImageProxyViewSet:
assert response["Content-Type"] == "image/jpeg"
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_CACHE_TTL=3600)
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
@patch("core.services.ssrf.socket.getaddrinfo")
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
@patch("magic.from_buffer")
def test_api_image_proxy_no_content(
@@ -547,7 +521,7 @@ class TestImageProxyViewSet:
assert response["Content-Type"] == "image/svg+xml"
@override_settings(IMAGE_PROXY_ENABLED=True)
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
@patch("core.services.ssrf.socket.getaddrinfo")
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
@patch("magic.from_buffer")
def test_api_image_proxy_url_with_special_characters(
@@ -578,7 +552,7 @@ class TestImageProxyViewSet:
assert mock_get.call_args[0][0] == test_url
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_CACHE_TTL=3600)
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
@patch("core.services.ssrf.socket.getaddrinfo")
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
@patch("magic.from_buffer")
def test_api_image_proxy_successfully_secure_headers(
@@ -617,7 +591,7 @@ class TestImageProxyViewSet:
# Error Handling Tests
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=10)
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
@patch("core.services.ssrf.socket.getaddrinfo")
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
def test_api_image_proxy_network_timeout(
self, mock_get, mock_getaddrinfo, api_client, user_mailbox
@@ -638,7 +612,7 @@ class TestImageProxyViewSet:
assert "Failed to fetch image" in str(response.data)
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=10)
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
@patch("core.services.ssrf.socket.getaddrinfo")
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
def test_api_image_proxy_connection_error(
self, mock_get, mock_getaddrinfo, api_client, user_mailbox
@@ -659,7 +633,7 @@ class TestImageProxyViewSet:
assert "Failed to fetch image" in str(response.data)
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=10)
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
@patch("core.services.ssrf.socket.getaddrinfo")
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
def test_api_image_proxy_http_error_404(
self, mock_get, mock_getaddrinfo, api_client, user_mailbox
@@ -687,7 +661,7 @@ class TestImageProxyViewSet:
assert "Failed to fetch image" in str(response.data)
@override_settings(IMAGE_PROXY_ENABLED=True)
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
@patch("core.services.ssrf.socket.getaddrinfo")
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
@patch("magic.from_buffer")
def test_api_image_proxy_invalid_content_length_header(
@@ -716,7 +690,7 @@ class TestImageProxyViewSet:
assert response.status_code == status.HTTP_200_OK
@override_settings(IMAGE_PROXY_ENABLED=True)
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
@patch("core.services.ssrf.socket.getaddrinfo")
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
@patch("magic.from_buffer")
def test_api_image_proxy_missing_content_length_header(

View File

@@ -2,6 +2,7 @@
# pylint: disable=redefined-outer-name, unused-argument, no-value-for-parameter, too-many-lines
import datetime
import socket
from unittest.mock import patch
from django.core.files.storage import storages
@@ -19,6 +20,25 @@ from core.services.importer.mbox_tasks import process_mbox_file_task
pytestmark = pytest.mark.django_db
@pytest.fixture(autouse=True)
def _mock_ssrf_dns():
"""Short-circuit SSRF DNS validation for IMAP import tests.
The IMAP endpoint validates the server hostname via
``core.services.ssrf.validate_hostname``; tests use unresolvable fixtures
like ``imap.example.com`` so we return a public IP to reach the mocked
IMAP task code.
"""
with patch(
"core.services.ssrf.socket.getaddrinfo",
return_value=[
(socket.AF_INET, socket.SOCK_STREAM, 6, "", ("93.184.216.34", 0))
],
):
yield
IMPORT_FILE_URL = "/api/v1.0/import/file/"
IMPORT_IMAP_URL = "/api/v1.0/import/imap/"

View File

@@ -0,0 +1,24 @@
"""Shared fixtures for importer tests."""
import socket
from unittest import mock
import pytest
@pytest.fixture(autouse=True)
def _mock_ssrf_dns():
"""Short-circuit SSRF DNS validation for IMAP tests.
The IMAP import path validates the server hostname via
``core.services.ssrf.validate_hostname`` which calls ``socket.getaddrinfo``.
Test fixtures use unresolvable hostnames like ``imap.example.com``, so we
return a valid public IP here to let tests reach the mocked IMAP code.
"""
with mock.patch(
"core.services.ssrf.socket.getaddrinfo",
return_value=[
(socket.AF_INET, socket.SOCK_STREAM, 6, "", ("93.184.216.34", 0))
],
):
yield

View File

@@ -0,0 +1,274 @@
"""Unit tests for core.services.ssrf — focus on redirect handling.
Hostname/IP validation itself is covered via the image proxy integration tests
(see core/tests/api/test_image_proxy.py). This file targets the per-hop
validation loop in SSRFSafeSession.get, which is the main SSRF-correctness
surface for redirect responses.
"""
import socket
from unittest.mock import MagicMock, patch
import pytest
from core.services.ssrf import (
MAX_REDIRECTS,
SSRFSafeSession,
SSRFValidationError,
)
PUBLIC_IP = "93.184.216.34"
PRIVATE_IP = "192.168.1.1"
def _addrinfo(ip: str):
return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", (ip, 0))]
def _mock_response(status_code: int = 200, location: str | None = None):
"""Build a mock requests.Response-like object."""
resp = MagicMock()
resp.status_code = status_code
resp.headers = {}
if location is not None:
resp.headers["Location"] = location
return resp
class TestSSRFSafeSessionRedirects:
"""Redirect-handling contract for SSRFSafeSession.get."""
@patch("core.services.ssrf.requests.Session.get")
@patch("core.services.ssrf.socket.getaddrinfo")
def test_no_redirect_returns_response_directly(self, mock_dns, mock_get):
"""A 200 response is returned without any extra request."""
mock_dns.return_value = _addrinfo(PUBLIC_IP)
mock_get.return_value = _mock_response(200)
response = SSRFSafeSession().get("https://legit.com/img.png", timeout=10)
assert response.status_code == 200
assert mock_get.call_count == 1
@patch("core.services.ssrf.requests.Session.get")
@patch("core.services.ssrf.socket.getaddrinfo")
def test_follows_redirect_to_safe_url(self, mock_dns, mock_get):
"""A single redirect is followed to a validated destination."""
mock_dns.return_value = _addrinfo(PUBLIC_IP)
mock_get.side_effect = [
_mock_response(302, location="https://cdn.legit.com/img.png"),
_mock_response(200),
]
response = SSRFSafeSession().get("https://legit.com/img.png", timeout=10)
assert response.status_code == 200
assert mock_get.call_count == 2
@pytest.mark.parametrize("redirect_status", [301, 302, 303, 307, 308])
@patch("core.services.ssrf.requests.Session.get")
@patch("core.services.ssrf.socket.getaddrinfo")
def test_all_redirect_statuses_followed(self, mock_dns, mock_get, redirect_status):
"""All standard redirect status codes trigger a new hop."""
mock_dns.return_value = _addrinfo(PUBLIC_IP)
mock_get.side_effect = [
_mock_response(redirect_status, location="https://b.com/img.png"),
_mock_response(200),
]
response = SSRFSafeSession().get("https://a.com/img.png", timeout=10)
assert response.status_code == 200
assert mock_get.call_count == 2
@patch("core.services.ssrf.requests.Session.get")
@patch("core.services.ssrf.socket.getaddrinfo")
def test_follows_multiple_hops(self, mock_dns, mock_get):
"""Chained redirects are followed up to MAX_REDIRECTS."""
mock_dns.return_value = _addrinfo(PUBLIC_IP)
mock_get.side_effect = [
_mock_response(302, location="https://b.com/"),
_mock_response(302, location="https://c.com/"),
_mock_response(302, location="https://d.com/"),
_mock_response(200),
]
response = SSRFSafeSession().get("https://a.com/", timeout=10)
assert response.status_code == 200
assert mock_get.call_count == 4
@patch("core.services.ssrf.requests.Session.get")
@patch("core.services.ssrf.socket.getaddrinfo")
def test_blocks_redirect_to_private_ip(self, mock_dns, mock_get):
"""A Location resolving to a private IP is rejected mid-chain."""
def dns_side_effect(host, *_args, **_kwargs):
if host == "legit.com":
return _addrinfo(PUBLIC_IP)
if host == "internal.evil.com":
return _addrinfo(PRIVATE_IP)
raise AssertionError(f"unexpected DNS lookup: {host}")
mock_dns.side_effect = dns_side_effect
mock_get.return_value = _mock_response(
302, location="https://internal.evil.com/pwn"
)
with pytest.raises(SSRFValidationError, match="private IP"):
SSRFSafeSession().get("https://legit.com/img.png", timeout=10)
# Only the first hop should have been issued; the second was blocked
# before the request was made.
assert mock_get.call_count == 1
@patch("core.services.ssrf.requests.Session.get")
@patch("core.services.ssrf.socket.getaddrinfo")
def test_blocks_redirect_to_loopback(self, mock_dns, mock_get):
"""Redirect pointing at loopback (DNS-rebinding style) is rejected."""
def dns_side_effect(host, *_args, **_kwargs):
if host == "legit.com":
return _addrinfo(PUBLIC_IP)
if host == "rebind.evil.com":
return _addrinfo("127.0.0.1")
raise AssertionError(f"unexpected DNS lookup: {host}")
mock_dns.side_effect = dns_side_effect
mock_get.return_value = _mock_response(302, location="https://rebind.evil.com/")
with pytest.raises(SSRFValidationError, match="loopback"):
SSRFSafeSession().get("https://legit.com/img.png", timeout=10)
@patch("core.services.ssrf.requests.Session.get")
@patch("core.services.ssrf.socket.getaddrinfo")
def test_blocks_redirect_to_cloud_metadata(self, mock_dns, mock_get):
"""Redirect pointing at cloud metadata endpoint is rejected."""
def dns_side_effect(host, *_args, **_kwargs):
if host == "legit.com":
return _addrinfo(PUBLIC_IP)
if host == "meta.evil.com":
return _addrinfo("169.254.169.254")
raise AssertionError(f"unexpected DNS lookup: {host}")
mock_dns.side_effect = dns_side_effect
mock_get.return_value = _mock_response(
302, location="https://meta.evil.com/latest/meta-data/"
)
with pytest.raises(SSRFValidationError, match="metadata"):
SSRFSafeSession().get("https://legit.com/img.png", timeout=10)
@patch("core.services.ssrf.requests.Session.get")
@patch("core.services.ssrf.socket.getaddrinfo")
def test_blocks_redirect_to_non_http_scheme(self, mock_dns, mock_get):
"""Redirect to file:// or other schemes is rejected."""
mock_dns.return_value = _addrinfo(PUBLIC_IP)
mock_get.return_value = _mock_response(302, location="file:///etc/passwd")
with pytest.raises(SSRFValidationError, match="scheme"):
SSRFSafeSession().get("https://legit.com/", timeout=10)
@patch("core.services.ssrf.requests.Session.get")
@patch("core.services.ssrf.socket.getaddrinfo")
def test_blocks_redirect_to_ip_literal(self, mock_dns, mock_get):
"""Redirect whose Location is a raw IP is rejected (domains only)."""
mock_dns.return_value = _addrinfo(PUBLIC_IP)
mock_get.return_value = _mock_response(302, location="http://203.0.113.5/stuff")
with pytest.raises(SSRFValidationError, match="IP addresses are not allowed"):
SSRFSafeSession().get("https://legit.com/", timeout=10)
@patch("core.services.ssrf.requests.Session.get")
@patch("core.services.ssrf.socket.getaddrinfo")
def test_blocks_protocol_relative_redirect_to_private(self, mock_dns, mock_get):
"""//host/path Location resolves against current scheme and is validated."""
def dns_side_effect(host, *_args, **_kwargs):
if host == "legit.com":
return _addrinfo(PUBLIC_IP)
if host == "internal":
return _addrinfo(PRIVATE_IP)
raise AssertionError(f"unexpected DNS lookup: {host}")
mock_dns.side_effect = dns_side_effect
mock_get.return_value = _mock_response(302, location="//internal/admin")
with pytest.raises(SSRFValidationError, match="private IP"):
SSRFSafeSession().get("https://legit.com/", timeout=10)
@patch("core.services.ssrf.requests.Session.get")
@patch("core.services.ssrf.socket.getaddrinfo")
def test_follows_relative_redirect(self, mock_dns, mock_get):
"""Relative Location is resolved against the current URL and validated."""
mock_dns.return_value = _addrinfo(PUBLIC_IP)
mock_get.side_effect = [
_mock_response(302, location="/new-path"),
_mock_response(200),
]
response = SSRFSafeSession().get("https://legit.com/old", timeout=10)
assert response.status_code == 200
# Second request should be made to https://legit.com/new-path.
second_call_url = mock_get.call_args_list[1].args[0]
assert second_call_url == "https://legit.com/new-path"
@patch("core.services.ssrf.requests.Session.get")
@patch("core.services.ssrf.socket.getaddrinfo")
def test_blocks_too_many_redirects(self, mock_dns, mock_get):
"""A redirect loop longer than MAX_REDIRECTS raises."""
mock_dns.return_value = _addrinfo(PUBLIC_IP)
mock_get.return_value = _mock_response(302, location="https://legit.com/loop")
with pytest.raises(SSRFValidationError, match="Too many redirects"):
SSRFSafeSession().get("https://legit.com/", timeout=10)
# Initial hop + MAX_REDIRECTS extra hops, then error.
assert mock_get.call_count == MAX_REDIRECTS + 1
@patch("core.services.ssrf.requests.Session.get")
@patch("core.services.ssrf.socket.getaddrinfo")
def test_redirect_without_location_returned_as_is(self, mock_dns, mock_get):
"""A redirect status with no Location header is returned unchanged."""
mock_dns.return_value = _addrinfo(PUBLIC_IP)
mock_get.return_value = _mock_response(302, location=None)
response = SSRFSafeSession().get("https://legit.com/", timeout=10)
assert response.status_code == 302
assert mock_get.call_count == 1
@patch("core.services.ssrf.requests.Session.get")
@patch("core.services.ssrf.socket.getaddrinfo")
def test_caller_allow_redirects_true_is_ignored(self, mock_dns, mock_get):
"""Caller cannot opt out of per-hop validation by passing allow_redirects=True."""
mock_dns.return_value = _addrinfo(PUBLIC_IP)
mock_get.side_effect = [
_mock_response(302, location="https://cdn.legit.com/"),
_mock_response(200),
]
response = SSRFSafeSession().get(
"https://legit.com/", timeout=10, allow_redirects=True
)
assert response.status_code == 200
# Each underlying Session.get must be called with allow_redirects=False.
for call in mock_get.call_args_list:
assert call.kwargs.get("allow_redirects") is False
@patch("core.services.ssrf.requests.Session.get")
@patch("core.services.ssrf.socket.getaddrinfo")
def test_intermediate_response_closed(self, mock_dns, mock_get):
"""Intermediate redirect responses are .close()d to release streams."""
mock_dns.return_value = _addrinfo(PUBLIC_IP)
intermediate = _mock_response(302, location="https://cdn.legit.com/")
final = _mock_response(200)
mock_get.side_effect = [intermediate, final]
SSRFSafeSession().get("https://legit.com/", timeout=10, stream=True)
intermediate.close.assert_called_once()
final.close.assert_not_called()