Automatically detect the function-calling method to fix structured output for ChatOpenAI, and upgrade to langchain-openai>=0.3.0

This commit is contained in:
magmueller
2025-01-19 11:33:20 -08:00
parent 8c090f2821
commit 41ebcddab6
3 changed files with 87 additions and 65 deletions

View File

@@ -10,6 +10,7 @@ from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
ToolMessage,
)
from langchain_openai import ChatOpenAI
@@ -34,7 +35,6 @@ class MessageManager:
include_attributes: list[str] = [],
max_error_length: int = 400,
max_actions_per_step: int = 10,
tool_call_in_content: bool = True,
):
self.llm = llm
self.system_prompt_class = system_prompt_class
@@ -55,7 +55,7 @@ class MessageManager:
self._add_message_with_tokens(system_message)
self.system_prompt = system_message
self.tool_call_in_content = tool_call_in_content
self.tool_id = 1
tool_calls = [
{
'name': 'AgentOutput',
@@ -63,28 +63,26 @@ class MessageManager:
'current_state': {
'evaluation_previous_goal': 'Unknown - No previous actions to evaluate.',
'memory': '',
'next_goal': 'Obtain task from user',
'next_goal': 'Start browser',
},
'action': [],
},
'id': '',
'id': str(self.tool_id),
'type': 'tool_call',
}
]
if self.tool_call_in_content:
# openai throws error if tool_calls are not responded -> move to content
example_tool_call = AIMessage(
content=f'{tool_calls}',
tool_calls=[],
)
else:
example_tool_call = AIMessage(
content=f'',
tool_calls=tool_calls,
)
example_tool_call = AIMessage(
content=f'',
tool_calls=tool_calls,
)
self._add_message_with_tokens(example_tool_call)
tool_message = ToolMessage(
content=f'Browser started',
tool_call_id=str(self.tool_id),
)
self._add_message_with_tokens(tool_message)
self.tool_id += 1
task_message = self.task_instructions(task)
self._add_message_with_tokens(task_message)
@@ -138,22 +136,18 @@ class MessageManager:
{
'name': 'AgentOutput',
'args': model_output.model_dump(mode='json', exclude_unset=True),
'id': '',
'id': str(self.tool_id),
'type': 'tool_call',
}
]
if self.tool_call_in_content:
msg = AIMessage(
content=f'{tool_calls}',
tool_calls=[],
)
else:
msg = AIMessage(
content='',
tool_calls=tool_calls,
)
msg = AIMessage(
content='',
tool_calls=tool_calls,
)
self._add_message_with_tokens(msg)
# empty tool response
def get_messages(self) -> List[BaseMessage]:
"""Get current message list, potentially trimmed to max tokens"""

View File

@@ -67,7 +67,7 @@ class Agent:
controller: Controller = Controller(),
use_vision: bool = True,
save_conversation_path: Optional[str] = None,
save_conversation_path_encoding: Optional[str] = "utf-8",
save_conversation_path_encoding: Optional[str] = 'utf-8',
max_failures: int = 3,
retry_delay: int = 10,
system_prompt_class: Type[SystemPrompt] = SystemPrompt,
@@ -88,7 +88,7 @@ class Agent:
],
max_error_length: int = 400,
max_actions_per_step: int = 10,
tool_call_in_content: bool = True,
tool_calling_method: Optional[str] = 'auto',
):
self.agent_id = str(uuid.uuid4()) # unique identifier for the agent
@@ -131,9 +131,12 @@ class Agent:
# Action and output models setup
self._setup_action_models()
self._set_version_and_source()
self.max_input_tokens = max_input_tokens
self.tool_call_in_content = tool_call_in_content
self._set_model_names()
self.tool_calling_method = self.set_tool_calling_method(tool_calling_method)
self.message_manager = MessageManager(
llm=self.llm,
@@ -144,7 +147,6 @@ class Agent:
include_attributes=self.include_attributes,
max_error_length=self.max_error_length,
max_actions_per_step=self.max_actions_per_step,
tool_call_in_content=tool_call_in_content,
)
# Tracking variables
@@ -158,6 +160,36 @@ class Agent:
if save_conversation_path:
logger.info(f'Saving conversation to {save_conversation_path}')
def _set_version_and_source(self) -> None:
    """Determine the installed package version and how it was installed.

    Resolution order:
      1. Installed distribution metadata (``source='pip'``).
      2. ``git describe --tags`` when running from a source checkout
         (``source='git'``).
      3. ``'unknown'`` for both when neither is available.

    Stores the result on ``self.version`` and ``self.source`` for
    telemetry/logging. Never raises.
    """
    try:
        # importlib.metadata is the stdlib replacement for the
        # deprecated pkg_resources API (available since Python 3.8).
        from importlib.metadata import version as _distribution_version

        version = _distribution_version('browser-use')
        source = 'pip'
    except Exception:
        try:
            import subprocess

            version = (
                subprocess.check_output(['git', 'describe', '--tags']).decode('utf-8').strip()
            )
            source = 'git'
        except Exception:
            version = 'unknown'
            source = 'unknown'
    logger.debug(f'Version: {version}, Source: {source}')
    self.version = version
    self.source = source
