Mirror of https://github.com/mudler/LocalAI.git, synced 2025-12-19 16:30:20 -06:00
* feat(mlx): add thread-safe LRU prompt cache

Port mlx-lm's LRUPromptCache to fix race condition where concurrent requests corrupt shared KV cache state. The previous implementation used a single prompt_cache instance shared across all requests.

Changes:
- Add backend/python/common/mlx_cache.py with ThreadSafeLRUPromptCache
- Modify backend.py to use per-request cache isolation via fetch/insert
- Add prefix matching for cache reuse across similar prompts
- Add LRU eviction (default 10 entries, configurable)
- Add concurrency and cache unit tests

The cache uses a trie-based structure for efficient prefix matching, allowing prompts that share common prefixes to reuse cached KV states. Thread safety is provided via threading.Lock.

New configuration options:
- max_cache_entries: Maximum LRU cache entries (default: 10)
- max_kv_size: Maximum KV cache size per entry (default: None)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Signed-off-by: Blightbow <blightbow@users.noreply.github.com>

* feat(mlx): add min_p and top_k sampler support

Add MinP field to proto (field 52) following the precedent set by other non-OpenAI sampling parameters like TopK, TailFreeSamplingZ, TypicalP, and Mirostat.

Changes:
- backend.proto: Add float MinP field for min-p sampling
- backend.py: Extract and pass min_p and top_k to mlx_lm sampler (top_k was in proto but not being passed)
- test.py: Fix test_sampling_params to use valid proto fields and switch to MLX-compatible model (mlx-community/Llama-3.2-1B-Instruct)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Signed-off-by: Blightbow <blightbow@users.noreply.github.com>

* refactor(mlx): move mlx_cache.py from common to mlx backend

The ThreadSafeLRUPromptCache is only used by the mlx backend. After evaluating mlx-vlm, it was determined that the cache cannot be shared because mlx-vlm's generate/stream_generate functions don't support the prompt_cache parameter that mlx_lm provides.

- Move mlx_cache.py from backend/python/common/ to backend/python/mlx/
- Remove sys.path manipulation from backend.py and test.py
- Fix test assertion to expect "MLX model loaded successfully"

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Signed-off-by: Blightbow <blightbow@users.noreply.github.com>

* test(mlx): add comprehensive cache tests and document upstream behavior

Added comprehensive unit tests (test_mlx_cache.py) covering all cache operation modes:
- Exact match
- Shorter prefix match
- Longer prefix match with trimming
- No match scenarios
- LRU eviction and access order
- Reference counting and deep copy behavior
- Multi-model namespacing
- Thread safety with data integrity verification

Documents upstream mlx_lm/server.py behavior: single-token prefixes are deliberately not matched (uses > 0, not >= 0) to allow longer cached sequences to be preferred for trimming. This is acceptable because real prompts with chat templates are always many tokens.

Removed weak unit tests from test.py that only verified "no exception thrown" rather than correctness.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Signed-off-by: Blightbow <blightbow@users.noreply.github.com>

* chore(mlx): remove unused MinP proto field

The MinP field was added to PredictOptions but is not populated by the Go frontend/API. The MLX backend uses getattr with a default value, so it works without the proto field.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Signed-off-by: Blightbow <blightbow@users.noreply.github.com>

---------

Signed-off-by: Blightbow <blightbow@users.noreply.github.com>
Co-authored-by: Blightbow <blightbow@users.noreply.github.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
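A minimal sketch of the per-request isolation described above, i.e. checking a cache out of the LRU before generation and re-inserting it afterwards. It assumes mlx_lm's make_prompt_cache, make_sampler, and stream_generate keep their current signatures; the class, method, and attribute names (MLXBackendSketch, self.model, self.tokenizer, self.model_path) are illustrative, not the actual backend.py code:

# Illustrative sketch only -- not the LocalAI backend implementation.
from mlx_lm import stream_generate
from mlx_lm.models.cache import make_prompt_cache
from mlx_lm.sample_utils import make_sampler

from mlx_cache import ThreadSafeLRUPromptCache


class MLXBackendSketch:
    def load_model(self, model, tokenizer, model_path, max_cache_entries=10):
        self.model, self.tokenizer, self.model_path = model, tokenizer, model_path
        # One shared LRU per process; entries are checked out per request,
        # so concurrent requests never mutate the same KV cache object.
        self.lru_cache = ThreadSafeLRUPromptCache(max_size=max_cache_entries)

    def predict(self, prompt: str, max_tokens: int = 256) -> str:
        tokens = self.tokenizer.encode(prompt)

        # Reuse the closest cached prefix (an exclusive copy) if one exists.
        prompt_cache, remaining = self.lru_cache.fetch_nearest_cache(
            self.model_path, tokens
        )
        if prompt_cache is None:
            prompt_cache = make_prompt_cache(self.model)
            remaining = list(tokens)

        sampler = make_sampler(temp=0.7, top_p=0.9, min_p=0.0, top_k=40)
        generated_ids, text_parts = [], []
        for response in stream_generate(
            self.model,
            self.tokenizer,
            remaining,  # only the tokens not already covered by the cache
            max_tokens=max_tokens,
            sampler=sampler,
            prompt_cache=prompt_cache,
        ):
            generated_ids.append(response.token)
            text_parts.append(response.text)

        # Hand the (now longer) cache back so later prompts can reuse its prefix.
        self.lru_cache.insert_cache(
            self.model_path, list(tokens) + generated_ids, prompt_cache
        )
        return "".join(text_parts)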
267 lines · 9.1 KiB · Python
"""
|
|
Thread-safe LRU prompt cache for MLX-based backends.
|
|
|
|
Ported from mlx_lm/server.py (MIT License, Copyright 2023-2024 Apple Inc.)
|
|
with thread-safety additions for LocalAI's gRPC backend.
|
|
|
|
Usage:
|
|
from mlx_cache import ThreadSafeLRUPromptCache
|
|
|
|
# In LoadModel:
|
|
self.lru_cache = ThreadSafeLRUPromptCache(max_size=10)
|
|
|
|
# In Predict/PredictStream:
|
|
prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache(model_key, tokens)
|
|
# ... generate ...
|
|
self.lru_cache.insert_cache(model_key, tokens, prompt_cache)
|
|
"""
|
|
import copy
|
|
import threading
|
|
from collections import deque
|
|
from dataclasses import dataclass
|
|
from typing import Any, List, Optional, Tuple
|
|
|
|
|
|
@dataclass
|
|
class CacheEntry:
|
|
"""A cache entry with reference counting."""
|
|
prompt_cache: List[Any]
|
|
count: int
|
|
|
|
|
|
@dataclass
|
|
class SearchResult:
|
|
"""Result of searching the cache trie."""
|
|
model: Any
|
|
exact: Optional[List[int]]
|
|
shorter: Optional[List[int]]
|
|
longer: Optional[List[int]]
|
|
common_prefix: int
|
|
|
|
|
|
class ThreadSafeLRUPromptCache:
|
|
"""
|
|
Thread-safe LRU cache with prefix matching for prompt KV caches.
|
|
|
|
This cache stores KV caches keyed by token sequences and supports:
|
|
- Exact match: Return the cache for the exact token sequence
|
|
- Shorter prefix match: Return a cache for a prefix of the tokens
|
|
- Longer prefix match: If a longer sequence is cached and can be trimmed
|
|
- LRU eviction: When max_size is exceeded, evict least recently used
|
|
|
|
Thread safety is provided via a threading.Lock that protects all
|
|
cache operations.
|
|
|
|
Args:
|
|
max_size: Maximum number of cache entries (default: 10)
|
|
can_trim_fn: Optional function to check if a cache can be trimmed
|
|
trim_fn: Optional function to trim a cache
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
max_size: int = 10,
|
|
can_trim_fn: Optional[Any] = None,
|
|
trim_fn: Optional[Any] = None,
|
|
):
|
|
self.max_size = max_size
|
|
self._cache = {}
|
|
self._lru = deque()
|
|
self._lock = threading.Lock()
|
|
|
|
# Optional trim functions (for longer prefix reuse)
|
|
self._can_trim_fn = can_trim_fn
|
|
self._trim_fn = trim_fn
|
|
|
|
def _search(self, model, tokens: List[int]) -> SearchResult:
|
|
"""
|
|
Search the cache for a prompt cache. Return exact or close match.
|
|
|
|
The cache is organized as a trie where each node is keyed by a token.
|
|
This allows efficient prefix matching.
|
|
"""
|
|
if model not in self._cache:
|
|
return SearchResult(model, None, None, None, 0)
|
|
|
|
current = self._cache[model]
|
|
last_cache_index = -1
|
|
index = 0
|
|
|
|
# Traverse the trie following the token sequence
|
|
while index < len(tokens) and tokens[index] in current:
|
|
current = current[tokens[index]]
|
|
if "cache" in current:
|
|
last_cache_index = index
|
|
index += 1
|
|
|
|
# Exact match - no need to search for longer or shorter caches
|
|
if last_cache_index == len(tokens) - 1:
|
|
return SearchResult(model, tuple(tokens), None, None, 0)
|
|
|
|
# Find the shorter cache (a prefix that has a cache)
|
|
# Note: Uses > 0 (not >= 0) to match upstream mlx_lm/server.py behavior.
|
|
# Single-token prefixes are not matched, which allows longer cached
|
|
# sequences to be preferred for trimming. This is acceptable because
|
|
# real prompts with chat templates are always many tokens.
|
|
shorter = None
|
|
if last_cache_index > 0:
|
|
shorter = tuple(tokens[: last_cache_index + 1])
|
|
|
|
# Check for caches that are longer than our token sequence
|
|
longer = None
|
|
common_prefix = index
|
|
if index > 0 and last_cache_index <= 0:
|
|
best = None
|
|
stack = [(current, [])]
|
|
while stack:
|
|
current, extra = stack.pop()
|
|
if "cache" in current:
|
|
if best is None or len(extra) < len(best):
|
|
best = extra
|
|
else:
|
|
for tok in current:
|
|
stack.append((current[tok], extra + [tok]))
|
|
if best is not None:
|
|
longer = tuple(tokens[:index] + best)
|
|
|
|
return SearchResult(model, None, shorter, longer, common_prefix)
|
|
|
|
def _get(self, model, tokens: Tuple[int, ...]) -> CacheEntry:
|
|
"""Get a cache entry by traversing the trie."""
|
|
current = self._cache[model]
|
|
for tok in tokens:
|
|
current = current[tok]
|
|
return current["cache"]
|
|
|
|
def _delete(self, model, tokens: Tuple[int, ...]) -> None:
|
|
"""Delete a cache entry and clean up empty trie nodes."""
|
|
path = [self._cache[model]]
|
|
for tok in tokens:
|
|
path.append(path[-1][tok])
|
|
del path[-1]["cache"]
|
|
|
|
# Clean up empty nodes bottom-up
|
|
for i in reversed(range(len(tokens))):
|
|
d_prev, d, t = path[i], path[i + 1], tokens[i]
|
|
if len(d) > 0:
|
|
break
|
|
del d_prev[t]
|
|
|
|
def _extract(self, model, tokens: Tuple[int, ...]) -> CacheEntry:
|
|
"""
|
|
Extract a cache entry for exclusive use.
|
|
|
|
If the entry has count > 1, deep copy and decrement.
|
|
If count == 1, remove from cache entirely.
|
|
"""
|
|
cache_entry = self._get(model, tokens)
|
|
if cache_entry.count == 1:
|
|
self._delete(model, tokens)
|
|
self._lru.remove((model, tokens))
|
|
return cache_entry
|
|
|
|
cache_entry.count -= 1
|
|
return CacheEntry(
|
|
copy.deepcopy(cache_entry.prompt_cache),
|
|
1,
|
|
)
|
|
|
|
def fetch_nearest_cache(
|
|
self, model, tokens: List[int]
|
|
) -> Tuple[Optional[List[Any]], List[int]]:
|
|
"""
|
|
Fetch the nearest cache for the given token sequence.
|
|
|
|
Thread-safe. Returns (cache, remaining_tokens) where:
|
|
- cache: The KV cache to use (or None if no cache found)
|
|
- remaining_tokens: Tokens that still need to be processed
|
|
|
|
Args:
|
|
model: Model identifier (used to namespace caches)
|
|
tokens: The full token sequence for the prompt
|
|
|
|
Returns:
|
|
Tuple of (prompt_cache, remaining_tokens)
|
|
"""
|
|
with self._lock:
|
|
tokens_tuple = tuple(tokens)
|
|
result = self._search(model, tokens)
|
|
|
|
# Exact match - extract and return
|
|
if result.exact is not None:
|
|
cache_entry = self._extract(result.model, result.exact)
|
|
return cache_entry.prompt_cache, []
|
|
|
|
# Shorter prefix match - extract and return remaining
|
|
if result.shorter is not None:
|
|
cache_entry = self._extract(result.model, result.shorter)
|
|
prefix_len = len(result.shorter)
|
|
return cache_entry.prompt_cache, list(tokens[prefix_len:])
|
|
|
|
# Longer prefix match - try to trim if possible
|
|
if result.longer is not None and self._can_trim_fn is not None:
|
|
cache_entry = self._get(result.model, result.longer)
|
|
if self._can_trim_fn(cache_entry.prompt_cache):
|
|
# Deep copy and trim
|
|
trimmed_cache = copy.deepcopy(cache_entry.prompt_cache)
|
|
prefix = min(len(tokens) - 1, result.common_prefix)
|
|
num_to_trim = len(result.longer) - prefix
|
|
if self._trim_fn is not None:
|
|
self._trim_fn(trimmed_cache, num_to_trim)
|
|
return trimmed_cache, list(tokens[prefix:])
|
|
|
|
# No match found
|
|
return None, list(tokens)
|
|
|
|
def insert_cache(
|
|
self, model, tokens: List[int], prompt_cache: List[Any]
|
|
) -> None:
|
|
"""
|
|
Insert a cache entry after generation completes.
|
|
|
|
Thread-safe. Handles LRU eviction if max_size is exceeded.
|
|
|
|
Args:
|
|
model: Model identifier (used to namespace caches)
|
|
tokens: The full token sequence (prompt + generated)
|
|
prompt_cache: The KV cache to store
|
|
"""
|
|
with self._lock:
|
|
tokens_tuple = tuple(tokens)
|
|
|
|
if model not in self._cache:
|
|
self._cache[model] = {}
|
|
current = self._cache[model]
|
|
|
|
# Build trie path
|
|
for tok in tokens_tuple:
|
|
if tok not in current:
|
|
current[tok] = {}
|
|
current = current[tok]
|
|
|
|
# Update or create entry
|
|
if "cache" in current:
|
|
current["cache"].count += 1
|
|
self._lru.remove((model, tokens_tuple))
|
|
else:
|
|
current["cache"] = CacheEntry(prompt_cache, 1)
|
|
|
|
# Update LRU order
|
|
self._lru.append((model, tokens_tuple))
|
|
|
|
# Evict if over capacity
|
|
if len(self._lru) > self.max_size:
|
|
evict_model, evict_tokens = self._lru.popleft()
|
|
self._delete(evict_model, evict_tokens)
|
|
|
|
def clear(self) -> None:
|
|
"""Clear all cache entries. Thread-safe."""
|
|
with self._lock:
|
|
self._cache.clear()
|
|
self._lru.clear()
|
|
|
|
def __len__(self) -> int:
|
|
"""Return the number of cache entries. Thread-safe."""
|
|
with self._lock:
|
|
return len(self._lru)
|
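The can_trim_fn / trim_fn hooks above are deliberately generic. A short sketch of how they could be wired up, assuming mlx_lm exposes can_trim_prompt_cache and trim_prompt_cache in mlx_lm.models.cache as upstream mlx_lm/server.py uses them:

# Illustrative wiring of the trim hooks (function names assumed from upstream mlx_lm).
from mlx_lm.models.cache import can_trim_prompt_cache, trim_prompt_cache

from mlx_cache import ThreadSafeLRUPromptCache

lru_cache = ThreadSafeLRUPromptCache(
    max_size=10,
    can_trim_fn=can_trim_prompt_cache,  # True if every layer cache supports trimming
    trim_fn=trim_prompt_cache,          # trim_prompt_cache(cache, num_tokens) drops trailing tokens
)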