diff --git a/src/backend/core/api/viewsets/blob.py b/src/backend/core/api/viewsets/blob.py index ca587b02..27097590 100644 --- a/src/backend/core/api/viewsets/blob.py +++ b/src/backend/core/api/viewsets/blob.py @@ -4,6 +4,7 @@ import logging from django.http import HttpResponse from django.utils.decorators import method_decorator +from django.utils.http import content_disposition_header from django.views.decorators.csrf import csrf_exempt from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema @@ -177,8 +178,8 @@ class BlobViewSet(ViewSet): ) # Add appropriate headers for download - response["Content-Disposition"] = ( - f'attachment; filename="{attachment["name"]}"' + response["Content-Disposition"] = content_disposition_header( + True, attachment["name"] ) response["Content-Length"] = attachment["size"] # Enable browser caching for 30 days (inline images benefit from this) @@ -207,7 +208,9 @@ class BlobViewSet(ViewSet): ) # Add appropriate headers for download - response["Content-Disposition"] = f'attachment; filename="{filename}"' + response["Content-Disposition"] = content_disposition_header( + True, filename + ) response["Content-Length"] = blob.size # Enable browser caching for 30 days (inline images benefit from this) response["Cache-Control"] = "private, max-age=2592000" diff --git a/src/backend/core/api/viewsets/image_proxy.py b/src/backend/core/api/viewsets/image_proxy.py index 0958175f..9812ee99 100644 --- a/src/backend/core/api/viewsets/image_proxy.py +++ b/src/backend/core/api/viewsets/image_proxy.py @@ -1,9 +1,7 @@ """API ViewSet for proxying external images.""" -import ipaddress import logging -import socket -from urllib.parse import ParseResult, unquote, urlparse, urlunparse +from urllib.parse import unquote from django.conf import settings from django.http import HttpResponse @@ -11,261 +9,17 @@ from django.http import HttpResponse import magic import requests from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema -from requests.adapters import HTTPAdapter from rest_framework import status as http_status from rest_framework.response import Response from rest_framework.viewsets import ViewSet from core import enums, models from core.api import permissions +from core.services.ssrf import SSRFSafeSession, SSRFValidationError logger = logging.getLogger(__name__) -class SSRFValidationError(Exception): - """Exception raised when URL validation fails due to SSRF protection.""" - - -class SSRFProtectedAdapter(HTTPAdapter): - """ - HTTPAdapter that connects to a pre-validated IP address while maintaining - proper TLS certificate verification against the original hostname. - - This prevents TOCTOU DNS rebinding attacks by: - 1. Connecting to the IP address that was validated (not re-resolving DNS) - 2. Verifying TLS certificates against the original hostname (for HTTPS) - 3. Setting the Host header correctly for virtual hosting - """ - - def __init__( - self, - dest_ip: str, - dest_port: int, - original_hostname: str, - original_scheme: str, - **kwargs, - ): - self.dest_ip = dest_ip - self.dest_port = dest_port - self.original_hostname = original_hostname - self.original_scheme = original_scheme - super().__init__(**kwargs) - - def init_poolmanager(self, connections, maxsize, block=False, **pool_kwargs): - """Initialize pool manager with TLS hostname verification settings.""" - if self.original_scheme == "https": - # Ensure TLS certificate is verified against the original hostname - # even though we're connecting to an IP address - pool_kwargs["assert_hostname"] = self.original_hostname - pool_kwargs["server_hostname"] = self.original_hostname - super().init_poolmanager(connections, maxsize, block, **pool_kwargs) - - def send( - self, request, stream=False, timeout=None, verify=True, cert=None, proxies=None - ): - """Send request, rewriting URL to connect to the validated IP address.""" - parsed = urlparse(request.url) - - # Build URL with validated IP instead of hostname - # IPv6 addresses need brackets in URLs - if ":" in self.dest_ip: - ip_netloc = f"[{self.dest_ip}]:{self.dest_port}" - else: - ip_netloc = f"{self.dest_ip}:{self.dest_port}" - - # Reconstruct URL with IP address - request.url = urlunparse( - ( - parsed.scheme, - ip_netloc, - parsed.path, - parsed.params, - parsed.query, - parsed.fragment, - ) - ) - - # Set Host header to original hostname for virtual hosting - # Include port only if non-standard - if parsed.port and parsed.port not in (80, 443): - request.headers["Host"] = f"{self.original_hostname}:{parsed.port}" - else: - request.headers["Host"] = self.original_hostname - - return super().send( - request, - stream=stream, - timeout=timeout, - verify=verify, - cert=cert, - proxies=proxies, - ) - - -class SSRFSafeSession: - """ - HTTP Session with built-in SSRF protection. - - This class provides a safe way to make HTTP requests by: - 1. Validating URL scheme (only http/https allowed) - 2. Blocking direct IP addresses (legitimate services use domain names) - 3. Resolving hostnames and blocking private/internal IPs - 4. Pinning resolved IPs to prevent DNS rebinding attacks (TOCTOU) - - Usage: - try: - response = SSRFSafeSession().get("https://example.com/image.png", timeout=10) - except SSRFValidationError: - # URL was blocked for security reasons - pass - """ - - def _validate_url(self, parsed_url: ParseResult) -> list[str]: - """ - Validate that a URL is safe to fetch (SSRF protection). - - This function prevents Server-Side Request Forgery (SSRF) attacks by - validating URLs before making HTTP requests. It implements a defense-in-depth - approach: - - 1. Only allows http/https schemes - 2. Blocks all IP addresses (legitimate emails use domain names) - 3. Resolves hostnames and blocks if they resolve to private/internal IPs - (prevents DNS rebinding attacks where attacker-controlled DNS returns - 127.0.0.1 or internal IPs) - - Blocked addresses include: - - Any direct IP address (e.g., http://192.168.1.1/) - - Private IP ranges (RFC1918: 10.x.x.x, 172.16-31.x.x, 192.168.x.x) - - Loopback addresses (127.x.x.x, ::1) - - Link-local addresses (169.254.x.x, fe80::/10) - - Multicast and reserved addresses - - Cloud provider metadata endpoints (169.254.169.254, fd00:ec2::254) - - Args: - parsed_url: The parsed URL to validate - - Returns: - List of validated IP addresses that the hostname resolves to - - Raises: - SSRFValidationError: If the URL is unsafe - """ - # Only allow http and https schemes - if parsed_url.scheme not in {"http", "https"}: - raise SSRFValidationError("Invalid URL scheme (only http/https allowed)") - - # Require a hostname - if not parsed_url.hostname: - raise SSRFValidationError("Invalid URL (missing hostname)") - - # Block all IP addresses (legitimate services use domain names) - try: - ipaddress.ip_address(parsed_url.hostname) - raise SSRFValidationError( - "IP addresses are not allowed (domain name required)" - ) - except ValueError: - # Not an IP address, continue validation - pass - - # Resolve hostname to IP addresses - try: - addr_info = socket.getaddrinfo( - parsed_url.hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM - ) - except socket.gaierror as exc: - raise SSRFValidationError("Unable to resolve hostname") from exc - - # Check all resolved IP addresses - valid_ips = [] - for _, _, _, _, sockaddr in addr_info: - ip_str = sockaddr[0] - try: - ip_addr = ipaddress.ip_address(ip_str) - - if ip_addr.is_private: - raise SSRFValidationError("Domain resolves to private IP address") - - if ip_addr.is_loopback: - raise SSRFValidationError("Domain resolves to loopback address") - - if ip_addr.is_link_local: - raise SSRFValidationError("Domain resolves to link-local address") - - if ip_addr.is_multicast: - raise SSRFValidationError("Domain resolves to multicast address") - - if ip_addr.is_reserved: - raise SSRFValidationError("Domain resolves to reserved address") - - # Block known cloud metadata IPs - if ip_str in ("169.254.169.254", "fd00:ec2::254"): - raise SSRFValidationError( - "Domain resolves to cloud metadata endpoint" - ) - - valid_ips.append(ip_str) - - except ValueError as exc: - raise SSRFValidationError("Invalid IP address in DNS response") from exc - - if not valid_ips: - raise SSRFValidationError("No valid IP addresses found") - - return valid_ips - - def get(self, url: str, timeout: int, **kwargs) -> requests.Response: - """ - Perform a safe HTTP GET request with SSRF protection and IP pinning. - - This method: - 1. Parses and validates the URL - 2. Resolves DNS and validates all returned IPs - 3. Creates a requests Session with a custom HTTPAdapter that: - - Connects directly to the validated IP (preventing DNS rebinding) - - Maintains proper TLS certificate verification against the hostname - - Sets the Host header correctly for virtual hosting - - Args: - url: The URL to fetch - timeout: Request timeout in seconds - **kwargs: Additional arguments passed to requests.Session.get() - - Returns: - requests.Response object - - Raises: - SSRFValidationError: If the URL fails security validation - requests.RequestException: If the HTTP request fails - """ - parsed_url = urlparse(url) - valid_ips = self._validate_url(parsed_url) - - # Determine the port (explicit or default based on scheme) - if parsed_url.port: - port = parsed_url.port - elif parsed_url.scheme == "http": - port = 80 - else: - port = 443 - - # Create a session with our SSRF-protected adapter that pins to the validated IP - session = requests.Session() - adapter = SSRFProtectedAdapter( - dest_ip=valid_ips[0], - dest_port=port, - original_hostname=parsed_url.hostname, - original_scheme=parsed_url.scheme, - ) - - # Mount the adapter for both http and https schemes - session.mount("http://", adapter) - session.mount("https://", adapter) - - return session.get(url, timeout=timeout, **kwargs) - - class ImageProxySuspiciousResponse(HttpResponse): """ Response for suspicious content that has been blocked by our image proxy. @@ -356,7 +110,6 @@ class ImageProxyViewSet(ViewSet): timeout=10, stream=True, headers={"User-Agent": "Messages-ImageProxy/1.0"}, - allow_redirects=False, ) response.raise_for_status() diff --git a/src/backend/core/api/viewsets/inbound/widget.py b/src/backend/core/api/viewsets/inbound/widget.py index c8261751..6f03f92f 100644 --- a/src/backend/core/api/viewsets/inbound/widget.py +++ b/src/backend/core/api/viewsets/inbound/widget.py @@ -35,9 +35,9 @@ class WidgetAuthentication(BaseAuthentication): if not channel_id: raise AuthenticationFailed("Missing channel_id") - # API key authentication for check endpoint + # Only allow widget-type channels try: - channel = models.Channel.objects.get(id=channel_id) + channel = models.Channel.objects.get(id=channel_id, type="widget") except models.Channel.DoesNotExist as e: raise AuthenticationFailed("Invalid channel_id") from e diff --git a/src/backend/core/services/importer/imap.py b/src/backend/core/services/importer/imap.py index 017bcac9..cf3457b0 100644 --- a/src/backend/core/services/importer/imap.py +++ b/src/backend/core/services/importer/imap.py @@ -22,6 +22,7 @@ from celery.utils.log import get_task_logger from core.mda.inbound import deliver_inbound_message from core.mda.rfc5322 import parse_email_message +from core.services.ssrf import SSRFValidationError, validate_hostname logger = get_task_logger(__name__) @@ -62,6 +63,21 @@ def decode_imap_utf7(s): return re.sub(r"&([^-]*)-", decode_match, s) +def _validate_imap_host(server: str) -> None: + """Validate that the IMAP server hostname is not a private/internal address. + + Wraps the shared SSRF validator but allows public IP literals, which are + legitimate addresses for customer-supplied IMAP servers. + + Raises: + ValueError: If the hostname resolves to a blocked IP address. + """ + try: + validate_hostname(server, allow_ip_literal=True) + except SSRFValidationError as exc: + raise ValueError(f"IMAP server {server} is not allowed: {exc}") from exc + + class IMAPConnectionManager: """Context manager for IMAP connections with proper cleanup.""" @@ -76,6 +92,9 @@ class IMAPConnectionManager: self.connection = None def __enter__(self): + # Validate the server hostname to prevent SSRF + _validate_imap_host(self.server) + # Port 143 typically uses STARTTLS, port 993 uses SSL direct # If use_ssl=True and port is 143, use STARTTLS instead of SSL direct use_starttls = self.use_ssl and self.port == 143 diff --git a/src/backend/core/services/ssrf.py b/src/backend/core/services/ssrf.py new file mode 100644 index 00000000..5f23322a --- /dev/null +++ b/src/backend/core/services/ssrf.py @@ -0,0 +1,245 @@ +"""Server-Side Request Forgery (SSRF) protections. + +Shared across features that take user-supplied network destinations (image +proxy, IMAP import, etc.). Provides a hostname/IP validator plus an HTTP +session with IP pinning to defeat DNS-rebinding (TOCTOU) attacks. +""" + +import ipaddress +import socket +from urllib.parse import urljoin, urlparse, urlunparse + +import requests +from requests.adapters import HTTPAdapter + +CLOUD_METADATA_IPS = frozenset({"169.254.169.254", "fd00:ec2::254"}) + +MAX_REDIRECTS = 5 +REDIRECT_STATUS_CODES = frozenset({301, 302, 303, 307, 308}) + + +class SSRFValidationError(Exception): + """Raised when a URL or hostname fails SSRF validation.""" + + +def _check_ip(ip_addr: ipaddress._BaseAddress, hostname: str) -> None: + # Check specific categories before is_private: in Python's ipaddress + # module, loopback/link-local/etc. are subsets of is_private, so checking + # is_private first would mask the more informative error. + if str(ip_addr) in CLOUD_METADATA_IPS: + raise SSRFValidationError(f"{hostname} resolves to cloud metadata endpoint") + if ip_addr.is_loopback: + raise SSRFValidationError(f"{hostname} resolves to loopback address") + if ip_addr.is_link_local: + raise SSRFValidationError(f"{hostname} resolves to link-local address") + if ip_addr.is_multicast: + raise SSRFValidationError(f"{hostname} resolves to multicast address") + if ip_addr.is_reserved: + raise SSRFValidationError(f"{hostname} resolves to reserved address") + if ip_addr.is_private: + raise SSRFValidationError(f"{hostname} resolves to private IP address") + + +def validate_hostname(hostname: str, *, allow_ip_literal: bool = False) -> list[str]: + """Resolve hostname and reject private/internal/metadata addresses. + + Args: + hostname: A hostname or, when allow_ip_literal=True, an IP literal. + allow_ip_literal: If False (default), IP literals are rejected outright — + legitimate services use domain names. If True, public IP literals + are accepted (used for IMAP where customers may supply raw IPs). + + Returns: + List of validated IP addresses the hostname resolves to. + + Raises: + SSRFValidationError: If the hostname/IP resolves to a blocked address. + """ + if not hostname: + raise SSRFValidationError("Invalid hostname (missing)") + + try: + ip = ipaddress.ip_address(hostname) + except ValueError: + ip = None + + if ip is not None: + if not allow_ip_literal: + raise SSRFValidationError( + "IP addresses are not allowed (domain name required)" + ) + _check_ip(ip, hostname) + return [str(ip)] + + try: + addr_info = socket.getaddrinfo( + hostname, None, socket.AF_UNSPEC, socket.SOCK_STREAM + ) + except socket.gaierror as exc: + raise SSRFValidationError("Unable to resolve hostname") from exc + + valid_ips: list[str] = [] + for _, _, _, _, sockaddr in addr_info: + ip_str = sockaddr[0] + try: + ip_addr = ipaddress.ip_address(ip_str) + except ValueError as exc: + raise SSRFValidationError("Invalid IP address in DNS response") from exc + _check_ip(ip_addr, hostname) + valid_ips.append(ip_str) + + if not valid_ips: + raise SSRFValidationError("No valid IP addresses found") + + return valid_ips + + +class SSRFProtectedAdapter(HTTPAdapter): + """HTTPAdapter that pins the connection to a pre-validated IP. + + Prevents TOCTOU DNS rebinding by: + 1. Connecting to the IP address that was validated (no re-resolving DNS). + 2. Verifying TLS certificates against the original hostname (for HTTPS). + 3. Setting the Host header correctly for virtual hosting. + """ + + def __init__( + self, + dest_ip: str, + dest_port: int, + original_hostname: str, + original_scheme: str, + **kwargs, + ): + self.dest_ip = dest_ip + self.dest_port = dest_port + self.original_hostname = original_hostname + self.original_scheme = original_scheme + super().__init__(**kwargs) + + def init_poolmanager(self, connections, maxsize, block=False, **pool_kwargs): + if self.original_scheme == "https": + pool_kwargs["assert_hostname"] = self.original_hostname + pool_kwargs["server_hostname"] = self.original_hostname + super().init_poolmanager(connections, maxsize, block, **pool_kwargs) + + def send( + self, request, stream=False, timeout=None, verify=True, cert=None, proxies=None + ): + parsed = urlparse(request.url) + + if ":" in self.dest_ip: + ip_netloc = f"[{self.dest_ip}]:{self.dest_port}" + else: + ip_netloc = f"{self.dest_ip}:{self.dest_port}" + + request.url = urlunparse( + ( + parsed.scheme, + ip_netloc, + parsed.path, + parsed.params, + parsed.query, + parsed.fragment, + ) + ) + + if parsed.port and parsed.port not in (80, 443): + request.headers["Host"] = f"{self.original_hostname}:{parsed.port}" + else: + request.headers["Host"] = self.original_hostname + + return super().send( + request, + stream=stream, + timeout=timeout, + verify=verify, + cert=cert, + proxies=proxies, + ) + + +class SSRFSafeSession: + """HTTP Session with built-in SSRF protection. + + 1. Validates URL scheme (only http/https allowed). + 2. Blocks direct IP addresses (legitimate services use domain names). + 3. Resolves hostnames and blocks private/internal IPs. + 4. Pins resolved IPs to prevent DNS rebinding attacks (TOCTOU). + + Usage: + try: + response = SSRFSafeSession().get("https://example.com/image.png", timeout=10) + except SSRFValidationError: + # URL was blocked for security reasons + pass + """ + + def _validate_and_unpack(self, url: str) -> tuple[str, str, str, int]: + """Validate a URL and return (validated_ip, hostname, scheme, port). + + Raises: + SSRFValidationError: If the URL is unsafe. + """ + parsed = urlparse(url) + if parsed.scheme not in {"http", "https"}: + raise SSRFValidationError("Invalid URL scheme (only http/https allowed)") + if not parsed.hostname: + raise SSRFValidationError("Invalid URL (missing hostname)") + + valid_ips = validate_hostname(parsed.hostname, allow_ip_literal=False) + + if parsed.port: + port = parsed.port + elif parsed.scheme == "http": + port = 80 + else: + port = 443 + + return valid_ips[0], parsed.hostname, parsed.scheme, port + + def get(self, url: str, timeout: int, **kwargs) -> requests.Response: + """Perform a safe HTTP GET with per-hop SSRF validation on redirects. + + Redirects are followed manually up to MAX_REDIRECTS hops. Each Location + URL is re-validated from scratch, so an attacker-controlled server + cannot redirect to an internal address or a different private target + on a later hop. + """ + # We always handle redirects ourselves — strip any caller override so + # the underlying requests session never follows a redirect unchecked. + kwargs.pop("allow_redirects", None) + + current_url = url + for _ in range(MAX_REDIRECTS + 1): + validated_ip, hostname, scheme, port = self._validate_and_unpack( + current_url + ) + + session = requests.Session() + adapter = SSRFProtectedAdapter( + dest_ip=validated_ip, + dest_port=port, + original_hostname=hostname, + original_scheme=scheme, + ) + session.mount("http://", adapter) + session.mount("https://", adapter) + + response = session.get( + current_url, timeout=timeout, allow_redirects=False, **kwargs + ) + + if response.status_code not in REDIRECT_STATUS_CODES: + return response + + location = response.headers.get("Location") + if not location: + # Redirect without a Location — hand the response back unchanged. + return response + + next_url = urljoin(current_url, location) + response.close() + current_url = next_url + + raise SSRFValidationError(f"Too many redirects (max {MAX_REDIRECTS})") diff --git a/src/backend/core/tests/api/test_image_proxy.py b/src/backend/core/tests/api/test_image_proxy.py index a3f171e9..0b204c9a 100644 --- a/src/backend/core/tests/api/test_image_proxy.py +++ b/src/backend/core/tests/api/test_image_proxy.py @@ -191,7 +191,7 @@ class TestImageProxyViewSet: "https://localhost:8080/image.jpg", ], ) - @patch("core.api.viewsets.image_proxy.socket.getaddrinfo") + @patch("core.services.ssrf.socket.getaddrinfo") def test_api_image_proxy_localhost_hostname_blocked( self, mock_getaddrinfo, api_client, user_mailbox, test_url ): @@ -210,7 +210,7 @@ class TestImageProxyViewSet: assert response["Content-Type"] == "image/svg+xml" @override_settings(IMAGE_PROXY_ENABLED=True) - @patch("core.api.viewsets.image_proxy.socket.getaddrinfo") + @patch("core.services.ssrf.socket.getaddrinfo") def test_api_image_proxy_domain_resolving_to_private_ip_blocked( self, mock_getaddrinfo, api_client, user_mailbox ): @@ -229,7 +229,7 @@ class TestImageProxyViewSet: assert response["Content-Type"] == "image/svg+xml" @override_settings(IMAGE_PROXY_ENABLED=True) - @patch("core.api.viewsets.image_proxy.socket.getaddrinfo") + @patch("core.services.ssrf.socket.getaddrinfo") def test_api_image_proxy_unresolvable_hostname( self, mock_getaddrinfo, api_client, user_mailbox ): @@ -249,7 +249,7 @@ class TestImageProxyViewSet: # Content-Type Validation Tests @override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=10) - @patch("core.api.viewsets.image_proxy.socket.getaddrinfo") + @patch("core.services.ssrf.socket.getaddrinfo") @patch("core.api.viewsets.image_proxy.SSRFSafeSession.get") def test_api_image_proxy_non_image_content_type_blocked( self, mock_get, mock_getaddrinfo, api_client, user_mailbox @@ -273,7 +273,7 @@ class TestImageProxyViewSet: assert response["Content-Type"] == "image/svg+xml" @override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=10) - @patch("core.api.viewsets.image_proxy.socket.getaddrinfo") + @patch("core.services.ssrf.socket.getaddrinfo") @patch("core.api.viewsets.image_proxy.SSRFSafeSession.get") @patch("magic.from_buffer") def test_api_image_proxy_content_not_actually_image( @@ -301,7 +301,7 @@ class TestImageProxyViewSet: assert response["Content-Type"] == "image/svg+xml" @override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=10) - @patch("core.api.viewsets.image_proxy.socket.getaddrinfo") + @patch("core.services.ssrf.socket.getaddrinfo") @patch("core.api.viewsets.image_proxy.SSRFSafeSession.get") @patch("magic.from_buffer") @pytest.mark.parametrize( @@ -342,7 +342,7 @@ class TestImageProxyViewSet: # Size Limit Tests @override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=1) - @patch("core.api.viewsets.image_proxy.socket.getaddrinfo") + @patch("core.services.ssrf.socket.getaddrinfo") @patch("core.api.viewsets.image_proxy.SSRFSafeSession.get") def test_api_image_proxy_image_too_large_via_content_length( self, mock_get, mock_getaddrinfo, api_client, user_mailbox @@ -367,7 +367,7 @@ class TestImageProxyViewSet: assert "Image too large" in str(response.data) @override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=4096) - @patch("core.api.viewsets.image_proxy.socket.getaddrinfo") + @patch("core.services.ssrf.socket.getaddrinfo") @patch("core.api.viewsets.image_proxy.SSRFSafeSession.get") @patch("magic.from_buffer") def test_api_image_proxy_image_too_large_actual_content( @@ -395,39 +395,13 @@ class TestImageProxyViewSet: assert response.status_code == status.HTTP_413_REQUEST_ENTITY_TOO_LARGE assert "Image too large" in str(response.data) - # Redirect Tests - - @override_settings(IMAGE_PROXY_ENABLED=True) - @patch("core.api.viewsets.image_proxy.socket.getaddrinfo") - @patch("core.api.viewsets.image_proxy.SSRFSafeSession.get") - def test_api_image_proxy_redirects_blocked( - self, mock_get, mock_getaddrinfo, api_client, user_mailbox - ): - """Test that redirects are not followed (SSRF protection).""" - # Mock DNS resolution - mock_getaddrinfo.return_value = [(2, 1, 6, "", ("1.2.3.4", 80))] - - client, _ = api_client - url = self._get_image_proxy_url( - user_mailbox.id, "http://example.com/redirect.jpg" - ) - - mock_get.return_value = self._mock_requests_response( - content=b"", content_type="text/plain", status_code=302 - ) - - # Call the endpoint - client.get(url) - - # Verify that allow_redirects=False was passed to requests.get - mock_get.assert_called_once() - call_kwargs = mock_get.call_args[1] - assert call_kwargs["allow_redirects"] is False + # Redirect handling is covered end-to-end in core/tests/services/test_ssrf.py: + # SSRFSafeSession follows redirects internally, re-validating every hop. # Success Cases @override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_CACHE_TTL=3600) - @patch("core.api.viewsets.image_proxy.socket.getaddrinfo") + @patch("core.services.ssrf.socket.getaddrinfo") @patch("core.api.viewsets.image_proxy.SSRFSafeSession.get") @patch("magic.from_buffer") def test_api_image_proxy_successfully_jpg_image( @@ -459,7 +433,7 @@ class TestImageProxyViewSet: assert response["Cache-Control"] == "public, max-age=3600" @override_settings(IMAGE_PROXY_ENABLED=True) - @patch("core.api.viewsets.image_proxy.socket.getaddrinfo") + @patch("core.services.ssrf.socket.getaddrinfo") @patch("core.api.viewsets.image_proxy.SSRFSafeSession.get") @patch("magic.from_buffer") def test_api_image_proxy_successfully_png_image( @@ -488,7 +462,7 @@ class TestImageProxyViewSet: assert response["Content-Type"] == "image/png" @override_settings(IMAGE_PROXY_ENABLED=True) - @patch("core.api.viewsets.image_proxy.socket.getaddrinfo") + @patch("core.services.ssrf.socket.getaddrinfo") @patch("core.api.viewsets.image_proxy.SSRFSafeSession.get") @patch("magic.from_buffer") def test_api_image_proxy_successfully_with_octet_stream_content_type( @@ -517,7 +491,7 @@ class TestImageProxyViewSet: assert response["Content-Type"] == "image/jpeg" @override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_CACHE_TTL=3600) - @patch("core.api.viewsets.image_proxy.socket.getaddrinfo") + @patch("core.services.ssrf.socket.getaddrinfo") @patch("core.api.viewsets.image_proxy.SSRFSafeSession.get") @patch("magic.from_buffer") def test_api_image_proxy_no_content( @@ -547,7 +521,7 @@ class TestImageProxyViewSet: assert response["Content-Type"] == "image/svg+xml" @override_settings(IMAGE_PROXY_ENABLED=True) - @patch("core.api.viewsets.image_proxy.socket.getaddrinfo") + @patch("core.services.ssrf.socket.getaddrinfo") @patch("core.api.viewsets.image_proxy.SSRFSafeSession.get") @patch("magic.from_buffer") def test_api_image_proxy_url_with_special_characters( @@ -578,7 +552,7 @@ class TestImageProxyViewSet: assert mock_get.call_args[0][0] == test_url @override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_CACHE_TTL=3600) - @patch("core.api.viewsets.image_proxy.socket.getaddrinfo") + @patch("core.services.ssrf.socket.getaddrinfo") @patch("core.api.viewsets.image_proxy.SSRFSafeSession.get") @patch("magic.from_buffer") def test_api_image_proxy_successfully_secure_headers( @@ -617,7 +591,7 @@ class TestImageProxyViewSet: # Error Handling Tests @override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=10) - @patch("core.api.viewsets.image_proxy.socket.getaddrinfo") + @patch("core.services.ssrf.socket.getaddrinfo") @patch("core.api.viewsets.image_proxy.SSRFSafeSession.get") def test_api_image_proxy_network_timeout( self, mock_get, mock_getaddrinfo, api_client, user_mailbox @@ -638,7 +612,7 @@ class TestImageProxyViewSet: assert "Failed to fetch image" in str(response.data) @override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=10) - @patch("core.api.viewsets.image_proxy.socket.getaddrinfo") + @patch("core.services.ssrf.socket.getaddrinfo") @patch("core.api.viewsets.image_proxy.SSRFSafeSession.get") def test_api_image_proxy_connection_error( self, mock_get, mock_getaddrinfo, api_client, user_mailbox @@ -659,7 +633,7 @@ class TestImageProxyViewSet: assert "Failed to fetch image" in str(response.data) @override_settings(IMAGE_PROXY_ENABLED=True, IMAGE_PROXY_MAX_SIZE=10) - @patch("core.api.viewsets.image_proxy.socket.getaddrinfo") + @patch("core.services.ssrf.socket.getaddrinfo") @patch("core.api.viewsets.image_proxy.SSRFSafeSession.get") def test_api_image_proxy_http_error_404( self, mock_get, mock_getaddrinfo, api_client, user_mailbox @@ -687,7 +661,7 @@ class TestImageProxyViewSet: assert "Failed to fetch image" in str(response.data) @override_settings(IMAGE_PROXY_ENABLED=True) - @patch("core.api.viewsets.image_proxy.socket.getaddrinfo") + @patch("core.services.ssrf.socket.getaddrinfo") @patch("core.api.viewsets.image_proxy.SSRFSafeSession.get") @patch("magic.from_buffer") def test_api_image_proxy_invalid_content_length_header( @@ -716,7 +690,7 @@ class TestImageProxyViewSet: assert response.status_code == status.HTTP_200_OK @override_settings(IMAGE_PROXY_ENABLED=True) - @patch("core.api.viewsets.image_proxy.socket.getaddrinfo") + @patch("core.services.ssrf.socket.getaddrinfo") @patch("core.api.viewsets.image_proxy.SSRFSafeSession.get") @patch("magic.from_buffer") def test_api_image_proxy_missing_content_length_header( diff --git a/src/backend/core/tests/api/test_messages_import.py b/src/backend/core/tests/api/test_messages_import.py index 326ec6d5..0a6a1561 100644 --- a/src/backend/core/tests/api/test_messages_import.py +++ b/src/backend/core/tests/api/test_messages_import.py @@ -2,6 +2,7 @@ # pylint: disable=redefined-outer-name, unused-argument, no-value-for-parameter, too-many-lines import datetime +import socket from unittest.mock import patch from django.core.files.storage import storages @@ -19,6 +20,25 @@ from core.services.importer.mbox_tasks import process_mbox_file_task pytestmark = pytest.mark.django_db + +@pytest.fixture(autouse=True) +def _mock_ssrf_dns(): + """Short-circuit SSRF DNS validation for IMAP import tests. + + The IMAP endpoint validates the server hostname via + ``core.services.ssrf.validate_hostname``; tests use unresolvable fixtures + like ``imap.example.com`` so we return a public IP to reach the mocked + IMAP task code. + """ + with patch( + "core.services.ssrf.socket.getaddrinfo", + return_value=[ + (socket.AF_INET, socket.SOCK_STREAM, 6, "", ("93.184.216.34", 0)) + ], + ): + yield + + IMPORT_FILE_URL = "/api/v1.0/import/file/" IMPORT_IMAP_URL = "/api/v1.0/import/imap/" diff --git a/src/backend/core/tests/importer/conftest.py b/src/backend/core/tests/importer/conftest.py new file mode 100644 index 00000000..d1ea9998 --- /dev/null +++ b/src/backend/core/tests/importer/conftest.py @@ -0,0 +1,24 @@ +"""Shared fixtures for importer tests.""" + +import socket +from unittest import mock + +import pytest + + +@pytest.fixture(autouse=True) +def _mock_ssrf_dns(): + """Short-circuit SSRF DNS validation for IMAP tests. + + The IMAP import path validates the server hostname via + ``core.services.ssrf.validate_hostname`` which calls ``socket.getaddrinfo``. + Test fixtures use unresolvable hostnames like ``imap.example.com``, so we + return a valid public IP here to let tests reach the mocked IMAP code. + """ + with mock.patch( + "core.services.ssrf.socket.getaddrinfo", + return_value=[ + (socket.AF_INET, socket.SOCK_STREAM, 6, "", ("93.184.216.34", 0)) + ], + ): + yield diff --git a/src/backend/core/tests/services/test_ssrf.py b/src/backend/core/tests/services/test_ssrf.py new file mode 100644 index 00000000..a745ef1c --- /dev/null +++ b/src/backend/core/tests/services/test_ssrf.py @@ -0,0 +1,274 @@ +"""Unit tests for core.services.ssrf — focus on redirect handling. + +Hostname/IP validation itself is covered via the image proxy integration tests +(see core/tests/api/test_image_proxy.py). This file targets the per-hop +validation loop in SSRFSafeSession.get, which is the main SSRF-correctness +surface for redirect responses. +""" + +import socket +from unittest.mock import MagicMock, patch + +import pytest + +from core.services.ssrf import ( + MAX_REDIRECTS, + SSRFSafeSession, + SSRFValidationError, +) + +PUBLIC_IP = "93.184.216.34" +PRIVATE_IP = "192.168.1.1" + + +def _addrinfo(ip: str): + return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", (ip, 0))] + + +def _mock_response(status_code: int = 200, location: str | None = None): + """Build a mock requests.Response-like object.""" + resp = MagicMock() + resp.status_code = status_code + resp.headers = {} + if location is not None: + resp.headers["Location"] = location + return resp + + +class TestSSRFSafeSessionRedirects: + """Redirect-handling contract for SSRFSafeSession.get.""" + + @patch("core.services.ssrf.requests.Session.get") + @patch("core.services.ssrf.socket.getaddrinfo") + def test_no_redirect_returns_response_directly(self, mock_dns, mock_get): + """A 200 response is returned without any extra request.""" + mock_dns.return_value = _addrinfo(PUBLIC_IP) + mock_get.return_value = _mock_response(200) + + response = SSRFSafeSession().get("https://legit.com/img.png", timeout=10) + + assert response.status_code == 200 + assert mock_get.call_count == 1 + + @patch("core.services.ssrf.requests.Session.get") + @patch("core.services.ssrf.socket.getaddrinfo") + def test_follows_redirect_to_safe_url(self, mock_dns, mock_get): + """A single redirect is followed to a validated destination.""" + mock_dns.return_value = _addrinfo(PUBLIC_IP) + mock_get.side_effect = [ + _mock_response(302, location="https://cdn.legit.com/img.png"), + _mock_response(200), + ] + + response = SSRFSafeSession().get("https://legit.com/img.png", timeout=10) + + assert response.status_code == 200 + assert mock_get.call_count == 2 + + @pytest.mark.parametrize("redirect_status", [301, 302, 303, 307, 308]) + @patch("core.services.ssrf.requests.Session.get") + @patch("core.services.ssrf.socket.getaddrinfo") + def test_all_redirect_statuses_followed(self, mock_dns, mock_get, redirect_status): + """All standard redirect status codes trigger a new hop.""" + mock_dns.return_value = _addrinfo(PUBLIC_IP) + mock_get.side_effect = [ + _mock_response(redirect_status, location="https://b.com/img.png"), + _mock_response(200), + ] + + response = SSRFSafeSession().get("https://a.com/img.png", timeout=10) + + assert response.status_code == 200 + assert mock_get.call_count == 2 + + @patch("core.services.ssrf.requests.Session.get") + @patch("core.services.ssrf.socket.getaddrinfo") + def test_follows_multiple_hops(self, mock_dns, mock_get): + """Chained redirects are followed up to MAX_REDIRECTS.""" + mock_dns.return_value = _addrinfo(PUBLIC_IP) + mock_get.side_effect = [ + _mock_response(302, location="https://b.com/"), + _mock_response(302, location="https://c.com/"), + _mock_response(302, location="https://d.com/"), + _mock_response(200), + ] + + response = SSRFSafeSession().get("https://a.com/", timeout=10) + + assert response.status_code == 200 + assert mock_get.call_count == 4 + + @patch("core.services.ssrf.requests.Session.get") + @patch("core.services.ssrf.socket.getaddrinfo") + def test_blocks_redirect_to_private_ip(self, mock_dns, mock_get): + """A Location resolving to a private IP is rejected mid-chain.""" + + def dns_side_effect(host, *_args, **_kwargs): + if host == "legit.com": + return _addrinfo(PUBLIC_IP) + if host == "internal.evil.com": + return _addrinfo(PRIVATE_IP) + raise AssertionError(f"unexpected DNS lookup: {host}") + + mock_dns.side_effect = dns_side_effect + mock_get.return_value = _mock_response( + 302, location="https://internal.evil.com/pwn" + ) + + with pytest.raises(SSRFValidationError, match="private IP"): + SSRFSafeSession().get("https://legit.com/img.png", timeout=10) + + # Only the first hop should have been issued; the second was blocked + # before the request was made. + assert mock_get.call_count == 1 + + @patch("core.services.ssrf.requests.Session.get") + @patch("core.services.ssrf.socket.getaddrinfo") + def test_blocks_redirect_to_loopback(self, mock_dns, mock_get): + """Redirect pointing at loopback (DNS-rebinding style) is rejected.""" + + def dns_side_effect(host, *_args, **_kwargs): + if host == "legit.com": + return _addrinfo(PUBLIC_IP) + if host == "rebind.evil.com": + return _addrinfo("127.0.0.1") + raise AssertionError(f"unexpected DNS lookup: {host}") + + mock_dns.side_effect = dns_side_effect + mock_get.return_value = _mock_response(302, location="https://rebind.evil.com/") + + with pytest.raises(SSRFValidationError, match="loopback"): + SSRFSafeSession().get("https://legit.com/img.png", timeout=10) + + @patch("core.services.ssrf.requests.Session.get") + @patch("core.services.ssrf.socket.getaddrinfo") + def test_blocks_redirect_to_cloud_metadata(self, mock_dns, mock_get): + """Redirect pointing at cloud metadata endpoint is rejected.""" + + def dns_side_effect(host, *_args, **_kwargs): + if host == "legit.com": + return _addrinfo(PUBLIC_IP) + if host == "meta.evil.com": + return _addrinfo("169.254.169.254") + raise AssertionError(f"unexpected DNS lookup: {host}") + + mock_dns.side_effect = dns_side_effect + mock_get.return_value = _mock_response( + 302, location="https://meta.evil.com/latest/meta-data/" + ) + + with pytest.raises(SSRFValidationError, match="metadata"): + SSRFSafeSession().get("https://legit.com/img.png", timeout=10) + + @patch("core.services.ssrf.requests.Session.get") + @patch("core.services.ssrf.socket.getaddrinfo") + def test_blocks_redirect_to_non_http_scheme(self, mock_dns, mock_get): + """Redirect to file:// or other schemes is rejected.""" + mock_dns.return_value = _addrinfo(PUBLIC_IP) + mock_get.return_value = _mock_response(302, location="file:///etc/passwd") + + with pytest.raises(SSRFValidationError, match="scheme"): + SSRFSafeSession().get("https://legit.com/", timeout=10) + + @patch("core.services.ssrf.requests.Session.get") + @patch("core.services.ssrf.socket.getaddrinfo") + def test_blocks_redirect_to_ip_literal(self, mock_dns, mock_get): + """Redirect whose Location is a raw IP is rejected (domains only).""" + mock_dns.return_value = _addrinfo(PUBLIC_IP) + mock_get.return_value = _mock_response(302, location="http://203.0.113.5/stuff") + + with pytest.raises(SSRFValidationError, match="IP addresses are not allowed"): + SSRFSafeSession().get("https://legit.com/", timeout=10) + + @patch("core.services.ssrf.requests.Session.get") + @patch("core.services.ssrf.socket.getaddrinfo") + def test_blocks_protocol_relative_redirect_to_private(self, mock_dns, mock_get): + """//host/path Location resolves against current scheme and is validated.""" + + def dns_side_effect(host, *_args, **_kwargs): + if host == "legit.com": + return _addrinfo(PUBLIC_IP) + if host == "internal": + return _addrinfo(PRIVATE_IP) + raise AssertionError(f"unexpected DNS lookup: {host}") + + mock_dns.side_effect = dns_side_effect + mock_get.return_value = _mock_response(302, location="//internal/admin") + + with pytest.raises(SSRFValidationError, match="private IP"): + SSRFSafeSession().get("https://legit.com/", timeout=10) + + @patch("core.services.ssrf.requests.Session.get") + @patch("core.services.ssrf.socket.getaddrinfo") + def test_follows_relative_redirect(self, mock_dns, mock_get): + """Relative Location is resolved against the current URL and validated.""" + mock_dns.return_value = _addrinfo(PUBLIC_IP) + mock_get.side_effect = [ + _mock_response(302, location="/new-path"), + _mock_response(200), + ] + + response = SSRFSafeSession().get("https://legit.com/old", timeout=10) + + assert response.status_code == 200 + # Second request should be made to https://legit.com/new-path. + second_call_url = mock_get.call_args_list[1].args[0] + assert second_call_url == "https://legit.com/new-path" + + @patch("core.services.ssrf.requests.Session.get") + @patch("core.services.ssrf.socket.getaddrinfo") + def test_blocks_too_many_redirects(self, mock_dns, mock_get): + """A redirect loop longer than MAX_REDIRECTS raises.""" + mock_dns.return_value = _addrinfo(PUBLIC_IP) + mock_get.return_value = _mock_response(302, location="https://legit.com/loop") + + with pytest.raises(SSRFValidationError, match="Too many redirects"): + SSRFSafeSession().get("https://legit.com/", timeout=10) + + # Initial hop + MAX_REDIRECTS extra hops, then error. + assert mock_get.call_count == MAX_REDIRECTS + 1 + + @patch("core.services.ssrf.requests.Session.get") + @patch("core.services.ssrf.socket.getaddrinfo") + def test_redirect_without_location_returned_as_is(self, mock_dns, mock_get): + """A redirect status with no Location header is returned unchanged.""" + mock_dns.return_value = _addrinfo(PUBLIC_IP) + mock_get.return_value = _mock_response(302, location=None) + + response = SSRFSafeSession().get("https://legit.com/", timeout=10) + + assert response.status_code == 302 + assert mock_get.call_count == 1 + + @patch("core.services.ssrf.requests.Session.get") + @patch("core.services.ssrf.socket.getaddrinfo") + def test_caller_allow_redirects_true_is_ignored(self, mock_dns, mock_get): + """Caller cannot opt out of per-hop validation by passing allow_redirects=True.""" + mock_dns.return_value = _addrinfo(PUBLIC_IP) + mock_get.side_effect = [ + _mock_response(302, location="https://cdn.legit.com/"), + _mock_response(200), + ] + + response = SSRFSafeSession().get( + "https://legit.com/", timeout=10, allow_redirects=True + ) + + assert response.status_code == 200 + # Each underlying Session.get must be called with allow_redirects=False. + for call in mock_get.call_args_list: + assert call.kwargs.get("allow_redirects") is False + + @patch("core.services.ssrf.requests.Session.get") + @patch("core.services.ssrf.socket.getaddrinfo") + def test_intermediate_response_closed(self, mock_dns, mock_get): + """Intermediate redirect responses are .close()d to release streams.""" + mock_dns.return_value = _addrinfo(PUBLIC_IP) + intermediate = _mock_response(302, location="https://cdn.legit.com/") + final = _mock_response(200) + mock_get.side_effect = [intermediate, final] + + SSRFSafeSession().get("https://legit.com/", timeout=10, stream=True) + + intermediate.close.assert_called_once() + final.close.assert_not_called()