Mirror of https://github.com/trycua/computer.git
Merge pull request #76 from Lizzard1123/feature/agent/add-ollama-support
Add Ollama support in Omni parser
@@ -6,10 +6,12 @@ from ..providers.omni.types import LLMProvider
 DEFAULT_MODELS = {
     LLMProvider.OPENAI: "gpt-4o",
     LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
+    LLMProvider.OLLAMA: "gemma3:4b-it-q4_K_M",
 }

 # Map providers to their environment variable names
 ENV_VARS = {
     LLMProvider.OPENAI: "OPENAI_API_KEY",
     LLMProvider.ANTHROPIC: "ANTHROPIC_API_KEY",
+    LLMProvider.OLLAMA: "OLLAMA_API_KEY",
 }
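A minimal sketch of how these two new entries get consumed, assuming it runs in the module that defines DEFAULT_MODELS and ENV_VARS (LLMProvider is already imported there per the hunk header); note that a local Ollama server does not check an API key, so the OLLAMA_API_KEY lookup only keeps the interface uniform.

import os

provider = LLMProvider.OLLAMA
model = DEFAULT_MODELS[provider]                   # "gemma3:4b-it-q4_K_M"
api_key = os.environ.get(ENV_VARS[provider], "")   # unused by a local Ollama server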
@@ -10,6 +10,7 @@ class AgentLoop(Enum):
     ANTHROPIC = auto()  # Anthropic implementation
     OMNI = auto()  # OmniLoop implementation
     OPENAI = auto()  # OpenAI implementation
+    OLLAMA = auto()  # OLLAMA implementation
     # Add more loop types as needed
libs/agent/agent/providers/omni/clients/ollama.py (new file, 122 lines added)
@@ -0,0 +1,122 @@
"""Ollama API client implementation."""

import logging
from typing import Any, Dict, List, Optional, Tuple, cast
import asyncio
from httpx import ConnectError, ReadTimeout

from ollama import AsyncClient, Options
from ollama import Message
from .base import BaseOmniClient

logger = logging.getLogger(__name__)


class OllamaClient(BaseOmniClient):
    """Client for making calls to Ollama API."""

    def __init__(self, api_key: str, model: str, max_retries: int = 3, retry_delay: float = 1.0):
        """Initialize the Ollama client.

        Args:
            api_key: Not used
            model: Ollama model name (e.g. "gemma3:4b-it-q4_K_M")
            max_retries: Maximum number of retries for API calls
            retry_delay: Base delay between retries in seconds
        """
        if not model:
            raise ValueError("Model name must be provided")

        self.client = AsyncClient(
            host="http://localhost:11434",
        )
        self.model: str = model  # Add explicit type annotation
        self.max_retries = max_retries
        self.retry_delay = retry_delay

    def _convert_message_format(self, system: str, messages: List[Dict[str, Any]]) -> List[Any]:
        """Convert messages from standard format to Ollama format.

        Args:
            messages: Messages in standard format

        Returns:
            Messages in Ollama format
        """
        Ollama_messages = []

        # Add system message
        Ollama_messages.append(
            {
                "role": "system",
                "content": system,
            }
        )

        for message in messages:
            # Skip messages with empty content
            if not message.get("content"):
                continue
            content = message.get("content", [{}])[0]
            isImage = content.get("type", "") == "image_url"
            isText = content.get("type", "") == "text"
            if isText:
                data = content.get("text", "")
                Ollama_messages.append({"role": message["role"], "content": data})
            if isImage:
                data = content.get("image_url", {}).get("url", "")
                # remove header
                data = data.removeprefix("data:image/png;base64,")
                Ollama_messages.append(
                    {"role": message["role"], "content": "Use this image", "images": [data]}
                )

        # Cast the list to the correct type expected by Ollama
        return cast(List[Any], Ollama_messages)

    async def run_interleaved(
        self, messages: List[Dict[str, Any]], system: str, max_tokens: int
    ) -> Any:
        """Run model with interleaved conversation format.

        Args:
            messages: List of messages to process
            system: System prompt
            max_tokens: Not used

        Returns:
            Model response
        """
        last_error = None

        for attempt in range(self.max_retries):
            try:
                # Convert messages to Ollama format
                Ollama_messages = self._convert_message_format(system, messages)

                response = await self.client.chat(
                    model=self.model,
                    options=Options(
                        temperature=0,
                    ),
                    messages=Ollama_messages,
                    format="json",
                )

                return response

            except (ConnectError, ReadTimeout) as e:
                last_error = e
                logger.warning(
                    f"Connection error on attempt {attempt + 1}/{self.max_retries}: {str(e)}"
                )
                if attempt < self.max_retries - 1:
                    await asyncio.sleep(self.retry_delay * (attempt + 1))  # Exponential backoff
                continue

            except Exception as e:
                logger.error(f"Unexpected error in Ollama API call: {str(e)}")
                raise RuntimeError(f"Ollama API call failed: {str(e)}")

        # If we get here, all retries failed
        raise RuntimeError(f"Connection error after {self.max_retries} retries: {str(last_error)}")
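A minimal usage sketch for the new client, assuming a local Ollama server is already running on the default port 11434 with the gemma3:4b-it-q4_K_M model pulled; the message shape follows the standard format the converter above expects.

import asyncio

from agent.providers.omni.clients.ollama import OllamaClient  # path as added in this PR

async def main():
    client = OllamaClient(api_key="", model="gemma3:4b-it-q4_K_M")  # api_key is ignored
    messages = [
        {
            "role": "user",
            "content": [{"type": "text", "text": 'Reply with {"answer": 4} as JSON.'}],
        }
    ]
    # max_tokens is accepted for interface parity but not used by this client
    response = await client.run_interleaved(
        messages, system="You are a helpful assistant.", max_tokens=512
    )
    print(response["message"]["content"])

asyncio.run(main())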
@@ -19,6 +19,7 @@ from computer import Computer
 from .types import LLMProvider
 from .clients.openai import OpenAIClient
 from .clients.anthropic import AnthropicClient
+from .clients.ollama import OllamaClient
 from .prompts import SYSTEM_PROMPT
 from .api_handler import OmniAPIHandler
 from .tools.manager import ToolManager
@@ -135,6 +136,11 @@ class OmniLoop(BaseLoop):
                 api_key=self.api_key,
                 model=self.model,
             )
+        elif self.provider == LLMProvider.OLLAMA:
+            self.client = OllamaClient(
+                api_key=self.api_key,
+                model=self.model,
+            )
         else:
             raise ValueError(f"Unsupported provider: {self.provider}")
@@ -160,6 +166,11 @@ class OmniLoop(BaseLoop):
                 max_retries=self.max_retries,
                 retry_delay=self.retry_delay,
             )
+        elif self.provider == LLMProvider.OLLAMA:
+            self.client = OllamaClient(
+                api_key=self.api_key,
+                model=self.model,
+            )
         else:
             raise ValueError(f"Unsupported provider: {self.provider}")
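One observation on the two branches above: OllamaClient also accepts max_retries and retry_delay, but the OLLAMA branches do not forward the loop's retry settings, so the client keeps its defaults of 3 and 1.0. Constructing the client directly with custom values looks like this (a standalone sketch, values illustrative):

from agent.providers.omni.clients.ollama import OllamaClient  # path as added in this PR

client = OllamaClient(
    api_key="",                     # ignored by the client
    model="gemma3:4b-it-q4_K_M",
    max_retries=5,                  # the OLLAMA branches above leave this at the default of 3
    retry_delay=2.0,                # and this at the default of 1.0
)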
@@ -370,6 +381,13 @@ class OmniLoop(BaseLoop):
             else:
                 logger.warning("Invalid Anthropic response format")
                 return True, action_screenshot_saved
+        elif self.provider == LLMProvider.OLLAMA:
+            try:
+                raw_text = response["message"]["content"]
+                standard_content = [{"type": "text", "text": raw_text}]
+            except (KeyError, TypeError, IndexError) as e:
+                logger.error(f"Invalid response format: {str(e)}")
+                return True, action_screenshot_saved
         else:
             # Assume OpenAI or compatible format
             try:
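For reference, a sketch (with illustrative values) of the response shape this branch expects: ollama's chat() returns a message whose content is a single string, JSON-formatted here because the client requests format="json", and the loop wraps it into the provider-agnostic content list.

# Illustrative response from OllamaClient.run_interleaved
response = {"message": {"role": "assistant", "content": '{"action": "click", "x": 120, "y": 340}'}}

raw_text = response["message"]["content"]
standard_content = [{"type": "text", "text": raw_text}]
# -> [{"type": "text", "text": '{"action": "click", "x": 120, "y": 340}'}]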
@@ -11,6 +11,7 @@ class LLMProvider(StrEnum):
     ANTHROPIC = "anthropic"
     OMNI = "omni"
     OPENAI = "openai"
+    OLLAMA = "ollama"


 @dataclass
@@ -35,10 +36,12 @@ Model = LLM
 PROVIDER_TO_DEFAULT_MODEL: Dict[LLMProvider, str] = {
     LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
     LLMProvider.OPENAI: "gpt-4o",
+    LLMProvider.OLLAMA: "gemma3:4b-it-q4_K_M",
 }

 # Environment variable names for each provider
 PROVIDER_TO_ENV_VAR: Dict[LLMProvider, str] = {
     LLMProvider.ANTHROPIC: "ANTHROPIC_API_KEY",
     LLMProvider.OPENAI: "OPENAI_API_KEY",
+    LLMProvider.OLLAMA: "none",
 }
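Putting the pieces together, the provider would be selected through the LLM dataclass above; the sketch below assumes the public exports and constructor arguments (ComputerAgent with computer, loop, and model, and LLM with provider and name fields), which may differ slightly in the repo.

from computer import Computer                                   # as imported by the loop module
from agent import ComputerAgent, AgentLoop, LLM, LLMProvider    # assumed public exports

agent = ComputerAgent(
    computer=Computer(),
    loop=AgentLoop.OMNI,
    model=LLM(provider=LLMProvider.OLLAMA, name="gemma3:4b-it-q4_K_M"),
)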