mirror of
https://github.com/suitenumerique/messages.git
synced 2026-04-25 17:15:21 +02:00
🔒️(ssrf) factorize SSRF code, allow redirects in image proxy (#631)
Also use the SSRF code in IMAP imports.
This commit is contained in:
@@ -4,6 +4,7 @@ import logging
|
|||||||
|
|
||||||
from django.http import HttpResponse
|
from django.http import HttpResponse
|
||||||
from django.utils.decorators import method_decorator
|
from django.utils.decorators import method_decorator
|
||||||
|
from django.utils.http import content_disposition_header
|
||||||
from django.views.decorators.csrf import csrf_exempt
|
from django.views.decorators.csrf import csrf_exempt
|
||||||
|
|
||||||
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
|
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
|
||||||
@@ -177,8 +178,8 @@ class BlobViewSet(ViewSet):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Add appropriate headers for download
|
# Add appropriate headers for download
|
||||||
response["Content-Disposition"] = (
|
response["Content-Disposition"] = content_disposition_header(
|
||||||
f'attachment; filename="{attachment["name"]}"'
|
True, attachment["name"]
|
||||||
)
|
)
|
||||||
response["Content-Length"] = attachment["size"]
|
response["Content-Length"] = attachment["size"]
|
||||||
# Enable browser caching for 30 days (inline images benefit from this)
|
# Enable browser caching for 30 days (inline images benefit from this)
|
||||||
@@ -207,7 +208,9 @@ class BlobViewSet(ViewSet):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Add appropriate headers for download
|
# Add appropriate headers for download
|
||||||
response["Content-Disposition"] = f'attachment; filename="{filename}"'
|
response["Content-Disposition"] = content_disposition_header(
|
||||||
|
True, filename
|
||||||
|
)
|
||||||
response["Content-Length"] = blob.size
|
response["Content-Length"] = blob.size
|
||||||
# Enable browser caching for 30 days (inline images benefit from this)
|
# Enable browser caching for 30 days (inline images benefit from this)
|
||||||
response["Cache-Control"] = "private, max-age=2592000"
|
response["Cache-Control"] = "private, max-age=2592000"
|
||||||
|
|||||||
@@ -1,9 +1,7 @@
|
|||||||
"""API ViewSet for proxying external images."""
|
"""API ViewSet for proxying external images."""
|
||||||
|
|
||||||
import ipaddress
|
|
||||||
import logging
|
import logging
|
||||||
import socket
|
from urllib.parse import unquote
|
||||||
from urllib.parse import ParseResult, unquote, urlparse, urlunparse
|
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.http import HttpResponse
|
from django.http import HttpResponse
|
||||||
@@ -11,261 +9,17 @@ from django.http import HttpResponse
|
|||||||
import magic
|
import magic
|
||||||
import requests
|
import requests
|
||||||
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
|
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
|
||||||
from requests.adapters import HTTPAdapter
|
|
||||||
from rest_framework import status as http_status
|
from rest_framework import status as http_status
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework.viewsets import ViewSet
|
from rest_framework.viewsets import ViewSet
|
||||||
|
|
||||||
from core import enums, models
|
from core import enums, models
|
||||||
from core.api import permissions
|
from core.api import permissions
|
||||||
|
from core.services.ssrf import SSRFSafeSession, SSRFValidationError
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class SSRFValidationError(Exception):
|
|
||||||
"""Exception raised when URL validation fails due to SSRF protection."""
|
|
||||||
|
|
||||||
|
|
||||||
class SSRFProtectedAdapter(HTTPAdapter):
|
|
||||||
"""
|
|
||||||
HTTPAdapter that connects to a pre-validated IP address while maintaining
|
|
||||||
proper TLS certificate verification against the original hostname.
|
|
||||||
|
|
||||||
This prevents TOCTOU DNS rebinding attacks by:
|
|
||||||
1. Connecting to the IP address that was validated (not re-resolving DNS)
|
|
||||||
2. Verifying TLS certificates against the original hostname (for HTTPS)
|
|
||||||
3. Setting the Host header correctly for virtual hosting
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dest_ip: str,
|
|
||||||
dest_port: int,
|
|
||||||
original_hostname: str,
|
|
||||||
original_scheme: str,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
self.dest_ip = dest_ip
|
|
||||||
self.dest_port = dest_port
|
|
||||||
self.original_hostname = original_hostname
|
|
||||||
self.original_scheme = original_scheme
|
|
||||||
super().__init__(**kwargs)
|
|
||||||
|
|
||||||
def init_poolmanager(self, connections, maxsize, block=False, **pool_kwargs):
|
|
||||||
"""Initialize pool manager with TLS hostname verification settings."""
|
|
||||||
if self.original_scheme == "https":
|
|
||||||
# Ensure TLS certificate is verified against the original hostname
|
|
||||||
# even though we're connecting to an IP address
|
|
||||||
pool_kwargs["assert_hostname"] = self.original_hostname
|
|
||||||
pool_kwargs["server_hostname"] = self.original_hostname
|
|
||||||
super().init_poolmanager(connections, maxsize, block, **pool_kwargs)
|
|
||||||
|
|
||||||
def send(
|
|
||||||
self, request, stream=False, timeout=None, verify=True, cert=None, proxies=None
|
|
||||||
):
|
|
||||||
"""Send request, rewriting URL to connect to the validated IP address."""
|
|
||||||
parsed = urlparse(request.url)
|
|
||||||
|
|
||||||
# Build URL with validated IP instead of hostname
|
|
||||||
# IPv6 addresses need brackets in URLs
|
|
||||||
if ":" in self.dest_ip:
|
|
||||||
ip_netloc = f"[{self.dest_ip}]:{self.dest_port}"
|
|
||||||
else:
|
|
||||||
ip_netloc = f"{self.dest_ip}:{self.dest_port}"
|
|
||||||
|
|
||||||
# Reconstruct URL with IP address
|
|
||||||
request.url = urlunparse(
|
|
||||||
(
|
|
||||||
parsed.scheme,
|
|
||||||
ip_netloc,
|
|
||||||
parsed.path,
|
|
||||||
parsed.params,
|
|
||||||
parsed.query,
|
|
||||||
parsed.fragment,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set Host header to original hostname for virtual hosting
|
|
||||||
# Include port only if non-standard
|
|
||||||
if parsed.port and parsed.port not in (80, 443):
|
|
||||||
request.headers["Host"] = f"{self.original_hostname}:{parsed.port}"
|
|
||||||
else:
|
|
||||||
request.headers["Host"] = self.original_hostname
|
|
||||||
|
|
||||||
return super().send(
|
|
||||||
request,
|
|
||||||
stream=stream,
|
|
||||||
timeout=timeout,
|
|
||||||
verify=verify,
|
|
||||||
cert=cert,
|
|
||||||
proxies=proxies,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class SSRFSafeSession:
|
|
||||||
"""
|
|
||||||
HTTP Session with built-in SSRF protection.
|
|
||||||
|
|
||||||
This class provides a safe way to make HTTP requests by:
|
|
||||||
1. Validating URL scheme (only http/https allowed)
|
|
||||||
2. Blocking direct IP addresses (legitimate services use domain names)
|
|
||||||
3. Resolving hostnames and blocking private/internal IPs
|
|
||||||
4. Pinning resolved IPs to prevent DNS rebinding attacks (TOCTOU)
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
try:
|
|
||||||
response = SSRFSafeSession().get("https://example.com/image.png", timeout=10)
|
|
||||||
except SSRFValidationError:
|
|
||||||
# URL was blocked for security reasons
|
|
||||||
pass
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _validate_url(self, parsed_url: ParseResult) -> list[str]:
|
|
||||||
"""
|
|
||||||
Validate that a URL is safe to fetch (SSRF protection).
|
|
||||||
|
|
||||||
This function prevents Server-Side Request Forgery (SSRF) attacks by
|
|
||||||
validating URLs before making HTTP requests. It implements a defense-in-depth
|
|
||||||
approach:
|
|
||||||
|
|
||||||
1. Only allows http/https schemes
|
|
||||||
2. Blocks all IP addresses (legitimate emails use domain names)
|
|
||||||
3. Resolves hostnames and blocks if they resolve to private/internal IPs
|
|
||||||
(prevents DNS rebinding attacks where attacker-controlled DNS returns
|
|
||||||
127.0.0.1 or internal IPs)
|
|
||||||
|
|
||||||
Blocked addresses include:
|
|
||||||
- Any direct IP address (e.g., http://192.168.1.1/)
|
|
||||||
- Private IP ranges (RFC1918: 10.x.x.x, 172.16-31.x.x, 192.168.x.x)
|
|
||||||
- Loopback addresses (127.x.x.x, ::1)
|
|
||||||
- Link-local addresses (169.254.x.x, fe80::/10)
|
|
||||||
- Multicast and reserved addresses
|
|
||||||
- Cloud provider metadata endpoints (169.254.169.254, fd00:ec2::254)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
parsed_url: The parsed URL to validate
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of validated IP addresses that the hostname resolves to
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
SSRFValidationError: If the URL is unsafe
|
|
||||||
"""
|
|
||||||
# Only allow http and https schemes
|
|
||||||
if parsed_url.scheme not in {"http", "https"}:
|
|
||||||
raise SSRFValidationError("Invalid URL scheme (only http/https allowed)")
|
|
||||||
|
|
||||||
# Require a hostname
|
|
||||||
if not parsed_url.hostname:
|
|
||||||
raise SSRFValidationError("Invalid URL (missing hostname)")
|
|
||||||
|
|
||||||
# Block all IP addresses (legitimate services use domain names)
|
|
||||||
try:
|
|
||||||
ipaddress.ip_address(parsed_url.hostname)
|
|
||||||
raise SSRFValidationError(
|
|
||||||
"IP addresses are not allowed (domain name required)"
|
|
||||||
)
|
|
||||||
except ValueError:
|
|
||||||
# Not an IP address, continue validation
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Resolve hostname to IP addresses
|
|
||||||
try:
|
|
||||||
addr_info = socket.getaddrinfo(
|
|
||||||
parsed_url.hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM
|
|
||||||
)
|
|
||||||
except socket.gaierror as exc:
|
|
||||||
raise SSRFValidationError("Unable to resolve hostname") from exc
|
|
||||||
|
|
||||||
# Check all resolved IP addresses
|
|
||||||
valid_ips = []
|
|
||||||
for _, _, _, _, sockaddr in addr_info:
|
|
||||||
ip_str = sockaddr[0]
|
|
||||||
try:
|
|
||||||
ip_addr = ipaddress.ip_address(ip_str)
|
|
||||||
|
|
||||||
if ip_addr.is_private:
|
|
||||||
raise SSRFValidationError("Domain resolves to private IP address")
|
|
||||||
|
|
||||||
if ip_addr.is_loopback:
|
|
||||||
raise SSRFValidationError("Domain resolves to loopback address")
|
|
||||||
|
|
||||||
if ip_addr.is_link_local:
|
|
||||||
raise SSRFValidationError("Domain resolves to link-local address")
|
|
||||||
|
|
||||||
if ip_addr.is_multicast:
|
|
||||||
raise SSRFValidationError("Domain resolves to multicast address")
|
|
||||||
|
|
||||||
if ip_addr.is_reserved:
|
|
||||||
raise SSRFValidationError("Domain resolves to reserved address")
|
|
||||||
|
|
||||||
# Block known cloud metadata IPs
|
|
||||||
if ip_str in ("169.254.169.254", "fd00:ec2::254"):
|
|
||||||
raise SSRFValidationError(
|
|
||||||
"Domain resolves to cloud metadata endpoint"
|
|
||||||
)
|
|
||||||
|
|
||||||
valid_ips.append(ip_str)
|
|
||||||
|
|
||||||
except ValueError as exc:
|
|
||||||
raise SSRFValidationError("Invalid IP address in DNS response") from exc
|
|
||||||
|
|
||||||
if not valid_ips:
|
|
||||||
raise SSRFValidationError("No valid IP addresses found")
|
|
||||||
|
|
||||||
return valid_ips
|
|
||||||
|
|
||||||
def get(self, url: str, timeout: int, **kwargs) -> requests.Response:
|
|
||||||
"""
|
|
||||||
Perform a safe HTTP GET request with SSRF protection and IP pinning.
|
|
||||||
|
|
||||||
This method:
|
|
||||||
1. Parses and validates the URL
|
|
||||||
2. Resolves DNS and validates all returned IPs
|
|
||||||
3. Creates a requests Session with a custom HTTPAdapter that:
|
|
||||||
- Connects directly to the validated IP (preventing DNS rebinding)
|
|
||||||
- Maintains proper TLS certificate verification against the hostname
|
|
||||||
- Sets the Host header correctly for virtual hosting
|
|
||||||
|
|
||||||
Args:
|
|
||||||
url: The URL to fetch
|
|
||||||
timeout: Request timeout in seconds
|
|
||||||
**kwargs: Additional arguments passed to requests.Session.get()
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
requests.Response object
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
SSRFValidationError: If the URL fails security validation
|
|
||||||
requests.RequestException: If the HTTP request fails
|
|
||||||
"""
|
|
||||||
parsed_url = urlparse(url)
|
|
||||||
valid_ips = self._validate_url(parsed_url)
|
|
||||||
|
|
||||||
# Determine the port (explicit or default based on scheme)
|
|
||||||
if parsed_url.port:
|
|
||||||
port = parsed_url.port
|
|
||||||
elif parsed_url.scheme == "http":
|
|
||||||
port = 80
|
|
||||||
else:
|
|
||||||
port = 443
|
|
||||||
|
|
||||||
# Create a session with our SSRF-protected adapter that pins to the validated IP
|
|
||||||
session = requests.Session()
|
|
||||||
adapter = SSRFProtectedAdapter(
|
|
||||||
dest_ip=valid_ips[0],
|
|
||||||
dest_port=port,
|
|
||||||
original_hostname=parsed_url.hostname,
|
|
||||||
original_scheme=parsed_url.scheme,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mount the adapter for both http and https schemes
|
|
||||||
session.mount("http://", adapter)
|
|
||||||
session.mount("https://", adapter)
|
|
||||||
|
|
||||||
return session.get(url, timeout=timeout, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class ImageProxySuspiciousResponse(HttpResponse):
|
class ImageProxySuspiciousResponse(HttpResponse):
|
||||||
"""
|
"""
|
||||||
Response for suspicious content that has been blocked by our image proxy.
|
Response for suspicious content that has been blocked by our image proxy.
|
||||||
@@ -356,7 +110,6 @@ class ImageProxyViewSet(ViewSet):
|
|||||||
timeout=10,
|
timeout=10,
|
||||||
stream=True,
|
stream=True,
|
||||||
headers={"User-Agent": "Messages-ImageProxy/1.0"},
|
headers={"User-Agent": "Messages-ImageProxy/1.0"},
|
||||||
allow_redirects=False,
|
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
|
|
||||||
|
|||||||
@@ -35,9 +35,9 @@ class WidgetAuthentication(BaseAuthentication):
|
|||||||
if not channel_id:
|
if not channel_id:
|
||||||
raise AuthenticationFailed("Missing channel_id")
|
raise AuthenticationFailed("Missing channel_id")
|
||||||
|
|
||||||
# API key authentication for check endpoint
|
# Only allow widget-type channels
|
||||||
try:
|
try:
|
||||||
channel = models.Channel.objects.get(id=channel_id)
|
channel = models.Channel.objects.get(id=channel_id, type="widget")
|
||||||
except models.Channel.DoesNotExist as e:
|
except models.Channel.DoesNotExist as e:
|
||||||
raise AuthenticationFailed("Invalid channel_id") from e
|
raise AuthenticationFailed("Invalid channel_id") from e
|
||||||
|
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from celery.utils.log import get_task_logger
|
|||||||
|
|
||||||
from core.mda.inbound import deliver_inbound_message
|
from core.mda.inbound import deliver_inbound_message
|
||||||
from core.mda.rfc5322 import parse_email_message
|
from core.mda.rfc5322 import parse_email_message
|
||||||
|
from core.services.ssrf import SSRFValidationError, validate_hostname
|
||||||
|
|
||||||
logger = get_task_logger(__name__)
|
logger = get_task_logger(__name__)
|
||||||
|
|
||||||
@@ -62,6 +63,21 @@ def decode_imap_utf7(s):
|
|||||||
return re.sub(r"&([^-]*)-", decode_match, s)
|
return re.sub(r"&([^-]*)-", decode_match, s)
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_imap_host(server: str) -> None:
|
||||||
|
"""Validate that the IMAP server hostname is not a private/internal address.
|
||||||
|
|
||||||
|
Wraps the shared SSRF validator but allows public IP literals, which are
|
||||||
|
legitimate addresses for customer-supplied IMAP servers.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the hostname resolves to a blocked IP address.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
validate_hostname(server, allow_ip_literal=True)
|
||||||
|
except SSRFValidationError as exc:
|
||||||
|
raise ValueError(f"IMAP server {server} is not allowed: {exc}") from exc
|
||||||
|
|
||||||
|
|
||||||
class IMAPConnectionManager:
|
class IMAPConnectionManager:
|
||||||
"""Context manager for IMAP connections with proper cleanup."""
|
"""Context manager for IMAP connections with proper cleanup."""
|
||||||
|
|
||||||
@@ -76,6 +92,9 @@ class IMAPConnectionManager:
|
|||||||
self.connection = None
|
self.connection = None
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
|
# Validate the server hostname to prevent SSRF
|
||||||
|
_validate_imap_host(self.server)
|
||||||
|
|
||||||
# Port 143 typically uses STARTTLS, port 993 uses SSL direct
|
# Port 143 typically uses STARTTLS, port 993 uses SSL direct
|
||||||
# If use_ssl=True and port is 143, use STARTTLS instead of SSL direct
|
# If use_ssl=True and port is 143, use STARTTLS instead of SSL direct
|
||||||
use_starttls = self.use_ssl and self.port == 143
|
use_starttls = self.use_ssl and self.port == 143
|
||||||
|
|||||||
245
src/backend/core/services/ssrf.py
Normal file
245
src/backend/core/services/ssrf.py
Normal file
@@ -0,0 +1,245 @@
|
|||||||
|
"""Server-Side Request Forgery (SSRF) protections.
|
||||||
|
|
||||||
|
Shared across features that take user-supplied network destinations (image
|
||||||
|
proxy, IMAP import, etc.). Provides a hostname/IP validator plus an HTTP
|
||||||
|
session with IP pinning to defeat DNS-rebinding (TOCTOU) attacks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import ipaddress
|
||||||
|
import socket
|
||||||
|
from urllib.parse import urljoin, urlparse, urlunparse
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from requests.adapters import HTTPAdapter
|
||||||
|
|
||||||
|
CLOUD_METADATA_IPS = frozenset({"169.254.169.254", "fd00:ec2::254"})
|
||||||
|
|
||||||
|
MAX_REDIRECTS = 5
|
||||||
|
REDIRECT_STATUS_CODES = frozenset({301, 302, 303, 307, 308})
|
||||||
|
|
||||||
|
|
||||||
|
class SSRFValidationError(Exception):
|
||||||
|
"""Raised when a URL or hostname fails SSRF validation."""
|
||||||
|
|
||||||
|
|
||||||
|
def _check_ip(ip_addr: ipaddress._BaseAddress, hostname: str) -> None:
|
||||||
|
# Check specific categories before is_private: in Python's ipaddress
|
||||||
|
# module, loopback/link-local/etc. are subsets of is_private, so checking
|
||||||
|
# is_private first would mask the more informative error.
|
||||||
|
if str(ip_addr) in CLOUD_METADATA_IPS:
|
||||||
|
raise SSRFValidationError(f"{hostname} resolves to cloud metadata endpoint")
|
||||||
|
if ip_addr.is_loopback:
|
||||||
|
raise SSRFValidationError(f"{hostname} resolves to loopback address")
|
||||||
|
if ip_addr.is_link_local:
|
||||||
|
raise SSRFValidationError(f"{hostname} resolves to link-local address")
|
||||||
|
if ip_addr.is_multicast:
|
||||||
|
raise SSRFValidationError(f"{hostname} resolves to multicast address")
|
||||||
|
if ip_addr.is_reserved:
|
||||||
|
raise SSRFValidationError(f"{hostname} resolves to reserved address")
|
||||||
|
if ip_addr.is_private:
|
||||||
|
raise SSRFValidationError(f"{hostname} resolves to private IP address")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_hostname(hostname: str, *, allow_ip_literal: bool = False) -> list[str]:
|
||||||
|
"""Resolve hostname and reject private/internal/metadata addresses.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hostname: A hostname or, when allow_ip_literal=True, an IP literal.
|
||||||
|
allow_ip_literal: If False (default), IP literals are rejected outright —
|
||||||
|
legitimate services use domain names. If True, public IP literals
|
||||||
|
are accepted (used for IMAP where customers may supply raw IPs).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of validated IP addresses the hostname resolves to.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SSRFValidationError: If the hostname/IP resolves to a blocked address.
|
||||||
|
"""
|
||||||
|
if not hostname:
|
||||||
|
raise SSRFValidationError("Invalid hostname (missing)")
|
||||||
|
|
||||||
|
try:
|
||||||
|
ip = ipaddress.ip_address(hostname)
|
||||||
|
except ValueError:
|
||||||
|
ip = None
|
||||||
|
|
||||||
|
if ip is not None:
|
||||||
|
if not allow_ip_literal:
|
||||||
|
raise SSRFValidationError(
|
||||||
|
"IP addresses are not allowed (domain name required)"
|
||||||
|
)
|
||||||
|
_check_ip(ip, hostname)
|
||||||
|
return [str(ip)]
|
||||||
|
|
||||||
|
try:
|
||||||
|
addr_info = socket.getaddrinfo(
|
||||||
|
hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM
|
||||||
|
)
|
||||||
|
except socket.gaierror as exc:
|
||||||
|
raise SSRFValidationError("Unable to resolve hostname") from exc
|
||||||
|
|
||||||
|
valid_ips: list[str] = []
|
||||||
|
for _, _, _, _, sockaddr in addr_info:
|
||||||
|
ip_str = sockaddr[0]
|
||||||
|
try:
|
||||||
|
ip_addr = ipaddress.ip_address(ip_str)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise SSRFValidationError("Invalid IP address in DNS response") from exc
|
||||||
|
_check_ip(ip_addr, hostname)
|
||||||
|
valid_ips.append(ip_str)
|
||||||
|
|
||||||
|
if not valid_ips:
|
||||||
|
raise SSRFValidationError("No valid IP addresses found")
|
||||||
|
|
||||||
|
return valid_ips
|
||||||
|
|
||||||
|
|
||||||
|
class SSRFProtectedAdapter(HTTPAdapter):
|
||||||
|
"""HTTPAdapter that pins the connection to a pre-validated IP.
|
||||||
|
|
||||||
|
Prevents TOCTOU DNS rebinding by:
|
||||||
|
1. Connecting to the IP address that was validated (no re-resolving DNS).
|
||||||
|
2. Verifying TLS certificates against the original hostname (for HTTPS).
|
||||||
|
3. Setting the Host header correctly for virtual hosting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
dest_ip: str,
|
||||||
|
dest_port: int,
|
||||||
|
original_hostname: str,
|
||||||
|
original_scheme: str,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.dest_ip = dest_ip
|
||||||
|
self.dest_port = dest_port
|
||||||
|
self.original_hostname = original_hostname
|
||||||
|
self.original_scheme = original_scheme
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
def init_poolmanager(self, connections, maxsize, block=False, **pool_kwargs):
|
||||||
|
if self.original_scheme == "https":
|
||||||
|
pool_kwargs["assert_hostname"] = self.original_hostname
|
||||||
|
pool_kwargs["server_hostname"] = self.original_hostname
|
||||||
|
super().init_poolmanager(connections, maxsize, block, **pool_kwargs)
|
||||||
|
|
||||||
|
def send(
|
||||||
|
self, request, stream=False, timeout=None, verify=True, cert=None, proxies=None
|
||||||
|
):
|
||||||
|
parsed = urlparse(request.url)
|
||||||
|
|
||||||
|
if ":" in self.dest_ip:
|
||||||
|
ip_netloc = f"[{self.dest_ip}]:{self.dest_port}"
|
||||||
|
else:
|
||||||
|
ip_netloc = f"{self.dest_ip}:{self.dest_port}"
|
||||||
|
|
||||||
|
request.url = urlunparse(
|
||||||
|
(
|
||||||
|
parsed.scheme,
|
||||||
|
ip_netloc,
|
||||||
|
parsed.path,
|
||||||
|
parsed.params,
|
||||||
|
parsed.query,
|
||||||
|
parsed.fragment,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if parsed.port and parsed.port not in (80, 443):
|
||||||
|
request.headers["Host"] = f"{self.original_hostname}:{parsed.port}"
|
||||||
|
else:
|
||||||
|
request.headers["Host"] = self.original_hostname
|
||||||
|
|
||||||
|
return super().send(
|
||||||
|
request,
|
||||||
|
stream=stream,
|
||||||
|
timeout=timeout,
|
||||||
|
verify=verify,
|
||||||
|
cert=cert,
|
||||||
|
proxies=proxies,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SSRFSafeSession:
|
||||||
|
"""HTTP Session with built-in SSRF protection.
|
||||||
|
|
||||||
|
1. Validates URL scheme (only http/https allowed).
|
||||||
|
2. Blocks direct IP addresses (legitimate services use domain names).
|
||||||
|
3. Resolves hostnames and blocks private/internal IPs.
|
||||||
|
4. Pins resolved IPs to prevent DNS rebinding attacks (TOCTOU).
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
try:
|
||||||
|
response = SSRFSafeSession().get("https://example.com/image.png", timeout=10)
|
||||||
|
except SSRFValidationError:
|
||||||
|
# URL was blocked for security reasons
|
||||||
|
pass
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _validate_and_unpack(self, url: str) -> tuple[str, str, str, int]:
|
||||||
|
"""Validate a URL and return (validated_ip, hostname, scheme, port).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SSRFValidationError: If the URL is unsafe.
|
||||||
|
"""
|
||||||
|
parsed = urlparse(url)
|
||||||
|
if parsed.scheme not in {"http", "https"}:
|
||||||
|
raise SSRFValidationError("Invalid URL scheme (only http/https allowed)")
|
||||||
|
if not parsed.hostname:
|
||||||
|
raise SSRFValidationError("Invalid URL (missing hostname)")
|
||||||
|
|
||||||
|
valid_ips = validate_hostname(parsed.hostname, allow_ip_literal=False)
|
||||||
|
|
||||||
|
if parsed.port:
|
||||||
|
port = parsed.port
|
||||||
|
elif parsed.scheme == "http":
|
||||||
|
port = 80
|
||||||
|
else:
|
||||||
|
port = 443
|
||||||
|
|
||||||
|
return valid_ips[0], parsed.hostname, parsed.scheme, port
|
||||||
|
|
||||||
|
def get(self, url: str, timeout: int, **kwargs) -> requests.Response:
|
||||||
|
"""Perform a safe HTTP GET with per-hop SSRF validation on redirects.
|
||||||
|
|
||||||
|
Redirects are followed manually up to MAX_REDIRECTS hops. Each Location
|
||||||
|
URL is re-validated from scratch, so an attacker-controlled server
|
||||||
|
cannot redirect to an internal address or a different private target
|
||||||
|
on a later hop.
|
||||||
|
"""
|
||||||
|
# We always handle redirects ourselves — strip any caller override so
|
||||||
|
# the underlying requests session never follows a redirect unchecked.
|
||||||
|
kwargs.pop("allow_redirects", None)
|
||||||
|
|
||||||
|
current_url = url
|
||||||
|
for _ in range(MAX_REDIRECTS + 1):
|
||||||
|
validated_ip, hostname, scheme, port = self._validate_and_unpack(
|
||||||
|
current_url
|
||||||
|
)
|
||||||
|
|
||||||
|
session = requests.Session()
|
||||||
|
adapter = SSRFProtectedAdapter(
|
||||||
|
dest_ip=validated_ip,
|
||||||
|
dest_port=port,
|
||||||
|
original_hostname=hostname,
|
||||||
|
original_scheme=scheme,
|
||||||
|
)
|
||||||
|
session.mount("http://", adapter)
|
||||||
|
session.mount("https://", adapter)
|
||||||
|
|
||||||
|
response = session.get(
|
||||||
|
current_url, timeout=timeout, allow_redirects=False, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code not in REDIRECT_STATUS_CODES:
|
||||||
|
return response
|
||||||
|
|
||||||
|
location = response.headers.get("Location")
|
||||||
|
if not location:
|
||||||
|
# Redirect without a Location — hand the response back unchanged.
|
||||||
|
return response
|
||||||
|
|
||||||
|
next_url = urljoin(current_url, location)
|
||||||
|
response.close()
|
||||||
|
current_url = next_url
|
||||||
|
|
||||||
|
raise SSRFValidationError(f"Too many redirects (max {MAX_REDIRECTS})")
|
||||||
@@ -191,7 +191,7 @@ class TestImageProxyViewSet:
|
|||||||
"https://localhost:8080/image.jpg",
|
"https://localhost:8080/image.jpg",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||||
def test_api_image_proxy_localhost_hostname_blocked(
|
def test_api_image_proxy_localhost_hostname_blocked(
|
||||||
self, mock_getaddrinfo, api_client, user_mailbox, test_url
|
self, mock_getaddrinfo, api_client, user_mailbox, test_url
|
||||||
):
|
):
|
||||||
@@ -210,7 +210,7 @@ class TestImageProxyViewSet:
|
|||||||
assert response["Content-Type"] == "image/svg+xml"
|
assert response["Content-Type"] == "image/svg+xml"
|
||||||
|
|
||||||
@override_settings(IMAGE_PROXY_ENABLED=True)
|
@override_settings(IMAGE_PROXY_ENABLED=True)
|
||||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||||
def test_api_image_proxy_domain_resolving_to_private_ip_blocked(
|
def test_api_image_proxy_domain_resolving_to_private_ip_blocked(
|
||||||
self, mock_getaddrinfo, api_client, user_mailbox
|
self, mock_getaddrinfo, api_client, user_mailbox
|
||||||
):
|
):
|
||||||
@@ -229,7 +229,7 @@ class TestImageProxyViewSet:
|
|||||||
assert response["Content-Type"] == "image/svg+xml"
|
assert response["Content-Type"] == "image/svg+xml"
|
||||||
|
|
||||||
@override_settings(IMAGE_PROXY_ENABLED=True)
|
@override_settings(IMAGE_PROXY_ENABLED=True)
|
||||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||||
def test_api_image_proxy_unresolvable_hostname(
|
def test_api_image_proxy_unresolvable_hostname(
|
||||||
self, mock_getaddrinfo, api_client, user_mailbox
|
self, mock_getaddrinfo, api_client, user_mailbox
|
||||||
):
|
):
|
||||||
@@ -249,7 +249,7 @@ class TestImageProxyViewSet:
|
|||||||
# Content-Type Validation Tests
|
# Content-Type Validation Tests
|
||||||
|
|
||||||
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=10)
|
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=10)
|
||||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||||
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
||||||
def test_api_image_proxy_non_image_content_type_blocked(
|
def test_api_image_proxy_non_image_content_type_blocked(
|
||||||
self, mock_get, mock_getaddrinfo, api_client, user_mailbox
|
self, mock_get, mock_getaddrinfo, api_client, user_mailbox
|
||||||
@@ -273,7 +273,7 @@ class TestImageProxyViewSet:
|
|||||||
assert response["Content-Type"] == "image/svg+xml"
|
assert response["Content-Type"] == "image/svg+xml"
|
||||||
|
|
||||||
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=10)
|
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=10)
|
||||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||||
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
||||||
@patch("magic.from_buffer")
|
@patch("magic.from_buffer")
|
||||||
def test_api_image_proxy_content_not_actually_image(
|
def test_api_image_proxy_content_not_actually_image(
|
||||||
@@ -301,7 +301,7 @@ class TestImageProxyViewSet:
|
|||||||
assert response["Content-Type"] == "image/svg+xml"
|
assert response["Content-Type"] == "image/svg+xml"
|
||||||
|
|
||||||
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=10)
|
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=10)
|
||||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||||
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
||||||
@patch("magic.from_buffer")
|
@patch("magic.from_buffer")
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
@@ -342,7 +342,7 @@ class TestImageProxyViewSet:
|
|||||||
# Size Limit Tests
|
# Size Limit Tests
|
||||||
|
|
||||||
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=1)
|
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=1)
|
||||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||||
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
||||||
def test_api_image_proxy_image_too_large_via_content_length(
|
def test_api_image_proxy_image_too_large_via_content_length(
|
||||||
self, mock_get, mock_getaddrinfo, api_client, user_mailbox
|
self, mock_get, mock_getaddrinfo, api_client, user_mailbox
|
||||||
@@ -367,7 +367,7 @@ class TestImageProxyViewSet:
|
|||||||
assert "Image too large" in str(response.data)
|
assert "Image too large" in str(response.data)
|
||||||
|
|
||||||
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=4096)
|
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=4096)
|
||||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||||
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
||||||
@patch("magic.from_buffer")
|
@patch("magic.from_buffer")
|
||||||
def test_api_image_proxy_image_too_large_actual_content(
|
def test_api_image_proxy_image_too_large_actual_content(
|
||||||
@@ -395,39 +395,13 @@ class TestImageProxyViewSet:
|
|||||||
assert response.status_code == status.HTTP_413_REQUEST_ENTITY_TOO_LARGE
|
assert response.status_code == status.HTTP_413_REQUEST_ENTITY_TOO_LARGE
|
||||||
assert "Image too large" in str(response.data)
|
assert "Image too large" in str(response.data)
|
||||||
|
|
||||||
# Redirect Tests
|
# Redirect handling is covered end-to-end in core/tests/services/test_ssrf.py:
|
||||||
|
# SSRFSafeSession follows redirects internally, re-validating every hop.
|
||||||
@override_settings(IMAGE_PROXY_ENABLED=True)
|
|
||||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
|
||||||
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
|
||||||
def test_api_image_proxy_redirects_blocked(
|
|
||||||
self, mock_get, mock_getaddrinfo, api_client, user_mailbox
|
|
||||||
):
|
|
||||||
"""Test that redirects are not followed (SSRF protection)."""
|
|
||||||
# Mock DNS resolution
|
|
||||||
mock_getaddrinfo.return_value = [(2, 1, 6, "", ("1.2.3.4", 80))]
|
|
||||||
|
|
||||||
client, _ = api_client
|
|
||||||
url = self._get_image_proxy_url(
|
|
||||||
user_mailbox.id, "http://example.com/redirect.jpg"
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_get.return_value = self._mock_requests_response(
|
|
||||||
content=b"", content_type="text/plain", status_code=302
|
|
||||||
)
|
|
||||||
|
|
||||||
# Call the endpoint
|
|
||||||
client.get(url)
|
|
||||||
|
|
||||||
# Verify that allow_redirects=False was passed to requests.get
|
|
||||||
mock_get.assert_called_once()
|
|
||||||
call_kwargs = mock_get.call_args[1]
|
|
||||||
assert call_kwargs["allow_redirects"] is False
|
|
||||||
|
|
||||||
# Success Cases
|
# Success Cases
|
||||||
|
|
||||||
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_CACHE_TTL=3600)
|
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_CACHE_TTL=3600)
|
||||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||||
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
||||||
@patch("magic.from_buffer")
|
@patch("magic.from_buffer")
|
||||||
def test_api_image_proxy_successfully_jpg_image(
|
def test_api_image_proxy_successfully_jpg_image(
|
||||||
@@ -459,7 +433,7 @@ class TestImageProxyViewSet:
|
|||||||
assert response["Cache-Control"] == "public, max-age=3600"
|
assert response["Cache-Control"] == "public, max-age=3600"
|
||||||
|
|
||||||
@override_settings(IMAGE_PROXY_ENABLED=True)
|
@override_settings(IMAGE_PROXY_ENABLED=True)
|
||||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||||
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
||||||
@patch("magic.from_buffer")
|
@patch("magic.from_buffer")
|
||||||
def test_api_image_proxy_successfully_png_image(
|
def test_api_image_proxy_successfully_png_image(
|
||||||
@@ -488,7 +462,7 @@ class TestImageProxyViewSet:
|
|||||||
assert response["Content-Type"] == "image/png"
|
assert response["Content-Type"] == "image/png"
|
||||||
|
|
||||||
@override_settings(IMAGE_PROXY_ENABLED=True)
|
@override_settings(IMAGE_PROXY_ENABLED=True)
|
||||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||||
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
||||||
@patch("magic.from_buffer")
|
@patch("magic.from_buffer")
|
||||||
def test_api_image_proxy_successfully_with_octet_stream_content_type(
|
def test_api_image_proxy_successfully_with_octet_stream_content_type(
|
||||||
@@ -517,7 +491,7 @@ class TestImageProxyViewSet:
|
|||||||
assert response["Content-Type"] == "image/jpeg"
|
assert response["Content-Type"] == "image/jpeg"
|
||||||
|
|
||||||
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_CACHE_TTL=3600)
|
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_CACHE_TTL=3600)
|
||||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||||
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
||||||
@patch("magic.from_buffer")
|
@patch("magic.from_buffer")
|
||||||
def test_api_image_proxy_no_content(
|
def test_api_image_proxy_no_content(
|
||||||
@@ -547,7 +521,7 @@ class TestImageProxyViewSet:
|
|||||||
assert response["Content-Type"] == "image/svg+xml"
|
assert response["Content-Type"] == "image/svg+xml"
|
||||||
|
|
||||||
@override_settings(IMAGE_PROXY_ENABLED=True)
|
@override_settings(IMAGE_PROXY_ENABLED=True)
|
||||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||||
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
||||||
@patch("magic.from_buffer")
|
@patch("magic.from_buffer")
|
||||||
def test_api_image_proxy_url_with_special_characters(
|
def test_api_image_proxy_url_with_special_characters(
|
||||||
@@ -578,7 +552,7 @@ class TestImageProxyViewSet:
|
|||||||
assert mock_get.call_args[0][0] == test_url
|
assert mock_get.call_args[0][0] == test_url
|
||||||
|
|
||||||
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_CACHE_TTL=3600)
|
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_CACHE_TTL=3600)
|
||||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||||
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
||||||
@patch("magic.from_buffer")
|
@patch("magic.from_buffer")
|
||||||
def test_api_image_proxy_successfully_secure_headers(
|
def test_api_image_proxy_successfully_secure_headers(
|
||||||
@@ -617,7 +591,7 @@ class TestImageProxyViewSet:
|
|||||||
# Error Handling Tests
|
# Error Handling Tests
|
||||||
|
|
||||||
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=10)
|
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=10)
|
||||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||||
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
||||||
def test_api_image_proxy_network_timeout(
|
def test_api_image_proxy_network_timeout(
|
||||||
self, mock_get, mock_getaddrinfo, api_client, user_mailbox
|
self, mock_get, mock_getaddrinfo, api_client, user_mailbox
|
||||||
@@ -638,7 +612,7 @@ class TestImageProxyViewSet:
|
|||||||
assert "Failed to fetch image" in str(response.data)
|
assert "Failed to fetch image" in str(response.data)
|
||||||
|
|
||||||
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=10)
|
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=10)
|
||||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||||
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
||||||
def test_api_image_proxy_connection_error(
|
def test_api_image_proxy_connection_error(
|
||||||
self, mock_get, mock_getaddrinfo, api_client, user_mailbox
|
self, mock_get, mock_getaddrinfo, api_client, user_mailbox
|
||||||
@@ -659,7 +633,7 @@ class TestImageProxyViewSet:
|
|||||||
assert "Failed to fetch image" in str(response.data)
|
assert "Failed to fetch image" in str(response.data)
|
||||||
|
|
||||||
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=10)
|
@override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=10)
|
||||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||||
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
||||||
def test_api_image_proxy_http_error_404(
|
def test_api_image_proxy_http_error_404(
|
||||||
self, mock_get, mock_getaddrinfo, api_client, user_mailbox
|
self, mock_get, mock_getaddrinfo, api_client, user_mailbox
|
||||||
@@ -687,7 +661,7 @@ class TestImageProxyViewSet:
|
|||||||
assert "Failed to fetch image" in str(response.data)
|
assert "Failed to fetch image" in str(response.data)
|
||||||
|
|
||||||
@override_settings(IMAGE_PROXY_ENABLED=True)
|
@override_settings(IMAGE_PROXY_ENABLED=True)
|
||||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||||
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
||||||
@patch("magic.from_buffer")
|
@patch("magic.from_buffer")
|
||||||
def test_api_image_proxy_invalid_content_length_header(
|
def test_api_image_proxy_invalid_content_length_header(
|
||||||
@@ -716,7 +690,7 @@ class TestImageProxyViewSet:
|
|||||||
assert response.status_code == status.HTTP_200_OK
|
assert response.status_code == status.HTTP_200_OK
|
||||||
|
|
||||||
@override_settings(IMAGE_PROXY_ENABLED=True)
|
@override_settings(IMAGE_PROXY_ENABLED=True)
|
||||||
@patch("core.api.viewsets.image_proxy.socket.getaddrinfo")
|
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||||
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
@patch("core.api.viewsets.image_proxy.SSRFSafeSession.get")
|
||||||
@patch("magic.from_buffer")
|
@patch("magic.from_buffer")
|
||||||
def test_api_image_proxy_missing_content_length_header(
|
def test_api_image_proxy_missing_content_length_header(
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
# pylint: disable=redefined-outer-name, unused-argument, no-value-for-parameter, too-many-lines
|
# pylint: disable=redefined-outer-name, unused-argument, no-value-for-parameter, too-many-lines
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
|
import socket
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from django.core.files.storage import storages
|
from django.core.files.storage import storages
|
||||||
@@ -19,6 +20,25 @@ from core.services.importer.mbox_tasks import process_mbox_file_task
|
|||||||
|
|
||||||
pytestmark = pytest.mark.django_db
|
pytestmark = pytest.mark.django_db
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _mock_ssrf_dns():
|
||||||
|
"""Short-circuit SSRF DNS validation for IMAP import tests.
|
||||||
|
|
||||||
|
The IMAP endpoint validates the server hostname via
|
||||||
|
``core.services.ssrf.validate_hostname``; tests use unresolvable fixtures
|
||||||
|
like ``imap.example.com`` so we return a public IP to reach the mocked
|
||||||
|
IMAP task code.
|
||||||
|
"""
|
||||||
|
with patch(
|
||||||
|
"core.services.ssrf.socket.getaddrinfo",
|
||||||
|
return_value=[
|
||||||
|
(socket.AF_INET, socket.SOCK_STREAM, 6, "", ("93.184.216.34", 0))
|
||||||
|
],
|
||||||
|
):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
IMPORT_FILE_URL = "/api/v1.0/import/file/"
|
IMPORT_FILE_URL = "/api/v1.0/import/file/"
|
||||||
IMPORT_IMAP_URL = "/api/v1.0/import/imap/"
|
IMPORT_IMAP_URL = "/api/v1.0/import/imap/"
|
||||||
|
|
||||||
|
|||||||
24
src/backend/core/tests/importer/conftest.py
Normal file
24
src/backend/core/tests/importer/conftest.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
"""Shared fixtures for importer tests."""
|
||||||
|
|
||||||
|
import socket
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _mock_ssrf_dns():
|
||||||
|
"""Short-circuit SSRF DNS validation for IMAP tests.
|
||||||
|
|
||||||
|
The IMAP import path validates the server hostname via
|
||||||
|
``core.services.ssrf.validate_hostname`` which calls ``socket.getaddrinfo``.
|
||||||
|
Test fixtures use unresolvable hostnames like ``imap.example.com``, so we
|
||||||
|
return a valid public IP here to let tests reach the mocked IMAP code.
|
||||||
|
"""
|
||||||
|
with mock.patch(
|
||||||
|
"core.services.ssrf.socket.getaddrinfo",
|
||||||
|
return_value=[
|
||||||
|
(socket.AF_INET, socket.SOCK_STREAM, 6, "", ("93.184.216.34", 0))
|
||||||
|
],
|
||||||
|
):
|
||||||
|
yield
|
||||||
274
src/backend/core/tests/services/test_ssrf.py
Normal file
274
src/backend/core/tests/services/test_ssrf.py
Normal file
@@ -0,0 +1,274 @@
|
|||||||
|
"""Unit tests for core.services.ssrf — focus on redirect handling.
|
||||||
|
|
||||||
|
Hostname/IP validation itself is covered via the image proxy integration tests
|
||||||
|
(see core/tests/api/test_image_proxy.py). This file targets the per-hop
|
||||||
|
validation loop in SSRFSafeSession.get, which is the main SSRF-correctness
|
||||||
|
surface for redirect responses.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import socket
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from core.services.ssrf import (
|
||||||
|
MAX_REDIRECTS,
|
||||||
|
SSRFSafeSession,
|
||||||
|
SSRFValidationError,
|
||||||
|
)
|
||||||
|
|
||||||
|
PUBLIC_IP = "93.184.216.34"
|
||||||
|
PRIVATE_IP = "192.168.1.1"
|
||||||
|
|
||||||
|
|
||||||
|
def _addrinfo(ip: str):
|
||||||
|
return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", (ip, 0))]
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_response(status_code: int = 200, location: str | None = None):
|
||||||
|
"""Build a mock requests.Response-like object."""
|
||||||
|
resp = MagicMock()
|
||||||
|
resp.status_code = status_code
|
||||||
|
resp.headers = {}
|
||||||
|
if location is not None:
|
||||||
|
resp.headers["Location"] = location
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
class TestSSRFSafeSessionRedirects:
|
||||||
|
"""Redirect-handling contract for SSRFSafeSession.get."""
|
||||||
|
|
||||||
|
@patch("core.services.ssrf.requests.Session.get")
|
||||||
|
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||||
|
def test_no_redirect_returns_response_directly(self, mock_dns, mock_get):
|
||||||
|
"""A 200 response is returned without any extra request."""
|
||||||
|
mock_dns.return_value = _addrinfo(PUBLIC_IP)
|
||||||
|
mock_get.return_value = _mock_response(200)
|
||||||
|
|
||||||
|
response = SSRFSafeSession().get("https://legit.com/img.png", timeout=10)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert mock_get.call_count == 1
|
||||||
|
|
||||||
|
@patch("core.services.ssrf.requests.Session.get")
|
||||||
|
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||||
|
def test_follows_redirect_to_safe_url(self, mock_dns, mock_get):
|
||||||
|
"""A single redirect is followed to a validated destination."""
|
||||||
|
mock_dns.return_value = _addrinfo(PUBLIC_IP)
|
||||||
|
mock_get.side_effect = [
|
||||||
|
_mock_response(302, location="https://cdn.legit.com/img.png"),
|
||||||
|
_mock_response(200),
|
||||||
|
]
|
||||||
|
|
||||||
|
response = SSRFSafeSession().get("https://legit.com/img.png", timeout=10)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert mock_get.call_count == 2
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("redirect_status", [301, 302, 303, 307, 308])
|
||||||
|
@patch("core.services.ssrf.requests.Session.get")
|
||||||
|
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||||
|
def test_all_redirect_statuses_followed(self, mock_dns, mock_get, redirect_status):
|
||||||
|
"""All standard redirect status codes trigger a new hop."""
|
||||||
|
mock_dns.return_value = _addrinfo(PUBLIC_IP)
|
||||||
|
mock_get.side_effect = [
|
||||||
|
_mock_response(redirect_status, location="https://b.com/img.png"),
|
||||||
|
_mock_response(200),
|
||||||
|
]
|
||||||
|
|
||||||
|
response = SSRFSafeSession().get("https://a.com/img.png", timeout=10)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert mock_get.call_count == 2
|
||||||
|
|
||||||
|
@patch("core.services.ssrf.requests.Session.get")
|
||||||
|
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||||
|
def test_follows_multiple_hops(self, mock_dns, mock_get):
|
||||||
|
"""Chained redirects are followed up to MAX_REDIRECTS."""
|
||||||
|
mock_dns.return_value = _addrinfo(PUBLIC_IP)
|
||||||
|
mock_get.side_effect = [
|
||||||
|
_mock_response(302, location="https://b.com/"),
|
||||||
|
_mock_response(302, location="https://c.com/"),
|
||||||
|
_mock_response(302, location="https://d.com/"),
|
||||||
|
_mock_response(200),
|
||||||
|
]
|
||||||
|
|
||||||
|
response = SSRFSafeSession().get("https://a.com/", timeout=10)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert mock_get.call_count == 4
|
||||||
|
|
||||||
|
@patch("core.services.ssrf.requests.Session.get")
|
||||||
|
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||||
|
def test_blocks_redirect_to_private_ip(self, mock_dns, mock_get):
|
||||||
|
"""A Location resolving to a private IP is rejected mid-chain."""
|
||||||
|
|
||||||
|
def dns_side_effect(host, *_args, **_kwargs):
|
||||||
|
if host == "legit.com":
|
||||||
|
return _addrinfo(PUBLIC_IP)
|
||||||
|
if host == "internal.evil.com":
|
||||||
|
return _addrinfo(PRIVATE_IP)
|
||||||
|
raise AssertionError(f"unexpected DNS lookup: {host}")
|
||||||
|
|
||||||
|
mock_dns.side_effect = dns_side_effect
|
||||||
|
mock_get.return_value = _mock_response(
|
||||||
|
302, location="https://internal.evil.com/pwn"
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(SSRFValidationError, match="private IP"):
|
||||||
|
SSRFSafeSession().get("https://legit.com/img.png", timeout=10)
|
||||||
|
|
||||||
|
# Only the first hop should have been issued; the second was blocked
|
||||||
|
# before the request was made.
|
||||||
|
assert mock_get.call_count == 1
|
||||||
|
|
||||||
|
@patch("core.services.ssrf.requests.Session.get")
|
||||||
|
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||||
|
def test_blocks_redirect_to_loopback(self, mock_dns, mock_get):
|
||||||
|
"""Redirect pointing at loopback (DNS-rebinding style) is rejected."""
|
||||||
|
|
||||||
|
def dns_side_effect(host, *_args, **_kwargs):
|
||||||
|
if host == "legit.com":
|
||||||
|
return _addrinfo(PUBLIC_IP)
|
||||||
|
if host == "rebind.evil.com":
|
||||||
|
return _addrinfo("127.0.0.1")
|
||||||
|
raise AssertionError(f"unexpected DNS lookup: {host}")
|
||||||
|
|
||||||
|
mock_dns.side_effect = dns_side_effect
|
||||||
|
mock_get.return_value = _mock_response(302, location="https://rebind.evil.com/")
|
||||||
|
|
||||||
|
with pytest.raises(SSRFValidationError, match="loopback"):
|
||||||
|
SSRFSafeSession().get("https://legit.com/img.png", timeout=10)
|
||||||
|
|
||||||
|
@patch("core.services.ssrf.requests.Session.get")
|
||||||
|
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||||
|
def test_blocks_redirect_to_cloud_metadata(self, mock_dns, mock_get):
|
||||||
|
"""Redirect pointing at cloud metadata endpoint is rejected."""
|
||||||
|
|
||||||
|
def dns_side_effect(host, *_args, **_kwargs):
|
||||||
|
if host == "legit.com":
|
||||||
|
return _addrinfo(PUBLIC_IP)
|
||||||
|
if host == "meta.evil.com":
|
||||||
|
return _addrinfo("169.254.169.254")
|
||||||
|
raise AssertionError(f"unexpected DNS lookup: {host}")
|
||||||
|
|
||||||
|
mock_dns.side_effect = dns_side_effect
|
||||||
|
mock_get.return_value = _mock_response(
|
||||||
|
302, location="https://meta.evil.com/latest/meta-data/"
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(SSRFValidationError, match="metadata"):
|
||||||
|
SSRFSafeSession().get("https://legit.com/img.png", timeout=10)
|
||||||
|
|
||||||
|
@patch("core.services.ssrf.requests.Session.get")
|
||||||
|
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||||
|
def test_blocks_redirect_to_non_http_scheme(self, mock_dns, mock_get):
|
||||||
|
"""Redirect to file:// or other schemes is rejected."""
|
||||||
|
mock_dns.return_value = _addrinfo(PUBLIC_IP)
|
||||||
|
mock_get.return_value = _mock_response(302, location="file:///etc/passwd")
|
||||||
|
|
||||||
|
with pytest.raises(SSRFValidationError, match="scheme"):
|
||||||
|
SSRFSafeSession().get("https://legit.com/", timeout=10)
|
||||||
|
|
||||||
|
@patch("core.services.ssrf.requests.Session.get")
|
||||||
|
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||||
|
def test_blocks_redirect_to_ip_literal(self, mock_dns, mock_get):
|
||||||
|
"""Redirect whose Location is a raw IP is rejected (domains only)."""
|
||||||
|
mock_dns.return_value = _addrinfo(PUBLIC_IP)
|
||||||
|
mock_get.return_value = _mock_response(302, location="http://203.0.113.5/stuff")
|
||||||
|
|
||||||
|
with pytest.raises(SSRFValidationError, match="IP addresses are not allowed"):
|
||||||
|
SSRFSafeSession().get("https://legit.com/", timeout=10)
|
||||||
|
|
||||||
|
@patch("core.services.ssrf.requests.Session.get")
|
||||||
|
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||||
|
def test_blocks_protocol_relative_redirect_to_private(self, mock_dns, mock_get):
|
||||||
|
"""//host/path Location resolves against current scheme and is validated."""
|
||||||
|
|
||||||
|
def dns_side_effect(host, *_args, **_kwargs):
|
||||||
|
if host == "legit.com":
|
||||||
|
return _addrinfo(PUBLIC_IP)
|
||||||
|
if host == "internal":
|
||||||
|
return _addrinfo(PRIVATE_IP)
|
||||||
|
raise AssertionError(f"unexpected DNS lookup: {host}")
|
||||||
|
|
||||||
|
mock_dns.side_effect = dns_side_effect
|
||||||
|
mock_get.return_value = _mock_response(302, location="//internal/admin")
|
||||||
|
|
||||||
|
with pytest.raises(SSRFValidationError, match="private IP"):
|
||||||
|
SSRFSafeSession().get("https://legit.com/", timeout=10)
|
||||||
|
|
||||||
|
@patch("core.services.ssrf.requests.Session.get")
|
||||||
|
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||||
|
def test_follows_relative_redirect(self, mock_dns, mock_get):
|
||||||
|
"""Relative Location is resolved against the current URL and validated."""
|
||||||
|
mock_dns.return_value = _addrinfo(PUBLIC_IP)
|
||||||
|
mock_get.side_effect = [
|
||||||
|
_mock_response(302, location="/new-path"),
|
||||||
|
_mock_response(200),
|
||||||
|
]
|
||||||
|
|
||||||
|
response = SSRFSafeSession().get("https://legit.com/old", timeout=10)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
# Second request should be made to https://legit.com/new-path.
|
||||||
|
second_call_url = mock_get.call_args_list[1].args[0]
|
||||||
|
assert second_call_url == "https://legit.com/new-path"
|
||||||
|
|
||||||
|
@patch("core.services.ssrf.requests.Session.get")
|
||||||
|
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||||
|
def test_blocks_too_many_redirects(self, mock_dns, mock_get):
|
||||||
|
"""A redirect loop longer than MAX_REDIRECTS raises."""
|
||||||
|
mock_dns.return_value = _addrinfo(PUBLIC_IP)
|
||||||
|
mock_get.return_value = _mock_response(302, location="https://legit.com/loop")
|
||||||
|
|
||||||
|
with pytest.raises(SSRFValidationError, match="Too many redirects"):
|
||||||
|
SSRFSafeSession().get("https://legit.com/", timeout=10)
|
||||||
|
|
||||||
|
# Initial hop + MAX_REDIRECTS extra hops, then error.
|
||||||
|
assert mock_get.call_count == MAX_REDIRECTS + 1
|
||||||
|
|
||||||
|
@patch("core.services.ssrf.requests.Session.get")
|
||||||
|
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||||
|
def test_redirect_without_location_returned_as_is(self, mock_dns, mock_get):
|
||||||
|
"""A redirect status with no Location header is returned unchanged."""
|
||||||
|
mock_dns.return_value = _addrinfo(PUBLIC_IP)
|
||||||
|
mock_get.return_value = _mock_response(302, location=None)
|
||||||
|
|
||||||
|
response = SSRFSafeSession().get("https://legit.com/", timeout=10)
|
||||||
|
|
||||||
|
assert response.status_code == 302
|
||||||
|
assert mock_get.call_count == 1
|
||||||
|
|
||||||
|
@patch("core.services.ssrf.requests.Session.get")
|
||||||
|
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||||
|
def test_caller_allow_redirects_true_is_ignored(self, mock_dns, mock_get):
|
||||||
|
"""Caller cannot opt out of per-hop validation by passing allow_redirects=True."""
|
||||||
|
mock_dns.return_value = _addrinfo(PUBLIC_IP)
|
||||||
|
mock_get.side_effect = [
|
||||||
|
_mock_response(302, location="https://cdn.legit.com/"),
|
||||||
|
_mock_response(200),
|
||||||
|
]
|
||||||
|
|
||||||
|
response = SSRFSafeSession().get(
|
||||||
|
"https://legit.com/", timeout=10, allow_redirects=True
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
# Each underlying Session.get must be called with allow_redirects=False.
|
||||||
|
for call in mock_get.call_args_list:
|
||||||
|
assert call.kwargs.get("allow_redirects") is False
|
||||||
|
|
||||||
|
@patch("core.services.ssrf.requests.Session.get")
|
||||||
|
@patch("core.services.ssrf.socket.getaddrinfo")
|
||||||
|
def test_intermediate_response_closed(self, mock_dns, mock_get):
|
||||||
|
"""Intermediate redirect responses are .close()d to release streams."""
|
||||||
|
mock_dns.return_value = _addrinfo(PUBLIC_IP)
|
||||||
|
intermediate = _mock_response(302, location="https://cdn.legit.com/")
|
||||||
|
final = _mock_response(200)
|
||||||
|
mock_get.side_effect = [intermediate, final]
|
||||||
|
|
||||||
|
SSRFSafeSession().get("https://legit.com/", timeout=10, stream=True)
|
||||||
|
|
||||||
|
intermediate.close.assert_called_once()
|
||||||
|
final.close.assert_not_called()
|
||||||
Reference in New Issue
Block a user