mirror of
https://github.com/trycua/computer.git
synced 2026-01-01 19:10:30 -06:00
Fix cua adapter
This commit is contained in:
@@ -19,25 +19,98 @@ class CUAAdapter(CustomLLM):
|
||||
return model.split("/", 1)[1] if model and model.startswith("cua/") else model
|
||||
|
||||
def completion(self, *args, **kwargs) -> ModelResponse:
|
||||
model = kwargs.get("model", "")
|
||||
api_base = kwargs.get("api_base") or self.base_url
|
||||
if "anthropic/" in model:
|
||||
model = f"anthropic/{self._normalize_model(model)}"
|
||||
api_base = api_base.removesuffix("/v1")
|
||||
else:
|
||||
model = f"openai/{self._normalize_model(model)}"
|
||||
|
||||
params = {
|
||||
"model": f"openai/{self._normalize_model(kwargs.get("model", ""))}",
|
||||
"model": model,
|
||||
"messages": kwargs.get("messages", []),
|
||||
"api_base": self.base_url,
|
||||
"api_key": self.api_key,
|
||||
"api_base": api_base,
|
||||
"api_key": kwargs.get("api_key") or self.api_key,
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
|
||||
if "optional_params" in kwargs:
|
||||
params.update(kwargs["optional_params"])
|
||||
del kwargs["optional_params"]
|
||||
|
||||
if "headers" in kwargs:
|
||||
params["headers"] = kwargs["headers"]
|
||||
del kwargs["headers"]
|
||||
|
||||
# Print dropped parameters
|
||||
original_keys = set(kwargs.keys())
|
||||
used_keys = set(params.keys()) # Only these are extracted from kwargs
|
||||
ignored_keys = {
|
||||
"litellm_params",
|
||||
"client",
|
||||
"print_verbose",
|
||||
"acompletion",
|
||||
"timeout",
|
||||
"logging_obj",
|
||||
"encoding",
|
||||
"custom_prompt_dict",
|
||||
"model_response",
|
||||
"logger_fn",
|
||||
}
|
||||
dropped_keys = original_keys - used_keys - ignored_keys
|
||||
if dropped_keys:
|
||||
dropped_keyvals = {k: kwargs[k] for k in dropped_keys}
|
||||
# print(f"CUAAdapter.completion: Dropped parameters: {dropped_keyvals}")
|
||||
|
||||
return completion(**params) # type: ignore
|
||||
|
||||
async def acompletion(self, *args, **kwargs) -> ModelResponse:
|
||||
model = kwargs.get("model", "")
|
||||
api_base = kwargs.get("api_base") or self.base_url
|
||||
if "anthropic/" in model:
|
||||
model = f"anthropic/{self._normalize_model(model)}"
|
||||
api_base = api_base.removesuffix("/v1")
|
||||
api_base = "http://127.0.0.1:5001"
|
||||
else:
|
||||
model = f"openai/{self._normalize_model(model)}"
|
||||
|
||||
params = {
|
||||
"model": f"openai/{self._normalize_model(kwargs.get("model", ""))}",
|
||||
"model": model,
|
||||
"messages": kwargs.get("messages", []),
|
||||
"api_base": self.base_url,
|
||||
"api_key": self.api_key,
|
||||
"api_base": api_base,
|
||||
"api_key": kwargs.get("api_key") or self.api_key,
|
||||
"stream": False,
|
||||
}
|
||||
|
||||
|
||||
if "optional_params" in kwargs:
|
||||
params.update(kwargs["optional_params"])
|
||||
del kwargs["optional_params"]
|
||||
|
||||
if "headers" in kwargs:
|
||||
params["headers"] = kwargs["headers"]
|
||||
del kwargs["headers"]
|
||||
|
||||
# Print dropped parameters
|
||||
original_keys = set(kwargs.keys())
|
||||
used_keys = set(params.keys()) # Only these are extracted from kwargs
|
||||
ignored_keys = {
|
||||
"litellm_params",
|
||||
"client",
|
||||
"print_verbose",
|
||||
"acompletion",
|
||||
"timeout",
|
||||
"logging_obj",
|
||||
"encoding",
|
||||
"custom_prompt_dict",
|
||||
"model_response",
|
||||
"logger_fn",
|
||||
}
|
||||
dropped_keys = original_keys - used_keys - ignored_keys
|
||||
if dropped_keys:
|
||||
dropped_keyvals = {k: kwargs[k] for k in dropped_keys}
|
||||
# print(f"CUAAdapter.acompletion: Dropped parameters: {dropped_keyvals}")
|
||||
|
||||
response = await acompletion(**params) # type: ignore
|
||||
|
||||
return response
|
||||
|
||||
Reference in New Issue
Block a user