mirror of
https://github.com/trycua/lume.git
synced 2026-04-28 16:20:10 -05:00
fix(agent): validate CUA API key before inference call in CUAAdapter
This commit is contained in:
@@ -18,6 +18,26 @@ class CUAAdapter(CustomLLM):
|
||||
# Accept either "cua/<model>" or raw "<model>"
|
||||
return model.split("/", 1)[1] if model and model.startswith("cua/") else model
|
||||
|
||||
def _resolve_api_key(self, kwargs: dict | None = None) -> str:
|
||||
"""Resolve the CUA API key, raising a clear error if missing.
|
||||
|
||||
Checks kwargs (from ComputerAgent api_key param) then falls back
|
||||
to self.api_key (from CUA_API_KEY / CUA_INFERENCE_API_KEY env vars).
|
||||
|
||||
This validation must run before the inner litellm call because that
|
||||
call uses an anthropic/ or openai/ model prefix, which would cause
|
||||
litellm to fall back to ANTHROPIC_API_KEY from env — sending the
|
||||
wrong key to the CUA inference endpoint.
|
||||
"""
|
||||
resolved = (kwargs.get("api_key") if kwargs else None) or self.api_key
|
||||
if not resolved:
|
||||
raise ValueError(
|
||||
"No CUA API key provided for cua/ model inference. "
|
||||
"Please either set the CUA_API_KEY environment variable "
|
||||
"or pass api_key to ComputerAgent()."
|
||||
)
|
||||
return resolved
|
||||
|
||||
def completion(self, *args, **kwargs) -> ModelResponse:
|
||||
model = kwargs.get("model", "")
|
||||
api_base = kwargs.get("api_base") or self.base_url
|
||||
@@ -31,11 +51,13 @@ class CUAAdapter(CustomLLM):
|
||||
else:
|
||||
model = f"openai/{self._normalize_model(model)}"
|
||||
|
||||
api_key = self._resolve_api_key(kwargs)
|
||||
|
||||
params = {
|
||||
"model": model,
|
||||
"messages": kwargs.get("messages", []),
|
||||
"api_base": api_base,
|
||||
"api_key": kwargs.get("api_key") or self.api_key,
|
||||
"api_key": api_key,
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
@@ -86,11 +108,13 @@ class CUAAdapter(CustomLLM):
|
||||
else:
|
||||
model = f"openai/{self._normalize_model(model)}"
|
||||
|
||||
api_key = self._resolve_api_key(kwargs)
|
||||
|
||||
params = {
|
||||
"model": model,
|
||||
"messages": kwargs.get("messages", []),
|
||||
"api_base": api_base,
|
||||
"api_key": kwargs.get("api_key") or self.api_key,
|
||||
"api_key": api_key,
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
@@ -133,11 +157,12 @@ class CUAAdapter(CustomLLM):
|
||||
def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
|
||||
params = dict(kwargs)
|
||||
inner_model = self._normalize_model(params.get("model", ""))
|
||||
api_key = self._resolve_api_key(kwargs)
|
||||
params.update(
|
||||
{
|
||||
"model": f"openai/{inner_model}",
|
||||
"api_base": self.base_url,
|
||||
"api_key": self.api_key,
|
||||
"api_key": api_key,
|
||||
"stream": True,
|
||||
}
|
||||
)
|
||||
@@ -148,11 +173,12 @@ class CUAAdapter(CustomLLM):
|
||||
async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
|
||||
params = dict(kwargs)
|
||||
inner_model = self._normalize_model(params.get("model", ""))
|
||||
api_key = self._resolve_api_key(kwargs)
|
||||
params.update(
|
||||
{
|
||||
"model": f"openai/{inner_model}",
|
||||
"api_base": self.base_url,
|
||||
"api_key": self.api_key,
|
||||
"api_key": api_key,
|
||||
"stream": True,
|
||||
}
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user