diff --git a/libs/python/agent/agent/adapters/__init__.py b/libs/python/agent/agent/adapters/__init__.py
index 1f07a9fc..cded1d34 100644
--- a/libs/python/agent/agent/adapters/__init__.py
+++ b/libs/python/agent/agent/adapters/__init__.py
@@ -2,6 +2,7 @@
 Adapters package for agent - Custom LLM adapters for LiteLLM
 """
 
+from .cua_adapter import CUAAdapter
 from .huggingfacelocal_adapter import HuggingFaceLocalAdapter
 from .human_adapter import HumanAdapter
 from .mlxvlm_adapter import MLXVLMAdapter
@@ -10,4 +11,5 @@ __all__ = [
     "HuggingFaceLocalAdapter",
     "HumanAdapter",
     "MLXVLMAdapter",
+    "CUAAdapter",
 ]
diff --git a/libs/python/agent/agent/adapters/cua_adapter.py b/libs/python/agent/agent/adapters/cua_adapter.py
new file mode 100644
index 00000000..76e13977
--- /dev/null
+++ b/libs/python/agent/agent/adapters/cua_adapter.py
@@ -0,0 +1,73 @@
+import os
+from typing import Any, AsyncIterator, Iterator
+
+from litellm import acompletion, completion
+from litellm.llms.custom_llm import CustomLLM
+from litellm.types.utils import GenericStreamingChunk, ModelResponse
+
+
+class CUAAdapter(CustomLLM):
+    def __init__(self, base_url: str | None = None, api_key: str | None = None, **_: Any):
+        super().__init__()
+        self.base_url = base_url or os.environ.get("CUA_BASE_URL") or "https://inference.cua.ai/v1"
+        self.api_key = api_key or os.environ.get("CUA_INFERENCE_API_KEY") or os.environ.get("CUA_API_KEY")
+
+    def _normalize_model(self, model: str) -> str:
+        # Accept either "cua/<model>" or raw "<model>"
+        return model.split("/", 1)[1] if model and model.startswith("cua/") else model
+
+    def completion(self, *args, **kwargs) -> ModelResponse:
+        params = dict(kwargs)
+        inner_model = self._normalize_model(params.get("model", ""))
+        params.update(
+            {
+                "model": f"openai/{inner_model}",
+                "api_base": self.base_url,
+                "api_key": self.api_key,
+                "stream": False,
+            }
+        )
+        return completion(**params)  # type: ignore
+
+    async def acompletion(self, *args, **kwargs) -> ModelResponse:
+        params = dict(kwargs)
+        inner_model = self._normalize_model(params.get("model", ""))
+        params.update(
+            {
+                "model": f"openai/{inner_model}",
+                "api_base": self.base_url,
+                "api_key": self.api_key,
+                "stream": False,
+            }
+        )
+        return await acompletion(**params)  # type: ignore
+
+    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
+        params = dict(kwargs)
+        inner_model = self._normalize_model(params.get("model", ""))
+        params.update(
+            {
+                "model": f"openai/{inner_model}",
+                "api_base": self.base_url,
+                "api_key": self.api_key,
+                "stream": True,
+            }
+        )
+        # Yield chunks directly from LiteLLM's streaming generator
+        for chunk in completion(**params):  # type: ignore
+            yield chunk  # type: ignore
+
+    async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
+        params = dict(kwargs)
+        inner_model = self._normalize_model(params.get("model", ""))
+        params.update(
+            {
+                "model": f"openai/{inner_model}",
+                "api_base": self.base_url,
+                "api_key": self.api_key,
+                "stream": True,
+            }
+        )
+        stream = await acompletion(**params)  # type: ignore
+        async for chunk in stream:  # type: ignore
+            yield chunk  # type: ignore
diff --git a/libs/python/agent/agent/agent.py b/libs/python/agent/agent/agent.py
index f85c513c..42f04a00 100644
--- a/libs/python/agent/agent/agent.py
+++ b/libs/python/agent/agent/agent.py
@@ -23,11 +23,7 @@ import litellm
 import litellm.utils
 from litellm.responses.utils import Usage
 
-from .adapters import (
-    HuggingFaceLocalAdapter,
-    HumanAdapter,
-    MLXVLMAdapter,
-)
+from .adapters import CUAAdapter, HuggingFaceLocalAdapter, HumanAdapter, MLXVLMAdapter
 from .callbacks import (
     BudgetManagerCallback,
     ImageRetentionCallback,
@@ -278,10 +274,12 @@
         )
         human_adapter = HumanAdapter()
         mlx_adapter = MLXVLMAdapter()
+        cua_adapter = CUAAdapter()
         litellm.custom_provider_map = [
             {"provider": "huggingface-local", "custom_handler": hf_adapter},
             {"provider": "human", "custom_handler": human_adapter},
             {"provider": "mlx", "custom_handler": mlx_adapter},
+            {"provider": "cua", "custom_handler": cua_adapter},
         ]
         litellm.suppress_debug_info = True
 