mirror of
https://github.com/suitenumerique/docs.git
synced 2026-04-25 17:15:01 +02:00
✨(backend) allow to use new ai feature using mistral sdk
We give the possibility, for the new ai feature, to choose between using the OpenAI or Mistral sdk. For instances having access to the mistral infrastructure, using it is mor appropriated than using the openai compatible chat model.
This commit is contained in:
@@ -6,6 +6,10 @@ and this project adheres to
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
### Added
|
||||
|
||||
- ✨(backend) allow to use new ai feature using mistral sdk
|
||||
|
||||
## [v4.8.6] - 2026-04-08
|
||||
|
||||
### Added
|
||||
|
||||
14
docs/env.md
14
docs/env.md
@@ -9,14 +9,16 @@ These are the environment variables you can set for the `impress-backend` contai
|
||||
| Option | Description | default |
|
||||
| ----------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------- |
|
||||
| AI_ALLOW_REACH_FROM | Users that can use AI must be this level. options are "public", "authenticated", "restricted" | authenticated |
|
||||
| AI_API_KEY | AI key to be used for AI Base url | |
|
||||
| AI_BASE_URL | OpenAI compatible AI base url | |
|
||||
| AI_BOT | Information to give to the frontend about the AI bot | { "name": "Docs AI", "color": "#8bc6ff" }
|
||||
| OPENAI_SDK_API_KEY | AI key to be used by the OpenAI python SDK | |
|
||||
| OPENAI_SDK_BASE_URL | OpenAI compatible AI base url | |
|
||||
| MISTRAL_SDK_API_KEY | AI key to be used by the Mistral python SDK /!\ Mistral sdk can be used only in async mode with uvicorn /!\ | |
|
||||
| MISTRAL_SDK_BASE_URL | Mistral compatible AI base url | |
|
||||
| AI_BOT | Information to give to the frontend about the AI bot | { "name": "Docs AI", "color": "#8bc6ff" } |
|
||||
| AI_FEATURE_ENABLED | Enable AI options | false |
|
||||
| AI_FEATURE_BLOCKNOTE_ENABLED | Enable Blocknote AI options | false |
|
||||
| AI_FEATURE_LEGACY_ENABLED | Enable legacyAI options | true |
|
||||
| AI_FEATURE_BLOCKNOTE_ENABLED | Enable Blocknote AI options | false |
|
||||
| AI_FEATURE_LEGACY_ENABLED | Enable legacyAI options | true |
|
||||
| AI_MODEL | AI Model to use | |
|
||||
| AI_VERCEL_SDK_VERSION | The vercel AI SDK version used | 6 |
|
||||
| AI_VERCEL_SDK_VERSION | The vercel AI SDK version used | 6 |
|
||||
| ALLOW_LOGOUT_GET_METHOD | Allow get logout method | true |
|
||||
| API_USERS_LIST_LIMIT | Limit on API users | 5 |
|
||||
| API_USERS_LIST_THROTTLE_RATE_BURST | Throttle rate for api on burst | 30/minute |
|
||||
|
||||
@@ -71,14 +71,6 @@ OIDC_RS_ALLOWED_AUDIENCES=""
|
||||
# User reconciliation
|
||||
USER_RECONCILIATION_FORM_URL=http://localhost:3000
|
||||
|
||||
# AI
|
||||
AI_FEATURE_ENABLED=true
|
||||
AI_FEATURE_BLOCKNOTE_ENABLED=true
|
||||
AI_FEATURE_LEGACY_ENABLED=true
|
||||
AI_BASE_URL=https://openaiendpoint.com
|
||||
AI_API_KEY=password
|
||||
AI_MODEL=llama
|
||||
|
||||
# Collaboration
|
||||
COLLABORATION_API_URL=http://y-provider-development:4444/collaboration/api/
|
||||
COLLABORATION_BACKEND_BASE_URL=http://app-dev:8000
|
||||
|
||||
@@ -7,6 +7,7 @@ import os
|
||||
import queue
|
||||
import threading
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from functools import cache
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
from django.conf import settings
|
||||
@@ -15,7 +16,9 @@ from django.core.exceptions import ImproperlyConfigured
|
||||
from langfuse import get_client
|
||||
from langfuse.openai import OpenAI as OpenAI_Langfuse
|
||||
from pydantic_ai import Agent, DeferredToolRequests
|
||||
from pydantic_ai.models.mistral import MistralModel
|
||||
from pydantic_ai.models.openai import OpenAIChatModel
|
||||
from pydantic_ai.providers.mistral import MistralProvider
|
||||
from pydantic_ai.providers.openai import OpenAIProvider
|
||||
from pydantic_ai.tools import ToolDefinition
|
||||
from pydantic_ai.toolsets.external import ExternalToolset
|
||||
@@ -143,22 +146,59 @@ def convert_async_generator_to_sync(async_gen: AsyncIterator[str]) -> Iterator[s
|
||||
thread.join()
|
||||
|
||||
|
||||
@cache
|
||||
def configure_pydantic_model_provider() -> OpenAIChatModel | MistralModel:
|
||||
"""Configure a pydantic Model and return it."""
|
||||
if (
|
||||
settings.OPENAI_SDK_API_KEY
|
||||
and settings.OPENAI_SDK_BASE_URL
|
||||
and settings.AI_MODEL
|
||||
):
|
||||
return OpenAIChatModel(
|
||||
settings.AI_MODEL,
|
||||
provider=OpenAIProvider(
|
||||
api_key=settings.OPENAI_SDK_API_KEY,
|
||||
base_url=settings.OPENAI_SDK_BASE_URL,
|
||||
),
|
||||
)
|
||||
|
||||
if (
|
||||
settings.MISTRAL_SDK_API_KEY
|
||||
and settings.MISTRAL_SDK_BASE_URL
|
||||
and settings.AI_MODEL
|
||||
):
|
||||
return MistralModel(
|
||||
settings.AI_MODEL,
|
||||
provider=MistralProvider(
|
||||
api_key=settings.MISTRAL_SDK_API_KEY,
|
||||
base_url=settings.MISTRAL_SDK_BASE_URL,
|
||||
),
|
||||
)
|
||||
|
||||
raise ImproperlyConfigured("AI configuration not set")
|
||||
|
||||
|
||||
@cache
|
||||
def configure_legacy_openai_client():
|
||||
"""Configure the open ai sdk client for the legacy AI feature."""
|
||||
if (
|
||||
settings.OPENAI_SDK_BASE_URL is None
|
||||
or settings.OPENAI_SDK_API_KEY is None
|
||||
or settings.AI_MODEL is None
|
||||
):
|
||||
raise ImproperlyConfigured("AI configuration not set")
|
||||
return OpenAI(
|
||||
base_url=settings.OPENAI_SDK_BASE_URL, api_key=settings.OPENAI_SDK_API_KEY
|
||||
)
|
||||
|
||||
|
||||
class AIService:
|
||||
"""Service class for AI-related operations."""
|
||||
|
||||
def __init__(self):
|
||||
"""Ensure that the AI configuration is set properly."""
|
||||
if (
|
||||
settings.AI_BASE_URL is None
|
||||
or settings.AI_API_KEY is None
|
||||
or settings.AI_MODEL is None
|
||||
):
|
||||
raise ImproperlyConfigured("AI configuration not set")
|
||||
self.client = OpenAI(base_url=settings.AI_BASE_URL, api_key=settings.AI_API_KEY)
|
||||
|
||||
def call_ai_api(self, system_content, text):
|
||||
"""Helper method to call the OpenAI API and process the response."""
|
||||
response = self.client.chat.completions.create(
|
||||
client = configure_legacy_openai_client()
|
||||
response = client.chat.completions.create(
|
||||
model=settings.AI_MODEL,
|
||||
messages=[
|
||||
{"role": "system", "content": system_content},
|
||||
@@ -324,13 +364,9 @@ class AIService:
|
||||
langfuse.auth_check()
|
||||
Agent.instrument_all()
|
||||
|
||||
model = OpenAIChatModel(
|
||||
settings.AI_MODEL,
|
||||
provider=OpenAIProvider(
|
||||
base_url=settings.AI_BASE_URL, api_key=settings.AI_API_KEY
|
||||
),
|
||||
agent = Agent(
|
||||
configure_pydantic_model_provider(), instrument=instrument_enabled
|
||||
)
|
||||
agent = Agent(model, instrument=instrument_enabled)
|
||||
|
||||
accept = request.META.get("HTTP_ACCEPT", SSE_CONTENT_TYPE)
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import pytest
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
from core import factories
|
||||
from core.services.ai_services import configure_pydantic_model_provider
|
||||
from core.tests.conftest import TEAM, USER, VIA
|
||||
|
||||
pytestmark = pytest.mark.django_db
|
||||
@@ -20,13 +21,15 @@ pytestmark = pytest.mark.django_db
|
||||
def ai_settings(settings):
|
||||
"""Fixture to set AI settings."""
|
||||
settings.AI_MODEL = "llama"
|
||||
settings.AI_BASE_URL = "http://localhost-ai:12345/"
|
||||
settings.AI_API_KEY = "test-key"
|
||||
settings.OPENAI_SDK_BASE_URL = "http://localhost-ai:12345/"
|
||||
settings.OPENAI_SDK_API_KEY = "test-key"
|
||||
settings.AI_FEATURE_ENABLED = True
|
||||
settings.AI_FEATURE_BLOCKNOTE_ENABLED = True
|
||||
settings.AI_FEATURE_LEGACY_ENABLED = True
|
||||
settings.LANGFUSE_PUBLIC_KEY = None
|
||||
settings.AI_VERCEL_SDK_VERSION = 6
|
||||
yield
|
||||
configure_pydantic_model_provider.cache_clear()
|
||||
|
||||
|
||||
@override_settings(
|
||||
|
||||
@@ -2,47 +2,61 @@
|
||||
Test AI transform API endpoint for users in impress's core app.
|
||||
"""
|
||||
|
||||
import random
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from django.test import override_settings
|
||||
|
||||
import pytest
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
from core import factories
|
||||
from core.services.ai_services import configure_legacy_openai_client
|
||||
from core.tests.conftest import TEAM, USER, VIA
|
||||
|
||||
pytestmark = pytest.mark.django_db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ai_settings():
|
||||
def ai_settings(settings):
|
||||
"""Fixture to set AI settings."""
|
||||
with override_settings(
|
||||
AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="llama"
|
||||
):
|
||||
yield
|
||||
settings.OPENAI_SDK_BASE_URL = "http://example.com"
|
||||
settings.OPENAI_SDK_API_KEY = "test-key"
|
||||
settings.AI_MODEL = "llama"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_openai_client_config():
|
||||
"""Clear the _configure_legacy_openai_client cache"""
|
||||
yield
|
||||
configure_legacy_openai_client.cache_clear()
|
||||
|
||||
|
||||
@override_settings(
|
||||
AI_ALLOW_REACH_FROM=random.choice(["public", "authenticated", "restricted"])
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"reach, role",
|
||||
"reach, role, ai_allow_reach_from",
|
||||
[
|
||||
("restricted", "reader"),
|
||||
("restricted", "editor"),
|
||||
("authenticated", "reader"),
|
||||
("authenticated", "editor"),
|
||||
("public", "reader"),
|
||||
("restricted", "reader", "public"),
|
||||
("restricted", "reader", "authenticated"),
|
||||
("restricted", "reader", "restricted"),
|
||||
("restricted", "editor", "public"),
|
||||
("restricted", "editor", "authenticated"),
|
||||
("restricted", "editor", "restrictied"),
|
||||
("authenticated", "reader", "public"),
|
||||
("authenticated", "reader", "authenticated"),
|
||||
("authenticated", "reader", "restricted"),
|
||||
("authenticated", "editor", "public"),
|
||||
("authenticated", "editor", "authenticated"),
|
||||
("authenticated", "editor", "restricted"),
|
||||
("public", "reader", "public"),
|
||||
("public", "reader", "authenticated"),
|
||||
("public", "reader", "restricted"),
|
||||
],
|
||||
)
|
||||
def test_api_documents_ai_transform_anonymous_forbidden(reach, role):
|
||||
def test_api_documents_ai_transform_anonymous_forbidden(
|
||||
reach, role, ai_allow_reach_from, settings
|
||||
):
|
||||
"""
|
||||
Anonymous users should not be able to request AI transform if the link reach
|
||||
and role don't allow it.
|
||||
"""
|
||||
settings.AI_ALLOW_REACH_FROM = ai_allow_reach_from
|
||||
document = factories.DocumentFactory(link_reach=reach, link_role=role)
|
||||
|
||||
url = f"/api/v1.0/documents/{document.id!s}/ai-transform/"
|
||||
@@ -54,14 +68,14 @@ def test_api_documents_ai_transform_anonymous_forbidden(reach, role):
|
||||
}
|
||||
|
||||
|
||||
@override_settings(AI_ALLOW_REACH_FROM="public")
|
||||
@pytest.mark.usefixtures("ai_settings")
|
||||
@patch("openai.resources.chat.completions.Completions.create")
|
||||
def test_api_documents_ai_transform_anonymous_success(mock_create):
|
||||
def test_api_documents_ai_transform_anonymous_success(mock_create, settings):
|
||||
"""
|
||||
Anonymous users should be able to request AI transform to a document
|
||||
if the link reach and role permit it.
|
||||
"""
|
||||
settings.AI_ALLOW_REACH_FROM = "public"
|
||||
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||
|
||||
mock_create.return_value = MagicMock(
|
||||
@@ -88,14 +102,17 @@ def test_api_documents_ai_transform_anonymous_success(mock_create):
|
||||
)
|
||||
|
||||
|
||||
@override_settings(AI_ALLOW_REACH_FROM=random.choice(["authenticated", "restricted"]))
|
||||
@pytest.mark.usefixtures("ai_settings")
|
||||
@pytest.mark.parametrize("ai_allow_reach_from", ["authenticated", "restricted"])
|
||||
@patch("openai.resources.chat.completions.Completions.create")
|
||||
def test_api_documents_ai_transform_anonymous_limited_by_setting(mock_create):
|
||||
def test_api_documents_ai_transform_anonymous_limited_by_setting(
|
||||
mock_create, ai_allow_reach_from, settings
|
||||
):
|
||||
"""
|
||||
Anonymous users should be able to request AI transform to a document
|
||||
if the link reach and role permit it.
|
||||
"""
|
||||
settings.AI_ALLOW_REACH_FROM = ai_allow_reach_from
|
||||
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||
|
||||
answer = '{"answer": "Salut"}'
|
||||
@@ -176,8 +193,8 @@ def test_api_documents_ai_transform_authenticated_success(mock_create, reach, ro
|
||||
"role": "system",
|
||||
"content": (
|
||||
"Answer the prompt using markdown formatting for structure and emphasis. "
|
||||
"Return the content directly without wrapping it in code blocks or markdown delimiters. "
|
||||
"Preserve the language and markdown formatting. "
|
||||
"Return the content directly without wrapping it in code blocks or markdown "
|
||||
"delimiters. Preserve the language and markdown formatting. "
|
||||
"Do not provide any other information. "
|
||||
"Preserve the language."
|
||||
),
|
||||
@@ -253,8 +270,8 @@ def test_api_documents_ai_transform_success(mock_create, via, role, mock_user_te
|
||||
"role": "system",
|
||||
"content": (
|
||||
"Answer the prompt using markdown formatting for structure and emphasis. "
|
||||
"Return the content directly without wrapping it in code blocks or markdown delimiters. "
|
||||
"Preserve the language and markdown formatting. "
|
||||
"Return the content directly without wrapping it in code blocks or markdown "
|
||||
"delimiters. Preserve the language and markdown formatting. "
|
||||
"Do not provide any other information. "
|
||||
"Preserve the language."
|
||||
),
|
||||
@@ -296,14 +313,14 @@ def test_api_documents_ai_transform_invalid_action():
|
||||
assert response.json() == {"action": ['"invalid" is not a valid choice.']}
|
||||
|
||||
|
||||
@override_settings(AI_DOCUMENT_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10})
|
||||
@pytest.mark.usefixtures("ai_settings")
|
||||
@patch("openai.resources.chat.completions.Completions.create")
|
||||
def test_api_documents_ai_transform_throttling_document(mock_create):
|
||||
def test_api_documents_ai_transform_throttling_document(mock_create, settings):
|
||||
"""
|
||||
Throttling per document should be triggered on the AI transform endpoint.
|
||||
For full throttle class test see: `test_api_utils_ai_document_rate_throttles`
|
||||
"""
|
||||
settings.AI_DOCUMENT_RATE_THROTTLE_RATES = {"minute": 3, "hour": 6, "day": 10}
|
||||
client = APIClient()
|
||||
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||
|
||||
@@ -329,14 +346,14 @@ def test_api_documents_ai_transform_throttling_document(mock_create):
|
||||
}
|
||||
|
||||
|
||||
@override_settings(AI_USER_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10})
|
||||
@pytest.mark.usefixtures("ai_settings")
|
||||
@patch("openai.resources.chat.completions.Completions.create")
|
||||
def test_api_documents_ai_transform_throttling_user(mock_create):
|
||||
def test_api_documents_ai_transform_throttling_user(mock_create, settings):
|
||||
"""
|
||||
Throttling per user should be triggered on the AI transform endpoint.
|
||||
For full throttle class test see: `test_api_utils_ai_user_rate_throttles`
|
||||
"""
|
||||
settings.AI_USER_RATE_THROTTLE_RATES = {"minute": 3, "hour": 6, "day": 10}
|
||||
user = factories.UserFactory()
|
||||
client = APIClient()
|
||||
client.force_login(user)
|
||||
|
||||
@@ -2,27 +2,31 @@
|
||||
Test AI translate API endpoint for users in impress's core app.
|
||||
"""
|
||||
|
||||
import random
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from django.test import override_settings
|
||||
|
||||
import pytest
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
from core import factories
|
||||
from core.services.ai_services import configure_legacy_openai_client
|
||||
from core.tests.conftest import TEAM, USER, VIA
|
||||
|
||||
pytestmark = pytest.mark.django_db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ai_settings():
|
||||
def ai_settings(settings):
|
||||
"""Fixture to set AI settings."""
|
||||
with override_settings(
|
||||
AI_BASE_URL="http://example.com", AI_API_KEY="test-key", AI_MODEL="llama"
|
||||
):
|
||||
yield
|
||||
settings.OPENAI_SDK_BASE_URL = "http://example.com"
|
||||
settings.OPENAI_SDK_API_KEY = "test-key"
|
||||
settings.AI_MODEL = "llama"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_openai_client_config():
|
||||
"clear the configure_legacy_openai_client cache"
|
||||
yield
|
||||
configure_legacy_openai_client.cache_clear()
|
||||
|
||||
|
||||
def test_api_documents_ai_translate_viewset_options_metadata():
|
||||
@@ -45,24 +49,34 @@ def test_api_documents_ai_translate_viewset_options_metadata():
|
||||
}
|
||||
|
||||
|
||||
@override_settings(
|
||||
AI_ALLOW_REACH_FROM=random.choice(["public", "authenticated", "restricted"])
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"reach, role",
|
||||
"reach, role, ai_allow_reach_from",
|
||||
[
|
||||
("restricted", "reader"),
|
||||
("restricted", "editor"),
|
||||
("authenticated", "reader"),
|
||||
("authenticated", "editor"),
|
||||
("public", "reader"),
|
||||
("restricted", "reader", "public"),
|
||||
("restricted", "reader", "authenticated"),
|
||||
("restricted", "reader", "restricted"),
|
||||
("restricted", "editor", "public"),
|
||||
("restricted", "editor", "authenticated"),
|
||||
("restricted", "editor", "restrictied"),
|
||||
("authenticated", "reader", "public"),
|
||||
("authenticated", "reader", "authenticated"),
|
||||
("authenticated", "reader", "restricted"),
|
||||
("authenticated", "editor", "public"),
|
||||
("authenticated", "editor", "authenticated"),
|
||||
("authenticated", "editor", "restricted"),
|
||||
("public", "reader", "public"),
|
||||
("public", "reader", "authenticated"),
|
||||
("public", "reader", "restricted"),
|
||||
],
|
||||
)
|
||||
def test_api_documents_ai_translate_anonymous_forbidden(reach, role):
|
||||
def test_api_documents_ai_translate_anonymous_forbidden(
|
||||
reach, role, ai_allow_reach_from, settings
|
||||
):
|
||||
"""
|
||||
Anonymous users should not be able to request AI translate if the link reach
|
||||
and role don't allow it.
|
||||
"""
|
||||
settings.AI_ALLOW_REACH_FROM = ai_allow_reach_from
|
||||
document = factories.DocumentFactory(link_reach=reach, link_role=role)
|
||||
|
||||
url = f"/api/v1.0/documents/{document.id!s}/ai-translate/"
|
||||
@@ -74,14 +88,14 @@ def test_api_documents_ai_translate_anonymous_forbidden(reach, role):
|
||||
}
|
||||
|
||||
|
||||
@override_settings(AI_ALLOW_REACH_FROM="public")
|
||||
@pytest.mark.usefixtures("ai_settings")
|
||||
@patch("openai.resources.chat.completions.Completions.create")
|
||||
def test_api_documents_ai_translate_anonymous_success(mock_create):
|
||||
def test_api_documents_ai_translate_anonymous_success(mock_create, settings):
|
||||
"""
|
||||
Anonymous users should be able to request AI translate to a document
|
||||
if the link reach and role permit it.
|
||||
"""
|
||||
settings.AI_ALLOW_REACH_FROM = "public"
|
||||
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||
|
||||
mock_create.return_value = MagicMock(
|
||||
@@ -110,14 +124,17 @@ def test_api_documents_ai_translate_anonymous_success(mock_create):
|
||||
)
|
||||
|
||||
|
||||
@override_settings(AI_ALLOW_REACH_FROM=random.choice(["authenticated", "restricted"]))
|
||||
@pytest.mark.usefixtures("ai_settings")
|
||||
@pytest.mark.parametrize("ai_allow_reach_from", ["authenticated", "restricted"])
|
||||
@patch("openai.resources.chat.completions.Completions.create")
|
||||
def test_api_documents_ai_translate_anonymous_limited_by_setting(mock_create):
|
||||
def test_api_documents_ai_translate_anonymous_limited_by_setting(
|
||||
mock_create, ai_allow_reach_from, settings
|
||||
):
|
||||
"""
|
||||
Anonymous users should be able to request AI translate to a document
|
||||
if the link reach and role permit it.
|
||||
"""
|
||||
settings.AI_ALLOW_REACH_FROM = ai_allow_reach_from
|
||||
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||
|
||||
answer = '{"answer": "Salut"}'
|
||||
@@ -318,14 +335,14 @@ def test_api_documents_ai_translate_invalid_action():
|
||||
assert response.json() == {"language": ['"invalid" is not a valid choice.']}
|
||||
|
||||
|
||||
@override_settings(AI_DOCUMENT_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10})
|
||||
@pytest.mark.usefixtures("ai_settings")
|
||||
@patch("openai.resources.chat.completions.Completions.create")
|
||||
def test_api_documents_ai_translate_throttling_document(mock_create):
|
||||
def test_api_documents_ai_translate_throttling_document(mock_create, settings):
|
||||
"""
|
||||
Throttling per document should be triggered on the AI translate endpoint.
|
||||
For full throttle class test see: `test_api_utils_ai_document_rate_throttles`
|
||||
"""
|
||||
settings.AI_DOCUMENT_RATE_THROTTLE_RATES = {"minute": 3, "hour": 6, "day": 10}
|
||||
client = APIClient()
|
||||
document = factories.DocumentFactory(link_reach="public", link_role="editor")
|
||||
|
||||
@@ -351,14 +368,14 @@ def test_api_documents_ai_translate_throttling_document(mock_create):
|
||||
}
|
||||
|
||||
|
||||
@override_settings(AI_USER_RATE_THROTTLE_RATES={"minute": 3, "hour": 6, "day": 10})
|
||||
@pytest.mark.usefixtures("ai_settings")
|
||||
@patch("openai.resources.chat.completions.Completions.create")
|
||||
def test_api_documents_ai_translate_throttling_user(mock_create):
|
||||
def test_api_documents_ai_translate_throttling_user(mock_create, settings):
|
||||
"""
|
||||
Throttling per user should be triggered on the AI translate endpoint.
|
||||
For full throttle class test see: `test_api_utils_ai_user_rate_throttles`
|
||||
"""
|
||||
settings.AI_USER_RATE_THROTTLE_RATES = {"minute": 3, "hour": 6, "day": 10}
|
||||
user = factories.UserFactory()
|
||||
client = APIClient()
|
||||
client.force_login(user)
|
||||
|
||||
@@ -14,6 +14,7 @@ import pytest
|
||||
from rest_framework.test import APIClient
|
||||
|
||||
from core import factories, models
|
||||
from core.services.ai_services import configure_legacy_openai_client
|
||||
from core.tests.documents.test_api_documents_ai_proxy import ( # pylint: disable=unused-import
|
||||
ai_settings,
|
||||
)
|
||||
@@ -23,6 +24,13 @@ pytestmark = pytest.mark.django_db
|
||||
# pylint: disable=unused-argument
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_openai_client_config():
|
||||
"""Clear the configure_legacy_openai_client cache."""
|
||||
yield
|
||||
configure_legacy_openai_client.cache_clear()
|
||||
|
||||
|
||||
def test_external_api_documents_ai_transform_not_allowed(
|
||||
user_token, resource_server_backend, user_specific_sub
|
||||
):
|
||||
|
||||
@@ -10,12 +10,16 @@ from django.core.exceptions import ImproperlyConfigured
|
||||
from django.test.utils import override_settings
|
||||
|
||||
import pytest
|
||||
from openai import OpenAIError
|
||||
from openai import OpenAI, OpenAIError
|
||||
from pydantic_ai.models.mistral import MistralModel
|
||||
from pydantic_ai.models.openai import OpenAIChatModel
|
||||
from pydantic_ai.ui.vercel_ai.request_types import TextUIPart, UIMessage
|
||||
|
||||
from core.services.ai_services import (
|
||||
BLOCKNOTE_TOOL_STRICT_PROMPT,
|
||||
AIService,
|
||||
configure_legacy_openai_client,
|
||||
configure_pydantic_model_provider,
|
||||
convert_async_generator_to_sync,
|
||||
)
|
||||
|
||||
@@ -26,35 +30,90 @@ pytestmark = pytest.mark.django_db
|
||||
def ai_settings(settings):
|
||||
"""Fixture to set AI settings."""
|
||||
settings.AI_MODEL = "llama"
|
||||
settings.AI_BASE_URL = "http://example.com"
|
||||
settings.AI_API_KEY = "test-key"
|
||||
settings.OPENAI_SDK_BASE_URL = "http://example.com"
|
||||
settings.OPENAI_SDK_API_KEY = "test-key"
|
||||
settings.AI_FEATURE_ENABLED = True
|
||||
settings.AI_FEATURE_BLOCKNOTE_ENABLED = True
|
||||
settings.AI_FEATURE_LEGACY_ENABLED = True
|
||||
settings.LANGFUSE_PUBLIC_KEY = None
|
||||
settings.AI_VERCEL_SDK_VERSION = 6
|
||||
yield
|
||||
configure_pydantic_model_provider.cache_clear()
|
||||
configure_legacy_openai_client.cache_clear()
|
||||
|
||||
|
||||
# -- AIService.__init__ --
|
||||
# -- AIService configure sdk--
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"setting_name, setting_value",
|
||||
[
|
||||
("AI_BASE_URL", None),
|
||||
("AI_API_KEY", None),
|
||||
("OPENAI_SDK_BASE_URL", None),
|
||||
("OPENAI_SDK_API_KEY", None),
|
||||
("AI_MODEL", None),
|
||||
],
|
||||
)
|
||||
def test_services_ai_setting_missing(setting_name, setting_value, settings):
|
||||
"""Setting should be set"""
|
||||
def test_ai_services_configure_legacy_openai_sdk_missing(
|
||||
setting_name, setting_value, settings
|
||||
):
|
||||
"""
|
||||
An exception must be raised if an expected settings is missing to configure the openai sdk.
|
||||
"""
|
||||
setattr(settings, setting_name, setting_value)
|
||||
|
||||
with pytest.raises(
|
||||
ImproperlyConfigured,
|
||||
match="AI configuration not set",
|
||||
):
|
||||
AIService()
|
||||
configure_legacy_openai_client()
|
||||
|
||||
|
||||
def test_ai_services_configure_legacy_openai_sdk(settings):
|
||||
"""With all required settings an open ai sdk instance should be configured."""
|
||||
settings.AI_MODEL = "llama"
|
||||
settings.OPENAI_SDK_BASE_URL = "http://example.com"
|
||||
settings.OPENAI_SDK_API_KEY = "test-key"
|
||||
|
||||
openai_sdk = configure_legacy_openai_client()
|
||||
|
||||
assert isinstance(openai_sdk, OpenAI)
|
||||
|
||||
|
||||
def test_ai_services_configure_pydantic_ai_model_openai(settings):
|
||||
"""When openai sdk settings are configured it should return an OpenAiChatModel."""
|
||||
settings.AI_MODEL = "llama"
|
||||
settings.OPENAI_SDK_BASE_URL = "http://example.com"
|
||||
settings.OPENAI_SDK_API_KEY = "test-key"
|
||||
|
||||
pydantic_ai_model = configure_pydantic_model_provider()
|
||||
assert isinstance(pydantic_ai_model, OpenAIChatModel)
|
||||
|
||||
|
||||
def test_ai_services_configure_pydantic_ai_model_mistral(settings):
|
||||
"""When mistral sdk settings are configured is should return a MistralModel."""
|
||||
settings.AI_MODEL = "llama"
|
||||
settings.OPENAI_SDK_BASE_URL = None
|
||||
settings.OPENAI_SDK_API_KEY = None
|
||||
settings.MISTRAL_SDK_API_KEY = "mistreal-sdk-key"
|
||||
settings.MISTRAL_SDK_BASE_URL = "https://mistral.base-url.com"
|
||||
|
||||
pydantic_ai_model = configure_pydantic_model_provider()
|
||||
assert isinstance(pydantic_ai_model, MistralModel)
|
||||
|
||||
|
||||
def test_ai_services_configure_pydantic_ai_model_no_settings(settings):
|
||||
"""When no settings are configured for a ai sdk it should raises an exception."""
|
||||
settings.AI_MODEL = None
|
||||
settings.OPENAI_SDK_BASE_URL = None
|
||||
settings.OPENAI_SDK_API_KEY = None
|
||||
settings.MISTRAL_SDK_API_KEY = None
|
||||
settings.MISTRAL_SDK_BASE_URL = None
|
||||
|
||||
with pytest.raises(
|
||||
ImproperlyConfigured,
|
||||
match="AI configuration not set",
|
||||
):
|
||||
configure_pydantic_model_provider()
|
||||
|
||||
|
||||
# -- AIService.transform --
|
||||
|
||||
@@ -801,8 +801,30 @@ class Base(Configuration):
|
||||
environ_name="AI_ALLOW_REACH_FROM",
|
||||
environ_prefix=None,
|
||||
)
|
||||
AI_API_KEY = SecretFileValue(None, environ_name="AI_API_KEY", environ_prefix=None)
|
||||
AI_BASE_URL = values.Value(None, environ_name="AI_BASE_URL", environ_prefix=None)
|
||||
|
||||
MISTRAL_SDK_BASE_URL = values.Value(
|
||||
None, environ_name="MISTRAL_SDK_BASE_URL", environ_prefix=None
|
||||
)
|
||||
MISTRAL_SDK_API_KEY = SecretFileValue(
|
||||
None, environ_name="MISTRAL_SDK_API_KEY", environ_prefix=None
|
||||
)
|
||||
|
||||
OPENAI_SDK_API_KEY = SecretFileValue(
|
||||
default=SecretFileValue( # retrocompatibility
|
||||
None,
|
||||
environ_name="AI_API_KEY",
|
||||
environ_prefix=None,
|
||||
),
|
||||
environ_name="OPENAI_SDK_API_KEY",
|
||||
environ_prefix=None,
|
||||
)
|
||||
OPENAI_SDK_BASE_URL = values.Value(
|
||||
default=values.Value( # retrocompatibility
|
||||
None, environ_name="AI_BASE_URL", environ_prefix=None
|
||||
),
|
||||
environ_name="OPENAI_SDK_BASE_URL",
|
||||
environ_prefix=None,
|
||||
)
|
||||
AI_BOT = values.DictValue(
|
||||
default={
|
||||
"name": _("Docs AI"),
|
||||
@@ -1138,6 +1160,11 @@ class Base(Configuration):
|
||||
}
|
||||
)
|
||||
|
||||
if cls.OPENAI_SDK_API_KEY and cls.MISTRAL_SDK_API_KEY:
|
||||
raise ValueError(
|
||||
"Both OPENAI_SDK and MISTRAL_SDK parameters can not be set simultaneously."
|
||||
)
|
||||
|
||||
|
||||
class Build(Base):
|
||||
"""Settings used when the application is built.
|
||||
|
||||
@@ -52,6 +52,7 @@ dependencies = [
|
||||
"langfuse==3.11.2",
|
||||
"lxml==6.0.2",
|
||||
"markdown==3.10.2",
|
||||
"mistralai==1.12.4",
|
||||
"mozilla-django-oidc==5.0.2",
|
||||
"nested-multipart-parser==1.6.0",
|
||||
"openai==2.24.0",
|
||||
|
||||
Reference in New Issue
Block a user