diff --git a/frontend/src/utils/chat/agent.js b/frontend/src/utils/chat/agent.js index 9f3bb479d..8e33e0967 100644 --- a/frontend/src/utils/chat/agent.js +++ b/frontend/src/utils/chat/agent.js @@ -158,6 +158,13 @@ export default function handleSocketResponse(socket, event, setChatHistory) { ); } + if (type === "chatId") { + if (!data.content.chatId) return prev; + return prev.map((msg) => + msg.uuid === uuid ? { ...msg, chatId: data.content.chatId } : msg + ); + } + if (type === "textResponseChunk") { return prev .map((msg) => diff --git a/server/models/workspaceChats.js b/server/models/workspaceChats.js index e48807be7..494c761f7 100644 --- a/server/models/workspaceChats.js +++ b/server/models/workspaceChats.js @@ -315,6 +315,45 @@ const WorkspaceChats = { return { chats: null, message: error.message }; } }, + upsert: async function ( + chatId = null, + data = { + workspaceId: null, + prompt: null, + response: {}, + user: null, + threadId: null, + include: true, + apiSessionId: null, + } + ) { + try { + const payload = { + workspaceId: data.workspaceId, + response: safeJSONStringify(data.response), + user_id: data.user?.id || null, + thread_id: data.threadId, + api_session_id: data.apiSessionId, + include: data.include, + }; + + const { chat } = await prisma.workspace_chats.upsert({ + where: { + id: Number(chatId), + user_id: data.user?.id || null, + }, + // On updates, we already have the prompt so we don't need to set it again. + update: { ...payload, lastUpdatedAt: new Date() }, + + // On creates, we need to set the prompt or else record will fail. + create: { ...payload, prompt: data.prompt }, + }); + return { chat, message: null }; + } catch (error) { + console.error(error.message); + return { chat: null, message: error.message }; + } + }, }; module.exports = { WorkspaceChats }; diff --git a/server/utils/agents/aibitat/index.js b/server/utils/agents/aibitat/index.js index c372740f5..edfef8b9b 100644 --- a/server/utils/agents/aibitat/index.js +++ b/server/utils/agents/aibitat/index.js @@ -33,7 +33,7 @@ class AIbitat { defaultInterrupt; maxRounds; _chats; - + _trackedChatId = null; agents = new Map(); channels = new Map(); functions = new Map(); @@ -114,6 +114,44 @@ class AIbitat { return this; } + /** + * Register a new chat ID for tracking for a given conversation exchange + * @param {number} chatId - The ID of the chat to register. + */ + registerChatId(chatId = null) { + if (!chatId) return; + this._trackedChatId = Number(chatId); + } + + /** + * Get the tracked chat ID for a given conversation exchange + * @returns {number|null} The ID of the chat to register. + */ + get trackedChatId() { + return this._trackedChatId ?? null; + } + + /** + * Clear the tracked chat ID for a given conversation exchange + */ + clearTrackedChatId() { + this._trackedChatId = null; + } + + /** + * Emit the tracked chat ID to the frontend via the websocket + * plugin (assumed to be attached). + * @param {string} [uuid] - The message UUID to associate with this chatId + */ + emitChatId(uuid = null) { + if (!this.trackedChatId || !uuid) return null; + this.socket?.send?.("reportStreamEvent", { + type: "chatId", + uuid, + chatId: this.trackedChatId, + }); + } + /** * Add citation(s) to be reported when the response is finalized. * Citations are buffered and flushed with the correct message UUID. @@ -916,6 +954,7 @@ https://docs.anythingllm.com/agent/intelligent-tool-selection metrics: provider.getUsage(), }); this?.flushCitations?.(directOutputUUID); + this?.emitChatId?.(directOutputUUID); return result; } @@ -957,6 +996,7 @@ https://docs.anythingllm.com/agent/intelligent-tool-selection metrics: provider.getUsage(), }); this?.flushCitations?.(responseUuid); + this?.emitChatId?.(responseUuid); return completionStream?.textResponse; } @@ -1096,6 +1136,7 @@ https://docs.anythingllm.com/agent/intelligent-tool-selection metrics: provider.getUsage(), }); this?.flushCitations?.(msgUUID); + this?.emitChatId?.(msgUUID); return completion?.textResponse; } diff --git a/server/utils/agents/aibitat/plugins/chat-history.js b/server/utils/agents/aibitat/plugins/chat-history.js index 99b0ca641..647aef1b7 100644 --- a/server/utils/agents/aibitat/plugins/chat-history.js +++ b/server/utils/agents/aibitat/plugins/chat-history.js @@ -12,6 +12,46 @@ const chatHistory = { return { name: this.name, setup: function (aibitat) { + // pre-register a workspace chat ID to secure it in the DB + aibitat.onMessage(async (message) => { + if (message.from !== "USER") return; + + /** + * If we don't have a tracked chat ID, we need to create a new one so we can upsert the response later. + * Normally, if this was a totally fresh chat from the user, we can assume that the message from the socket is + * the message we want to store for the prompt. However, if this is a regeneration of a previous message and that message + * called tools the history could include intermediate messages so need to search backwards to find the most recent user message + * as that is actually the prompt. + */ + if (!aibitat.trackedChatId) { + let userMessage = message.content; + if (userMessage.startsWith("@agent:")) { + const lastUserMsgIndex = aibitat._chats.findLastIndex( + (c) => c.from === "USER" && !c.content.startsWith("@agent:") + ); + + // When regenerating a message, we need to use the last user message as the prompt. + // Also prune the chats array to only include the messages before target prompt to re-run + // or else tool call results from the previous run will be included in the history and the model will not re-call tools + // that previously worked for the to-be-regenerated prompt. + if (lastUserMsgIndex !== -1) { + userMessage = aibitat._chats[lastUserMsgIndex].content; + aibitat._chats = aibitat._chats.slice(0, lastUserMsgIndex + 1); + } + } + + const { chat } = await WorkspaceChats.new({ + workspaceId: Number(aibitat.handlerProps.invocation.workspace_id), + user: { id: aibitat.handlerProps.invocation.user_id || null }, + threadId: aibitat.handlerProps.invocation.thread_id || null, + include: false, + prompt: userMessage, + response: {}, + }); + if (chat) aibitat.registerChatId(chat.id); + } + }); + aibitat.onMessage(async () => { try { const lastResponses = aibitat.chats.slice(-2); @@ -54,7 +94,7 @@ const chatHistory = { const metrics = aibitat.provider?.getUsage?.() ?? {}; const citations = aibitat._pendingCitations ?? []; const outputs = aibitat._pendingOutputs ?? []; - await WorkspaceChats.new({ + await WorkspaceChats.upsert(aibitat.trackedChatId, { workspaceId: Number(invocation.workspace_id), prompt, response: { @@ -67,9 +107,9 @@ const chatHistory = { }, user: { id: invocation?.user_id || null }, threadId: invocation?.thread_id || null, + include: true, }); - aibitat.clearCitations?.(); - aibitat._pendingOutputs = []; + this._cleanup(aibitat); }, _storeSpecial: async function ( aibitat, @@ -80,7 +120,7 @@ const chatHistory = { const citations = aibitat._pendingCitations ?? []; const outputs = aibitat._pendingOutputs ?? []; const existingSources = options?.sources ?? []; - await WorkspaceChats.new({ + await WorkspaceChats.upsert(aibitat.trackedChatId, { workspaceId: Number(invocation.workspace_id), prompt, response: { @@ -97,10 +137,16 @@ const chatHistory = { }, user: { id: invocation?.user_id || null }, threadId: invocation?.thread_id || null, + include: true, }); + options?.postSave(); + this._cleanup(aibitat); + }, + + _cleanup: function (aibitat) { aibitat.clearCitations?.(); aibitat._pendingOutputs = []; - options?.postSave(); + aibitat.clearTrackedChatId(); }, }; },