From be1ce76849191bf6e702ed99c9a6fb2a8423d895 Mon Sep 17 00:00:00 2001
From: Dillon DuPont
Date: Mon, 11 Aug 2025 10:31:49 -0400
Subject: [PATCH] Limited pytorch inference to 1 thread

Run HuggingFaceLocalAdapter._generate on a dedicated single-worker
ThreadPoolExecutor instead of asyncio.to_thread, so every generation
request submitted through this adapter goes to the same one-thread pool.

loop.run_in_executor() only forwards positional arguments to the
callable, so the keyword arguments are bound in a closure before
submission; passing **kwargs to run_in_executor directly raises
TypeError. The loop is obtained with asyncio.get_running_loop(), since
both call sites are inside coroutines.

Also pass device_map to AutoProcessor.from_pretrained.
---
 .../agent/agent/adapters/huggingfacelocal_adapter.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/libs/python/agent/agent/adapters/huggingfacelocal_adapter.py b/libs/python/agent/agent/adapters/huggingfacelocal_adapter.py
index e8281114..11f03c0f 100644
--- a/libs/python/agent/agent/adapters/huggingfacelocal_adapter.py
+++ b/libs/python/agent/agent/adapters/huggingfacelocal_adapter.py
@@ -1,5 +1,6 @@
 import asyncio
 import warnings
+from concurrent.futures import ThreadPoolExecutor
 from typing import Iterator, AsyncIterator, Dict, List, Any, Optional
 from litellm.types.utils import GenericStreamingChunk, ModelResponse
 from litellm.llms.custom_llm import CustomLLM
@@ -28,6 +29,7 @@ class HuggingFaceLocalAdapter(CustomLLM):
         self.device = device
         self.models = {}  # Cache for loaded models
         self.processors = {}  # Cache for loaded processors
+        self._executor = ThreadPoolExecutor(max_workers=1)  # Single thread pool
 
     def _load_model_and_processor(self, model_name: str):
         """Load model and processor if not already cached.
@@ -51,7 +53,8 @@ class HuggingFaceLocalAdapter(CustomLLM):
             processor = AutoProcessor.from_pretrained(
                 model_name,
                 min_pixels=3136,
-                max_pixels=4096 * 2160
+                max_pixels=4096 * 2160,
+                device_map=self.device
             )
 
             # Cache them
@@ -185,7 +188,8 @@ class HuggingFaceLocalAdapter(CustomLLM):
             ModelResponse with generated text
         """
         # Run _generate in thread pool to avoid blocking
-        generated_text = await asyncio.to_thread(self._generate, **kwargs)
+        loop = asyncio.get_running_loop()
+        generated_text = await loop.run_in_executor(self._executor, lambda: self._generate(**kwargs))
 
         return await acompletion(
             model=f"huggingface-local/{kwargs['model']}",
@@ -218,7 +222,8 @@ class HuggingFaceLocalAdapter(CustomLLM):
             AsyncIterator of GenericStreamingChunk
         """
         # Run _generate in thread pool to avoid blocking
-        generated_text = await asyncio.to_thread(self._generate, **kwargs)
+        loop = asyncio.get_running_loop()
+        generated_text = await loop.run_in_executor(self._executor, lambda: self._generate(**kwargs))
 
         generic_streaming_chunk: GenericStreamingChunk = {
             "finish_reason": "stop",