Add api_key and api_base kwargs and thread them through to the agent loops

This commit is contained in:
Dillon DuPont
2025-10-22 17:25:55 -07:00
parent 0d91fe6f38
commit f18103dc20
6 changed files with 33 additions and 8 deletions

View File

@@ -185,6 +185,8 @@ class ComputerAgent:
max_trajectory_budget: Optional[float | dict] = None,
telemetry_enabled: Optional[bool] = True,
trust_remote_code: Optional[bool] = False,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
**kwargs,
):
"""
@@ -225,6 +227,8 @@ class ComputerAgent:
self.telemetry_enabled = telemetry_enabled
self.kwargs = kwargs
self.trust_remote_code = trust_remote_code
self.api_key = api_key
self.api_base = api_base
# == Add built-in callbacks ==
@@ -593,7 +597,7 @@ class ComputerAgent:
# ============================================================================
async def run(
self, messages: Messages, stream: bool = False, **kwargs
self, messages: Messages, stream: bool = False, api_key: Optional[str] = None, api_base: Optional[str] = None, **kwargs
) -> AsyncGenerator[Dict[str, Any], None]:
"""
Run the agent with the given messages using Computer protocol handler pattern.
@@ -617,8 +621,12 @@ class ComputerAgent:
await self._initialize_computers()
# Merge kwargs
# Merge kwargs and thread api credentials (run overrides constructor)
merged_kwargs = {**self.kwargs, **kwargs}
if (api_key is not None) or (self.api_key is not None):
merged_kwargs["api_key"] = api_key if api_key is not None else self.api_key
if (api_base is not None) or (self.api_base is not None):
merged_kwargs["api_base"] = api_base if api_base is not None else self.api_base
old_items = self._process_input(messages)
new_items = []
@@ -728,8 +736,14 @@ class ComputerAgent:
if not self.computer_handler:
raise ValueError("Computer tool or image_b64 is required for predict_click")
image_b64 = await self.computer_handler.screenshot()
# Pass along api credentials if available
click_kwargs: Dict[str, Any] = {}
if self.api_key is not None:
click_kwargs["api_key"] = self.api_key
if self.api_base is not None:
click_kwargs["api_base"] = self.api_base
return await self.agent_loop.predict_click(
model=self.model, image_b64=image_b64, instruction=instruction
model=self.model, image_b64=image_b64, instruction=instruction, **click_kwargs
)
return None

View File

@@ -1615,6 +1615,11 @@ Task: Click {instruction}. Output ONLY a click action on the target element.""",
"max_tokens": 100, # Keep response short for click prediction
"headers": {"anthropic-beta": tool_config["beta_flag"]},
}
# Thread optional API params
if "api_key" in kwargs and kwargs.get("api_key") is not None:
api_kwargs["api_key"] = kwargs.get("api_key")
if "api_base" in kwargs and kwargs.get("api_base") is not None:
api_kwargs["api_base"] = kwargs.get("api_base")
# Use liteLLM acompletion
response = await litellm.acompletion(**api_kwargs)

View File

@@ -24,7 +24,7 @@ class AsyncAgentConfig(Protocol):
_on_api_end=None,
_on_usage=None,
_on_screenshot=None,
**kwargs,
**generation_kwargs,
) -> Dict[str, Any]:
"""
Predict the next step based on input items.
@@ -40,7 +40,9 @@ class AsyncAgentConfig(Protocol):
_on_api_end: Callback for API end
_on_usage: Callback for usage tracking
_on_screenshot: Callback for screenshot events
**kwargs: Additional arguments
**generation_kwargs: Additional arguments for generation
- api_key: Optional API key for the provider
- api_base: Optional API base URL for the provider
Returns:
Dictionary with "output" (output items) and "usage" array
@@ -49,7 +51,7 @@ class AsyncAgentConfig(Protocol):
@abstractmethod
async def predict_click(
self, model: str, image_b64: str, instruction: str
self, model: str, image_b64: str, instruction: str, **generation_config
) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates based on image and instruction.

View File

@@ -762,6 +762,7 @@ class Glm4vConfig(AsyncAgentConfig):
# "skip_special_tokens": False,
# }
}
api_kwargs.update({k: v for k, v in (kwargs or {}).items()})
# Add API callbacks
if _on_api_start:
@@ -852,6 +853,7 @@ Where x,y are coordinates normalized to 0-999 range."""
"skip_special_tokens": False,
},
}
api_kwargs.update({k: v for k, v in (kwargs or {}).items()})
# Call liteLLM
response = await litellm.acompletion(**api_kwargs)

View File

@@ -140,7 +140,7 @@ class OpenAIComputerUseConfig:
return output_dict
async def predict_click(
self, model: str, image_b64: str, instruction: str
self, model: str, image_b64: str, instruction: str, **kwargs
) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates based on image and instruction.
@@ -208,6 +208,7 @@ Task: Click {instruction}. Output ONLY a click action on the target element.""",
"reasoning": {"summary": "concise"},
"truncation": "auto",
"max_tokens": 200, # Keep response short for click prediction
**kwargs,
}
# Use liteLLM responses

View File

@@ -773,7 +773,7 @@ class UITARSConfig:
return agent_response
async def predict_click(
self, model: str, image_b64: str, instruction: str
self, model: str, image_b64: str, instruction: str, **kwargs
) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates based on image and instruction.
@@ -819,6 +819,7 @@ class UITARSConfig:
"temperature": 0.0,
"do_sample": False,
}
api_kwargs.update({k: v for k, v in (kwargs or {}).items()})
# Call liteLLM with UITARS model
response = await litellm.acompletion(**api_kwargs)