Mirror of https://github.com/mudler/LocalAI.git (synced 2025-12-19 08:20:13 -06:00)
feat(mlx): add thread-safe LRU prompt cache and min_p/top_k sampling (#7556)
* feat(mlx): add thread-safe LRU prompt cache

  Port mlx-lm's LRUPromptCache to fix a race condition where concurrent requests corrupt shared KV cache state. The previous implementation used a single prompt_cache instance shared across all requests.

  Changes:
  - Add backend/python/common/mlx_cache.py with ThreadSafeLRUPromptCache
  - Modify backend.py to use per-request cache isolation via fetch/insert
  - Add prefix matching for cache reuse across similar prompts
  - Add LRU eviction (default 10 entries, configurable)
  - Add concurrency and cache unit tests

  The cache uses a trie-based structure for efficient prefix matching, allowing prompts that share common prefixes to reuse cached KV states. Thread safety is provided via threading.Lock.

  New configuration options:
  - max_cache_entries: Maximum LRU cache entries (default: 10)
  - max_kv_size: Maximum KV cache size per entry (default: None)

  🤖 Generated with [Claude Code](https://claude.com/claude-code)
  Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
  Signed-off-by: Blightbow <blightbow@users.noreply.github.com>

* feat(mlx): add min_p and top_k sampler support

  Add a MinP field to the proto (field 52), following the precedent set by other non-OpenAI sampling parameters like TopK, TailFreeSamplingZ, TypicalP, and Mirostat.

  Changes:
  - backend.proto: Add float MinP field for min-p sampling
  - backend.py: Extract and pass min_p and top_k to the mlx_lm sampler (top_k was in the proto but not being passed)
  - test.py: Fix test_sampling_params to use valid proto fields and switch to an MLX-compatible model (mlx-community/Llama-3.2-1B-Instruct)

  🤖 Generated with [Claude Code](https://claude.com/claude-code)
  Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
  Signed-off-by: Blightbow <blightbow@users.noreply.github.com>

* refactor(mlx): move mlx_cache.py from common to mlx backend

  The ThreadSafeLRUPromptCache is only used by the mlx backend. After evaluating mlx-vlm, it was determined that the cache cannot be shared because mlx-vlm's generate/stream_generate functions don't support the prompt_cache parameter that mlx_lm provides.

  - Move mlx_cache.py from backend/python/common/ to backend/python/mlx/
  - Remove sys.path manipulation from backend.py and test.py
  - Fix test assertion to expect "MLX model loaded successfully"

  🤖 Generated with [Claude Code](https://claude.com/claude-code)
  Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
  Signed-off-by: Blightbow <blightbow@users.noreply.github.com>

* test(mlx): add comprehensive cache tests and document upstream behavior

  Added comprehensive unit tests (test_mlx_cache.py) covering all cache operation modes:
  - Exact match
  - Shorter prefix match
  - Longer prefix match with trimming
  - No match scenarios
  - LRU eviction and access order
  - Reference counting and deep copy behavior
  - Multi-model namespacing
  - Thread safety with data integrity verification

  Documents upstream mlx_lm/server.py behavior: single-token prefixes are deliberately not matched (uses > 0, not >= 0) to allow longer cached sequences to be preferred for trimming. This is acceptable because real prompts with chat templates are always many tokens.

  Removed weak unit tests from test.py that only verified "no exception thrown" rather than correctness.

  🤖 Generated with [Claude Code](https://claude.com/claude-code)
  Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
  Signed-off-by: Blightbow <blightbow@users.noreply.github.com>

* chore(mlx): remove unused MinP proto field

  The MinP field was added to PredictOptions but is not populated by the Go frontend/API. The MLX backend uses getattr with a default value, so it works without the proto field.

  🤖 Generated with [Claude Code](https://claude.com/claude-code)
  Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
  Signed-off-by: Blightbow <blightbow@users.noreply.github.com>

---------

Signed-off-by: Blightbow <blightbow@users.noreply.github.com>
Co-authored-by: Blightbow <blightbow@users.noreply.github.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
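The per-request flow these commits describe (tokenize the prompt, fetch the nearest cached prefix, generate over only the remaining tokens, then insert the updated cache back into the LRU) looks roughly like the sketch below. This is a simplified, standalone sketch rather than the backend.py code itself: the predict() wrapper, the "model-key" namespace string and the sampler values are illustrative placeholders, while the ThreadSafeLRUPromptCache and mlx_lm calls mirror the APIs used in the diff.

# Hedged sketch of the per-request cache isolation pattern from this PR.
# gRPC plumbing and error handling are omitted; "model-key" stands in for
# the request.Model value used as the cache namespace in the real backend.
from mlx_lm import load, stream_generate
from mlx_lm.sample_utils import make_sampler
from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache
from mlx_cache import ThreadSafeLRUPromptCache

model, tokenizer = load("mlx-community/Llama-3.2-1B-Instruct-4bit")
lru_cache = ThreadSafeLRUPromptCache(
    max_size=10,                      # max_cache_entries option
    can_trim_fn=can_trim_prompt_cache,
    trim_fn=trim_prompt_cache,
)

def predict(prompt_text: str, max_tokens: int = 128) -> str:
    tokens = list(tokenizer.encode(prompt_text))
    # Reuse the longest cached prefix; otherwise start a fresh per-request cache.
    cache, remaining = lru_cache.fetch_nearest_cache("model-key", tokens)
    if cache is None:
        cache = make_prompt_cache(model, max_kv_size=None)  # max_kv_size option
        remaining = tokens
    # Sampler values here are arbitrary examples; the backend builds them
    # from the request's Temperature/TopP/MinP/TopK fields.
    sampler = make_sampler(temp=0.7, top_p=1.0, min_p=0.0, top_k=0)
    out = []
    for resp in stream_generate(
        model, tokenizer,
        prompt=remaining if remaining else tokens,
        max_tokens=max_tokens,
        sampler=sampler,
        prompt_cache=cache,
    ):
        out.append(resp.text)
        tokens.append(resp.token)   # extend the cache key with generated tokens
    # Return the cache so later requests sharing this prefix can reuse it.
    lru_cache.insert_cache("model-key", tokens, cache)
    return "".join(out)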
backend/python/mlx/backend.py
@@ -14,11 +14,13 @@ import backend_pb2_grpc
 import grpc
 from mlx_lm import load, generate, stream_generate
 from mlx_lm.sample_utils import make_sampler
-from mlx_lm.models.cache import make_prompt_cache
+from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache
 import mlx.core as mx
 import base64
 import io
 
+from mlx_cache import ThreadSafeLRUPromptCache
+
 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
 
 # If MAX_WORKERS are specified in the environment use it, otherwise default to 1
@@ -118,10 +120,16 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                 self.model, self.tokenizer = load(request.Model, tokenizer_config=tokenizer_config)
             else:
                 self.model, self.tokenizer = load(request.Model)
 
-            # Initialize prompt cache for efficient generation
-            max_kv_size = self.options.get("max_kv_size", None)
-            self.prompt_cache = make_prompt_cache(self.model, max_kv_size)
+            # Initialize thread-safe LRU prompt cache for efficient generation
+            max_cache_entries = self.options.get("max_cache_entries", 10)
+            self.max_kv_size = self.options.get("max_kv_size", None)
+            self.model_key = request.Model
+            self.lru_cache = ThreadSafeLRUPromptCache(
+                max_size=max_cache_entries,
+                can_trim_fn=can_trim_prompt_cache,
+                trim_fn=trim_prompt_cache,
+            )
 
         except Exception as err:
             print(f"Error loading MLX model {err=}, {type(err)=}", file=sys.stderr)
@@ -134,6 +142,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         """
         Generates text based on the given prompt and sampling parameters using MLX.
 
+        Uses thread-safe LRU prompt cache for efficient prefix reuse across requests.
+
         Args:
             request: The predict request.
             context: The gRPC context.
@@ -141,31 +151,48 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         Returns:
             backend_pb2.Reply: The predict result.
         """
+        prompt_cache = None
+        cache_key = None
+
         try:
-            # Prepare the prompt
-            prompt = self._prepare_prompt(request)
+            # Prepare the prompt and tokenize for cache key
+            prompt_text = self._prepare_prompt(request)
+            cache_key = self._get_tokens_from_prompt(prompt_text)
+
+            # Fetch nearest cache (exact, shorter prefix, or create new)
+            prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache(
+                self.model_key, cache_key
+            )
+            if prompt_cache is None:
+                prompt_cache = make_prompt_cache(self.model, self.max_kv_size)
+                remaining_tokens = cache_key
 
             # Build generation parameters using request attributes and options
             max_tokens, sampler_params = self._build_generation_params(request)
 
-            print(f"Generating text with MLX - max_tokens: {max_tokens}, sampler_params: {sampler_params}", file=sys.stderr)
+            print(f"Generating text with MLX - max_tokens: {max_tokens}, cache_hit: {len(remaining_tokens) < len(cache_key)}", file=sys.stderr)
 
             # Create sampler with parameters
             sampler = make_sampler(**sampler_params)
 
-            # Generate text using MLX with proper parameters
-            response = generate(
+            # Use stream_generate to track generated tokens for cache key
+            generated_text = []
+            for response in stream_generate(
                 self.model,
                 self.tokenizer,
-                prompt=prompt,
+                prompt=remaining_tokens if remaining_tokens else cache_key,
                 max_tokens=max_tokens,
                 sampler=sampler,
-                prompt_cache=self.prompt_cache,
-                verbose=False
-            )
-
-            return backend_pb2.Reply(message=bytes(response, encoding='utf-8'))
+                prompt_cache=prompt_cache,
+            ):
+                generated_text.append(response.text)
+                cache_key.append(response.token)
+
+            # Insert completed cache
+            self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache)
+
+            return backend_pb2.Reply(message=bytes(''.join(generated_text), encoding='utf-8'))
 
         except Exception as e:
             print(f"Error in MLX Predict: {e}", file=sys.stderr)
             context.set_code(grpc.StatusCode.INTERNAL)
@@ -194,6 +221,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         """
         Generates text based on the given prompt and sampling parameters, and streams the results using MLX.
 
+        Uses thread-safe LRU prompt cache for efficient prefix reuse across requests.
+
         Args:
             request: The predict stream request.
             context: The gRPC context.
@@ -201,35 +230,56 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         Yields:
             backend_pb2.Reply: Streaming predict results.
         """
+        prompt_cache = None
+        cache_key = None
+
         try:
-            # Prepare the prompt
-            prompt = self._prepare_prompt(request)
+            # Prepare the prompt and tokenize for cache key
+            prompt_text = self._prepare_prompt(request)
+            cache_key = self._get_tokens_from_prompt(prompt_text)
+
+            # Fetch nearest cache (exact, shorter prefix, or create new)
+            prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache(
+                self.model_key, cache_key
+            )
+            if prompt_cache is None:
+                prompt_cache = make_prompt_cache(self.model, self.max_kv_size)
+                remaining_tokens = cache_key
 
             # Build generation parameters using request attributes and options
             max_tokens, sampler_params = self._build_generation_params(request, default_max_tokens=512)
 
-            print(f"Streaming text with MLX - max_tokens: {max_tokens}, sampler_params: {sampler_params}", file=sys.stderr)
+            print(f"Streaming text with MLX - max_tokens: {max_tokens}, cache_hit: {len(remaining_tokens) < len(cache_key)}", file=sys.stderr)
 
             # Create sampler with parameters
             sampler = make_sampler(**sampler_params)
 
             # Stream text generation using MLX with proper parameters
             for response in stream_generate(
                 self.model,
                 self.tokenizer,
-                prompt=prompt,
+                prompt=remaining_tokens if remaining_tokens else cache_key,
                 max_tokens=max_tokens,
                 sampler=sampler,
-                prompt_cache=self.prompt_cache,
+                prompt_cache=prompt_cache,
             ):
+                cache_key.append(response.token)
                 yield backend_pb2.Reply(message=bytes(response.text, encoding='utf-8'))
 
         except Exception as e:
             print(f"Error in MLX PredictStream: {e}", file=sys.stderr)
             context.set_code(grpc.StatusCode.INTERNAL)
             context.set_details(f"Streaming generation failed: {str(e)}")
             yield backend_pb2.Reply(message=bytes("", encoding='utf-8'))
 
+        finally:
+            # Always insert cache, even on interruption
+            if prompt_cache is not None and cache_key is not None:
+                try:
+                    self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache)
+                except Exception as e:
+                    print(f"Error inserting cache: {e}", file=sys.stderr)
+
     def _prepare_prompt(self, request):
         """
         Prepare the prompt for MLX generation, handling chat templates if needed.
@@ -246,16 +296,31 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             messages = []
             for msg in request.Messages:
                 messages.append({"role": msg.role, "content": msg.content})
 
             prompt = self.tokenizer.apply_chat_template(
                 messages,
                 tokenize=False,
                 add_generation_prompt=True
             )
             return prompt
         else:
             return request.Prompt
 
+    def _get_tokens_from_prompt(self, prompt_text: str) -> List[int]:
+        """
+        Tokenize prompt text for cache key generation.
+
+        Args:
+            prompt_text: The prompt string to tokenize.
+
+        Returns:
+            List[int]: List of token IDs.
+        """
+        tokens = self.tokenizer.encode(prompt_text)
+        if hasattr(tokens, 'tolist'):
+            return tokens.tolist()
+        return list(tokens)
+
@@ -284,11 +349,19 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         top_p = getattr(request, 'TopP', 0.0)
         if top_p == 0.0:
             top_p = 1.0 # Default top_p
 
+        min_p = getattr(request, 'MinP', 0.0)
+        # min_p default of 0.0 means disabled (no filtering)
+
+        top_k = getattr(request, 'TopK', 0)
+        # top_k default of 0 means disabled (no filtering)
+
         # Initialize sampler parameters
         sampler_params = {
             'temp': temp,
             'top_p': top_p,
+            'min_p': min_p,
+            'top_k': top_k,
             'xtc_threshold': 0.0,
             'xtc_probability': 0.0,
         }
@@ -308,7 +381,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         sampler_option_mapping = {
             'temp': 'temp',
             'temperature': 'temp', # alias
             'top_p': 'top_p',
+            'min_p': 'min_p',
+            'top_k': 'top_k',
             'xtc_threshold': 'xtc_threshold',
             'xtc_probability': 'xtc_probability',
         }
backend/python/mlx/mlx_cache.py (new file, 266 lines)
@@ -0,0 +1,266 @@
"""
Thread-safe LRU prompt cache for MLX-based backends.

Ported from mlx_lm/server.py (MIT License, Copyright 2023-2024 Apple Inc.)
with thread-safety additions for LocalAI's gRPC backend.

Usage:
    from mlx_cache import ThreadSafeLRUPromptCache

    # In LoadModel:
    self.lru_cache = ThreadSafeLRUPromptCache(max_size=10)

    # In Predict/PredictStream:
    prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache(model_key, tokens)
    # ... generate ...
    self.lru_cache.insert_cache(model_key, tokens, prompt_cache)
"""
import copy
import threading
from collections import deque
from dataclasses import dataclass
from typing import Any, List, Optional, Tuple


@dataclass
class CacheEntry:
    """A cache entry with reference counting."""
    prompt_cache: List[Any]
    count: int


@dataclass
class SearchResult:
    """Result of searching the cache trie."""
    model: Any
    exact: Optional[List[int]]
    shorter: Optional[List[int]]
    longer: Optional[List[int]]
    common_prefix: int


class ThreadSafeLRUPromptCache:
    """
    Thread-safe LRU cache with prefix matching for prompt KV caches.

    This cache stores KV caches keyed by token sequences and supports:
    - Exact match: Return the cache for the exact token sequence
    - Shorter prefix match: Return a cache for a prefix of the tokens
    - Longer prefix match: If a longer sequence is cached and can be trimmed
    - LRU eviction: When max_size is exceeded, evict least recently used

    Thread safety is provided via a threading.Lock that protects all
    cache operations.

    Args:
        max_size: Maximum number of cache entries (default: 10)
        can_trim_fn: Optional function to check if a cache can be trimmed
        trim_fn: Optional function to trim a cache
    """

    def __init__(
        self,
        max_size: int = 10,
        can_trim_fn: Optional[Any] = None,
        trim_fn: Optional[Any] = None,
    ):
        self.max_size = max_size
        self._cache = {}
        self._lru = deque()
        self._lock = threading.Lock()

        # Optional trim functions (for longer prefix reuse)
        self._can_trim_fn = can_trim_fn
        self._trim_fn = trim_fn

    def _search(self, model, tokens: List[int]) -> SearchResult:
        """
        Search the cache for a prompt cache. Return exact or close match.

        The cache is organized as a trie where each node is keyed by a token.
        This allows efficient prefix matching.
        """
        if model not in self._cache:
            return SearchResult(model, None, None, None, 0)

        current = self._cache[model]
        last_cache_index = -1
        index = 0

        # Traverse the trie following the token sequence
        while index < len(tokens) and tokens[index] in current:
            current = current[tokens[index]]
            if "cache" in current:
                last_cache_index = index
            index += 1

        # Exact match - no need to search for longer or shorter caches
        if last_cache_index == len(tokens) - 1:
            return SearchResult(model, tuple(tokens), None, None, 0)

        # Find the shorter cache (a prefix that has a cache)
        # Note: Uses > 0 (not >= 0) to match upstream mlx_lm/server.py behavior.
        # Single-token prefixes are not matched, which allows longer cached
        # sequences to be preferred for trimming. This is acceptable because
        # real prompts with chat templates are always many tokens.
        shorter = None
        if last_cache_index > 0:
            shorter = tuple(tokens[: last_cache_index + 1])

        # Check for caches that are longer than our token sequence
        longer = None
        common_prefix = index
        if index > 0 and last_cache_index <= 0:
            best = None
            stack = [(current, [])]
            while stack:
                current, extra = stack.pop()
                if "cache" in current:
                    if best is None or len(extra) < len(best):
                        best = extra
                else:
                    for tok in current:
                        stack.append((current[tok], extra + [tok]))
            if best is not None:
                longer = tuple(tokens[:index] + best)

        return SearchResult(model, None, shorter, longer, common_prefix)

    def _get(self, model, tokens: Tuple[int, ...]) -> CacheEntry:
        """Get a cache entry by traversing the trie."""
        current = self._cache[model]
        for tok in tokens:
            current = current[tok]
        return current["cache"]

    def _delete(self, model, tokens: Tuple[int, ...]) -> None:
        """Delete a cache entry and clean up empty trie nodes."""
        path = [self._cache[model]]
        for tok in tokens:
            path.append(path[-1][tok])
        del path[-1]["cache"]

        # Clean up empty nodes bottom-up
        for i in reversed(range(len(tokens))):
            d_prev, d, t = path[i], path[i + 1], tokens[i]
            if len(d) > 0:
                break
            del d_prev[t]

    def _extract(self, model, tokens: Tuple[int, ...]) -> CacheEntry:
        """
        Extract a cache entry for exclusive use.

        If the entry has count > 1, deep copy and decrement.
        If count == 1, remove from cache entirely.
        """
        cache_entry = self._get(model, tokens)
        if cache_entry.count == 1:
            self._delete(model, tokens)
            self._lru.remove((model, tokens))
            return cache_entry

        cache_entry.count -= 1
        return CacheEntry(
            copy.deepcopy(cache_entry.prompt_cache),
            1,
        )

    def fetch_nearest_cache(
        self, model, tokens: List[int]
    ) -> Tuple[Optional[List[Any]], List[int]]:
        """
        Fetch the nearest cache for the given token sequence.

        Thread-safe. Returns (cache, remaining_tokens) where:
        - cache: The KV cache to use (or None if no cache found)
        - remaining_tokens: Tokens that still need to be processed

        Args:
            model: Model identifier (used to namespace caches)
            tokens: The full token sequence for the prompt

        Returns:
            Tuple of (prompt_cache, remaining_tokens)
        """
        with self._lock:
            tokens_tuple = tuple(tokens)
            result = self._search(model, tokens)

            # Exact match - extract and return
            if result.exact is not None:
                cache_entry = self._extract(result.model, result.exact)
                return cache_entry.prompt_cache, []

            # Shorter prefix match - extract and return remaining
            if result.shorter is not None:
                cache_entry = self._extract(result.model, result.shorter)
                prefix_len = len(result.shorter)
                return cache_entry.prompt_cache, list(tokens[prefix_len:])

            # Longer prefix match - try to trim if possible
            if result.longer is not None and self._can_trim_fn is not None:
                cache_entry = self._get(result.model, result.longer)
                if self._can_trim_fn(cache_entry.prompt_cache):
                    # Deep copy and trim
                    trimmed_cache = copy.deepcopy(cache_entry.prompt_cache)
                    prefix = min(len(tokens) - 1, result.common_prefix)
                    num_to_trim = len(result.longer) - prefix
                    if self._trim_fn is not None:
                        self._trim_fn(trimmed_cache, num_to_trim)
                    return trimmed_cache, list(tokens[prefix:])

            # No match found
            return None, list(tokens)

    def insert_cache(
        self, model, tokens: List[int], prompt_cache: List[Any]
    ) -> None:
        """
        Insert a cache entry after generation completes.

        Thread-safe. Handles LRU eviction if max_size is exceeded.

        Args:
            model: Model identifier (used to namespace caches)
            tokens: The full token sequence (prompt + generated)
            prompt_cache: The KV cache to store
        """
        with self._lock:
            tokens_tuple = tuple(tokens)

            if model not in self._cache:
                self._cache[model] = {}
            current = self._cache[model]

            # Build trie path
            for tok in tokens_tuple:
                if tok not in current:
                    current[tok] = {}
                current = current[tok]

            # Update or create entry
            if "cache" in current:
                current["cache"].count += 1
                self._lru.remove((model, tokens_tuple))
            else:
                current["cache"] = CacheEntry(prompt_cache, 1)

            # Update LRU order
            self._lru.append((model, tokens_tuple))

            # Evict if over capacity
            if len(self._lru) > self.max_size:
                evict_model, evict_tokens = self._lru.popleft()
                self._delete(evict_model, evict_tokens)

    def clear(self) -> None:
        """Clear all cache entries. Thread-safe."""
        with self._lock:
            self._cache.clear()
            self._lru.clear()

    def __len__(self) -> int:
        """Return the number of cache entries. Thread-safe."""
        with self._lock:
            return len(self._lru)
backend/python/mlx/test.py
@@ -1,17 +1,10 @@
-import unittest
-import subprocess
-import time
-import backend_pb2
-import backend_pb2_grpc
-
-import grpc
-
+import unittest
+import subprocess
+import time
+import grpc
+import backend_pb2
+import backend_pb2_grpc
 
 class TestBackendServicer(unittest.TestCase):
     """
@@ -47,9 +40,9 @@ class TestBackendServicer(unittest.TestCase):
             self.setUp()
             with grpc.insecure_channel("localhost:50051") as channel:
                 stub = backend_pb2_grpc.BackendStub(channel)
-                response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m"))
+                response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit"))
                 self.assertTrue(response.success)
-                self.assertEqual(response.message, "Model loaded successfully")
+                self.assertEqual(response.message, "MLX model loaded successfully")
         except Exception as err:
             print(err)
             self.fail("LoadModel service failed")
@@ -64,7 +57,7 @@ class TestBackendServicer(unittest.TestCase):
             self.setUp()
             with grpc.insecure_channel("localhost:50051") as channel:
                 stub = backend_pb2_grpc.BackendStub(channel)
-                response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m"))
+                response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit"))
                 self.assertTrue(response.success)
                 req = backend_pb2.PredictOptions(Prompt="The capital of France is")
                 resp = stub.Predict(req)
@@ -84,7 +77,7 @@ class TestBackendServicer(unittest.TestCase):
             self.setUp()
             with grpc.insecure_channel("localhost:50051") as channel:
                 stub = backend_pb2_grpc.BackendStub(channel)
-                response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m"))
+                response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit"))
                 self.assertTrue(response.success)
 
                 req = backend_pb2.PredictOptions(
|
||||
TopK=40,
|
||||
PresencePenalty=0.1,
|
||||
FrequencyPenalty=0.2,
|
||||
RepetitionPenalty=1.1,
|
||||
MinP=0.05,
|
||||
Seed=42,
|
||||
StopPrompts=["\n"],
|
||||
StopTokenIds=[50256],
|
||||
BadWords=["badword"],
|
||||
IncludeStopStrInOutput=True,
|
||||
IgnoreEOS=True,
|
||||
MinTokens=5,
|
||||
Logprobs=5,
|
||||
PromptLogprobs=5,
|
||||
SkipSpecialTokens=True,
|
||||
SpacesBetweenSpecialTokens=True,
|
||||
TruncatePromptTokens=10,
|
||||
GuidedDecoding=True,
|
||||
N=2,
|
||||
)
|
||||
resp = stub.Predict(req)
|
||||
self.assertIsNotNone(resp.message)
|
||||
self.assertIsNotNone(resp.logprobs)
|
||||
except Exception as err:
|
||||
print(err)
|
||||
self.fail("sampling params service failed")
|
||||
@@ -143,4 +123,112 @@ class TestBackendServicer(unittest.TestCase):
             print(err)
             self.fail("Embedding service failed")
         finally:
             self.tearDown()
+
+    def test_concurrent_requests(self):
+        """
+        This method tests that concurrent requests don't corrupt each other's cache state.
+        This is a regression test for the race condition in the original implementation.
+        """
+        import concurrent.futures
+
+        try:
+            self.setUp()
+            with grpc.insecure_channel("localhost:50051") as channel:
+                stub = backend_pb2_grpc.BackendStub(channel)
+                response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit"))
+                self.assertTrue(response.success)
+
+                def make_request(prompt):
+                    req = backend_pb2.PredictOptions(Prompt=prompt, Tokens=20)
+                    return stub.Predict(req)
+
+                # Run 5 concurrent requests with different prompts
+                prompts = [
+                    "The capital of France is",
+                    "The capital of Germany is",
+                    "The capital of Italy is",
+                    "The capital of Spain is",
+                    "The capital of Portugal is",
+                ]
+
+                with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
+                    futures = [executor.submit(make_request, p) for p in prompts]
+                    results = [f.result() for f in concurrent.futures.as_completed(futures)]
+
+                # All results should be non-empty
+                messages = [r.message for r in results]
+                self.assertTrue(all(len(m) > 0 for m in messages), "All requests should return non-empty responses")
+                print(f"Concurrent test passed: {len(messages)} responses received")
+
+        except Exception as err:
+            print(err)
+            self.fail("Concurrent requests test failed")
+        finally:
+            self.tearDown()
+
+    def test_cache_reuse(self):
+        """
+        This method tests that repeated prompts reuse cached KV states.
+        The second request should benefit from the cached prompt processing.
+        """
+        try:
+            self.setUp()
+            with grpc.insecure_channel("localhost:50051") as channel:
+                stub = backend_pb2_grpc.BackendStub(channel)
+                response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit"))
+                self.assertTrue(response.success)
+
+                prompt = "The quick brown fox jumps over the lazy dog. "
+
+                # First request - populates cache
+                req1 = backend_pb2.PredictOptions(Prompt=prompt, Tokens=10)
+                resp1 = stub.Predict(req1)
+                self.assertIsNotNone(resp1.message)
+
+                # Second request with same prompt - should reuse cache
+                req2 = backend_pb2.PredictOptions(Prompt=prompt, Tokens=10)
+                resp2 = stub.Predict(req2)
+                self.assertIsNotNone(resp2.message)
+
+                print(f"Cache reuse test passed: first={len(resp1.message)} bytes, second={len(resp2.message)} bytes")
+
+        except Exception as err:
+            print(err)
+            self.fail("Cache reuse test failed")
+        finally:
+            self.tearDown()
+
+    def test_prefix_cache_reuse(self):
+        """
+        This method tests that prompts sharing a common prefix benefit from cached KV states.
+        """
+        try:
+            self.setUp()
+            with grpc.insecure_channel("localhost:50051") as channel:
+                stub = backend_pb2_grpc.BackendStub(channel)
+                response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit"))
+                self.assertTrue(response.success)
+
+                # First request with base prompt
+                prompt_base = "Once upon a time in a land far away, "
+                req1 = backend_pb2.PredictOptions(Prompt=prompt_base, Tokens=10)
+                resp1 = stub.Predict(req1)
+                self.assertIsNotNone(resp1.message)
+
+                # Second request with extended prompt (same prefix)
+                prompt_extended = prompt_base + "there lived a brave knight who "
+                req2 = backend_pb2.PredictOptions(Prompt=prompt_extended, Tokens=10)
+                resp2 = stub.Predict(req2)
+                self.assertIsNotNone(resp2.message)
+
+                print(f"Prefix cache test passed: base={len(resp1.message)} bytes, extended={len(resp2.message)} bytes")
+
+        except Exception as err:
+            print(err)
+            self.fail("Prefix cache reuse test failed")
+        finally:
+            self.tearDown()
+
+
+# Unit tests for ThreadSafeLRUPromptCache are in test_mlx_cache.py
backend/python/mlx/test_mlx_cache.py (new file, 480 lines)
@@ -0,0 +1,480 @@
"""
Comprehensive unit tests for ThreadSafeLRUPromptCache.

Tests all cache operation modes:
- Exact match
- Shorter prefix match
- Longer prefix match (with trimming)
- No match
- LRU eviction
- Reference counting
- Multi-model namespacing
- Thread safety with data integrity verification
"""
import unittest
import concurrent.futures
import threading
import copy
from mlx_cache import ThreadSafeLRUPromptCache


class TestCacheExactMatch(unittest.TestCase):
    """Tests for exact match cache behavior."""

    def setUp(self):
        self.cache = ThreadSafeLRUPromptCache(max_size=10)

    def test_exact_match_returns_cache_and_empty_remaining(self):
        """Exact match should return the cache with no remaining tokens."""
        tokens = [1, 2, 3, 4, 5]
        mock_cache = ["kv_cache_data"]

        self.cache.insert_cache("model1", tokens, mock_cache)
        result_cache, remaining = self.cache.fetch_nearest_cache("model1", tokens)

        self.assertEqual(result_cache, mock_cache)
        self.assertEqual(remaining, [])

    def test_exact_match_extracts_and_removes_from_cache(self):
        """Fetching exact match with count=1 should remove entry from cache."""
        tokens = [1, 2, 3]
        self.cache.insert_cache("model1", tokens, ["cache"])

        self.assertEqual(len(self.cache), 1)

        # First fetch extracts the entry
        self.cache.fetch_nearest_cache("model1", tokens)

        # Cache should now be empty
        self.assertEqual(len(self.cache), 0)

        # Second fetch should return None (no match)
        result_cache, remaining = self.cache.fetch_nearest_cache("model1", tokens)
        self.assertIsNone(result_cache)
        self.assertEqual(remaining, tokens)


class TestCacheShorterPrefix(unittest.TestCase):
    """Tests for shorter prefix match behavior."""

    def setUp(self):
        self.cache = ThreadSafeLRUPromptCache(max_size=10)

    def test_shorter_prefix_returns_cache_with_remaining_tokens(self):
        """When cached prefix is shorter, return cache and remaining suffix."""
        short_tokens = [1, 2, 3]
        long_tokens = [1, 2, 3, 4, 5, 6]
        mock_cache = ["prefix_cache"]

        self.cache.insert_cache("model1", short_tokens, mock_cache)
        result_cache, remaining = self.cache.fetch_nearest_cache("model1", long_tokens)

        self.assertEqual(result_cache, mock_cache)
        self.assertEqual(remaining, [4, 5, 6])

    def test_shorter_prefix_correct_remaining_calculation(self):
        """Verify remaining tokens are calculated correctly for various prefix lengths."""
        # Note: Single-token prefixes ([1] -> [1,2,3]) are deliberately not matched
        # to allow longer cached sequences to be preferred for trimming.
        # This matches upstream mlx_lm/server.py behavior.
        test_cases = [
            # (cached_tokens, requested_tokens, expected_remaining)
            ([1, 2], [1, 2, 3, 4, 5], [3, 4, 5]),
            ([10, 20, 30, 40], [10, 20, 30, 40, 50], [50]),
        ]

        for cached, requested, expected_remaining in test_cases:
            with self.subTest(cached=cached, requested=requested):
                cache = ThreadSafeLRUPromptCache(max_size=10)
                cache.insert_cache("model", cached, ["cache"])
                result_cache, remaining = cache.fetch_nearest_cache("model", requested)

                self.assertIsNotNone(result_cache)
                self.assertEqual(remaining, expected_remaining)

    def test_single_token_prefix_not_matched(self):
        """Single-token prefixes are not matched (by design, matches upstream).

        This allows longer cached sequences to be preferred for trimming,
        which provides better KV cache reuse. Single-token caches are rare
        in practice since real prompts with chat templates are many tokens.
        """
        cache = ThreadSafeLRUPromptCache(max_size=10)
        cache.insert_cache("model", [1], ["cache"])

        result_cache, remaining = cache.fetch_nearest_cache("model", [1, 2, 3])

        # Single-token prefix is NOT matched
        self.assertIsNone(result_cache)
        self.assertEqual(remaining, [1, 2, 3])


class TestCacheLongerPrefix(unittest.TestCase):
    """Tests for longer prefix match behavior (trimming)."""

    def setUp(self):
        # Track trim calls for verification
        self.trim_calls = []

        def mock_can_trim(cache):
            return True

        def mock_trim(cache, num_to_trim):
            self.trim_calls.append(num_to_trim)
            # Simulate trimming by modifying the cache
            cache.append(f"trimmed_{num_to_trim}")

        self.cache = ThreadSafeLRUPromptCache(
            max_size=10,
            can_trim_fn=mock_can_trim,
            trim_fn=mock_trim,
        )

    def test_longer_prefix_triggers_trim(self):
        """When cached sequence is longer, should trim to match requested prefix."""
        long_tokens = [1, 2, 3, 4, 5]
        short_tokens = [1, 2, 3]

        self.cache.insert_cache("model1", long_tokens, ["original_cache"])
        result_cache, remaining = self.cache.fetch_nearest_cache("model1", short_tokens)

        # Should have called trim
        self.assertTrue(len(self.trim_calls) > 0, "trim_fn should have been called")
        # Result should be a trimmed copy, not the original
        self.assertIn("trimmed_", str(result_cache))

    def test_longer_prefix_without_trim_fn_returns_no_match(self):
        """Without trim functions, longer prefix should not match."""
        cache_no_trim = ThreadSafeLRUPromptCache(max_size=10)

        long_tokens = [1, 2, 3, 4, 5]
        short_tokens = [1, 2, 3]

        cache_no_trim.insert_cache("model1", long_tokens, ["cache"])
        result_cache, remaining = cache_no_trim.fetch_nearest_cache("model1", short_tokens)

        # Without trim_fn, should return no match
        self.assertIsNone(result_cache)
        self.assertEqual(remaining, short_tokens)

    def test_longer_prefix_can_trim_false_returns_no_match(self):
        """When can_trim_fn returns False, should not attempt trim."""
        cache = ThreadSafeLRUPromptCache(
            max_size=10,
            can_trim_fn=lambda c: False,
            trim_fn=lambda c, n: None,
        )

        cache.insert_cache("model1", [1, 2, 3, 4, 5], ["cache"])
        result_cache, remaining = cache.fetch_nearest_cache("model1", [1, 2, 3])

        self.assertIsNone(result_cache)
        self.assertEqual(remaining, [1, 2, 3])


class TestCacheNoMatch(unittest.TestCase):
    """Tests for no match behavior."""

    def setUp(self):
        self.cache = ThreadSafeLRUPromptCache(max_size=10)

    def test_empty_cache_returns_none(self):
        """Empty cache should return None and all tokens as remaining."""
        tokens = [1, 2, 3]
        result_cache, remaining = self.cache.fetch_nearest_cache("model1", tokens)

        self.assertIsNone(result_cache)
        self.assertEqual(remaining, tokens)

    def test_different_prefix_returns_none(self):
        """Tokens with different prefix should not match."""
        self.cache.insert_cache("model1", [1, 2, 3], ["cache"])

        # Completely different tokens
        result_cache, remaining = self.cache.fetch_nearest_cache("model1", [4, 5, 6])

        self.assertIsNone(result_cache)
        self.assertEqual(remaining, [4, 5, 6])

    def test_partial_prefix_mismatch_returns_none(self):
        """Tokens that diverge mid-sequence should not match."""
        self.cache.insert_cache("model1", [1, 2, 3], ["cache"])

        # Same start but diverges
        result_cache, remaining = self.cache.fetch_nearest_cache("model1", [1, 2, 99])

        self.assertIsNone(result_cache)
        self.assertEqual(remaining, [1, 2, 99])

    def test_wrong_model_returns_none(self):
        """Different model key should not match."""
        self.cache.insert_cache("model1", [1, 2, 3], ["cache"])

        result_cache, remaining = self.cache.fetch_nearest_cache("model2", [1, 2, 3])

        self.assertIsNone(result_cache)
        self.assertEqual(remaining, [1, 2, 3])


class TestCacheLRUEviction(unittest.TestCase):
    """Tests for LRU eviction behavior."""

    def setUp(self):
        self.cache = ThreadSafeLRUPromptCache(max_size=3)

    def test_evicts_oldest_when_full(self):
        """Should evict least recently used entry when capacity exceeded."""
        self.cache.insert_cache("model", [1], ["cache1"])
        self.cache.insert_cache("model", [2], ["cache2"])
        self.cache.insert_cache("model", [3], ["cache3"])

        self.assertEqual(len(self.cache), 3)

        # Insert 4th entry - should evict [1]
        self.cache.insert_cache("model", [4], ["cache4"])

        self.assertEqual(len(self.cache), 3)

        # [1] should be evicted
        result, _ = self.cache.fetch_nearest_cache("model", [1])
        self.assertIsNone(result)

        # [2], [3], [4] should still exist
        for tokens in [[2], [3], [4]]:
            # Re-insert since fetch extracts
            self.cache.insert_cache("model", tokens, [f"cache{tokens[0]}"])

        result2, _ = self.cache.fetch_nearest_cache("model", [2])
        self.assertIsNotNone(result2)

    def test_access_updates_lru_order(self):
        """Accessing an entry should move it to most recently used."""
        self.cache.insert_cache("model", [1], ["cache1"])
        self.cache.insert_cache("model", [2], ["cache2"])
        self.cache.insert_cache("model", [3], ["cache3"])

        # Access [1] to make it most recently used
        cache1, _ = self.cache.fetch_nearest_cache("model", [1])
        # Re-insert it (simulating normal usage pattern)
        self.cache.insert_cache("model", [1], cache1)

        # Now insert two more entries - should evict [2] then [3], not [1]
        self.cache.insert_cache("model", [4], ["cache4"])
        self.cache.insert_cache("model", [5], ["cache5"])

        # [1] should still exist (was accessed, so not evicted)
        result1, _ = self.cache.fetch_nearest_cache("model", [1])
        self.assertIsNotNone(result1)

        # [2] should be evicted (was oldest after [1] was accessed)
        result2, _ = self.cache.fetch_nearest_cache("model", [2])
        self.assertIsNone(result2)


class TestCacheReferenceCount(unittest.TestCase):
    """Tests for reference counting behavior."""

    def setUp(self):
        self.cache = ThreadSafeLRUPromptCache(max_size=10)

    def test_multiple_inserts_increment_count(self):
        """Inserting same tokens multiple times should increment count."""
        tokens = [1, 2, 3]

        self.cache.insert_cache("model", tokens, ["cache"])
        self.cache.insert_cache("model", tokens, ["cache"])
        self.cache.insert_cache("model", tokens, ["cache"])

        # Should still be one entry (with count=3 internally)
        self.assertEqual(len(self.cache), 1)

        # First two fetches should return copies (count decremented)
        result1, _ = self.cache.fetch_nearest_cache("model", tokens)
        self.assertIsNotNone(result1)

        result2, _ = self.cache.fetch_nearest_cache("model", tokens)
        self.assertIsNotNone(result2)

        # Third fetch extracts the last reference
        result3, _ = self.cache.fetch_nearest_cache("model", tokens)
        self.assertIsNotNone(result3)

        # Fourth fetch should return None (entry fully extracted)
        result4, _ = self.cache.fetch_nearest_cache("model", tokens)
        self.assertIsNone(result4)

    def test_extract_with_high_count_returns_deep_copy(self):
        """When count > 1, extract should return a deep copy."""
        tokens = [1, 2, 3]
        original_cache = [{"nested": "data"}]

        self.cache.insert_cache("model", tokens, original_cache)
        self.cache.insert_cache("model", tokens, original_cache)  # count=2

        result1, _ = self.cache.fetch_nearest_cache("model", tokens)

        # Modify the returned cache
        result1[0]["nested"] = "modified"

        # Second fetch should get unmodified copy
        result2, _ = self.cache.fetch_nearest_cache("model", tokens)
        self.assertEqual(result2[0]["nested"], "data")


class TestCacheMultiModel(unittest.TestCase):
    """Tests for multi-model namespacing."""

    def setUp(self):
        self.cache = ThreadSafeLRUPromptCache(max_size=10)

    def test_same_tokens_different_models_are_separate(self):
        """Same token sequence under different models should be independent."""
        tokens = [1, 2, 3]

        self.cache.insert_cache("model_a", tokens, ["cache_a"])
        self.cache.insert_cache("model_b", tokens, ["cache_b"])

        self.assertEqual(len(self.cache), 2)

        result_a, _ = self.cache.fetch_nearest_cache("model_a", tokens)
        result_b, _ = self.cache.fetch_nearest_cache("model_b", tokens)

        self.assertEqual(result_a, ["cache_a"])
        self.assertEqual(result_b, ["cache_b"])

    def test_eviction_across_models(self):
        """LRU eviction should work across different models."""
        cache = ThreadSafeLRUPromptCache(max_size=3)

        cache.insert_cache("model_a", [1], ["a1"])
        cache.insert_cache("model_b", [1], ["b1"])
        cache.insert_cache("model_a", [2], ["a2"])

        self.assertEqual(len(cache), 3)

        # Insert 4th - should evict model_a:[1] (oldest)
        cache.insert_cache("model_b", [2], ["b2"])

        result, _ = cache.fetch_nearest_cache("model_a", [1])
        self.assertIsNone(result)


class TestCacheThreadSafety(unittest.TestCase):
    """Tests for thread safety with data integrity verification."""

    def test_concurrent_inserts_no_data_loss(self):
        """Concurrent inserts should not lose data."""
        cache = ThreadSafeLRUPromptCache(max_size=100)
        num_threads = 10
        inserts_per_thread = 20

        def insert_entries(thread_id):
            for i in range(inserts_per_thread):
                tokens = [thread_id, i]
                cache.insert_cache("model", tokens, [f"cache_{thread_id}_{i}"])

        with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
            futures = [executor.submit(insert_entries, tid) for tid in range(num_threads)]
            concurrent.futures.wait(futures)

        # Verify expected number of entries (may be less due to LRU eviction with max_size=100)
        # But should be exactly 100 since we inserted exactly 200 and max_size is 100
        self.assertEqual(len(cache), 100)

    def test_concurrent_fetch_and_insert_no_corruption(self):
        """Concurrent fetches and inserts should not corrupt data."""
        cache = ThreadSafeLRUPromptCache(max_size=50)
        errors = []
        lock = threading.Lock()

        # Pre-populate with known data
        for i in range(20):
            cache.insert_cache("model", [i], [f"original_{i}"])

        def fetch_and_verify(thread_id):
            try:
                for _ in range(50):
                    token_id = thread_id % 20
                    result, remaining = cache.fetch_nearest_cache("model", [token_id])

                    if result is not None:
                        # Verify data integrity
                        expected_prefix = f"original_{token_id}"
                        if not str(result[0]).startswith("original_"):
                            with lock:
                                errors.append(f"Corrupted data: {result}")

                        # Re-insert to keep cache populated
                        cache.insert_cache("model", [token_id], result)

            except Exception as e:
                with lock:
                    errors.append(str(e))

        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
            futures = [executor.submit(fetch_and_verify, tid) for tid in range(10)]
            concurrent.futures.wait(futures)

        self.assertEqual(errors, [], f"Thread safety errors: {errors}")

    def test_concurrent_operations_maintain_cache_bounds(self):
        """Cache size should never exceed max_size under concurrent operations."""
        max_size = 10
        cache = ThreadSafeLRUPromptCache(max_size=max_size)
        size_violations = []
        lock = threading.Lock()

        def random_operations(thread_id):
            import random
            for i in range(100):
                tokens = [random.randint(0, 50)]
                if random.random() < 0.7:
                    cache.insert_cache("model", tokens, [f"cache_{thread_id}_{i}"])
                else:
                    cache.fetch_nearest_cache("model", tokens)

                current_size = len(cache)
                if current_size > max_size:
                    with lock:
                        size_violations.append(current_size)

        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
            futures = [executor.submit(random_operations, tid) for tid in range(10)]
            concurrent.futures.wait(futures)

        self.assertEqual(size_violations, [], f"Size exceeded max: {size_violations}")
        self.assertLessEqual(len(cache), max_size)


class TestCacheClear(unittest.TestCase):
    """Tests for cache clear operation."""

    def setUp(self):
        self.cache = ThreadSafeLRUPromptCache(max_size=10)

    def test_clear_removes_all_entries(self):
        """Clear should remove all entries."""
        self.cache.insert_cache("model1", [1, 2], ["cache1"])
        self.cache.insert_cache("model2", [3, 4], ["cache2"])
        self.cache.insert_cache("model1", [5, 6], ["cache3"])

        self.assertEqual(len(self.cache), 3)

        self.cache.clear()

        self.assertEqual(len(self.cache), 0)

    def test_clear_allows_new_inserts(self):
        """After clear, new inserts should work normally."""
        self.cache.insert_cache("model", [1], ["cache1"])
        self.cache.clear()
        self.cache.insert_cache("model", [2], ["cache2"])

        self.assertEqual(len(self.cache), 1)

        result, _ = self.cache.fetch_nearest_cache("model", [2])
        self.assertEqual(result, ["cache2"])


if __name__ == "__main__":
    unittest.main()