mirror of
https://github.com/browser-use/browser-use
synced 2026-05-06 17:52:15 +02:00
Automatically detect function calling method to fix structured output for ChatOpenAI and upgrade to langchain-openai>=0.3.0
This commit is contained in:
@@ -10,6 +10,7 @@ from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
HumanMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
@@ -34,7 +35,6 @@ class MessageManager:
|
||||
include_attributes: list[str] = [],
|
||||
max_error_length: int = 400,
|
||||
max_actions_per_step: int = 10,
|
||||
tool_call_in_content: bool = True,
|
||||
):
|
||||
self.llm = llm
|
||||
self.system_prompt_class = system_prompt_class
|
||||
@@ -55,7 +55,7 @@ class MessageManager:
|
||||
|
||||
self._add_message_with_tokens(system_message)
|
||||
self.system_prompt = system_message
|
||||
self.tool_call_in_content = tool_call_in_content
|
||||
self.tool_id = 1
|
||||
tool_calls = [
|
||||
{
|
||||
'name': 'AgentOutput',
|
||||
@@ -63,28 +63,26 @@ class MessageManager:
|
||||
'current_state': {
|
||||
'evaluation_previous_goal': 'Unknown - No previous actions to evaluate.',
|
||||
'memory': '',
|
||||
'next_goal': 'Obtain task from user',
|
||||
'next_goal': 'Start browser',
|
||||
},
|
||||
'action': [],
|
||||
},
|
||||
'id': '',
|
||||
'id': str(self.tool_id),
|
||||
'type': 'tool_call',
|
||||
}
|
||||
]
|
||||
if self.tool_call_in_content:
|
||||
# openai throws error if tool_calls are not responded -> move to content
|
||||
example_tool_call = AIMessage(
|
||||
content=f'{tool_calls}',
|
||||
tool_calls=[],
|
||||
)
|
||||
else:
|
||||
example_tool_call = AIMessage(
|
||||
content=f'',
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
example_tool_call = AIMessage(
|
||||
content=f'',
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
self._add_message_with_tokens(example_tool_call)
|
||||
|
||||
tool_message = ToolMessage(
|
||||
content=f'Browser started',
|
||||
tool_call_id=str(self.tool_id),
|
||||
)
|
||||
self._add_message_with_tokens(tool_message)
|
||||
self.tool_id += 1
|
||||
task_message = self.task_instructions(task)
|
||||
self._add_message_with_tokens(task_message)
|
||||
|
||||
@@ -138,22 +136,18 @@ class MessageManager:
|
||||
{
|
||||
'name': 'AgentOutput',
|
||||
'args': model_output.model_dump(mode='json', exclude_unset=True),
|
||||
'id': '',
|
||||
'id': str(self.tool_id),
|
||||
'type': 'tool_call',
|
||||
}
|
||||
]
|
||||
if self.tool_call_in_content:
|
||||
msg = AIMessage(
|
||||
content=f'{tool_calls}',
|
||||
tool_calls=[],
|
||||
)
|
||||
else:
|
||||
msg = AIMessage(
|
||||
content='',
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
msg = AIMessage(
|
||||
content='',
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
self._add_message_with_tokens(msg)
|
||||
# empty tool response
|
||||
|
||||
def get_messages(self) -> List[BaseMessage]:
|
||||
"""Get current message list, potentially trimmed to max tokens"""
|
||||
|
||||
@@ -67,7 +67,7 @@ class Agent:
|
||||
controller: Controller = Controller(),
|
||||
use_vision: bool = True,
|
||||
save_conversation_path: Optional[str] = None,
|
||||
save_conversation_path_encoding: Optional[str] = "utf-8",
|
||||
save_conversation_path_encoding: Optional[str] = 'utf-8',
|
||||
max_failures: int = 3,
|
||||
retry_delay: int = 10,
|
||||
system_prompt_class: Type[SystemPrompt] = SystemPrompt,
|
||||
@@ -88,7 +88,7 @@ class Agent:
|
||||
],
|
||||
max_error_length: int = 400,
|
||||
max_actions_per_step: int = 10,
|
||||
tool_call_in_content: bool = True,
|
||||
tool_calling_method: Optional[str] = 'auto',
|
||||
):
|
||||
self.agent_id = str(uuid.uuid4()) # unique identifier for the agent
|
||||
|
||||
@@ -131,9 +131,12 @@ class Agent:
|
||||
|
||||
# Action and output models setup
|
||||
self._setup_action_models()
|
||||
|
||||
self._set_version_and_source()
|
||||
self.max_input_tokens = max_input_tokens
|
||||
self.tool_call_in_content = tool_call_in_content
|
||||
|
||||
self._set_model_names()
|
||||
|
||||
self.tool_calling_method = self.set_tool_calling_method(tool_calling_method)
|
||||
|
||||
self.message_manager = MessageManager(
|
||||
llm=self.llm,
|
||||
@@ -144,7 +147,6 @@ class Agent:
|
||||
include_attributes=self.include_attributes,
|
||||
max_error_length=self.max_error_length,
|
||||
max_actions_per_step=self.max_actions_per_step,
|
||||
tool_call_in_content=tool_call_in_content,
|
||||
)
|
||||
|
||||
# Tracking variables
|
||||
@@ -158,6 +160,36 @@ class Agent:
|
||||
if save_conversation_path:
|
||||
logger.info(f'Saving conversation to {save_conversation_path}')
|
||||
|
||||
def _set_version_and_source(self) -> None:
|
||||
try:
|
||||
import pkg_resources
|
||||
|
||||
version = pkg_resources.get_distribution('browser-use').version
|
||||
source = 'pip'
|
||||
except Exception:
|
||||
try:
|
||||
import subprocess
|
||||
|
||||
version = (
|
||||
subprocess.check_output(['git', 'describe', '--tags']).decode('utf-8').strip()
|
||||
)
|
||||
source = 'git'
|
||||
except Exception:
|
||||
version = 'unknown'
|
||||
source = 'unknown'
|
||||
logger.debug(f'Version: {version}, Source: {source}')
|
||||
self.version = version
|
||||
self.source = source
|
||||
|
||||
def _set_model_names(self) -> None:
|
||||
self.chat_model_library = self.llm.__class__.__name__
|
||||
if hasattr(self.llm, 'model_name'):
|
||||
self.model_name = self.llm.model_name # type: ignore
|
||||
elif hasattr(self.llm, 'model'):
|
||||
self.model_name = self.llm.model # type: ignore
|
||||
else:
|
||||
self.model_name = 'Unknown'
|
||||
|
||||
def _setup_action_models(self) -> None:
|
||||
"""Setup dynamic action models from controller's registry"""
|
||||
# Get the dynamic action model from controller's registry
|
||||
@@ -165,6 +197,17 @@ class Agent:
|
||||
# Create output model with the dynamic actions
|
||||
self.AgentOutput = AgentOutput.type_with_custom_actions(self.ActionModel)
|
||||
|
||||
def set_tool_calling_method(self, tool_calling_method: Optional[str]) -> Optional[str]:
|
||||
if tool_calling_method == 'auto':
|
||||
if self.chat_model_library == 'ChatGoogleGenerativeAI':
|
||||
return None
|
||||
elif self.chat_model_library == 'ChatOpenAI':
|
||||
return 'function_calling'
|
||||
elif self.chat_model_library == 'AzureChatOpenAI':
|
||||
return 'function_calling'
|
||||
else:
|
||||
return None
|
||||
|
||||
@time_execution_async('--step')
|
||||
async def step(self, step_info: Optional[AgentStepInfo] = None) -> None:
|
||||
"""Execute one step of the task"""
|
||||
@@ -284,8 +327,13 @@ class Agent:
|
||||
@time_execution_async('--get_next_action')
|
||||
async def get_next_action(self, input_messages: list[BaseMessage]) -> AgentOutput:
|
||||
"""Get next action from LLM based on current state"""
|
||||
if self.tool_calling_method is None:
|
||||
structured_llm = self.llm.with_structured_output(self.AgentOutput, include_raw=True)
|
||||
else:
|
||||
structured_llm = self.llm.with_structured_output(
|
||||
self.AgentOutput, include_raw=True, method=self.tool_calling_method
|
||||
)
|
||||
|
||||
structured_llm = self.llm.with_structured_output(self.AgentOutput, include_raw=True)
|
||||
response: dict[str, Any] = await structured_llm.ainvoke(input_messages) # type: ignore
|
||||
|
||||
parsed: AgentOutput = response['parsed']
|
||||
@@ -324,7 +372,11 @@ class Agent:
|
||||
# create folders if not exists
|
||||
os.makedirs(os.path.dirname(self.save_conversation_path), exist_ok=True)
|
||||
|
||||
with open(self.save_conversation_path + f'_{self.n_steps}.txt', 'w', encoding=self.save_conversation_path_encoding) as f:
|
||||
with open(
|
||||
self.save_conversation_path + f'_{self.n_steps}.txt',
|
||||
'w',
|
||||
encoding=self.save_conversation_path_encoding,
|
||||
) as f:
|
||||
self._write_messages_to_file(f, input_messages)
|
||||
self._write_response_to_file(f, response)
|
||||
|
||||
@@ -354,41 +406,17 @@ class Agent:
|
||||
def _log_agent_run(self) -> None:
|
||||
"""Log the agent run"""
|
||||
logger.info(f'🚀 Starting task: {self.task}')
|
||||
# model_name is eiter model or model_name
|
||||
if hasattr(self.llm, 'model_name'):
|
||||
model_name = self.llm.model_name # type: ignore
|
||||
elif hasattr(self.llm, 'model'):
|
||||
model_name = self.llm.model # type: ignore
|
||||
else:
|
||||
model_name = 'Unknown'
|
||||
|
||||
try:
|
||||
import pkg_resources
|
||||
|
||||
version = pkg_resources.get_distribution('browser-use').version
|
||||
source = 'pip'
|
||||
except Exception:
|
||||
try:
|
||||
import subprocess
|
||||
|
||||
version = (
|
||||
subprocess.check_output(['git', 'describe', '--tags']).decode('utf-8').strip()
|
||||
)
|
||||
source = 'git'
|
||||
except Exception:
|
||||
version = 'unknown'
|
||||
source = 'unknown'
|
||||
logger.debug(f'Version: {version}, Source: {source}')
|
||||
logger.debug(f'Version: {self.version}, Source: {self.source}')
|
||||
self.telemetry.capture(
|
||||
AgentRunTelemetryEvent(
|
||||
agent_id=self.agent_id,
|
||||
use_vision=self.use_vision,
|
||||
tool_call_in_content=self.tool_call_in_content,
|
||||
task=self.task,
|
||||
model_name=model_name,
|
||||
chat_model_library=self.llm.__class__.__name__,
|
||||
version=version,
|
||||
source=source,
|
||||
model_name=self.model_name,
|
||||
chat_model_library=self.chat_model_library,
|
||||
version=self.version,
|
||||
source=self.source,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -803,7 +831,7 @@ class Agent:
|
||||
margin: int,
|
||||
logo: Optional[Image.Image] = None,
|
||||
display_step: bool = True,
|
||||
text_color: tuple[int,int,int,int] = (255, 255, 255, 255),
|
||||
text_color: tuple[int, int, int, int] = (255, 255, 255, 255),
|
||||
text_box_color: tuple[int, int, int, int] = (0, 0, 0, 255),
|
||||
) -> Image.Image:
|
||||
"""Add step number and goal overlay to an image."""
|
||||
|
||||
@@ -17,7 +17,7 @@ dependencies = [
|
||||
"beautifulsoup4>=4.12.3",
|
||||
"httpx>=0.27.2",
|
||||
"langchain>=0.3.14",
|
||||
"langchain-openai==0.2.14",
|
||||
"langchain-openai>=0.3.0",
|
||||
"langchain-anthropic>=0.3.1",
|
||||
"langchain-fireworks>=0.2.5",
|
||||
"langchain-aws>=0.2.10",
|
||||
|
||||
Reference in New Issue
Block a user