Mirror of https://github.com/trycua/computer.git, synced 2026-01-07 14:00:04 -06:00
Merge pull request #504 from trycua/feat/api_key_overrides
[Agent] Add support for overriding api_key and api_base kwargs
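
The change lets callers override the model provider's credentials both when constructing ComputerAgent and on a per-run() basis. A minimal usage sketch, assuming ComputerAgent is importable from the agent package; the import path, model string, and endpoint values are placeholders rather than part of this diff:

import asyncio
from agent import ComputerAgent  # import path assumed

async def main():
    # Constructor-level override: every run() call inherits these credentials.
    agent = ComputerAgent(
        model="openai/gpt-4o",                  # placeholder model string
        api_key="sk-example",                   # placeholder key
        api_base="https://llm.example.com/v1",  # placeholder endpoint
    )

    # Per-call override: takes precedence over the constructor values
    # (see the merged_kwargs hunk below).
    async for chunk in agent.run(
        [{"role": "user", "content": "Open the settings page"}],
        api_key="sk-other-example",
    ):
        print(chunk)

asyncio.run(main())
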
@@ -185,7 +185,9 @@ class ComputerAgent:
         max_trajectory_budget: Optional[float | dict] = None,
         telemetry_enabled: Optional[bool] = True,
         trust_remote_code: Optional[bool] = False,
-        **kwargs,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        **additional_generation_kwargs,
     ):
         """
         Initialize ComputerAgent.

@@ -205,7 +207,9 @@ class ComputerAgent:
             max_trajectory_budget: If set, adds BudgetManagerCallback to track usage costs and stop when budget is exceeded
             telemetry_enabled: If set, adds TelemetryCallback to track anonymized usage data. Enabled by default.
             trust_remote_code: If set, trust remote code when loading local models. Disabled by default.
-            **kwargs: Additional arguments passed to the agent loop
+            api_key: Optional API key override for the model provider
+            api_base: Optional API base URL override for the model provider
+            **additional_generation_kwargs: Additional arguments passed to the model provider
         """
         # If the loop is "human/human", we need to prefix a grounding model fallback
         if model in ["human/human", "human"]:

@@ -223,8 +227,10 @@ class ComputerAgent:
         self.screenshot_delay = screenshot_delay
         self.use_prompt_caching = use_prompt_caching
         self.telemetry_enabled = telemetry_enabled
-        self.kwargs = kwargs
+        self.kwargs = additional_generation_kwargs
         self.trust_remote_code = trust_remote_code
+        self.api_key = api_key
+        self.api_base = api_base

         # == Add built-in callbacks ==

@@ -593,7 +599,12 @@ class ComputerAgent:
     # ============================================================================

     async def run(
-        self, messages: Messages, stream: bool = False, **kwargs
+        self,
+        messages: Messages,
+        stream: bool = False,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        **additional_generation_kwargs,
     ) -> AsyncGenerator[Dict[str, Any], None]:
         """
         Run the agent with the given messages using Computer protocol handler pattern.

@@ -601,7 +612,9 @@ class ComputerAgent:
         Args:
             messages: List of message dictionaries
             stream: Whether to stream the response
-            **kwargs: Additional arguments
+            api_key: Optional API key override for the model provider
+            api_base: Optional API base URL override for the model provider
+            **additional_generation_kwargs: Additional arguments passed to the model provider

         Returns:
             AsyncGenerator that yields response chunks

@@ -617,8 +630,12 @@ class ComputerAgent:

         await self._initialize_computers()

-        # Merge kwargs
-        merged_kwargs = {**self.kwargs, **kwargs}
+        # Merge kwargs and thread api credentials (run overrides constructor)
+        merged_kwargs = {**self.kwargs, **additional_generation_kwargs}
+        if (api_key is not None) or (self.api_key is not None):
+            merged_kwargs["api_key"] = api_key if api_key is not None else self.api_key
+        if (api_base is not None) or (self.api_base is not None):
+            merged_kwargs["api_base"] = api_base if api_base is not None else self.api_base

         old_items = self._process_input(messages)
         new_items = []

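The merge above is where the precedence rule lives: run()-level generation kwargs override constructor-level ones, and a run()-level api_key or api_base wins over the value stored on the agent. A standalone sketch of the same rule with placeholder values:

from typing import Optional

def merge_credentials(
    constructor_kwargs: dict,
    run_kwargs: dict,
    constructor_api_key: Optional[str],
    run_api_key: Optional[str],
) -> dict:
    # Mirrors the diff: later dict unpacking wins, and a run-level key,
    # when present, replaces the constructor-level key.
    merged = {**constructor_kwargs, **run_kwargs}
    if (run_api_key is not None) or (constructor_api_key is not None):
        merged["api_key"] = run_api_key if run_api_key is not None else constructor_api_key
    return merged

print(merge_credentials({"temperature": 0.2}, {}, "sk-ctor", None))      # constructor key used
print(merge_credentials({"temperature": 0.2}, {}, "sk-ctor", "sk-run"))  # run() override wins
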
@@ -728,8 +745,14 @@ class ComputerAgent:
                 if not self.computer_handler:
                     raise ValueError("Computer tool or image_b64 is required for predict_click")
                 image_b64 = await self.computer_handler.screenshot()
+            # Pass along api credentials if available
+            click_kwargs: Dict[str, Any] = {}
+            if self.api_key is not None:
+                click_kwargs["api_key"] = self.api_key
+            if self.api_base is not None:
+                click_kwargs["api_base"] = self.api_base
             return await self.agent_loop.predict_click(
-                model=self.model, image_b64=image_b64, instruction=instruction
+                model=self.model, image_b64=image_b64, instruction=instruction, **click_kwargs
             )
         return None

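With this hunk, predict_click reuses whatever credentials were set on the agent, so callers never pass them per call. A hedged sketch, assuming predict_click takes an instruction string and an optional pre-captured screenshot (its full signature is not shown in this diff; the import path and model string are placeholders):

import asyncio
from agent import ComputerAgent  # import path assumed

async def main():
    agent = ComputerAgent(
        model="openai/computer-use-preview",    # placeholder model string
        api_key="sk-example",                   # forwarded to predict_click automatically
        api_base="https://llm.example.com/v1",  # placeholder endpoint
    )
    # Without an image_b64 argument the agent captures its own screenshot,
    # which requires a configured computer tool (per the ValueError above).
    coords = await agent.predict_click(instruction="the blue Submit button")
    print(coords)  # (x, y) tuple or None

asyncio.run(main())
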
@@ -297,6 +297,20 @@ Examples:
         help="Maximum number of retries for the LLM API calls",
     )

+    # Provider override credentials
+    parser.add_argument(
+        "--api-key",
+        dest="api_key",
+        type=str,
+        help="API key override for the model provider (passed to ComputerAgent)",
+    )
+    parser.add_argument(
+        "--api-base",
+        dest="api_base",
+        type=str,
+        help="API base URL override for the model provider (passed to ComputerAgent)",
+    )
+
     args = parser.parse_args()

     # Check for required environment variables

@@ -380,6 +394,12 @@ Examples:
         "max_retries": args.max_retries,
     }

+    # Thread API credentials to agent if provided
+    if args.api_key:
+        agent_kwargs["api_key"] = args.api_key
+    if args.api_base:
+        agent_kwargs["api_base"] = args.api_base
+
     if args.images > 0:
         agent_kwargs["only_n_most_recent_images"] = args.images

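The two new flags map directly onto the agent_kwargs threading shown above. A self-contained sketch of that mapping, using only the flags from this diff (argument values are placeholders and the surrounding CLI setup is simplified):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--api-key", dest="api_key", type=str,
                    help="API key override for the model provider (passed to ComputerAgent)")
parser.add_argument("--api-base", dest="api_base", type=str,
                    help="API base URL override for the model provider (passed to ComputerAgent)")

# Simulated command line; the real CLI reads sys.argv.
args = parser.parse_args(["--api-key", "sk-example", "--api-base", "https://llm.example.com/v1"])

agent_kwargs = {}
if args.api_key:
    agent_kwargs["api_key"] = args.api_key
if args.api_base:
    agent_kwargs["api_base"] = args.api_base

print(agent_kwargs)
# {'api_key': 'sk-example', 'api_base': 'https://llm.example.com/v1'}
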
@@ -1615,6 +1615,11 @@ Task: Click {instruction}. Output ONLY a click action on the target element.""",
             "max_tokens": 100, # Keep response short for click prediction
             "headers": {"anthropic-beta": tool_config["beta_flag"]},
         }
+        # Thread optional API params
+        if "api_key" in kwargs and kwargs.get("api_key") is not None:
+            api_kwargs["api_key"] = kwargs.get("api_key")
+        if "api_base" in kwargs and kwargs.get("api_base") is not None:
+            api_kwargs["api_base"] = kwargs.get("api_base")

         # Use liteLLM acompletion
         response = await litellm.acompletion(**api_kwargs)

@@ -24,7 +24,7 @@ class AsyncAgentConfig(Protocol):
         _on_api_end=None,
         _on_usage=None,
         _on_screenshot=None,
-        **kwargs,
+        **generation_config,
     ) -> Dict[str, Any]:
         """
         Predict the next step based on input items.

@@ -40,7 +40,9 @@ class AsyncAgentConfig(Protocol):
             _on_api_end: Callback for API end
             _on_usage: Callback for usage tracking
             _on_screenshot: Callback for screenshot events
-            **kwargs: Additional arguments
+            **generation_config: Additional arguments to pass to the model provider
+                - api_key: Optional API key for the provider
+                - api_base: Optional API base URL for the provider

         Returns:
             Dictionary with "output" (output items) and "usage" array

@@ -49,7 +51,7 @@ class AsyncAgentConfig(Protocol):

     @abstractmethod
     async def predict_click(
-        self, model: str, image_b64: str, instruction: str
+        self, model: str, image_b64: str, instruction: str, **generation_config
     ) -> Optional[Tuple[int, int]]:
         """
         Predict click coordinates based on image and instruction.

@@ -58,6 +60,9 @@ class AsyncAgentConfig(Protocol):
             model: Model name to use
             image_b64: Base64 encoded image
             instruction: Instruction for where to click
+            **generation_config: Additional arguments to pass to the model provider
+                - api_key: Optional API key for the provider
+                - api_base: Optional API base URL for the provider

         Returns:
             None or tuple with (x, y) coordinates

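Provider configs implementing this protocol are expected to forward the optional credentials to their completion call. A minimal sketch of a hypothetical config doing so through liteLLM, which accepts per-call api_key and api_base overrides; the class name, prompt, and coordinate handling are illustrative, not from this diff:

from typing import Any, Dict, Optional, Tuple
import litellm

class ExampleClickConfig:
    """Hypothetical AsyncAgentConfig-style implementation (illustrative only)."""

    async def predict_click(
        self, model: str, image_b64: str, instruction: str, **generation_config
    ) -> Optional[Tuple[int, int]]:
        api_kwargs: Dict[str, Any] = {
            "model": model,
            "messages": [{
                "role": "user",
                "content": [
                    {"type": "image_url",
                     "image_url": {"url": f"data:image/png;base64,{image_b64}"}},
                    {"type": "text", "text": f"Click: {instruction}"},
                ],
            }],
            "max_tokens": 100,
        }
        # api_key / api_base arrive via **generation_config and are handed
        # straight to liteLLM as per-call credential overrides.
        for key in ("api_key", "api_base"):
            if generation_config.get(key) is not None:
                api_kwargs[key] = generation_config[key]
        response = await litellm.acompletion(**api_kwargs)
        _ = response  # coordinate parsing omitted in this sketch
        return None
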
@@ -762,6 +762,7 @@ class Glm4vConfig(AsyncAgentConfig):
             # "skip_special_tokens": False,
             # }
         }
+        api_kwargs.update({k: v for k, v in (kwargs or {}).items()})

         # Add API callbacks
         if _on_api_start:

@@ -852,6 +853,7 @@ Where x,y are coordinates normalized to 0-999 range."""
                 "skip_special_tokens": False,
             },
         }
+        api_kwargs.update({k: v for k, v in (kwargs or {}).items()})

         # Call liteLLM
         response = await litellm.acompletion(**api_kwargs)

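The update line is what actually routes the overrides here: every caller-supplied kwarg, including api_key and api_base, is copied into the liteLLM call, and a caller value wins whenever the key already exists in api_kwargs. A tiny illustration with placeholder values:

api_kwargs = {"model": "hosted_vllm/example-model", "temperature": 0.0}
kwargs = {"api_key": "sk-example", "api_base": "http://localhost:8000/v1", "temperature": 0.2}

# Same pattern as the diff; equivalent to api_kwargs.update(kwargs or {}).
api_kwargs.update({k: v for k, v in (kwargs or {}).items()})

print(api_kwargs["temperature"])  # 0.2, the caller's kwarg takes precedence
print(api_kwargs["api_base"])     # forwarded on to litellm.acompletion(**api_kwargs)
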
@@ -140,7 +140,7 @@ class OpenAIComputerUseConfig:
         return output_dict

     async def predict_click(
-        self, model: str, image_b64: str, instruction: str
+        self, model: str, image_b64: str, instruction: str, **kwargs
     ) -> Optional[Tuple[int, int]]:
         """
         Predict click coordinates based on image and instruction.

@@ -208,6 +208,7 @@ Task: Click {instruction}. Output ONLY a click action on the target element.""",
             "reasoning": {"summary": "concise"},
             "truncation": "auto",
             "max_tokens": 200, # Keep response short for click prediction
+            **kwargs,
         }

         # Use liteLLM responses

@@ -773,7 +773,7 @@ class UITARSConfig:
         return agent_response

     async def predict_click(
-        self, model: str, image_b64: str, instruction: str
+        self, model: str, image_b64: str, instruction: str, **kwargs
     ) -> Optional[Tuple[int, int]]:
         """
         Predict click coordinates based on image and instruction.

@@ -819,6 +819,7 @@ class UITARSConfig:
             "temperature": 0.0,
             "do_sample": False,
         }
+        api_kwargs.update({k: v for k, v in (kwargs or {}).items()})

         # Call liteLLM with UITARS model
         response = await litellm.acompletion(**api_kwargs)