diff --git a/CHANGELOG.md b/CHANGELOG.md index 7f0028184..b29105b82 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,10 @@ and this project adheres to ## [Unreleased] +### Added + +- ✨(backend) allow to use new ai feature using mistral sdk + ## [v4.8.6] - 2026-04-08 ### Added diff --git a/docs/env.md b/docs/env.md index 2dad89ecf..03c6833b5 100644 --- a/docs/env.md +++ b/docs/env.md @@ -9,14 +9,16 @@ These are the environment variables you can set for the `impress-backend` contai | Option | Description | default | | ----------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------- | | AI_ALLOW_REACH_FROM | Users that can use AI must be this level. options are "public", "authenticated", "restricted" | authenticated | -| AI_API_KEY | AI key to be used for AI Base url | | -| AI_BASE_URL | OpenAI compatible AI base url | | -| AI_BOT | Information to give to the frontend about the AI bot | { "name": "Docs AI", "color": "#8bc6ff" } +| OPENAI_SDK_API_KEY | AI key to be used by the OpenAI python SDK | | +| OPENAI_SDK_BASE_URL | OpenAI compatible AI base url | | +| MISTRAL_SDK_API_KEY | AI key to be used by the Mistral python SDK /!\ Mistral sdk can be used only in async mode with uvicorn /!\ | | +| MISTRAL_SDK_BASE_URL | Mistral compatible AI base url | | +| AI_BOT | Information to give to the frontend about the AI bot | { "name": "Docs AI", "color": "#8bc6ff" } | | AI_FEATURE_ENABLED | Enable AI options | false | -| AI_FEATURE_BLOCKNOTE_ENABLED | Enable Blocknote AI options | false | -| AI_FEATURE_LEGACY_ENABLED | Enable legacyAI options | true | +| AI_FEATURE_BLOCKNOTE_ENABLED | Enable Blocknote AI options | false | +| AI_FEATURE_LEGACY_ENABLED | Enable legacyAI options | true | | AI_MODEL | AI Model to use | | -| AI_VERCEL_SDK_VERSION | 
The vercel AI SDK version used | 6 | +| AI_VERCEL_SDK_VERSION | The vercel AI SDK version used | 6 | | ALLOW_LOGOUT_GET_METHOD | Allow get logout method | true | | API_USERS_LIST_LIMIT | Limit on API users | 5 | | API_USERS_LIST_THROTTLE_RATE_BURST | Throttle rate for api on burst | 30/minute | diff --git a/env.d/development/common b/env.d/development/common index fb311964f..252f10518 100644 --- a/env.d/development/common +++ b/env.d/development/common @@ -71,14 +71,6 @@ OIDC_RS_ALLOWED_AUDIENCES="" # User reconciliation USER_RECONCILIATION_FORM_URL=http://localhost:3000 -# AI -AI_FEATURE_ENABLED=true -AI_FEATURE_BLOCKNOTE_ENABLED=true -AI_FEATURE_LEGACY_ENABLED=true -AI_BASE_URL=https://openaiendpoint.com -AI_API_KEY=password -AI_MODEL=llama - # Collaboration COLLABORATION_API_URL=http://y-provider-development:4444/collaboration/api/ COLLABORATION_BACKEND_BASE_URL=http://app-dev:8000 diff --git a/src/backend/core/services/ai_services.py b/src/backend/core/services/ai_services.py index 89e322ab1..0f6a59c68 100644 --- a/src/backend/core/services/ai_services.py +++ b/src/backend/core/services/ai_services.py @@ -7,6 +7,7 @@ import os import queue import threading from collections.abc import AsyncIterator, Iterator +from functools import cache from typing import Any, Dict, Union from django.conf import settings @@ -15,7 +16,9 @@ from django.core.exceptions import ImproperlyConfigured from langfuse import get_client from langfuse.openai import OpenAI as OpenAI_Langfuse from pydantic_ai import Agent, DeferredToolRequests +from pydantic_ai.models.mistral import MistralModel from pydantic_ai.models.openai import OpenAIChatModel +from pydantic_ai.providers.mistral import MistralProvider from pydantic_ai.providers.openai import OpenAIProvider from pydantic_ai.tools import ToolDefinition from pydantic_ai.toolsets.external import ExternalToolset @@ -143,22 +146,59 @@ def convert_async_generator_to_sync(async_gen: AsyncIterator[str]) -> Iterator[s thread.join() +@cache +def 
configure_pydantic_model_provider() -> OpenAIChatModel | MistralModel: + """Configure a pydantic Model and return it.""" + if ( + settings.OPENAI_SDK_API_KEY + and settings.OPENAI_SDK_BASE_URL + and settings.AI_MODEL + ): + return OpenAIChatModel( + settings.AI_MODEL, + provider=OpenAIProvider( + api_key=settings.OPENAI_SDK_API_KEY, + base_url=settings.OPENAI_SDK_BASE_URL, + ), + ) + + if ( + settings.MISTRAL_SDK_API_KEY + and settings.MISTRAL_SDK_BASE_URL + and settings.AI_MODEL + ): + return MistralModel( + settings.AI_MODEL, + provider=MistralProvider( + api_key=settings.MISTRAL_SDK_API_KEY, + base_url=settings.MISTRAL_SDK_BASE_URL, + ), + ) + + raise ImproperlyConfigured("AI configuration not set") + + +@cache +def configure_legacy_openai_client(): + """Configure the open ai sdk client for the legacy AI feature.""" + if ( + settings.OPENAI_SDK_BASE_URL is None + or settings.OPENAI_SDK_API_KEY is None + or settings.AI_MODEL is None + ): + raise ImproperlyConfigured("AI configuration not set") + return OpenAI( + base_url=settings.OPENAI_SDK_BASE_URL, api_key=settings.OPENAI_SDK_API_KEY + ) + + class AIService: """Service class for AI-related operations.""" - def __init__(self): - """Ensure that the AI configuration is set properly.""" - if ( - settings.AI_BASE_URL is None - or settings.AI_API_KEY is None - or settings.AI_MODEL is None - ): - raise ImproperlyConfigured("AI configuration not set") - self.client = OpenAI(base_url=settings.AI_BASE_URL, api_key=settings.AI_API_KEY) - def call_ai_api(self, system_content, text): """Helper method to call the OpenAI API and process the response.""" - response = self.client.chat.completions.create( + client = configure_legacy_openai_client() + response = client.chat.completions.create( model=settings.AI_MODEL, messages=[ {"role": "system", "content": system_content}, @@ -324,13 +364,9 @@ class AIService: langfuse.auth_check() Agent.instrument_all() - model = OpenAIChatModel( - settings.AI_MODEL, - provider=OpenAIProvider( 
- base_url=settings.AI_BASE_URL, api_key=settings.AI_API_KEY - ), + agent = Agent( + configure_pydantic_model_provider(), instrument=instrument_enabled ) - agent = Agent(model, instrument=instrument_enabled) accept = request.META.get("HTTP_ACCEPT", SSE_CONTENT_TYPE) diff --git a/src/backend/core/tests/documents/test_api_documents_ai_proxy.py b/src/backend/core/tests/documents/test_api_documents_ai_proxy.py index 443fa7626..ffda9d9ec 100644 --- a/src/backend/core/tests/documents/test_api_documents_ai_proxy.py +++ b/src/backend/core/tests/documents/test_api_documents_ai_proxy.py @@ -11,6 +11,7 @@ import pytest from rest_framework.test import APIClient from core import factories +from core.services.ai_services import configure_pydantic_model_provider from core.tests.conftest import TEAM, USER, VIA pytestmark = pytest.mark.django_db @@ -20,13 +21,15 @@ pytestmark = pytest.mark.django_db def ai_settings(settings): """Fixture to set AI settings.""" settings.AI_MODEL = "llama" - settings.AI_BASE_URL = "http://localhost-ai:12345/" - settings.AI_API_KEY = "test-key" + settings.OPENAI_SDK_BASE_URL = "http://localhost-ai:12345/" + settings.OPENAI_SDK_API_KEY = "test-key" settings.AI_FEATURE_ENABLED = True settings.AI_FEATURE_BLOCKNOTE_ENABLED = True settings.AI_FEATURE_LEGACY_ENABLED = True settings.LANGFUSE_PUBLIC_KEY = None settings.AI_VERCEL_SDK_VERSION = 6 + yield + configure_pydantic_model_provider.cache_clear() @override_settings( diff --git a/src/backend/core/tests/documents/test_api_documents_ai_transform.py b/src/backend/core/tests/documents/test_api_documents_ai_transform.py index d047839c1..5ada04c3c 100644 --- a/src/backend/core/tests/documents/test_api_documents_ai_transform.py +++ b/src/backend/core/tests/documents/test_api_documents_ai_transform.py @@ -2,47 +2,61 @@ Test AI transform API endpoint for users in impress's core app. 
""" -import random from unittest.mock import MagicMock, patch -from django.test import override_settings - import pytest from rest_framework.test import APIClient from core import factories +from core.services.ai_services import configure_legacy_openai_client from core.tests.conftest import TEAM, USER, VIA pytestmark = pytest.mark.django_db @pytest.fixture -def ai_settings(): +def ai_settings(settings): """Fixture to set AI settings.""" - with override_settings( - AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="llama" - ): - yield + settings.OPENAI_SDK_BASE_URL = "http://example.com" + settings.OPENAI_SDK_API_KEY = "test-key" + settings.AI_MODEL = "llama" + + +@pytest.fixture(autouse=True) +def clear_openai_client_config(): + """Clear the _configure_legacy_openai_client cache""" + yield + configure_legacy_openai_client.cache_clear() -@override_settings( - AI_ALLOW_REACH_FROM=random.choice(["public", "authenticated", "restricted"]) -) @pytest.mark.parametrize( - "reach, role", + "reach, role, ai_allow_reach_from", [ - ("restricted", "reader"), - ("restricted", "editor"), - ("authenticated", "reader"), - ("authenticated", "editor"), - ("public", "reader"), + ("restricted", "reader", "public"), + ("restricted", "reader", "authenticated"), + ("restricted", "reader", "restricted"), + ("restricted", "editor", "public"), + ("restricted", "editor", "authenticated"), + ("restricted", "editor", "restrictied"), + ("authenticated", "reader", "public"), + ("authenticated", "reader", "authenticated"), + ("authenticated", "reader", "restricted"), + ("authenticated", "editor", "public"), + ("authenticated", "editor", "authenticated"), + ("authenticated", "editor", "restricted"), + ("public", "reader", "public"), + ("public", "reader", "authenticated"), + ("public", "reader", "restricted"), ], ) -def test_api_documents_ai_transform_anonymous_forbidden(reach, role): +def test_api_documents_ai_transform_anonymous_forbidden( + reach, role, ai_allow_reach_from, 
settings +): """ Anonymous users should not be able to request AI transform if the link reach and role don't allow it. """ + settings.AI_ALLOW_REACH_FROM = ai_allow_reach_from document = factories.DocumentFactory(link_reach=reach, link_role=role) url = f"/api/v1.0/documents/{document.id!s}/ai-transform/" @@ -54,14 +68,14 @@ def test_api_documents_ai_transform_anonymous_forbidden(reach, role): } -@override_settings(AI_ALLOW_REACH_FROM="public") @pytest.mark.usefixtures("ai_settings") @patch("openai.resources.chat.completions.Completions.create") -def test_api_documents_ai_transform_anonymous_success(mock_create): +def test_api_documents_ai_transform_anonymous_success(mock_create, settings): """ Anonymous users should be able to request AI transform to a document if the link reach and role permit it. """ + settings.AI_ALLOW_REACH_FROM = "public" document = factories.DocumentFactory(link_reach="public", link_role="editor") mock_create.return_value = MagicMock( @@ -88,14 +102,17 @@ def test_api_documents_ai_transform_anonymous_success(mock_create): ) -@override_settings(AI_ALLOW_REACH_FROM=random.choice(["authenticated", "restricted"])) @pytest.mark.usefixtures("ai_settings") +@pytest.mark.parametrize("ai_allow_reach_from", ["authenticated", "restricted"]) @patch("openai.resources.chat.completions.Completions.create") -def test_api_documents_ai_transform_anonymous_limited_by_setting(mock_create): +def test_api_documents_ai_transform_anonymous_limited_by_setting( + mock_create, ai_allow_reach_from, settings +): """ Anonymous users should be able to request AI transform to a document if the link reach and role permit it. 
""" + settings.AI_ALLOW_REACH_FROM = ai_allow_reach_from document = factories.DocumentFactory(link_reach="public", link_role="editor") answer = '{"answer": "Salut"}' @@ -176,8 +193,8 @@ def test_api_documents_ai_transform_authenticated_success(mock_create, reach, ro "role": "system", "content": ( "Answer the prompt using markdown formatting for structure and emphasis. " - "Return the content directly without wrapping it in code blocks or markdown delimiters. " - "Preserve the language and markdown formatting. " + "Return the content directly without wrapping it in code blocks or markdown " + "delimiters. Preserve the language and markdown formatting. " "Do not provide any other information. " "Preserve the language." ), @@ -253,8 +270,8 @@ def test_api_documents_ai_transform_success(mock_create, via, role, mock_user_te "role": "system", "content": ( "Answer the prompt using markdown formatting for structure and emphasis. " - "Return the content directly without wrapping it in code blocks or markdown delimiters. " - "Preserve the language and markdown formatting. " + "Return the content directly without wrapping it in code blocks or markdown " + "delimiters. Preserve the language and markdown formatting. " "Do not provide any other information. " "Preserve the language." ), @@ -296,14 +313,14 @@ def test_api_documents_ai_transform_invalid_action(): assert response.json() == {"action": ['"invalid" is not a valid choice.']} -@override_settings(AI_DOCUMENT_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10}) @pytest.mark.usefixtures("ai_settings") @patch("openai.resources.chat.completions.Completions.create") -def test_api_documents_ai_transform_throttling_document(mock_create): +def test_api_documents_ai_transform_throttling_document(mock_create, settings): """ Throttling per document should be triggered on the AI transform endpoint. 
For full throttle class test see: `test_api_utils_ai_document_rate_throttles` """ + settings.AI_DOCUMENT_RATE_THROTTLE_RATES = {"minute": 3, "hour": 6, "day": 10} client = APIClient() document = factories.DocumentFactory(link_reach="public", link_role="editor") @@ -329,14 +346,14 @@ def test_api_documents_ai_transform_throttling_document(mock_create): } -@override_settings(AI_USER_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10}) @pytest.mark.usefixtures("ai_settings") @patch("openai.resources.chat.completions.Completions.create") -def test_api_documents_ai_transform_throttling_user(mock_create): +def test_api_documents_ai_transform_throttling_user(mock_create, settings): """ Throttling per user should be triggered on the AI transform endpoint. For full throttle class test see: `test_api_utils_ai_user_rate_throttles` """ + settings.AI_USER_RATE_THROTTLE_RATES = {"minute": 3, "hour": 6, "day": 10} user = factories.UserFactory() client = APIClient() client.force_login(user) diff --git a/src/backend/core/tests/documents/test_api_documents_ai_translate.py b/src/backend/core/tests/documents/test_api_documents_ai_translate.py index f0d7978c2..c800a588e 100644 --- a/src/backend/core/tests/documents/test_api_documents_ai_translate.py +++ b/src/backend/core/tests/documents/test_api_documents_ai_translate.py @@ -2,27 +2,31 @@ Test AI translate API endpoint for users in impress's core app. 
""" -import random from unittest.mock import MagicMock, patch -from django.test import override_settings - import pytest from rest_framework.test import APIClient from core import factories +from core.services.ai_services import configure_legacy_openai_client from core.tests.conftest import TEAM, USER, VIA pytestmark = pytest.mark.django_db @pytest.fixture -def ai_settings(): +def ai_settings(settings): """Fixture to set AI settings.""" - with override_settings( - AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="llama" - ): - yield + settings.OPENAI_SDK_BASE_URL = "http://example.com" + settings.OPENAI_SDK_API_KEY = "test-key" + settings.AI_MODEL = "llama" + + +@pytest.fixture(autouse=True) +def clear_openai_client_config(): + "clear the configure_legacy_openai_client cache" + yield + configure_legacy_openai_client.cache_clear() def test_api_documents_ai_translate_viewset_options_metadata(): @@ -45,24 +49,34 @@ def test_api_documents_ai_translate_viewset_options_metadata(): } -@override_settings( - AI_ALLOW_REACH_FROM=random.choice(["public", "authenticated", "restricted"]) -) @pytest.mark.parametrize( - "reach, role", + "reach, role, ai_allow_reach_from", [ - ("restricted", "reader"), - ("restricted", "editor"), - ("authenticated", "reader"), - ("authenticated", "editor"), - ("public", "reader"), + ("restricted", "reader", "public"), + ("restricted", "reader", "authenticated"), + ("restricted", "reader", "restricted"), + ("restricted", "editor", "public"), + ("restricted", "editor", "authenticated"), + ("restricted", "editor", "restrictied"), + ("authenticated", "reader", "public"), + ("authenticated", "reader", "authenticated"), + ("authenticated", "reader", "restricted"), + ("authenticated", "editor", "public"), + ("authenticated", "editor", "authenticated"), + ("authenticated", "editor", "restricted"), + ("public", "reader", "public"), + ("public", "reader", "authenticated"), + ("public", "reader", "restricted"), ], ) -def 
test_api_documents_ai_translate_anonymous_forbidden(reach, role): +def test_api_documents_ai_translate_anonymous_forbidden( + reach, role, ai_allow_reach_from, settings +): """ Anonymous users should not be able to request AI translate if the link reach and role don't allow it. """ + settings.AI_ALLOW_REACH_FROM = ai_allow_reach_from document = factories.DocumentFactory(link_reach=reach, link_role=role) url = f"/api/v1.0/documents/{document.id!s}/ai-translate/" @@ -74,14 +88,14 @@ def test_api_documents_ai_translate_anonymous_forbidden(reach, role): } -@override_settings(AI_ALLOW_REACH_FROM="public") @pytest.mark.usefixtures("ai_settings") @patch("openai.resources.chat.completions.Completions.create") -def test_api_documents_ai_translate_anonymous_success(mock_create): +def test_api_documents_ai_translate_anonymous_success(mock_create, settings): """ Anonymous users should be able to request AI translate to a document if the link reach and role permit it. """ + settings.AI_ALLOW_REACH_FROM = "public" document = factories.DocumentFactory(link_reach="public", link_role="editor") mock_create.return_value = MagicMock( @@ -110,14 +124,17 @@ def test_api_documents_ai_translate_anonymous_success(mock_create): ) -@override_settings(AI_ALLOW_REACH_FROM=random.choice(["authenticated", "restricted"])) @pytest.mark.usefixtures("ai_settings") +@pytest.mark.parametrize("ai_allow_reach_from", ["authenticated", "restricted"]) @patch("openai.resources.chat.completions.Completions.create") -def test_api_documents_ai_translate_anonymous_limited_by_setting(mock_create): +def test_api_documents_ai_translate_anonymous_limited_by_setting( + mock_create, ai_allow_reach_from, settings +): """ Anonymous users should be able to request AI translate to a document if the link reach and role permit it. 
""" + settings.AI_ALLOW_REACH_FROM = ai_allow_reach_from document = factories.DocumentFactory(link_reach="public", link_role="editor") answer = '{"answer": "Salut"}' @@ -318,14 +335,14 @@ def test_api_documents_ai_translate_invalid_action(): assert response.json() == {"language": ['"invalid" is not a valid choice.']} -@override_settings(AI_DOCUMENT_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10}) @pytest.mark.usefixtures("ai_settings") @patch("openai.resources.chat.completions.Completions.create") -def test_api_documents_ai_translate_throttling_document(mock_create): +def test_api_documents_ai_translate_throttling_document(mock_create, settings): """ Throttling per document should be triggered on the AI translate endpoint. For full throttle class test see: `test_api_utils_ai_document_rate_throttles` """ + settings.AI_DOCUMENT_RATE_THROTTLE_RATES = {"minute": 3, "hour": 6, "day": 10} client = APIClient() document = factories.DocumentFactory(link_reach="public", link_role="editor") @@ -351,14 +368,14 @@ def test_api_documents_ai_translate_throttling_document(mock_create): } -@override_settings(AI_USER_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10}) @pytest.mark.usefixtures("ai_settings") @patch("openai.resources.chat.completions.Completions.create") -def test_api_documents_ai_translate_throttling_user(mock_create): +def test_api_documents_ai_translate_throttling_user(mock_create, settings): """ Throttling per user should be triggered on the AI translate endpoint. 
For full throttle class test see: `test_api_utils_ai_user_rate_throttles` """ + settings.AI_USER_RATE_THROTTLE_RATES = {"minute": 3, "hour": 6, "day": 10} user = factories.UserFactory() client = APIClient() client.force_login(user) diff --git a/src/backend/core/tests/external_api/test_external_api_documents_ai.py b/src/backend/core/tests/external_api/test_external_api_documents_ai.py index 848be4f9b..92480fe81 100644 --- a/src/backend/core/tests/external_api/test_external_api_documents_ai.py +++ b/src/backend/core/tests/external_api/test_external_api_documents_ai.py @@ -14,6 +14,7 @@ import pytest from rest_framework.test import APIClient from core import factories, models +from core.services.ai_services import configure_legacy_openai_client from core.tests.documents.test_api_documents_ai_proxy import ( # pylint: disable=unused-import ai_settings, ) @@ -23,6 +24,13 @@ pytestmark = pytest.mark.django_db # pylint: disable=unused-argument +@pytest.fixture(autouse=True) +def clear_openai_client_config(): + """Clear the configure_legacy_openai_client cache.""" + yield + configure_legacy_openai_client.cache_clear() + + def test_external_api_documents_ai_transform_not_allowed( user_token, resource_server_backend, user_specific_sub ): diff --git a/src/backend/core/tests/test_services_ai_services.py b/src/backend/core/tests/test_services_ai_services.py index ff4e12d19..bcdba86c8 100644 --- a/src/backend/core/tests/test_services_ai_services.py +++ b/src/backend/core/tests/test_services_ai_services.py @@ -10,12 +10,16 @@ from django.core.exceptions import ImproperlyConfigured from django.test.utils import override_settings import pytest -from openai import OpenAIError +from openai import OpenAI, OpenAIError +from pydantic_ai.models.mistral import MistralModel +from pydantic_ai.models.openai import OpenAIChatModel from pydantic_ai.ui.vercel_ai.request_types import TextUIPart, UIMessage from core.services.ai_services import ( BLOCKNOTE_TOOL_STRICT_PROMPT, AIService, + 
configure_legacy_openai_client, + configure_pydantic_model_provider, convert_async_generator_to_sync, ) @@ -26,35 +30,90 @@ pytestmark = pytest.mark.django_db def ai_settings(settings): """Fixture to set AI settings.""" settings.AI_MODEL = "llama" - settings.AI_BASE_URL = "http://example.com" - settings.AI_API_KEY = "test-key" + settings.OPENAI_SDK_BASE_URL = "http://example.com" + settings.OPENAI_SDK_API_KEY = "test-key" settings.AI_FEATURE_ENABLED = True settings.AI_FEATURE_BLOCKNOTE_ENABLED = True settings.AI_FEATURE_LEGACY_ENABLED = True settings.LANGFUSE_PUBLIC_KEY = None settings.AI_VERCEL_SDK_VERSION = 6 + yield + configure_pydantic_model_provider.cache_clear() + configure_legacy_openai_client.cache_clear() -# -- AIService.__init__ -- +# -- AIService configure sdk -- @pytest.mark.parametrize( "setting_name, setting_value", [ - ("AI_BASE_URL", None), - ("AI_API_KEY", None), + ("OPENAI_SDK_BASE_URL", None), + ("OPENAI_SDK_API_KEY", None), ("AI_MODEL", None), ], ) -def test_services_ai_setting_missing(setting_name, setting_value, settings): - """Setting should be set""" +def test_ai_services_configure_legacy_openai_sdk_missing( + setting_name, setting_value, settings +): + """ + An exception must be raised if an expected setting is missing to configure the openai sdk.
+ """ setattr(settings, setting_name, setting_value) with pytest.raises( ImproperlyConfigured, match="AI configuration not set", ): - AIService() + configure_legacy_openai_client() + + +def test_ai_services_configure_legacy_openai_sdk(settings): + """With all required settings an open ai sdk instance should be configured.""" + settings.AI_MODEL = "llama" + settings.OPENAI_SDK_BASE_URL = "http://example.com" + settings.OPENAI_SDK_API_KEY = "test-key" + + openai_sdk = configure_legacy_openai_client() + + assert isinstance(openai_sdk, OpenAI) + + +def test_ai_services_configure_pydantic_ai_model_openai(settings): + """When openai sdk settings are configured it should return an OpenAiChatModel.""" + settings.AI_MODEL = "llama" + settings.OPENAI_SDK_BASE_URL = "http://example.com" + settings.OPENAI_SDK_API_KEY = "test-key" + + pydantic_ai_model = configure_pydantic_model_provider() + assert isinstance(pydantic_ai_model, OpenAIChatModel) + + +def test_ai_services_configure_pydantic_ai_model_mistral(settings): + """When mistral sdk settings are configured is should return a MistralModel.""" + settings.AI_MODEL = "llama" + settings.OPENAI_SDK_BASE_URL = None + settings.OPENAI_SDK_API_KEY = None + settings.MISTRAL_SDK_API_KEY = "mistreal-sdk-key" + settings.MISTRAL_SDK_BASE_URL = "https://mistral.base-url.com" + + pydantic_ai_model = configure_pydantic_model_provider() + assert isinstance(pydantic_ai_model, MistralModel) + + +def test_ai_services_configure_pydantic_ai_model_no_settings(settings): + """When no settings are configured for a ai sdk it should raises an exception.""" + settings.AI_MODEL = None + settings.OPENAI_SDK_BASE_URL = None + settings.OPENAI_SDK_API_KEY = None + settings.MISTRAL_SDK_API_KEY = None + settings.MISTRAL_SDK_BASE_URL = None + + with pytest.raises( + ImproperlyConfigured, + match="AI configuration not set", + ): + configure_pydantic_model_provider() # -- AIService.transform -- diff --git a/src/backend/impress/settings.py 
b/src/backend/impress/settings.py index aea0b45d4..5962522ef 100755 --- a/src/backend/impress/settings.py +++ b/src/backend/impress/settings.py @@ -801,8 +801,30 @@ class Base(Configuration): environ_name="AI_ALLOW_REACH_FROM", environ_prefix=None, ) - AI_API_KEY = SecretFileValue(None, environ_name="AI_API_KEY", environ_prefix=None) - AI_BASE_URL = values.Value(None, environ_name="AI_BASE_URL", environ_prefix=None) + + MISTRAL_SDK_BASE_URL = values.Value( + None, environ_name="MISTRAL_SDK_BASE_URL", environ_prefix=None + ) + MISTRAL_SDK_API_KEY = SecretFileValue( + None, environ_name="MISTRAL_SDK_API_KEY", environ_prefix=None + ) + + OPENAI_SDK_API_KEY = SecretFileValue( + default=SecretFileValue( # retrocompatibility + None, + environ_name="AI_API_KEY", + environ_prefix=None, + ), + environ_name="OPENAI_SDK_API_KEY", + environ_prefix=None, + ) + OPENAI_SDK_BASE_URL = values.Value( + default=values.Value( # retrocompatibility + None, environ_name="AI_BASE_URL", environ_prefix=None + ), + environ_name="OPENAI_SDK_BASE_URL", + environ_prefix=None, + ) AI_BOT = values.DictValue( default={ "name": _("Docs AI"), @@ -1138,6 +1160,11 @@ class Base(Configuration): } ) + if cls.OPENAI_SDK_API_KEY and cls.MISTRAL_SDK_API_KEY: + raise ValueError( + "Both OPENAI_SDK and MISTRAL_SDK parameters can not be set simultaneously." + ) + class Build(Base): """Settings used when the application is built. diff --git a/src/backend/pyproject.toml b/src/backend/pyproject.toml index d6613069f..33a68cd86 100644 --- a/src/backend/pyproject.toml +++ b/src/backend/pyproject.toml @@ -52,6 +52,7 @@ dependencies = [ "langfuse==3.11.2", "lxml==6.0.2", "markdown==3.10.2", + "mistralai==1.12.4", "mozilla-django-oidc==5.0.2", "nested-multipart-parser==1.6.0", "openai==2.24.0",