mirror of
https://github.com/trycua/computer.git
synced 2026-01-04 04:19:57 -06:00
Limited PyTorch inference to 1 thread
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import warnings
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Iterator, AsyncIterator, Dict, List, Any, Optional
|
||||
from litellm.types.utils import GenericStreamingChunk, ModelResponse
|
||||
from litellm.llms.custom_llm import CustomLLM
|
||||
@@ -28,6 +29,7 @@ class HuggingFaceLocalAdapter(CustomLLM):
|
||||
self.device = device
|
||||
self.models = {} # Cache for loaded models
|
||||
self.processors = {} # Cache for loaded processors
|
||||
self._executor = ThreadPoolExecutor(max_workers=1) # Single thread pool
|
||||
|
||||
def _load_model_and_processor(self, model_name: str):
|
||||
"""Load model and processor if not already cached.
|
||||
@@ -51,7 +53,8 @@ class HuggingFaceLocalAdapter(CustomLLM):
|
||||
processor = AutoProcessor.from_pretrained(
|
||||
model_name,
|
||||
min_pixels=3136,
|
||||
max_pixels=4096 * 2160
|
||||
max_pixels=4096 * 2160,
|
||||
device_map=self.device
|
||||
)
|
||||
|
||||
# Cache them
|
||||
@@ -185,7 +188,8 @@ class HuggingFaceLocalAdapter(CustomLLM):
|
||||
ModelResponse with generated text
|
||||
"""
|
||||
# Run _generate in a thread pool to avoid blocking the event loop
|
||||
generated_text = await asyncio.to_thread(self._generate, **kwargs)
|
||||
loop = asyncio.get_event_loop()
|
||||
generated_text = await loop.run_in_executor(self._executor, self._generate, **kwargs)
|
||||
|
||||
return await acompletion(
|
||||
model=f"huggingface-local/{kwargs['model']}",
|
||||
@@ -218,7 +222,8 @@ class HuggingFaceLocalAdapter(CustomLLM):
|
||||
AsyncIterator of GenericStreamingChunk
|
||||
"""
|
||||
# Run _generate in a thread pool to avoid blocking the event loop
|
||||
generated_text = await asyncio.to_thread(self._generate, **kwargs)
|
||||
loop = asyncio.get_event_loop()
|
||||
generated_text = await loop.run_in_executor(self._executor, self._generate, **kwargs)
|
||||
|
||||
generic_streaming_chunk: GenericStreamingChunk = {
|
||||
"finish_reason": "stop",
|
||||
|
||||
Reference in New Issue
Block a user