Pinned transformers version, first working grounding version

This commit is contained in:
Dillon DuPont
2025-08-21 11:54:13 -04:00
parent dad6634ffd
commit b20d2a0a93
4 changed files with 6 additions and 6 deletions

View File

@@ -22,7 +22,7 @@ def load_model(model_name: str, device: str = "auto"):
)
cfg = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
cls = cfg.__class__.__name__
-print(f"cls: {cls}")
+# print(f"cls: {cls}")
if "OpenCUA" in cls:
return OpenCUAModel(model_name=model_name, device=device)
return GenericHFModel(model_name=model_name, device=device)

View File

@@ -37,6 +37,7 @@ class OpenCUAModel:
torch_dtype="auto",
device_map=self.device,
trust_remote_code=True,
+attn_implementation="sdpa",
)
self.image_processor = AutoImageProcessor.from_pretrained(
self.model_name, trust_remote_code=True

View File

@@ -97,7 +97,7 @@ class OpenCUAConfig(AsyncAgentConfig):
},
{
"type": "text",
-"text": instruction
+"text": f"Click on {instruction}"
}
]
}
@@ -116,8 +116,7 @@ class OpenCUAConfig(AsyncAgentConfig):
# Extract response text
output_text = response.choices[0].message.content
-print(output_text)
+# print(output_text)
# Extract coordinates from pyautogui format
coordinates = extract_coordinates_from_pyautogui(output_text)

View File

@@ -50,7 +50,7 @@ glm45v-hf = [
opencua-hf = [
"accelerate",
"torch",
-"transformers>=4.54.0",
+"transformers==4.53.0",
"tiktoken>=0.11.0",
"blobfile>=3.0.0"
]
@@ -75,7 +75,7 @@ all = [
"transformers>=4.54.0",
# opencua requirements
"tiktoken>=0.11.0",
-"blobfile>=3.0.0"
+"blobfile>=3.0.0",
# ui requirements
"gradio>=5.23.3",
"python-dotenv>=1.0.1",