From 654f83686a0404ac74b69573152c5fc5eb61bba6 Mon Sep 17 00:00:00 2001
From: Ethan Gutierrez
Date: Mon, 31 Mar 2025 16:39:54 -0400
Subject: [PATCH] Add Ollama support in Omni parser

---
 libs/agent/agent/core/provider_config.py    |   2 +
 libs/agent/agent/core/types.py              |   1 +
 .../agent/providers/omni/clients/ollama.py  | 122 ++++++++++++++++++
 libs/agent/agent/providers/omni/loop.py     |  18 +++
 libs/agent/agent/providers/omni/types.py    |   3 +
 5 files changed, 146 insertions(+)
 create mode 100644 libs/agent/agent/providers/omni/clients/ollama.py

diff --git a/libs/agent/agent/core/provider_config.py b/libs/agent/agent/core/provider_config.py
index f7078f3f..68b69e2e 100644
--- a/libs/agent/agent/core/provider_config.py
+++ b/libs/agent/agent/core/provider_config.py
@@ -6,10 +6,12 @@ from ..providers.omni.types import LLMProvider
 DEFAULT_MODELS = {
     LLMProvider.OPENAI: "gpt-4o",
     LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
+    LLMProvider.OLLAMA: "gemma3:4b-it-q4_K_M",
 }
 
 # Map providers to their environment variable names
 ENV_VARS = {
     LLMProvider.OPENAI: "OPENAI_API_KEY",
     LLMProvider.ANTHROPIC: "ANTHROPIC_API_KEY",
+    LLMProvider.OLLAMA: "OLLAMA_API_KEY",
 }
diff --git a/libs/agent/agent/core/types.py b/libs/agent/agent/core/types.py
index ae2af868..b9dad859 100644
--- a/libs/agent/agent/core/types.py
+++ b/libs/agent/agent/core/types.py
@@ -10,6 +10,7 @@ class AgentLoop(Enum):
     ANTHROPIC = auto()  # Anthropic implementation
     OMNI = auto()  # OmniLoop implementation
     OPENAI = auto()  # OpenAI implementation
+    OLLAMA = auto()  # Ollama implementation
     # Add more loop types as needed
diff --git a/libs/agent/agent/providers/omni/clients/ollama.py b/libs/agent/agent/providers/omni/clients/ollama.py
new file mode 100644
index 00000000..f33cc769
--- /dev/null
+++ b/libs/agent/agent/providers/omni/clients/ollama.py
@@ -0,0 +1,122 @@
+"""Ollama API client implementation."""
+
+import logging
+from typing import Any, Dict, List, cast
+import asyncio
+from httpx import ConnectError, ReadTimeout
+
+from ollama import AsyncClient, Options
+from .base import BaseOmniClient
+
+logger = logging.getLogger(__name__)
+
+
+class OllamaClient(BaseOmniClient):
+    """Client for making calls to a local Ollama server."""
+
+    def __init__(self, api_key: str, model: str, max_retries: int = 3, retry_delay: float = 1.0):
+        """Initialize the Ollama client.
+
+        Args:
+            api_key: Unused; a local Ollama server requires no API key
+            model: Ollama model name (e.g. "gemma3:4b-it-q4_K_M")
+            max_retries: Maximum number of retries for API calls
+            retry_delay: Base delay between retries in seconds
+        """
+        if not model:
+            raise ValueError("Model name must be provided")
+
+        self.client = AsyncClient(
+            host="http://localhost:11434",  # default local Ollama endpoint
+        )
+        self.model: str = model
+        self.max_retries = max_retries
+        self.retry_delay = retry_delay
+
+    def _convert_message_format(self, system: str, messages: List[Dict[str, Any]]) -> List[Any]:
+        """Convert messages from standard format to Ollama format.
+
+        Args:
+            system: System prompt to send as the first message
+            messages: Messages in standard format
+
+        Returns:
+            Messages in Ollama format
+        """
+        ollama_messages = []
+
+        # Add the system prompt as the first message
+        ollama_messages.append(
+            {
+                "role": "system",
+                "content": system,
+            }
+        )
+
+        for message in messages:
+            # Skip messages with empty content
+            if not message.get("content"):
+                continue
+            # Only the first content block of each message is considered
+            content = message.get("content", [{}])[0]
+            is_image = content.get("type", "") == "image_url"
+            is_text = content.get("type", "") == "text"
+            if is_text:
+                data = content.get("text", "")
+                ollama_messages.append({"role": message["role"], "content": data})
+            if is_image:
+                data = content.get("image_url", {}).get("url", "")
+                # Strip the data-URL header so only the raw base64 payload remains
+                data = data.removeprefix("data:image/png;base64,")
+                ollama_messages.append(
+                    {"role": message["role"], "content": "Use this image", "images": [data]}
+                )
+
+        # Cast the list to the type expected by the Ollama client
+        return cast(List[Any], ollama_messages)
+
+    async def run_interleaved(
+        self, messages: List[Dict[str, Any]], system: str, max_tokens: int
+    ) -> Any:
+        """Run model with interleaved conversation format.
+
+        Args:
+            messages: List of messages to process
+            system: System prompt
+            max_tokens: Unused; accepted for interface compatibility
+
+        Returns:
+            Model response
+        """
+        last_error = None
+
+        for attempt in range(self.max_retries):
+            try:
+                # Convert messages to Ollama format
+                ollama_messages = self._convert_message_format(system, messages)
+
+                response = await self.client.chat(
+                    model=self.model,
+                    options=Options(
+                        temperature=0,
+                    ),
+                    messages=ollama_messages,
+                    format="json",  # constrain the model to emit JSON
+                )
+
+                return response
+
+            except (ConnectError, ReadTimeout) as e:
+                last_error = e
+                logger.warning(
+                    f"Connection error on attempt {attempt + 1}/{self.max_retries}: {str(e)}"
+                )
+                if attempt < self.max_retries - 1:
+                    await asyncio.sleep(self.retry_delay * (attempt + 1))  # Linear backoff
+                    continue
+
+            except Exception as e:
+                logger.error(f"Unexpected error in Ollama API call: {str(e)}")
+                raise RuntimeError(f"Ollama API call failed: {str(e)}") from e
+
+        # If we get here, all retries failed
+        raise RuntimeError(f"Connection error after {self.max_retries} retries: {str(last_error)}")
diff --git a/libs/agent/agent/providers/omni/loop.py b/libs/agent/agent/providers/omni/loop.py
index 3223583e..d5fa930f 100644
--- a/libs/agent/agent/providers/omni/loop.py
+++ b/libs/agent/agent/providers/omni/loop.py
@@ -19,6 +19,7 @@ from computer import Computer
 from .types import LLMProvider
 from .clients.openai import OpenAIClient
 from .clients.anthropic import AnthropicClient
+from .clients.ollama import OllamaClient
 from .prompts import SYSTEM_PROMPT
 from .api_handler import OmniAPIHandler
 from .tools.manager import ToolManager
@@ -135,6 +136,11 @@ class OmniLoop(BaseLoop):
                 api_key=self.api_key,
                 model=self.model,
             )
+        elif self.provider == LLMProvider.OLLAMA:
+            self.client = OllamaClient(
+                api_key=self.api_key,
+                model=self.model,
+            )
         else:
             raise ValueError(f"Unsupported provider: {self.provider}")
@@ -160,6 +166,11 @@ class OmniLoop(BaseLoop):
                 max_retries=self.max_retries,
                 retry_delay=self.retry_delay,
             )
+        elif self.provider == LLMProvider.OLLAMA:
+            self.client = OllamaClient(
+                api_key=self.api_key,
+                model=self.model,
+            )
         else:
             raise ValueError(f"Unsupported provider: {self.provider}")
@@ -370,6 +381,13 @@ class OmniLoop(BaseLoop):
             else:
                 logger.warning("Invalid Anthropic response format")
                 return True, action_screenshot_saved
+        elif self.provider == LLMProvider.OLLAMA:
+            try:
+                raw_text = response["message"]["content"]
+                standard_content = [{"type": "text", "text": raw_text}]
+            except (KeyError, TypeError, IndexError) as e:
+                logger.error(f"Invalid Ollama response format: {str(e)}")
+                return True, action_screenshot_saved
         else:
             # Assume OpenAI or compatible format
             try:
diff --git a/libs/agent/agent/providers/omni/types.py b/libs/agent/agent/providers/omni/types.py
index 1f3aae93..c0d9837b 100644
--- a/libs/agent/agent/providers/omni/types.py
+++ b/libs/agent/agent/providers/omni/types.py
@@ -11,6 +11,7 @@ class LLMProvider(StrEnum):
     ANTHROPIC = "anthropic"
     OMNI = "omni"
     OPENAI = "openai"
+    OLLAMA = "ollama"
 
 
 @dataclass
@@ -35,10 +36,12 @@ Model = LLM
 PROVIDER_TO_DEFAULT_MODEL: Dict[LLMProvider, str] = {
     LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
     LLMProvider.OPENAI: "gpt-4o",
+    LLMProvider.OLLAMA: "gemma3:4b-it-q4_K_M",
 }
 
 # Environment variable names for each provider
 PROVIDER_TO_ENV_VAR: Dict[LLMProvider, str] = {
     LLMProvider.ANTHROPIC: "ANTHROPIC_API_KEY",
     LLMProvider.OPENAI: "OPENAI_API_KEY",
+    LLMProvider.OLLAMA: "OLLAMA_API_KEY",
 }
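
Usage sketch (not part of the patch): the snippet below drives the new OllamaClient end to end, assuming an Ollama server is running on the default port (http://localhost:11434) and the default model has been pulled (ollama pull gemma3:4b-it-q4_K_M). The import path mirrors the file location added above; adjust it if the package root differs.

# usage_sketch.py - illustrative only; file name and import path are assumptions.
import asyncio

from agent.providers.omni.clients.ollama import OllamaClient


async def main() -> None:
    # api_key is unused for Ollama; any string works.
    client = OllamaClient(api_key="", model="gemma3:4b-it-q4_K_M")

    # Messages use the standard block format that _convert_message_format expects:
    # each message carries a list of content blocks, and only the first is read.
    messages = [
        {
            "role": "user",
            "content": [{"type": "text", "text": 'Reply with {"status": "ok"}.'}],
        }
    ]

    response = await client.run_interleaved(
        messages=messages,
        system="You are a helpful assistant that always answers in JSON.",
        max_tokens=512,  # accepted but unused by this client
    )

    # Same access pattern the OmniLoop branch uses for Ollama responses.
    print(response["message"]["content"])


if __name__ == "__main__":
    asyncio.run(main())

Because run_interleaved passes format="json" to the Ollama chat call, the text in response["message"]["content"] is constrained to JSON, which is the shape the new OmniLoop parsing branch expects.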