def _set_model_names(self) -> None:
self.chat_model_library = self.llm.__class__.__name__
if hasattr(self.llm, 'model_name'):
self.model_name = self.llm.model_name # type: ignore
elif hasattr(self.llm, 'model'):
self.model_name = self.llm.model # type: ignore
else:
self.model_name = 'Unknown'
def _setup_action_models(self) -> None:
"""Setup dynamic action models from controller's registry"""
# Get the dynamic action model from controller's registry
@@ -165,6 +197,17 @@ class Agent:
# Create output model with the dynamic actions
self.AgentOutput = AgentOutput.type_with_custom_actions(self.ActionModel)
def set_tool_calling_method(self, tool_calling_method: Optional[str]) -> Optional[str]:
    """Resolve the ``method`` to pass to ``with_structured_output``.

    Args:
        tool_calling_method: ``'auto'`` to pick per chat-model library,
            ``None`` to use the library default, or an explicit method
            name (e.g. ``'function_calling'``, ``'json_mode'``).

    Returns:
        The resolved method name, or ``None`` for the library default.
    """
    if tool_calling_method == 'auto':
        if self.chat_model_library == 'ChatGoogleGenerativeAI':
            # Gemini's structured output works best with the default method.
            return None
        elif self.chat_model_library in ('ChatOpenAI', 'AzureChatOpenAI'):
            # OpenAI-style models need explicit function calling.
            return 'function_calling'
        else:
            return None
    else:
        # Bug fix: an explicitly requested method was previously discarded
        # (the function fell off the end and returned None). Honor it.
        return tool_calling_method
@time_execution_async('--step')
async def step(self, step_info: Optional[AgentStepInfo] = None) -> None:
"""Execute one step of the task"""
@@ -284,8 +327,13 @@ class Agent:
@time_execution_async('--get_next_action')
async def get_next_action(self, input_messages: list[BaseMessage]) -> AgentOutput:
"""Get next action from LLM based on current state"""
if self.tool_calling_method is None:
structured_llm = self.llm.with_structured_output(self.AgentOutput, include_raw=True)
else:
structured_llm = self.llm.with_structured_output(
self.AgentOutput, include_raw=True, method=self.tool_calling_method
)
structured_llm = self.llm.with_structured_output(self.AgentOutput, include_raw=True)
response: dict[str, Any] = await structured_llm.ainvoke(input_messages) # type: ignore
parsed: AgentOutput = response['parsed']
@@ -324,7 +372,11 @@ class Agent:
# create folders if not exists
os.makedirs(os.path.dirname(self.save_conversation_path), exist_ok=True)
with open(self.save_conversation_path + f'_{self.n_steps}.txt', 'w', encoding=self.save_conversation_path_encoding) as f:
with open(
self.save_conversation_path + f'_{self.n_steps}.txt',
'w',
encoding=self.save_conversation_path_encoding,
) as f:
self._write_messages_to_file(f, input_messages)
self._write_response_to_file(f, response)
@@ -354,41 +406,17 @@ class Agent:
def _log_agent_run(self) -> None:
"""Log the agent run"""
logger.info(f'🚀 Starting task: {self.task}')
# model_name is eiter model or model_name
if hasattr(self.llm, 'model_name'):
model_name = self.llm.model_name # type: ignore
elif hasattr(self.llm, 'model'):
model_name = self.llm.model # type: ignore
else:
model_name = 'Unknown'
try:
import pkg_resources
version = pkg_resources.get_distribution('browser-use').version
source = 'pip'
except Exception:
try:
import subprocess
version = (
subprocess.check_output(['git', 'describe', '--tags']).decode('utf-8').strip()
)
source = 'git'
except Exception:
version = 'unknown'
source = 'unknown'
logger.debug(f'Version: {version}, Source: {source}')
logger.debug(f'Version: {self.version}, Source: {self.source}')
self.telemetry.capture(
AgentRunTelemetryEvent(
agent_id=self.agent_id,
use_vision=self.use_vision,
tool_call_in_content=self.tool_call_in_content,
task=self.task,
model_name=model_name,
chat_model_library=self.llm.__class__.__name__,
version=version,
source=source,
model_name=self.model_name,
chat_model_library=self.chat_model_library,
version=self.version,
source=self.source,
)
)
@@ -803,7 +831,7 @@ class Agent:
margin: int,
logo: Optional[Image.Image] = None,
display_step: bool = True,
text_color: tuple[int,int,int,int] = (255, 255, 255, 255),
text_color: tuple[int, int, int, int] = (255, 255, 255, 255),
text_box_color: tuple[int, int, int, int] = (0, 0, 0, 255),
) -> Image.Image:
"""Add step number and goal overlay to an image."""

View File

@@ -17,7 +17,7 @@ dependencies = [
"beautifulsoup4>=4.12.3",
"httpx>=0.27.2",
"langchain>=0.3.14",
"langchain-openai==0.2.14",
"langchain-openai>=0.3.0",
"langchain-anthropic>=0.3.1",
"langchain-fireworks>=0.2.5",
"langchain-aws>=0.2.10",