Fix cua adapter

This commit is contained in:
Dillon DuPont
2025-11-19 10:52:50 -05:00
parent 3f3a1c776d
commit d930032c82

View File

@@ -19,25 +19,98 @@ class CUAAdapter(CustomLLM):
return model.split("/", 1)[1] if model and model.startswith("cua/") else model
def completion(self, *args, **kwargs) -> ModelResponse:
model = kwargs.get("model", "")
api_base = kwargs.get("api_base") or self.base_url
if "anthropic/" in model:
model = f"anthropic/{self._normalize_model(model)}"
api_base = api_base.removesuffix("/v1")
else:
model = f"openai/{self._normalize_model(model)}"
params = {
"model": f"openai/{self._normalize_model(kwargs.get("model", ""))}",
"model": model,
"messages": kwargs.get("messages", []),
"api_base": self.base_url,
"api_key": self.api_key,
"api_base": api_base,
"api_key": kwargs.get("api_key") or self.api_key,
"stream": False,
}
if "optional_params" in kwargs:
params.update(kwargs["optional_params"])
del kwargs["optional_params"]
if "headers" in kwargs:
params["headers"] = kwargs["headers"]
del kwargs["headers"]
# Print dropped parameters
original_keys = set(kwargs.keys())
used_keys = set(params.keys()) # Only these are extracted from kwargs
ignored_keys = {
"litellm_params",
"client",
"print_verbose",
"acompletion",
"timeout",
"logging_obj",
"encoding",
"custom_prompt_dict",
"model_response",
"logger_fn",
}
dropped_keys = original_keys - used_keys - ignored_keys
if dropped_keys:
dropped_keyvals = {k: kwargs[k] for k in dropped_keys}
# print(f"CUAAdapter.completion: Dropped parameters: {dropped_keyvals}")
return completion(**params) # type: ignore
async def acompletion(self, *args, **kwargs) -> ModelResponse:
model = kwargs.get("model", "")
api_base = kwargs.get("api_base") or self.base_url
if "anthropic/" in model:
model = f"anthropic/{self._normalize_model(model)}"
api_base = api_base.removesuffix("/v1")
api_base = "http://127.0.0.1:5001"
else:
model = f"openai/{self._normalize_model(model)}"
params = {
"model": f"openai/{self._normalize_model(kwargs.get("model", ""))}",
"model": model,
"messages": kwargs.get("messages", []),
"api_base": self.base_url,
"api_key": self.api_key,
"api_base": api_base,
"api_key": kwargs.get("api_key") or self.api_key,
"stream": False,
}
if "optional_params" in kwargs:
params.update(kwargs["optional_params"])
del kwargs["optional_params"]
if "headers" in kwargs:
params["headers"] = kwargs["headers"]
del kwargs["headers"]
# Print dropped parameters
original_keys = set(kwargs.keys())
used_keys = set(params.keys()) # Only these are extracted from kwargs
ignored_keys = {
"litellm_params",
"client",
"print_verbose",
"acompletion",
"timeout",
"logging_obj",
"encoding",
"custom_prompt_dict",
"model_response",
"logger_fn",
}
dropped_keys = original_keys - used_keys - ignored_keys
if dropped_keys:
dropped_keyvals = {k: kwargs[k] for k in dropped_keys}
# print(f"CUAAdapter.acompletion: Dropped parameters: {dropped_keyvals}")
response = await acompletion(**params) # type: ignore
return response