Limited PyTorch inference to 1 thread

Dillon DuPont
2025-08-11 10:31:49 -04:00
parent 8bbcbec54b
commit be1ce76849
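
In outline, the change replaces asyncio.to_thread (which submits work to the default, multi-worker executor) with a dedicated single-worker ThreadPoolExecutor owned by the adapter, so at most one PyTorch inference call runs at a time while the event loop stays responsive. A minimal, self-contained sketch of that pattern, with illustrative class and method names rather than the adapter's actual API:

import asyncio
from concurrent.futures import ThreadPoolExecutor

class SingleThreadInference:
    """Illustrative adapter: all blocking inference funnels through one worker thread."""

    def __init__(self):
        # max_workers=1 guarantees at most one model call executes at a time
        self._executor = ThreadPoolExecutor(max_workers=1)

    def _generate(self, prompt: str) -> str:
        # Stand-in for the blocking PyTorch forward pass
        return f"echo: {prompt}"

    async def agenerate(self, prompt: str) -> str:
        loop = asyncio.get_running_loop()
        # run_in_executor forwards positional args only; keyword calls need a lambda/partial
        return await loop.run_in_executor(self._executor, self._generate, prompt)

async def main():
    adapter = SingleThreadInference()
    results = await asyncio.gather(*(adapter.agenerate(f"p{i}") for i in range(3)))
    print(results)  # the three calls ran serially on the single worker

asyncio.run(main())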


@@ -1,5 +1,6 @@
 import asyncio
 import warnings
+from concurrent.futures import ThreadPoolExecutor
 from typing import Iterator, AsyncIterator, Dict, List, Any, Optional
 from litellm.types.utils import GenericStreamingChunk, ModelResponse
 from litellm.llms.custom_llm import CustomLLM
@@ -28,6 +29,7 @@ class HuggingFaceLocalAdapter(CustomLLM):
         self.device = device
         self.models = {}  # Cache for loaded models
         self.processors = {}  # Cache for loaded processors
+        self._executor = ThreadPoolExecutor(max_workers=1)  # Single thread pool

     def _load_model_and_processor(self, model_name: str):
         """Load model and processor if not already cached.
@@ -51,7 +53,8 @@ class HuggingFaceLocalAdapter(CustomLLM):
processor = AutoProcessor.from_pretrained(
model_name,
min_pixels=3136,
max_pixels=4096 * 2160
max_pixels=4096 * 2160,
device_map=self.device
)
# Cache them
@@ -185,7 +188,8 @@ class HuggingFaceLocalAdapter(CustomLLM):
             ModelResponse with generated text
         """
         # Run _generate in thread pool to avoid blocking
-        generated_text = await asyncio.to_thread(self._generate, **kwargs)
+        loop = asyncio.get_running_loop()
+        generated_text = await loop.run_in_executor(self._executor, lambda: self._generate(**kwargs))

         return await acompletion(
             model=f"huggingface-local/{kwargs['model']}",
@@ -218,7 +222,8 @@ class HuggingFaceLocalAdapter(CustomLLM):
             AsyncIterator of GenericStreamingChunk
         """
         # Run _generate in thread pool to avoid blocking
-        generated_text = await asyncio.to_thread(self._generate, **kwargs)
+        loop = asyncio.get_running_loop()
+        generated_text = await loop.run_in_executor(self._executor, lambda: self._generate(**kwargs))

         generic_streaming_chunk: GenericStreamingChunk = {
             "finish_reason": "stop",