mirror of
https://github.com/suitenumerique/messages.git
synced 2026-04-25 17:15:21 +02:00
🔒️(ssrf) factorize SSRF code, allow redirects in image proxy
Also use the SSRF code in IMAP imports.
This commit is contained in:
@@ -4,6 +4,7 @@ import logging
|
||||
|
||||
from django.http import HttpResponse
|
||||
from django.utils.decorators import method_decorator
|
||||
from django.utils.http import content_disposition_header
|
||||
from django.views.decorators.csrf import csrf_exempt
|
||||
|
||||
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
|
||||
@@ -177,8 +178,8 @@ class BlobViewSet(ViewSet):
|
||||
)
|
||||
|
||||
# Add appropriate headers for download
|
||||
response["Content-Disposition"] = (
|
||||
f'attachment; filename="{attachment["name"]}"'
|
||||
response["Content-Disposition"] = content_disposition_header(
|
||||
True, attachment["name"]
|
||||
)
|
||||
response["Content-Length"] = attachment["size"]
|
||||
# Enable browser caching for 30 days (inline images benefit from this)
|
||||
@@ -207,7 +208,9 @@ class BlobViewSet(ViewSet):
|
||||
)
|
||||
|
||||
# Add appropriate headers for download
|
||||
response["Content-Disposition"] = f'attachment; filename="{filename}"'
|
||||
response["Content-Disposition"] = content_disposition_header(
|
||||
True, filename
|
||||
)
|
||||
response["Content-Length"] = blob.size
|
||||
# Enable browser caching for 30 days (inline images benefit from this)
|
||||
response["Cache-Control"] = "private, max-age=2592000"
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
"""API ViewSet for proxying external images."""
|
||||
|
||||
import ipaddress
|
||||
import logging
|
||||
import socket
|
||||
from urllib.parse import ParseResult, unquote, urlparse, urlunparse
|
||||
from urllib.parse import unquote
|
||||
|
||||
from django.conf import settings
|
||||
from django.http import HttpResponse
|
||||
@@ -11,261 +9,17 @@ from django.http import HttpResponse
|
||||
import magic
|
||||
import requests
|
||||
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
|
||||
from requests.adapters import HTTPAdapter
|
||||
from rest_framework import status as http_status
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.viewsets import ViewSet
|
||||
|
||||
from core import enums, models
|
||||
from core.api import permissions
|
||||
from core.services.ssrf import SSRFSafeSession, SSRFValidationError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SSRFValidationError(Exception):
|
||||
"""Exception raised when URL validation fails due to SSRF protection."""
|
||||
|
||||
|
||||
class SSRFProtectedAdapter(HTTPAdapter):
|
||||
"""
|
||||
HTTPAdapter that connects to a pre-validated IP address while maintaining
|
||||
proper TLS certificate verification against the original hostname.
|
||||
|
||||
This prevents TOCTOU DNS rebinding attacks by:
|
||||
1. Connecting to the IP address that was validated (not re-resolving DNS)
|
||||
2. Verifying TLS certificates against the original hostname (for HTTPS)
|
||||
3. Setting the Host header correctly for virtual hosting
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dest_ip: str,
|
||||
dest_port: int,
|
||||
original_hostname: str,
|
||||
original_scheme: str,
|
||||
**kwargs,
|
||||
):
|
||||
self.dest_ip = dest_ip
|
||||
self.dest_port = dest_port
|
||||
self.original_hostname = original_hostname
|
||||
self.original_scheme = original_scheme
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def init_poolmanager(self, connections, maxsize, block=False, **pool_kwargs):
|
||||
"""Initialize pool manager with TLS hostname verification settings."""
|
||||
if self.original_scheme == "https":
|
||||
# Ensure TLS certificate is verified against the original hostname
|
||||
# even though we're connecting to an IP address
|
||||
pool_kwargs["assert_hostname"] = self.original_hostname
|
||||
pool_kwargs["server_hostname"] = self.original_hostname
|
||||
super().init_poolmanager(connections, maxsize, block, **pool_kwargs)
|
||||
|
||||
def send(
|
||||
self, request, stream=False, timeout=None, verify=True, cert=None, proxies=None
|
||||
):
|
||||
"""Send request, rewriting URL to connect to the validated IP address."""
|
||||
parsed = urlparse(request.url)
|
||||
|
||||
# Build URL with validated IP instead of hostname
|
||||
# IPv6 addresses need brackets in URLs
|
||||
if ":" in self.dest_ip:
|
||||
ip_netloc = f"[{self.dest_ip}]:{self.dest_port}"
|
||||
else:
|
||||
ip_netloc = f"{self.dest_ip}:{self.dest_port}"
|
||||
|
||||
# Reconstruct URL with IP address
|
||||
request.url = urlunparse(
|
||||
(
|
||||
parsed.scheme,
|
||||
ip_netloc,
|
||||
parsed.path,
|
||||
parsed.params,
|
||||
parsed.query,
|
||||
parsed.fragment,
|
||||
)
|
||||
)
|
||||
|
||||
# Set Host header to original hostname for virtual hosting
|
||||
# Include port only if non-standard
|
||||
if parsed.port and parsed.port not in (80, 443):
|
||||
request.headers["Host"] = f"{self.original_hostname}:{parsed.port}"
|
||||
else:
|
||||
request.headers["Host"] = self.original_hostname
|
||||
|
||||
return super().send(
|
||||
request,
|
||||
stream=stream,
|
||||
timeout=timeout,
|
||||
verify=verify,
|
||||
cert=cert,
|
||||
proxies=proxies,
|
||||
)
|
||||
|
||||
|
||||
class SSRFSafeSession:
|
||||
"""
|
||||
HTTP Session with built-in SSRF protection.
|
||||
|
||||
This class provides a safe way to make HTTP requests by:
|
||||
1. Validating URL scheme (only http/https allowed)
|
||||
2. Blocking direct IP addresses (legitimate services use domain names)
|
||||
3. Resolving hostnames and blocking private/internal IPs
|
||||
4. Pinning resolved IPs to prevent DNS rebinding attacks (TOCTOU)
|
||||
|
||||
Usage:
|
||||
try:
|
||||
response = SSRFSafeSession().get("https://example.com/image.png", timeout=10)
|
||||
except SSRFValidationError:
|
||||
# URL was blocked for security reasons
|
||||
pass
|
||||
"""
|
||||
|
||||
def _validate_url(self, parsed_url: ParseResult) -> list[str]:
|
||||
"""
|
||||
Validate that a URL is safe to fetch (SSRF protection).
|
||||
|
||||
This function prevents Server-Side Request Forgery (SSRF) attacks by
|
||||
validating URLs before making HTTP requests. It implements a defense-in-depth
|
||||
approach:
|
||||
|
||||
1. Only allows http/https schemes
|
||||
2. Blocks all IP addresses (legitimate emails use domain names)
|
||||
3. Resolves hostnames and blocks if they resolve to private/internal IPs
|
||||
(prevents DNS rebinding attacks where attacker-controlled DNS returns
|
||||
127.0.0.1 or internal IPs)
|
||||
|
||||
Blocked addresses include:
|
||||
- Any direct IP address (e.g., http://192.168.1.1/)
|
||||
- Private IP ranges (RFC1918: 10.x.x.x, 172.16-31.x.x, 192.168.x.x)
|
||||
- Loopback addresses (127.x.x.x, ::1)
|
||||
- Link-local addresses (169.254.x.x, fe80::/10)
|
||||
- Multicast and reserved addresses
|
||||
- Cloud provider metadata endpoints (169.254.169.254, fd00:ec2::254)
|
||||
|
||||
Args:
|
||||
parsed_url: The parsed URL to validate
|
||||
|
||||
Returns:
|
||||
List of validated IP addresses that the hostname resolves to
|
||||
|
||||
Raises:
|
||||
SSRFValidationError: If the URL is unsafe
|
||||
"""
|
||||
# Only allow http and https schemes
|
||||
if parsed_url.scheme not in {"http", "https"}:
|
||||
raise SSRFValidationError("Invalid URL scheme (only http/https allowed)")
|
||||
|
||||
# Require a hostname
|
||||
if not parsed_url.hostname:
|
||||
raise SSRFValidationError("Invalid URL (missing hostname)")
|
||||
|
||||
# Block all IP addresses (legitimate services use domain names)
|
||||
try:
|
||||
ipaddress.ip_address(parsed_url.hostname)
|
||||
raise SSRFValidationError(
|
||||
"IP addresses are not allowed (domain name required)"
|
||||
)
|
||||
except ValueError:
|
||||
# Not an IP address, continue validation
|
||||
pass
|
||||
|
||||
# Resolve hostname to IP addresses
|
||||
try:
|
||||
addr_info = socket.getaddrinfo(
|
||||
parsed_url.hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM
|
||||
)
|
||||
except socket.gaierror as exc:
|
||||
raise SSRFValidationError("Unable to resolve hostname") from exc
|
||||
|
||||
# Check all resolved IP addresses
|
||||
valid_ips = []
|
||||
for _, _, _, _, sockaddr in addr_info:
|
||||
ip_str = sockaddr[0]
|
||||
try:
|
||||
ip_addr = ipaddress.ip_address(ip_str)
|
||||
|
||||
if ip_addr.is_private:
|
||||
raise SSRFValidationError("Domain resolves to private IP address")
|
||||
|
||||
if ip_addr.is_loopback:
|
||||
raise SSRFValidationError("Domain resolves to loopback address")
|
||||
|
||||
if ip_addr.is_link_local:
|
||||
raise SSRFValidationError("Domain resolves to link-local address")
|
||||
|
||||
if ip_addr.is_multicast:
|
||||
raise SSRFValidationError("Domain resolves to multicast address")
|
||||
|
||||
if ip_addr.is_reserved:
|
||||
raise SSRFValidationError("Domain resolves to reserved address")
|
||||
|
||||
# Block known cloud metadata IPs
|
||||
if ip_str in ("169.254.169.254", "fd00:ec2::254"):
|
||||
raise SSRFValidationError(
|
||||
"Domain resolves to cloud metadata endpoint"
|
||||
)
|
||||
|
||||
valid_ips.append(ip_str)
|
||||
|
||||
except ValueError as exc:
|
||||
raise SSRFValidationError("Invalid IP address in DNS response") from exc
|
||||
|
||||
if not valid_ips:
|
||||
raise SSRFValidationError("No valid IP addresses found")
|
||||
|
||||
return valid_ips
|
||||
|
||||
def get(self, url: str, timeout: int, **kwargs) -> requests.Response:
|
||||
"""
|
||||
Perform a safe HTTP GET request with SSRF protection and IP pinning.
|
||||
|
||||
This method:
|
||||
1. Parses and validates the URL
|
||||
2. Resolves DNS and validates all returned IPs
|
||||
3. Creates a requests Session with a custom HTTPAdapter that:
|
||||
- Connects directly to the validated IP (preventing DNS rebinding)
|
||||
- Maintains proper TLS certificate verification against the hostname
|
||||
- Sets the Host header correctly for virtual hosting
|
||||
|
||||
Args:
|
||||
url: The URL to fetch
|
||||
timeout: Request timeout in seconds
|
||||
**kwargs: Additional arguments passed to requests.Session.get()
|
||||
|
||||
Returns:
|
||||
requests.Response object
|
||||
|
||||
Raises:
|
||||
SSRFValidationError: If the URL fails security validation
|
||||
requests.RequestException: If the HTTP request fails
|
||||
"""
|
||||
parsed_url = urlparse(url)
|
||||
valid_ips = self._validate_url(parsed_url)
|
||||
|
||||
# Determine the port (explicit or default based on scheme)
|
||||
if parsed_url.port:
|
||||
port = parsed_url.port
|
||||
elif parsed_url.scheme == "http":
|
||||
port = 80
|
||||
else:
|
||||
port = 443
|
||||
|
||||
# Create a session with our SSRF-protected adapter that pins to the validated IP
|
||||
session = requests.Session()
|
||||
adapter = SSRFProtectedAdapter(
|
||||
dest_ip=valid_ips[0],
|
||||
dest_port=port,
|
||||
original_hostname=parsed_url.hostname,
|
||||
original_scheme=parsed_url.scheme,
|
||||
)
|
||||
|
||||
# Mount the adapter for both http and https schemes
|
||||
session.mount("http://", adapter)
|
||||
session.mount("https://", adapter)
|
||||
|
||||
return session.get(url, timeout=timeout, **kwargs)
|
||||
|
||||
|
||||
class ImageProxySuspiciousResponse(HttpResponse):
|
||||
"""
|
||||
Response for suspicious content that has been blocked by our image proxy.
|
||||
@@ -356,7 +110,6 @@ class ImageProxyViewSet(ViewSet):
|
||||
timeout=10,
|
||||
stream=True,
|
||||
headers={"User-Agent": "Messages-ImageProxy/1.0"},
|
||||
allow_redirects=False,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
@@ -35,9 +35,9 @@ class WidgetAuthentication(BaseAuthentication):
|
||||
if not channel_id:
|
||||
raise AuthenticationFailed("Missing channel_id")
|
||||
|
||||
# API key authentication for check endpoint
|
||||
# Only allow widget-type channels
|
||||
try:
|
||||
channel = models.Channel.objects.get(id=channel_id)
|
||||
channel = models.Channel.objects.get(id=channel_id, type="widget")
|
||||
except models.Channel.DoesNotExist as e:
|
||||
raise AuthenticationFailed("Invalid channel_id") from e
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ from celery.utils.log import get_task_logger
|
||||
|
||||
from core.mda.inbound import deliver_inbound_message
|
||||
from core.mda.rfc5322 import parse_email_message
|
||||
from core.services.ssrf import SSRFValidationError, validate_hostname
|
||||
|
||||
logger = get_task_logger(__name__)
|
||||
|
||||
@@ -62,6 +63,21 @@ def decode_imap_utf7(s):
|
||||
return re.sub(r"&([^-]*)-", decode_match, s)
|
||||
|
||||
|
||||
def _validate_imap_host(server: str) -> None:
|
||||
"""Validate that the IMAP server hostname is not a private/internal address.
|
||||
|
||||
Wraps the shared SSRF validator but allows public IP literals, which are
|
||||
legitimate addresses for customer-supplied IMAP servers.
|
||||
|
||||
Raises:
|
||||
ValueError: If the hostname resolves to a blocked IP address.
|
||||
"""
|
||||
try:
|
||||
validate_hostname(server, allow_ip_literal=True)
|
||||
except SSRFValidationError as exc:
|
||||
raise ValueError(f"IMAP server {server} is not allowed: {exc}") from exc
|
||||
|
||||
|
||||
class IMAPConnectionManager:
|
||||
"""Context manager for IMAP connections with proper cleanup."""
|
||||
|
||||
@@ -76,6 +92,9 @@ class IMAPConnectionManager:
|
||||
self.connection = None
|
||||
|
||||
def __enter__(self):
|
||||
# Validate the server hostname to prevent SSRF
|
||||
_validate_imap_host(self.server)
|
||||
|
||||
# Port 143 typically uses STARTTLS, port 993 uses SSL direct
|
||||
# If use_ssl=True and port is 143, use STARTTLS instead of SSL direct
|
||||
use_starttls = self.use_ssl and self.port == 143
|
||||
|
||||
245
src/backend/core/services/ssrf.py
Normal file
245
src/backend/core/services/ssrf.py
Normal file
@@ -0,0 +1,245 @@
|
||||
"""Server-Side Request Forgery (SSRF) protections.
|
||||
|
||||
Shared across features that take user-supplied network destinations (image
|
||||
proxy, IMAP import, etc.). Provides a hostname/IP validator plus an HTTP
|
||||
session with IP pinning to defeat DNS-rebinding (TOCTOU) attacks.
|
||||
"""
|
||||
|
||||
import ipaddress
|
||||
import socket
|
||||
from urllib.parse import urljoin, urlparse, urlunparse
|
||||
|
||||
import requests
|
||||
from requests.adapters import HTTPAdapter
|
||||
|
||||
CLOUD_METADATA_IPS = frozenset({"169.254.169.254", "fd00:ec2::254"})
|
||||
|
||||
MAX_REDIRECTS = 5
|
||||
REDIRECT_STATUS_CODES = frozenset({301, 302, 303, 307, 308})
|
||||
|
||||
|
||||
class SSRFValidationError(Exception):
|
||||
"""Raised when a URL or hostname fails SSRF validation."""
|
||||
|
||||
|
||||
def _check_ip(ip_addr: ipaddress._BaseAddress, hostname: str) -> None:
|
||||
# Check specific categories before is_private: in Python's ipaddress
|
||||
# module, loopback/link-local/etc. are subsets of is_private, so checking
|
||||
# is_private first would mask the more informative error.
|
||||
if str(ip_addr) in CLOUD_METADATA_IPS:
|
||||
raise SSRFValidationError(f"{hostname} resolves to cloud metadata endpoint")
|
||||
if ip_addr.is_loopback:
|
||||
raise SSRFValidationError(f"{hostname} resolves to loopback address")
|
||||
if ip_addr.is_link_local:
|
||||
raise SSRFValidationError(f"{hostname} resolves to link-local address")
|
||||
if ip_addr.is_multicast:
|
||||
raise SSRFValidationError(f"{hostname} resolves to multicast address")
|
||||
if ip_addr.is_reserved:
|
||||
raise SSRFValidationError(f"{hostname} resolves to reserved address")
|
||||
if ip_addr.is_private:
|
||||
raise SSRFValidationError(f"{hostname} resolves to private IP address")
|
||||
|
||||
|
||||
def validate_hostname(hostname: str, *, allow_ip_literal: bool = False) -> list[str]:
|
||||
"""Resolve hostname and reject private/internal/metadata addresses.
|
||||
|
||||
Args:
|
||||
hostname: A hostname or, when allow_ip_literal=True, an IP literal.
|
||||
allow_ip_literal: If False (default), IP literals are rejected outright —
|
||||
legitimate services use domain names. If True, public IP literals
|
||||
are accepted (used for IMAP where customers may supply raw IPs).
|
||||
|
||||
Returns:
|
||||
List of validated IP addresses the hostname resolves to.
|
||||
|
||||
Raises:
|
||||
SSRFValidationError: If the hostname/IP resolves to a blocked address.
|
||||
"""
|
||||
if not hostname:
|
||||
raise SSRFValidationError("Invalid hostname (missing)")
|
||||
|
||||
try:
|
||||
ip = ipaddress.ip_address(hostname)
|
||||
except ValueError:
|
||||
ip = None
|
||||
|
||||
if ip is not None:
|
||||
if not allow_ip_literal:
|
||||
raise SSRFValidationError(
|
||||
"IP addresses are not allowed (domain name required)"
|
||||
)
|
||||
_check_ip(ip, hostname)
|
||||
return [str(ip)]
|
||||
|
||||
try:
|
||||
addr_info = socket.getaddrinfo(
|
||||
hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM
|
||||
)
|
||||
except socket.gaierror as exc:
|
||||
raise SSRFValidationError("Unable to resolve hostname") from exc
|
||||
|
||||
valid_ips: list[str] = []
|
||||
for _, _, _, _, sockaddr in addr_info:
|
||||
ip_str = sockaddr[0]
|
||||
try:
|
||||
ip_addr = ipaddress.ip_address(ip_str)
|
||||
except ValueError as exc:
|
||||
raise SSRFValidationError("Invalid IP address in DNS response") from exc
|
||||
_check_ip(ip_addr, hostname)
|
||||
valid_ips.append(ip_str)
|
||||
|
||||
if not valid_ips:
|
||||
raise SSRFValidationError("No valid IP addresses found")
|
||||
|
||||
return valid_ips
|
||||
|
||||
|
||||
class SSRFProtectedAdapter(HTTPAdapter):
|
||||
"""HTTPAdapter that pins the connection to a pre-validated IP.
|
||||
|
||||
Prevents TOCTOU DNS rebinding by:
|
||||
1. Connecting to the IP address that was validated (no re-resolving DNS).
|
||||
2. Verifying TLS certificates against the original hostname (for HTTPS).
|
||||
3. Setting the Host header correctly for virtual hosting.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dest_ip: str,
|
||||
dest_port: int,
|
||||
original_hostname: str,
|
||||
original_scheme: str,
|
||||
**kwargs,
|
||||
):
|
||||
self.dest_ip = dest_ip
|
||||
self.dest_port = dest_port
|
||||
self.original_hostname = original_hostname
|
||||
self.original_scheme = original_scheme
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def init_poolmanager(self, connections, maxsize, block=False, **pool_kwargs):
|
||||
if self.original_scheme == "https":
|
||||
pool_kwargs["assert_hostname"] = self.original_hostname
|
||||
pool_kwargs["server_hostname"] = self.original_hostname
|
||||
super().init_poolmanager(connections, maxsize, block, **pool_kwargs)
|
||||
|
||||
def send(
|
||||
self, request, stream=False, timeout=None, verify=True, cert=None, proxies=None
|
||||
):
|
||||
parsed = urlparse(request.url)
|
||||
|
||||
if ":" in self.dest_ip:
|
||||
ip_netloc = f"[{self.dest_ip}]:{self.dest_port}"
|
||||
else:
|
||||
ip_netloc = f"{self.dest_ip}:{self.dest_port}"
|
||||
|
||||
request.url = urlunparse(
|
||||
(
|
||||
parsed.scheme,
|
||||
ip_netloc,
|
||||
parsed.path,
|
||||
parsed.params,
|
||||
parsed.query,
|
||||
parsed.fragment,
|
||||
)
|
||||
)
|
||||
|
||||
if parsed.port and parsed.port not in (80, 443):
|
||||
request.headers["Host"] = f"{self.original_hostname}:{parsed.port}"
|
||||
else:
|
||||
request.headers["Host"] = self.original_hostname
|
||||
|
||||
return super().send(
|
||||
request,
|
||||
stream=stream,
|
||||
timeout=timeout,
|
||||
verify=verify,
|
||||
cert=cert,
|
||||
proxies=proxies,
|
||||
)
|
||||
|
||||
|
||||
class SSRFSafeSession:
|
||||
"""HTTP Session with built-in SSRF protection.
|
||||
|
||||
1. Validates URL scheme (only http/https allowed).
|
||||
2. Blocks direct IP addresses (legitimate services use domain names).
|
||||
3. Resolves hostnames and blocks private/internal IPs.
|
||||
4. Pins resolved IPs to prevent DNS rebinding attacks (TOCTOU).
|
||||
|
||||
Usage:
|
||||
try:
|
||||
response = SSRFSafeSession().get("https://example.com/image.png", timeout=10)
|
||||
except SSRFValidationError:
|
||||
# URL was blocked for security reasons
|
||||
pass
|
||||
"""
|
||||
|
||||
def _validate_and_unpack(self, url: str) -> tuple[str, str, str, int]:
|
||||
"""Validate a URL and return (validated_ip, hostname, scheme, port).
|
||||
|
||||
Raises:
|
||||
SSRFValidationError: If the URL is unsafe.
|
||||
"""
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme not in {"http", "https"}:
|
||||
raise SSRFValidationError("Invalid URL scheme (only http/https allowed)")
|
||||
if not parsed.hostname:
|
||||
raise SSRFValidationError("Invalid URL (missing hostname)")
|
||||
|
||||
valid_ips = validate_hostname(parsed.hostname, allow_ip_literal=False)
|
||||
|
||||
if parsed.port:
|
||||
port = parsed.port
|
||||
elif parsed.scheme == "http":
|
||||
port = 80
|
||||
else:
|
||||
port = 443
|
||||
|
||||
return valid_ips[0], parsed.hostname, parsed.scheme, port
|
||||
|
||||
def get(self, url: str, timeout: int, **kwargs) -> requests.Response:
|
||||
"""Perform a safe HTTP GET with per-hop SSRF validation on redirects.
|
||||
|
||||
Redirects are followed manually up to MAX_REDIRECTS hops. Each Location
|
||||
URL is re-validated from scratch, so an attacker-controlled server
|
||||
cannot redirect to an internal address or a different private target
|
||||
on a later hop.
|
||||
"""
|
||||
# We always handle redirects ourselves — strip any caller override so
|
||||
# the underlying requests session never follows a redirect unchecked.
|
||||
kwargs.pop("allow_redirects", None)
|
||||
|
||||
current_url = url
|
||||
for _ in range(MAX_REDIRECTS + 1):
|
||||
validated_ip, hostname, scheme, port = self._validate_and_unpack(
|
||||
current_url
|
||||
)
|
||||
|
||||
session = requests.Session()
|
||||
adapter = SSRFProtectedAdapter(
|
||||
dest_ip=validated_ip,
|
||||
dest_port=port,
|
||||
original_hostname=hostname,
|
||||
original_scheme=scheme,
|
||||
)
|
||||
session.mount("http://", adapter)
|
||||
session.mount("https://", adapter)
|
||||
|
||||
response = session.get(
|
||||
current_url, timeout=timeout, allow_redirects=False, **kwargs
|
||||
)
|
||||
|
||||
if response.status_code not in REDIRECT_STATUS_CODES:
|
||||
return response
|
||||
|
||||
location = response.headers.get("Location")
|
||||
if not location:
|
||||
# Redirect without a Location — hand the response back unchanged.
|
||||
return response
|
||||
|
||||
next_url = urljoin(current_url, location)
|
||||
response.close()
|
||||
current_url = next_url
|
||||
|
||||
raise SSRFValidationError(f"Too many redirects (max {MAX_REDIRECTS})")
|
||||
@@ -191,7 +191,7 @@ class TestImageProxyViewSet:
|
||||
"https://localhost:8080/image.jpg",
|
||||
],
|
||||
)
|
||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
||||
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||
def test_api_image_proxy_localhost_hostname_blocked(
|
||||
self, mock_getaddrinfo, api_client, user_mailbox, test_url
|
||||
):
|
||||
@@ -210,7 +210,7 @@ class TestImageProxyViewSet:
|
||||
assert response["Content-Type"] == "image/svg+xml"
|
||||
|
||||
@override_settings(IMAGE_PROXY_ENABLED=True)
|
||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
||||
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||
def test_api_image_proxy_domain_resolving_to_private_ip_blocked(
|
||||
self, mock_getaddrinfo, api_client, user_mailbox
|
||||
):
|
||||
@@ -229,7 +229,7 @@ class TestImageProxyViewSet:
|
||||
assert response["Content-Type"] == "image/svg+xml"
|
||||
|
||||
@override_settings(IMAGE_PROXY_ENABLED=True)
|
||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
||||
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||
def test_api_image_proxy_unresolvable_hostname(
|
||||
self, mock_getaddrinfo, api_client, user_mailbox
|
||||
):
|
||||
@@ -249,7 +249,7 @@ class TestImageProxyViewSet:
|
||||
# Content-Type Validation Tests
|
||||
|
||||
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=10)
|
||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
||||
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
||||
def test_api_image_proxy_non_image_content_type_blocked(
|
||||
self, mock_get, mock_getaddrinfo, api_client, user_mailbox
|
||||
@@ -273,7 +273,7 @@ class TestImageProxyViewSet:
|
||||
assert response["Content-Type"] == "image/svg+xml"
|
||||
|
||||
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=10)
|
||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
||||
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
||||
@patch("magic.from_buffer")
|
||||
def test_api_image_proxy_content_not_actually_image(
|
||||
@@ -301,7 +301,7 @@ class TestImageProxyViewSet:
|
||||
assert response["Content-Type"] == "image/svg+xml"
|
||||
|
||||
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=10)
|
||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
||||
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
||||
@patch("magic.from_buffer")
|
||||
@pytest.mark.parametrize(
|
||||
@@ -342,7 +342,7 @@ class TestImageProxyViewSet:
|
||||
# Size Limit Tests
|
||||
|
||||
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=1)
|
||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
||||
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
||||
def test_api_image_proxy_image_too_large_via_content_length(
|
||||
self, mock_get, mock_getaddrinfo, api_client, user_mailbox
|
||||
@@ -367,7 +367,7 @@ class TestImageProxyViewSet:
|
||||
assert "Image too large" in str(response.data)
|
||||
|
||||
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=4096)
|
||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
||||
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
||||
@patch("magic.from_buffer")
|
||||
def test_api_image_proxy_image_too_large_actual_content(
|
||||
@@ -395,39 +395,13 @@ class TestImageProxyViewSet:
|
||||
assert response.status_code == status.HTTP_413_REQUEST_ENTITY_TOO_LARGE
|
||||
assert "Image too large" in str(response.data)
|
||||
|
||||
# Redirect Tests
|
||||
|
||||
@override_settings(IMAGE_PROXY_ENABLED=True)
|
||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
||||
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
||||
def test_api_image_proxy_redirects_blocked(
|
||||
self, mock_get, mock_getaddrinfo, api_client, user_mailbox
|
||||
):
|
||||
"""Test that redirects are not followed (SSRF protection)."""
|
||||
# Mock DNS resolution
|
||||
mock_getaddrinfo.return_value = [(2, 1, 6, "", ("1.2.3.4", 80))]
|
||||
|
||||
client, _ = api_client
|
||||
url = self._get_image_proxy_url(
|
||||
user_mailbox.id, "http://example.com/redirect.jpg"
|
||||
)
|
||||
|
||||
mock_get.return_value = self._mock_requests_response(
|
||||
content=b"", content_type="text/plain", status_code=302
|
||||
)
|
||||
|
||||
# Call the endpoint
|
||||
client.get(url)
|
||||
|
||||
# Verify that allow_redirects=False was passed to requests.get
|
||||
mock_get.assert_called_once()
|
||||
call_kwargs = mock_get.call_args[1]
|
||||
assert call_kwargs["allow_redirects"] is False
|
||||
# Redirect handling is covered end-to-end in core/tests/services/test_ssrf.py:
|
||||
# SSRFSafeSession follows redirects internally, re-validating every hop.
|
||||
|
||||
# Success Cases
|
||||
|
||||
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_CACHE_TTL=3600)
|
||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
||||
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
||||
@patch("magic.from_buffer")
|
||||
def test_api_image_proxy_successfully_jpg_image(
|
||||
@@ -459,7 +433,7 @@ class TestImageProxyViewSet:
|
||||
assert response["Cache-Control"] == "public, max-age=3600"
|
||||
|
||||
@override_settings(IMAGE_PROXY_ENABLED=True)
|
||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
||||
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
||||
@patch("magic.from_buffer")
|
||||
def test_api_image_proxy_successfully_png_image(
|
||||
@@ -488,7 +462,7 @@ class TestImageProxyViewSet:
|
||||
assert response["Content-Type"] == "image/png"
|
||||
|
||||
@override_settings(IMAGE_PROXY_ENABLED=True)
|
||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
||||
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
||||
@patch("magic.from_buffer")
|
||||
def test_api_image_proxy_successfully_with_octet_stream_content_type(
|
||||
@@ -517,7 +491,7 @@ class TestImageProxyViewSet:
|
||||
assert response["Content-Type"] == "image/jpeg"
|
||||
|
||||
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_CACHE_TTL=3600)
|
||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
||||
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
||||
@patch("magic.from_buffer")
|
||||
def test_api_image_proxy_no_content(
|
||||
@@ -547,7 +521,7 @@ class TestImageProxyViewSet:
|
||||
assert response["Content-Type"] == "image/svg+xml"
|
||||
|
||||
@override_settings(IMAGE_PROXY_ENABLED=True)
|
||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
||||
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
||||
@patch("magic.from_buffer")
|
||||
def test_api_image_proxy_url_with_special_characters(
|
||||
@@ -578,7 +552,7 @@ class TestImageProxyViewSet:
|
||||
assert mock_get.call_args[0][0] == test_url
|
||||
|
||||
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_CACHE_TTL=3600)
|
||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
||||
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
||||
@patch("magic.from_buffer")
|
||||
def test_api_image_proxy_successfully_secure_headers(
|
||||
@@ -617,7 +591,7 @@ class TestImageProxyViewSet:
|
||||
# Error Handling Tests
|
||||
|
||||
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=10)
|
||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
||||
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
||||
def test_api_image_proxy_network_timeout(
|
||||
self, mock_get, mock_getaddrinfo, api_client, user_mailbox
|
||||
@@ -638,7 +612,7 @@ class TestImageProxyViewSet:
|
||||
assert "Failed to fetch image" in str(response.data)
|
||||
|
||||
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=10)
|
||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
||||
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
||||
def test_api_image_proxy_connection_error(
|
||||
self, mock_get, mock_getaddrinfo, api_client, user_mailbox
|
||||
@@ -659,7 +633,7 @@ class TestImageProxyViewSet:
|
||||
assert "Failed to fetch image" in str(response.data)
|
||||
|
||||
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=10)
|
||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
||||
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
||||
def test_api_image_proxy_http_error_404(
|
||||
self, mock_get, mock_getaddrinfo, api_client, user_mailbox
|
||||
@@ -687,7 +661,7 @@ class TestImageProxyViewSet:
|
||||
assert "Failed to fetch image" in str(response.data)
|
||||
|
||||
@override_settings(IMAGE_PROXY_ENABLED=True)
|
||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
||||
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
||||
@patch("magic.from_buffer")
|
||||
def test_api_image_proxy_invalid_content_length_header(
|
||||
@@ -716,7 +690,7 @@ class TestImageProxyViewSet:
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
|
||||
@override_settings(IMAGE_PROXY_ENABLED=True)
|
||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
||||
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
||||
@patch("magic.from_buffer")
|
||||
def test_api_image_proxy_missing_content_length_header(
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# pylint: disable=redefined-outer-name, unused-argument, no-value-for-parameter, too-many-lines
|
||||
|
||||
import datetime
|
||||
import socket
|
||||
from unittest.mock import patch
|
||||
|
||||
from django.core.files.storage import storages
|
||||
@@ -19,6 +20,25 @@ from core.services.importer.mbox_tasks import process_mbox_file_task
|
||||
|
||||
pytestmark = pytest.mark.django_db
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_ssrf_dns():
|
||||
"""Short-circuit SSRF DNS validation for IMAP import tests.
|
||||
|
||||
The IMAP endpoint validates the server hostname via
|
||||
``core.services.ssrf.validate_hostname``; tests use unresolvable fixtures
|
||||
like ``imap.example.com`` so we return a public IP to reach the mocked
|
||||
IMAP task code.
|
||||
"""
|
||||
with patch(
|
||||
"core.services.ssrf.socket.getaddrinfo",
|
||||
return_value=[
|
||||
(socket.AF_INET, socket.SOCK_STREAM, 6, "", ("93.184.216.34", 0))
|
||||
],
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
IMPORT_FILE_URL = "/api/v1.0/import/file/"
|
||||
IMPORT_IMAP_URL = "/api/v1.0/import/imap/"
|
||||
|
||||
|
||||
24
src/backend/core/tests/importer/conftest.py
Normal file
24
src/backend/core/tests/importer/conftest.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Shared fixtures for importer tests."""
|
||||
|
||||
import socket
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_ssrf_dns():
|
||||
"""Short-circuit SSRF DNS validation for IMAP tests.
|
||||
|
||||
The IMAP import path validates the server hostname via
|
||||
``core.services.ssrf.validate_hostname`` which calls ``socket.getaddrinfo``.
|
||||
Test fixtures use unresolvable hostnames like ``imap.example.com``, so we
|
||||
return a valid public IP here to let tests reach the mocked IMAP code.
|
||||
"""
|
||||
with mock.patch(
|
||||
"core.services.ssrf.socket.getaddrinfo",
|
||||
return_value=[
|
||||
(socket.AF_INET, socket.SOCK_STREAM, 6, "", ("93.184.216.34", 0))
|
||||
],
|
||||
):
|
||||
yield
|
||||
274
src/backend/core/tests/services/test_ssrf.py
Normal file
274
src/backend/core/tests/services/test_ssrf.py
Normal file
@@ -0,0 +1,274 @@
|
||||
"""Unit tests for core.services.ssrf — focus on redirect handling.
|
||||
|
||||
Hostname/IP validation itself is covered via the image proxy integration tests
|
||||
(see core/tests/api/test_image_proxy.py). This file targets the per-hop
|
||||
validation loop in SSRFSafeSession.get, which is the main SSRF-correctness
|
||||
surface for redirect responses.
|
||||
"""
|
||||
|
||||
import socket
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.services.ssrf import (
|
||||
MAX_REDIRECTS,
|
||||
SSRFSafeSession,
|
||||
SSRFValidationError,
|
||||
)
|
||||
|
||||
PUBLIC_IP = "93.184.216.34"
|
||||
PRIVATE_IP = "192.168.1.1"
|
||||
|
||||
|
||||
def _addrinfo(ip: str):
|
||||
return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", (ip, 0))]
|
||||
|
||||
|
||||
def _mock_response(status_code: int = 200, location: str | None = None):
|
||||
"""Build a mock requests.Response-like object."""
|
||||
resp = MagicMock()
|
||||
resp.status_code = status_code
|
||||
resp.headers = {}
|
||||
if location is not None:
|
||||
resp.headers["Location"] = location
|
||||
return resp
|
||||
|
||||
|
||||
class TestSSRFSafeSessionRedirects:
|
||||
"""Redirect-handling contract for SSRFSafeSession.get."""
|
||||
|
||||
@patch("core.services.ssrf.requests.Session.get")
|
||||
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||
def test_no_redirect_returns_response_directly(self, mock_dns, mock_get):
|
||||
"""A 200 response is returned without any extra request."""
|
||||
mock_dns.return_value = _addrinfo(PUBLIC_IP)
|
||||
mock_get.return_value = _mock_response(200)
|
||||
|
||||
response = SSRFSafeSession().get("https://legit.com/img.png", timeout=10)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert mock_get.call_count == 1
|
||||
|
||||
@patch("core.services.ssrf.requests.Session.get")
|
||||
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||
def test_follows_redirect_to_safe_url(self, mock_dns, mock_get):
|
||||
"""A single redirect is followed to a validated destination."""
|
||||
mock_dns.return_value = _addrinfo(PUBLIC_IP)
|
||||
mock_get.side_effect = [
|
||||
_mock_response(302, location="https://cdn.legit.com/img.png"),
|
||||
_mock_response(200),
|
||||
]
|
||||
|
||||
response = SSRFSafeSession().get("https://legit.com/img.png", timeout=10)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert mock_get.call_count == 2
|
||||
|
||||
@pytest.mark.parametrize("redirect_status", [301, 302, 303, 307, 308])
|
||||
@patch("core.services.ssrf.requests.Session.get")
|
||||
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||
def test_all_redirect_statuses_followed(self, mock_dns, mock_get, redirect_status):
|
||||
"""All standard redirect status codes trigger a new hop."""
|
||||
mock_dns.return_value = _addrinfo(PUBLIC_IP)
|
||||
mock_get.side_effect = [
|
||||
_mock_response(redirect_status, location="https://b.com/img.png"),
|
||||
_mock_response(200),
|
||||
]
|
||||
|
||||
response = SSRFSafeSession().get("https://a.com/img.png", timeout=10)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert mock_get.call_count == 2
|
||||
|
||||
@patch("core.services.ssrf.requests.Session.get")
|
||||
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||
def test_follows_multiple_hops(self, mock_dns, mock_get):
|
||||
"""Chained redirects are followed up to MAX_REDIRECTS."""
|
||||
mock_dns.return_value = _addrinfo(PUBLIC_IP)
|
||||
mock_get.side_effect = [
|
||||
_mock_response(302, location="https://b.com/"),
|
||||
_mock_response(302, location="https://c.com/"),
|
||||
_mock_response(302, location="https://d.com/"),
|
||||
_mock_response(200),
|
||||
]
|
||||
|
||||
response = SSRFSafeSession().get("https://a.com/", timeout=10)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert mock_get.call_count == 4
|
||||
|
||||
@patch("core.services.ssrf.requests.Session.get")
|
||||
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||
def test_blocks_redirect_to_private_ip(self, mock_dns, mock_get):
|
||||
"""A Location resolving to a private IP is rejected mid-chain."""
|
||||
|
||||
def dns_side_effect(host, *_args, **_kwargs):
|
||||
if host == "legit.com":
|
||||
return _addrinfo(PUBLIC_IP)
|
||||
if host == "internal.evil.com":
|
||||
return _addrinfo(PRIVATE_IP)
|
||||
raise AssertionError(f"unexpected DNS lookup: {host}")
|
||||
|
||||
mock_dns.side_effect = dns_side_effect
|
||||
mock_get.return_value = _mock_response(
|
||||
302, location="https://internal.evil.com/pwn"
|
||||
)
|
||||
|
||||
with pytest.raises(SSRFValidationError, match="private IP"):
|
||||
SSRFSafeSession().get("https://legit.com/img.png", timeout=10)
|
||||
|
||||
# Only the first hop should have been issued; the second was blocked
|
||||
# before the request was made.
|
||||
assert mock_get.call_count == 1
|
||||
|
||||
@patch("core.services.ssrf.requests.Session.get")
|
||||
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||
def test_blocks_redirect_to_loopback(self, mock_dns, mock_get):
|
||||
"""Redirect pointing at loopback (DNS-rebinding style) is rejected."""
|
||||
|
||||
def dns_side_effect(host, *_args, **_kwargs):
|
||||
if host == "legit.com":
|
||||
return _addrinfo(PUBLIC_IP)
|
||||
if host == "rebind.evil.com":
|
||||
return _addrinfo("127.0.0.1")
|
||||
raise AssertionError(f"unexpected DNS lookup: {host}")
|
||||
|
||||
mock_dns.side_effect = dns_side_effect
|
||||
mock_get.return_value = _mock_response(302, location="https://rebind.evil.com/")
|
||||
|
||||
with pytest.raises(SSRFValidationError, match="loopback"):
|
||||
SSRFSafeSession().get("https://legit.com/img.png", timeout=10)
|
||||
|
||||
@patch("core.services.ssrf.requests.Session.get")
|
||||
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||
def test_blocks_redirect_to_cloud_metadata(self, mock_dns, mock_get):
|
||||
"""Redirect pointing at cloud metadata endpoint is rejected."""
|
||||
|
||||
def dns_side_effect(host, *_args, **_kwargs):
|
||||
if host == "legit.com":
|
||||
return _addrinfo(PUBLIC_IP)
|
||||
if host == "meta.evil.com":
|
||||
return _addrinfo("169.254.169.254")
|
||||
raise AssertionError(f"unexpected DNS lookup: {host}")
|
||||
|
||||
mock_dns.side_effect = dns_side_effect
|
||||
mock_get.return_value = _mock_response(
|
||||
302, location="https://meta.evil.com/latest/meta-data/"
|
||||
)
|
||||
|
||||
with pytest.raises(SSRFValidationError, match="metadata"):
|
||||
SSRFSafeSession().get("https://legit.com/img.png", timeout=10)
|
||||
|
||||
@patch("core.services.ssrf.requests.Session.get")
|
||||
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||
def test_blocks_redirect_to_non_http_scheme(self, mock_dns, mock_get):
|
||||
"""Redirect to file:// or other schemes is rejected."""
|
||||
mock_dns.return_value = _addrinfo(PUBLIC_IP)
|
||||
mock_get.return_value = _mock_response(302, location="file:///etc/passwd")
|
||||
|
||||
with pytest.raises(SSRFValidationError, match="scheme"):
|
||||
SSRFSafeSession().get("https://legit.com/", timeout=10)
|
||||
|
||||
@patch("core.services.ssrf.requests.Session.get")
|
||||
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||
def test_blocks_redirect_to_ip_literal(self, mock_dns, mock_get):
|
||||
"""Redirect whose Location is a raw IP is rejected (domains only)."""
|
||||
mock_dns.return_value = _addrinfo(PUBLIC_IP)
|
||||
mock_get.return_value = _mock_response(302, location="http://203.0.113.5/stuff")
|
||||
|
||||
with pytest.raises(SSRFValidationError, match="IP addresses are not allowed"):
|
||||
SSRFSafeSession().get("https://legit.com/", timeout=10)
|
||||
|
||||
@patch("core.services.ssrf.requests.Session.get")
|
||||
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||
def test_blocks_protocol_relative_redirect_to_private(self, mock_dns, mock_get):
|
||||
"""//host/path Location resolves against current scheme and is validated."""
|
||||
|
||||
def dns_side_effect(host, *_args, **_kwargs):
|
||||
if host == "legit.com":
|
||||
return _addrinfo(PUBLIC_IP)
|
||||
if host == "internal":
|
||||
return _addrinfo(PRIVATE_IP)
|
||||
raise AssertionError(f"unexpected DNS lookup: {host}")
|
||||
|
||||
mock_dns.side_effect = dns_side_effect
|
||||
mock_get.return_value = _mock_response(302, location="//internal/admin")
|
||||
|
||||
with pytest.raises(SSRFValidationError, match="private IP"):
|
||||
SSRFSafeSession().get("https://legit.com/", timeout=10)
|
||||
|
||||
@patch("core.services.ssrf.requests.Session.get")
|
||||
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||
def test_follows_relative_redirect(self, mock_dns, mock_get):
|
||||
"""Relative Location is resolved against the current URL and validated."""
|
||||
mock_dns.return_value = _addrinfo(PUBLIC_IP)
|
||||
mock_get.side_effect = [
|
||||
_mock_response(302, location="/new-path"),
|
||||
_mock_response(200),
|
||||
]
|
||||
|
||||
response = SSRFSafeSession().get("https://legit.com/old", timeout=10)
|
||||
|
||||
assert response.status_code == 200
|
||||
# Second request should be made to https://legit.com/new-path.
|
||||
second_call_url = mock_get.call_args_list[1].args[0]
|
||||
assert second_call_url == "https://legit.com/new-path"
|
||||
|
||||
@patch("core.services.ssrf.requests.Session.get")
|
||||
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||
def test_blocks_too_many_redirects(self, mock_dns, mock_get):
|
||||
"""A redirect loop longer than MAX_REDIRECTS raises."""
|
||||
mock_dns.return_value = _addrinfo(PUBLIC_IP)
|
||||
mock_get.return_value = _mock_response(302, location="https://legit.com/loop")
|
||||
|
||||
with pytest.raises(SSRFValidationError, match="Too many redirects"):
|
||||
SSRFSafeSession().get("https://legit.com/", timeout=10)
|
||||
|
||||
# Initial hop + MAX_REDIRECTS extra hops, then error.
|
||||
assert mock_get.call_count == MAX_REDIRECTS + 1
|
||||
|
||||
@patch("core.services.ssrf.requests.Session.get")
|
||||
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||
def test_redirect_without_location_returned_as_is(self, mock_dns, mock_get):
|
||||
"""A redirect status with no Location header is returned unchanged."""
|
||||
mock_dns.return_value = _addrinfo(PUBLIC_IP)
|
||||
mock_get.return_value = _mock_response(302, location=None)
|
||||
|
||||
response = SSRFSafeSession().get("https://legit.com/", timeout=10)
|
||||
|
||||
assert response.status_code == 302
|
||||
assert mock_get.call_count == 1
|
||||
|
||||
@patch("core.services.ssrf.requests.Session.get")
|
||||
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||
def test_caller_allow_redirects_true_is_ignored(self, mock_dns, mock_get):
|
||||
"""Caller cannot opt out of per-hop validation by passing allow_redirects=True."""
|
||||
mock_dns.return_value = _addrinfo(PUBLIC_IP)
|
||||
mock_get.side_effect = [
|
||||
_mock_response(302, location="https://cdn.legit.com/"),
|
||||
_mock_response(200),
|
||||
]
|
||||
|
||||
response = SSRFSafeSession().get(
|
||||
"https://legit.com/", timeout=10, allow_redirects=True
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
# Each underlying Session.get must be called with allow_redirects=False.
|
||||
for call in mock_get.call_args_list:
|
||||
assert call.kwargs.get("allow_redirects") is False
|
||||
|
||||
@patch("core.services.ssrf.requests.Session.get")
|
||||
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||
def test_intermediate_response_closed(self, mock_dns, mock_get):
|
||||
"""Intermediate redirect responses are .close()d to release streams."""
|
||||
mock_dns.return_value = _addrinfo(PUBLIC_IP)
|
||||
intermediate = _mock_response(302, location="https://cdn.legit.com/")
|
||||
final = _mock_response(200)
|
||||
mock_get.side_effect = [intermediate, final]
|
||||
|
||||
SSRFSafeSession().get("https://legit.com/", timeout=10, stream=True)
|
||||
|
||||
intermediate.close.assert_called_once()
|
||||
final.close.assert_not_called()
|
||||
Reference in New Issue
Block a user