Mirror of https://github.com/trycua/computer.git (synced 2025-12-31 10:29:59 -06:00)

Merge pull request #506 from trycua/feat/inference-provider

Add "cua/" LLM provider
@@ -2,6 +2,7 @@
 Adapters package for agent - Custom LLM adapters for LiteLLM
 """

+from .cua_adapter import CUAAdapter
 from .huggingfacelocal_adapter import HuggingFaceLocalAdapter
 from .human_adapter import HumanAdapter
 from .mlxvlm_adapter import MLXVLMAdapter
@@ -10,4 +11,5 @@ __all__ = [
     "HuggingFaceLocalAdapter",
     "HumanAdapter",
     "MLXVLMAdapter",
+    "CUAAdapter",
 ]
libs/python/agent/agent/adapters/cua_adapter.py (new file, 73 lines)
@@ -0,0 +1,73 @@
import os
from typing import Any, AsyncIterator, Iterator

from litellm import acompletion, completion
from litellm.llms.custom_llm import CustomLLM
from litellm.types.utils import GenericStreamingChunk, ModelResponse


class CUAAdapter(CustomLLM):
    def __init__(self, base_url: str | None = None, api_key: str | None = None, **_: Any):
        super().__init__()
        self.base_url = base_url or os.environ.get("CUA_BASE_URL") or "https://inference.cua.ai/v1"
        self.api_key = api_key or os.environ.get("CUA_INFERENCE_API_KEY") or os.environ.get("CUA_API_KEY")

    def _normalize_model(self, model: str) -> str:
        # Accept either "cua/<model>" or raw "<model>"
        return model.split("/", 1)[1] if model and model.startswith("cua/") else model

    def completion(self, *args, **kwargs) -> ModelResponse:
        params = dict(kwargs)
        inner_model = self._normalize_model(params.get("model", ""))
        params.update(
            {
                "model": f"openai/{inner_model}",
                "api_base": self.base_url,
                "api_key": self.api_key,
                "stream": False,
            }
        )
        return completion(**params)  # type: ignore

    async def acompletion(self, *args, **kwargs) -> ModelResponse:
        params = dict(kwargs)
        inner_model = self._normalize_model(params.get("model", ""))
        params.update(
            {
                "model": f"openai/{inner_model}",
                "api_base": self.base_url,
                "api_key": self.api_key,
                "stream": False,
            }
        )
        return await acompletion(**params)  # type: ignore

    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
        params = dict(kwargs)
        inner_model = self._normalize_model(params.get("model", ""))
        params.update(
            {
                "model": f"openai/{inner_model}",
                "api_base": self.base_url,
                "api_key": self.api_key,
                "stream": True,
            }
        )
        # Yield chunks directly from LiteLLM's streaming generator
        for chunk in completion(**params):  # type: ignore
            yield chunk  # type: ignore

    async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
        params = dict(kwargs)
        inner_model = self._normalize_model(params.get("model", ""))
        params.update(
            {
                "model": f"openai/{inner_model}",
                "api_base": self.base_url,
                "api_key": self.api_key,
                "stream": True,
            }
        )
        stream = await acompletion(**params)  # type: ignore
        async for chunk in stream:  # type: ignore
            yield chunk  # type: ignore
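For reference, a minimal usage sketch of the new adapter on its own, assuming CUA_API_KEY (or CUA_INFERENCE_API_KEY) is set in the environment; the import path follows the file layout above, and the model name is an illustrative placeholder, not one taken from this PR:

import litellm

from agent.adapters import CUAAdapter

# Register the adapter under the "cua" provider, mirroring what ComputerAgent
# does below; any model string prefixed with "cua/" is then routed to it.
litellm.custom_provider_map = [
    {"provider": "cua", "custom_handler": CUAAdapter()},
]

# CUAAdapter strips the "cua/" prefix and forwards the call as
# "openai/<model>" against https://inference.cua.ai/v1.
response = litellm.completion(
    model="cua/some-model",  # placeholder model name
    messages=[{"role": "user", "content": "Hello"}],
)
print(response.choices[0].message.content)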
@@ -23,11 +23,7 @@ import litellm
 import litellm.utils
 from litellm.responses.utils import Usage

-from .adapters import (
-    HuggingFaceLocalAdapter,
-    HumanAdapter,
-    MLXVLMAdapter,
-)
+from .adapters import CUAAdapter, HuggingFaceLocalAdapter, HumanAdapter, MLXVLMAdapter
 from .callbacks import (
     BudgetManagerCallback,
     ImageRetentionCallback,
@@ -278,10 +274,12 @@ class ComputerAgent:
         )
         human_adapter = HumanAdapter()
         mlx_adapter = MLXVLMAdapter()
+        cua_adapter = CUAAdapter()
         litellm.custom_provider_map = [
             {"provider": "huggingface-local", "custom_handler": hf_adapter},
             {"provider": "human", "custom_handler": human_adapter},
             {"provider": "mlx", "custom_handler": mlx_adapter},
+            {"provider": "cua", "custom_handler": cua_adapter},
         ]
         litellm.suppress_debug_info = True
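With the provider registered, selecting the hosted inference endpoint from agent code is just a matter of using the "cua/" model prefix. A hedged sketch, assuming the ComputerAgent constructor takes a model string as shown elsewhere in the repo and using a placeholder model name:

from agent import ComputerAgent

# The "cua/" prefix matches the entry added to litellm.custom_provider_map
# above, so requests are handled by CUAAdapter and sent to
# https://inference.cua.ai/v1 (or CUA_BASE_URL, if overridden).
agent = ComputerAgent(
    model="cua/some-model",  # placeholder model name
    # ... tools, callbacks, and other arguments as usual ...
)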