fix(agent): validate CUA API key before inference call in CUAAdapter

Author: synacktra.work@gmail.com
Date: 2026-02-10 19:09:02 +05:30
parent 16983b9c8a
commit 9ce88b1bf9
@@ -18,6 +18,26 @@ class CUAAdapter(CustomLLM):
         # Accept either "cua/<model>" or raw "<model>"
         return model.split("/", 1)[1] if model and model.startswith("cua/") else model
+    def _resolve_api_key(self, kwargs: dict | None = None) -> str:
+        """Resolve the CUA API key, raising a clear error if missing.
+        Checks kwargs (from the ComputerAgent api_key param), then falls back
+        to self.api_key (from the CUA_API_KEY / CUA_INFERENCE_API_KEY env vars).
+        This validation must run before the inner litellm call because that
+        call uses an anthropic/ or openai/ model prefix, which would cause
+        litellm to fall back to ANTHROPIC_API_KEY from the environment and
+        send the wrong key to the CUA inference endpoint.
+        """
+        resolved = (kwargs.get("api_key") if kwargs else None) or self.api_key
+        if not resolved:
+            raise ValueError(
+                "No CUA API key provided for cua/ model inference. "
+                "Please either set the CUA_API_KEY environment variable "
+                "or pass api_key to ComputerAgent()."
+            )
+        return resolved
     def completion(self, *args, **kwargs) -> ModelResponse:
         model = kwargs.get("model", "")
         api_base = kwargs.get("api_base") or self.base_url
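The resolution order described in the docstring above can be illustrated roughly as follows. This is a sketch, not part of the commit; it assumes CUAAdapter() can be constructed without arguments and that self.api_key has already been populated from CUA_API_KEY / CUA_INFERENCE_API_KEY as the docstring states.

adapter = CUAAdapter()  # hypothetical construction; __init__ is not shown in this diff

# 1. An explicit per-call key (ComputerAgent's api_key param) takes precedence.
key = adapter._resolve_api_key({"api_key": "cua-key-from-caller"})

# 2. Without a per-call key, the env-derived self.api_key is used.
key = adapter._resolve_api_key({})

# 3. With neither, a ValueError is raised up front instead of letting litellm
#    silently substitute ANTHROPIC_API_KEY for the anthropic/-prefixed model.
adapter.api_key = None
adapter._resolve_api_key(None)  # raises ValueError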
@@ -31,11 +51,13 @@ class CUAAdapter(CustomLLM):
         else:
             model = f"openai/{self._normalize_model(model)}"
+        api_key = self._resolve_api_key(kwargs)
         params = {
             "model": model,
             "messages": kwargs.get("messages", []),
             "api_base": api_base,
-            "api_key": kwargs.get("api_key") or self.api_key,
+            "api_key": api_key,
             "stream": False,
         }
@@ -86,11 +108,13 @@ class CUAAdapter(CustomLLM):
         else:
             model = f"openai/{self._normalize_model(model)}"
+        api_key = self._resolve_api_key(kwargs)
         params = {
             "model": model,
             "messages": kwargs.get("messages", []),
             "api_base": api_base,
-            "api_key": kwargs.get("api_key") or self.api_key,
+            "api_key": api_key,
             "stream": False,
         }
@@ -133,11 +157,12 @@ class CUAAdapter(CustomLLM):
     def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
         params = dict(kwargs)
         inner_model = self._normalize_model(params.get("model", ""))
+        api_key = self._resolve_api_key(kwargs)
         params.update(
             {
                 "model": f"openai/{inner_model}",
                 "api_base": self.base_url,
-                "api_key": self.api_key,
+                "api_key": api_key,
                 "stream": True,
             }
         )
@@ -148,11 +173,12 @@ class CUAAdapter(CustomLLM):
     async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
         params = dict(kwargs)
         inner_model = self._normalize_model(params.get("model", ""))
+        api_key = self._resolve_api_key(kwargs)
         params.update(
             {
                 "model": f"openai/{inner_model}",
                 "api_base": self.base_url,
-                "api_key": self.api_key,
+                "api_key": api_key,
                 "stream": True,
             }
         )
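Taken together, the four call paths (completion, acompletion, streaming, astreaming) now fail fast on a missing key before any request is built. As a caller-facing illustration (the ComputerAgent constructor and the placeholder model string below are assumptions; only the api_key parameter and the CUA_API_KEY variable come from the diff's error message):

agent = ComputerAgent(model="cua/<provider>/<model>", api_key="sk-cua-...")  # explicit key

# Or rely on the environment instead of passing api_key:
#   export CUA_API_KEY=...
# If neither is set, the adapter raises
#   ValueError: No CUA API key provided for cua/ model inference. ...
# before litellm is called, so ANTHROPIC_API_KEY is never sent to the CUA endpoint.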