Merge pull request #76 from Lizzard1123/feature/agent/add-ollama-support

Add Ollama support in Omni parser
f-trycua authored on 2025-04-04 11:47:14 -07:00 · committed by GitHub
5 changed files with 146 additions and 0 deletions

@@ -6,10 +6,12 @@ from ..providers.omni.types import LLMProvider
DEFAULT_MODELS = {
    LLMProvider.OPENAI: "gpt-4o",
    LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
    LLMProvider.OLLAMA: "gemma3:4b-it-q4_K_M",
}

# Map providers to their environment variable names
ENV_VARS = {
    LLMProvider.OPENAI: "OPENAI_API_KEY",
    LLMProvider.ANTHROPIC: "ANTHROPIC_API_KEY",
    LLMProvider.OLLAMA: "OLLAMA_API_KEY",
}
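A brief usage sketch of how these mappings might be consumed, assuming they are imported from this module; resolve_defaults is a hypothetical helper, not part of this PR:

import os
from typing import Optional, Tuple

# Hypothetical helper: look up the default model and API key for a provider.
def resolve_defaults(provider: LLMProvider) -> Tuple[str, Optional[str]]:
    model = DEFAULT_MODELS[provider]  # e.g. "gemma3:4b-it-q4_K_M" for Ollama
    api_key = os.environ.get(ENV_VARS[provider])  # may be unset for a local Ollama server
    return model, api_key

model, api_key = resolve_defaults(LLMProvider.OLLAMA)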

@@ -10,6 +10,7 @@ class AgentLoop(Enum):
    ANTHROPIC = auto()  # Anthropic implementation
    OMNI = auto()  # OmniLoop implementation
    OPENAI = auto()  # OpenAI implementation
    OLLAMA = auto()  # Ollama implementation
    # Add more loop types as needed

@@ -0,0 +1,122 @@
"""Ollama API client implementation."""
import logging
from typing import Any, Dict, List, Optional, Tuple, cast
import asyncio
from httpx import ConnectError, ReadTimeout
from ollama import AsyncClient, Options
from ollama import Message
from .base import BaseOmniClient
logger = logging.getLogger(__name__)
class OllamaClient(BaseOmniClient):
"""Client for making calls to Ollama API."""
def __init__(self, api_key: str, model: str, max_retries: int = 3, retry_delay: float = 1.0):
"""Initialize the Ollama client.
Args:
api_key: Not used
model: Ollama model name (e.g. "gemma3:4b-it-q4_K_M")
max_retries: Maximum number of retries for API calls
retry_delay: Base delay between retries in seconds
"""
if not model:
raise ValueError("Model name must be provided")
self.client = AsyncClient(
host="http://localhost:11434",
)
self.model: str = model # Add explicit type annotation
self.max_retries = max_retries
self.retry_delay = retry_delay
def _convert_message_format(self, system: str, messages: List[Dict[str, Any]]) -> List[Any]:
"""Convert messages from standard format to Ollama format.
Args:
messages: Messages in standard format
Returns:
Messages in Ollama format
"""
Ollama_messages = []
# Add system message
Ollama_messages.append(
{
"role": "system",
"content": system,
}
)
for message in messages:
# Skip messages with empty content
if not message.get("content"):
continue
content = message.get("content", [{}])[0]
isImage = content.get("type", "") == "image_url"
isText = content.get("type", "") == "text"
if isText:
data = content.get("text", "")
Ollama_messages.append({"role": message["role"], "content": data})
if isImage:
data = content.get("image_url", {}).get("url", "")
# remove header
data = data.removeprefix("data:image/png;base64,")
Ollama_messages.append(
{"role": message["role"], "content": "Use this image", "images": [data]}
)
# Cast the list to the correct type expected by Ollama
return cast(List[Any], Ollama_messages)
async def run_interleaved(
self, messages: List[Dict[str, Any]], system: str, max_tokens: int
) -> Any:
"""Run model with interleaved conversation format.
Args:
messages: List of messages to process
system: System prompt
max_tokens: Not used
Returns:
Model response
"""
last_error = None
for attempt in range(self.max_retries):
try:
# Convert messages to Ollama format
Ollama_messages = self._convert_message_format(system, messages)
response = await self.client.chat(
model=self.model,
options=Options(
temperature=0,
),
messages=Ollama_messages,
format="json",
)
return response
except (ConnectError, ReadTimeout) as e:
last_error = e
logger.warning(
f"Connection error on attempt {attempt + 1}/{self.max_retries}: {str(e)}"
)
if attempt < self.max_retries - 1:
await asyncio.sleep(self.retry_delay * (attempt + 1)) # Exponential backoff
continue
except Exception as e:
logger.error(f"Unexpected error in Ollama API call: {str(e)}")
raise RuntimeError(f"Ollama API call failed: {str(e)}")
# If we get here, all retries failed
raise RuntimeError(f"Connection error after {self.max_retries} retries: {str(last_error)}")

@@ -19,6 +19,7 @@ from computer import Computer
from .types import LLMProvider
from .clients.openai import OpenAIClient
from .clients.anthropic import AnthropicClient
from .clients.ollama import OllamaClient
from .prompts import SYSTEM_PROMPT
from .api_handler import OmniAPIHandler
from .tools.manager import ToolManager
@@ -135,6 +136,11 @@ class OmniLoop(BaseLoop):
                api_key=self.api_key,
                model=self.model,
            )
        elif self.provider == LLMProvider.OLLAMA:
            self.client = OllamaClient(
                api_key=self.api_key,
                model=self.model,
            )
        else:
            raise ValueError(f"Unsupported provider: {self.provider}")
@@ -160,6 +166,11 @@ class OmniLoop(BaseLoop):
                max_retries=self.max_retries,
                retry_delay=self.retry_delay,
            )
        elif self.provider == LLMProvider.OLLAMA:
            self.client = OllamaClient(
                api_key=self.api_key,
                model=self.model,
            )
        else:
            raise ValueError(f"Unsupported provider: {self.provider}")
@@ -370,6 +381,13 @@ class OmniLoop(BaseLoop):
            else:
                logger.warning("Invalid Anthropic response format")
                return True, action_screenshot_saved
        elif self.provider == LLMProvider.OLLAMA:
            try:
                raw_text = response["message"]["content"]
                standard_content = [{"type": "text", "text": raw_text}]
            except (KeyError, TypeError, IndexError) as e:
                logger.error(f"Invalid response format: {str(e)}")
                return True, action_screenshot_saved
        else:
            # Assume OpenAI or compatible format
            try:

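For context, the Ollama chat call returns its text under message.content, and the new branch above lifts that into the same standard_content block list the other providers produce. A simplified, illustrative shape (not an exact transcript):

# Illustrative Ollama chat response (field values are made up)
response = {
    "model": "gemma3:4b-it-q4_K_M",
    "message": {"role": "assistant", "content": '{"action": "click", "x": 120, "y": 340}'},
    "done": True,
}

raw_text = response["message"]["content"]
standard_content = [{"type": "text", "text": raw_text}]  # normalized for downstream handling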
@@ -11,6 +11,7 @@ class LLMProvider(StrEnum):
    ANTHROPIC = "anthropic"
    OMNI = "omni"
    OPENAI = "openai"
    OLLAMA = "ollama"


@dataclass
@@ -35,10 +36,12 @@ Model = LLM
PROVIDER_TO_DEFAULT_MODEL: Dict[LLMProvider, str] = {
    LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
    LLMProvider.OPENAI: "gpt-4o",
    LLMProvider.OLLAMA: "gemma3:4b-it-q4_K_M",
}

# Environment variable names for each provider
PROVIDER_TO_ENV_VAR: Dict[LLMProvider, str] = {
    LLMProvider.ANTHROPIC: "ANTHROPIC_API_KEY",
    LLMProvider.OPENAI: "OPENAI_API_KEY",
    LLMProvider.OLLAMA: "none",
}
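Since LLMProvider is a StrEnum, provider names round-trip from plain strings, and the "none" entry can serve as a sentinel for providers that need no API key. A small sketch under those assumptions; api_key_for is a hypothetical helper, not part of this PR:

import os
from typing import Optional

provider = LLMProvider("ollama")
assert provider is LLMProvider.OLLAMA
assert provider == "ollama"  # StrEnum members compare equal to their string values

def api_key_for(provider: LLMProvider) -> Optional[str]:
    env_var = PROVIDER_TO_ENV_VAR.get(provider, "none")
    if env_var == "none":
        return None  # e.g. Ollama talks to a local server and needs no key
    return os.environ.get(env_var)

print(PROVIDER_TO_DEFAULT_MODEL[provider], api_key_for(provider))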