Mirror of https://github.com/mudler/LocalAI.git (synced 2026-02-12 21:19:23 -06:00)
Add mlx-vlm (#6119)
* Add mlx-vlm
* Add to CI workflows
* Add requirements-mps.txt
* Simplify

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Committed by GitHub · parent 1ba66d00f5 · commit 6971f71a6c
.github/workflows/backend.yml (vendored): 13 changes
@@ -972,6 +972,19 @@ jobs:
       dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }}
       quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
       quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
+  mlx-vlm-darwin:
+    uses: ./.github/workflows/backend_build_darwin.yml
+    with:
+      backend: "mlx-vlm"
+      build-type: "mps"
+      go-version: "1.24.x"
+      tag-suffix: "-metal-darwin-arm64-mlx-vlm"
+      runs-on: "macOS-14"
+    secrets:
+      dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }}
+      dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }}
+      quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
+      quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
   llama-cpp-darwin:
     runs-on: macOS-14
     strategy:
Makefile: 8 changes
@@ -366,13 +366,17 @@ build-darwin-python-backend: build
 	bash ./scripts/build/python-darwin.sh
 
 backends/mlx:
-	BACKEND=mlx BUILD_TYPE=mps $(MAKE) build-darwin-python-backend
+	BACKEND=mlx $(MAKE) build-darwin-python-backend
 	./local-ai backends install "ocifile://$(abspath ./backend-images/mlx.tar)"
 
 backends/diffuser-darwin:
-	USE_PIP=true BACKEND=diffusers BUILD_TYPE=mps $(MAKE) build-darwin-python-backend
+	BACKEND=diffusers $(MAKE) build-darwin-python-backend
 	./local-ai backends install "ocifile://$(abspath ./backend-images/diffusers.tar)"
 
+backends/mlx-vlm:
+	BACKEND=mlx-vlm $(MAKE) build-darwin-python-backend
+	./local-ai backends install "ocifile://$(abspath ./backend-images/mlx-vlm.tar)"
+
 backend-images:
 	mkdir -p backend-images
 
@@ -142,6 +142,23 @@
     - text-to-text
     - LLM
     - MLX
+- &mlx-vlm
+  name: "mlx-vlm"
+  uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-mlx-vlm"
+  icon: https://avatars.githubusercontent.com/u/102832242?s=200&v=4
+  urls:
+    - https://github.com/ml-explore/mlx-vlm
+  mirrors:
+    - localai/localai-backends:latest-metal-darwin-arm64-mlx-vlm
+  license: MIT
+  description: |
+    Run Vision-Language Models with MLX
+  tags:
+    - text-to-text
+    - multimodal
+    - vision-language
+    - LLM
+    - MLX
 - &rerankers
   name: "rerankers"
   alias: "rerankers"
@@ -392,6 +409,11 @@
   uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-mlx"
   mirrors:
     - localai/localai-backends:master-metal-darwin-arm64-mlx
+- !!merge <<: *mlx-vlm
+  name: "mlx-vlm-development"
+  uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-mlx-vlm"
+  mirrors:
+    - localai/localai-backends:master-metal-darwin-arm64-mlx-vlm
 - !!merge <<: *kitten-tts
   name: "kitten-tts-development"
   uri: "quay.io/go-skynet/local-ai-backends:master-kitten-tts"
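This is why the development entry stays so short: the &mlx-vlm anchor defined in the first hunk lets it inherit every field of the stable entry and override only name, uri, and mirrors. A minimal PyYAML sketch of the same mechanism (the two-field document is illustrative, not the real index):

import yaml  # PyYAML resolves anchors (&/*) and merge keys (<<) at load time

doc = """
- &mlx-vlm
  name: "mlx-vlm"
  license: MIT
- !!merge <<: *mlx-vlm
  name: "mlx-vlm-development"
"""
entries = yaml.safe_load(doc)
print(entries[1])  # -> {'name': 'mlx-vlm-development', 'license': 'MIT'}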
backend/python/mlx-vlm/Makefile: 23 lines (new file)
@@ -0,0 +1,23 @@
.PHONY: mlx-vlm
mlx-vlm:
	bash install.sh

.PHONY: run
run: mlx-vlm
	@echo "Running mlx-vlm..."
	bash run.sh
	@echo "mlx-vlm run."

.PHONY: test
test: mlx-vlm
	@echo "Testing mlx-vlm..."
	bash test.sh
	@echo "mlx-vlm tested."

.PHONY: protogen-clean
protogen-clean:
	$(RM) backend_pb2_grpc.py backend_pb2.py

.PHONY: clean
clean: protogen-clean
	rm -rf venv __pycache__
backend/python/mlx-vlm/backend.py: 477 lines (new file)
@@ -0,0 +1,477 @@
#!/usr/bin/env python3
import asyncio
from concurrent import futures
import argparse
import signal
import sys
import os
from typing import List
import time

import backend_pb2
import backend_pb2_grpc

import grpc
from mlx_vlm import load, generate, stream_generate
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config, load_image
import mlx.core as mx
import base64
import io
from PIL import Image
import tempfile

_ONE_DAY_IN_SECONDS = 60 * 60 * 24

# If MAX_WORKERS is specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))

# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    A gRPC servicer that implements the Backend service defined in backend.proto.
    """

    def _is_float(self, s):
        """Check if a string can be converted to float."""
        try:
            float(s)
            return True
        except ValueError:
            return False

    def _is_int(self, s):
        """Check if a string can be converted to int."""
        try:
            int(s)
            return True
        except ValueError:
            return False

    def Health(self, request, context):
        """
        Returns a health check message.

        Args:
            request: The health check request.
            context: The gRPC context.

        Returns:
            backend_pb2.Reply: The health check reply.
        """
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    async def LoadModel(self, request, context):
        """
        Loads a multimodal vision-language model using MLX-VLM.

        Args:
            request: The load model request.
            context: The gRPC context.

        Returns:
            backend_pb2.Result: The load model result.
        """
        try:
            print(f"Loading MLX-VLM model: {request.Model}", file=sys.stderr)
            print(f"Request: {request}", file=sys.stderr)

            # Parse options like in the diffusers backend
            options = request.Options
            self.options = {}

            # The options are a list of strings in this form optname:optvalue
            # We store all the options in a dict for later use
            for opt in options:
                if ":" not in opt:
                    continue
                key, value = opt.split(":", 1)  # Split only on first colon to handle values with colons

                # Convert numeric values to appropriate types
                # (check int before float so integer strings stay ints)
                if self._is_int(value):
                    value = int(value)
                elif self._is_float(value):
                    value = float(value)
                elif value.lower() in ["true", "false"]:
                    value = value.lower() == "true"

                self.options[key] = value
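            # Example (hypothetical values): Options ["temperature:0.7",
            # "seed:42", "trust_remote_code:true"] parse to
            # {"temperature": 0.7, "seed": 42, "trust_remote_code": True}.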

            print(f"Options: {self.options}", file=sys.stderr)

            # Load model and processor using MLX-VLM
            # mlx-vlm load function returns (model, processor) instead of (model, tokenizer)
            self.model, self.processor = load(request.Model)

            # Load model config for chat template support
            self.config = load_config(request.Model)

        except Exception as err:
            print(f"Error loading MLX-VLM model {err=}, {type(err)=}", file=sys.stderr)
            return backend_pb2.Result(success=False, message=f"Error loading MLX-VLM model: {err}")

        print("MLX-VLM model loaded successfully", file=sys.stderr)
        return backend_pb2.Result(message="MLX-VLM model loaded successfully", success=True)

    async def Predict(self, request, context):
        """
        Generates text based on the given prompt and sampling parameters using MLX-VLM with multimodal support.

        Args:
            request: The predict request.
            context: The gRPC context.

        Returns:
            backend_pb2.Reply: The predict result.
        """
        temp_files = []
        try:
            # Process images and audios from request
            image_paths = []
            audio_paths = []

            # Process images
            if request.Images:
                for img_data in request.Images:
                    img_path = self.load_image_from_base64(img_data)
                    if img_path:
                        image_paths.append(img_path)
                        temp_files.append(img_path)

            # Process audios
            if request.Audios:
                for audio_data in request.Audios:
                    audio_path = self.load_audio_from_base64(audio_data)
                    if audio_path:
                        audio_paths.append(audio_path)
                        temp_files.append(audio_path)

            # Prepare the prompt with multimodal information
            prompt = self._prepare_prompt(request, num_images=len(image_paths), num_audios=len(audio_paths))

            # Build generation parameters using request attributes and options
            max_tokens, generation_params = self._build_generation_params(request)

            print(f"Generating text with MLX-VLM - max_tokens: {max_tokens}, params: {generation_params}", file=sys.stderr)
            print(f"Images: {len(image_paths)}, Audios: {len(audio_paths)}", file=sys.stderr)

            # Generate text using MLX-VLM with multimodal inputs
            response = generate(
                model=self.model,
                processor=self.processor,
                prompt=prompt,
                image=image_paths if image_paths else None,
                audio=audio_paths if audio_paths else None,
                max_tokens=max_tokens,
                temperature=generation_params.get('temp', 0.6),
                top_p=generation_params.get('top_p', 1.0),
                verbose=False
            )

            return backend_pb2.Reply(message=bytes(response, encoding='utf-8'))

        except Exception as e:
            print(f"Error in MLX-VLM Predict: {e}", file=sys.stderr)
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(f"Generation failed: {str(e)}")
            return backend_pb2.Reply(message=bytes("", encoding='utf-8'))
        finally:
            # Clean up temporary files
            self.cleanup_temp_files(temp_files)

    def Embedding(self, request, context):
        """
        A gRPC method that would calculate embeddings for a given sentence.

        Note: MLX-VLM doesn't support embeddings directly. This method returns an error.

        Args:
            request: An EmbeddingRequest object that contains the request parameters.
            context: A grpc.ServicerContext object that provides information about the RPC.

        Returns:
            An empty EmbeddingResult, with the UNIMPLEMENTED status code set on the context.
        """
        print("Embeddings not supported in MLX-VLM backend", file=sys.stderr)
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details("Embeddings are not supported in the MLX-VLM backend.")
        return backend_pb2.EmbeddingResult()

    async def PredictStream(self, request, context):
        """
        Generates text based on the given prompt and sampling parameters, and streams the results using MLX-VLM with multimodal support.

        Args:
            request: The predict stream request.
            context: The gRPC context.

        Yields:
            backend_pb2.Reply: Streaming predict results.
        """
        temp_files = []
        try:
            # Process images and audios from request
            image_paths = []
            audio_paths = []

            # Process images
            if request.Images:
                for img_data in request.Images:
                    img_path = self.load_image_from_base64(img_data)
                    if img_path:
                        image_paths.append(img_path)
                        temp_files.append(img_path)

            # Process audios
            if request.Audios:
                for audio_data in request.Audios:
                    audio_path = self.load_audio_from_base64(audio_data)
                    if audio_path:
                        audio_paths.append(audio_path)
                        temp_files.append(audio_path)

            # Prepare the prompt with multimodal information
            prompt = self._prepare_prompt(request, num_images=len(image_paths), num_audios=len(audio_paths))

            # Build generation parameters using request attributes and options
            max_tokens, generation_params = self._build_generation_params(request, default_max_tokens=512)

            print(f"Streaming text with MLX-VLM - max_tokens: {max_tokens}, params: {generation_params}", file=sys.stderr)
            print(f"Images: {len(image_paths)}, Audios: {len(audio_paths)}", file=sys.stderr)

            # Stream text generation using MLX-VLM with multimodal inputs
            for response in stream_generate(
                model=self.model,
                processor=self.processor,
                prompt=prompt,
                image=image_paths if image_paths else None,
                audio=audio_paths if audio_paths else None,
                max_tokens=max_tokens,
                temperature=generation_params.get('temp', 0.6),
                top_p=generation_params.get('top_p', 1.0),
            ):
                yield backend_pb2.Reply(message=bytes(response.text, encoding='utf-8'))

        except Exception as e:
            print(f"Error in MLX-VLM PredictStream: {e}", file=sys.stderr)
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(f"Streaming generation failed: {str(e)}")
            yield backend_pb2.Reply(message=bytes("", encoding='utf-8'))
        finally:
            # Clean up temporary files
            self.cleanup_temp_files(temp_files)

    def _prepare_prompt(self, request, num_images=0, num_audios=0):
        """
        Prepare the prompt for MLX-VLM generation, handling chat templates and multimodal inputs.

        Args:
            request: The gRPC request containing prompt and message information.
            num_images: Number of images in the request.
            num_audios: Number of audio files in the request.

        Returns:
            str: The prepared prompt.
        """
        # If tokenizer template is enabled and messages are provided instead of prompt, apply the tokenizer template
        if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
            # Convert gRPC messages to the format expected by apply_chat_template
            messages = []
            for msg in request.Messages:
                messages.append({"role": msg.role, "content": msg.content})

            # Use mlx-vlm's apply_chat_template which handles multimodal inputs
            prompt = apply_chat_template(
                self.processor,
                self.config,
                messages,
                num_images=num_images,
                num_audios=num_audios
            )
            return prompt
        elif request.Prompt:
            # If we have a direct prompt but also have images/audio, we need to format it properly
            if num_images > 0 or num_audios > 0:
                # Create a simple message structure for multimodal prompt
                messages = [{"role": "user", "content": request.Prompt}]
                prompt = apply_chat_template(
                    self.processor,
                    self.config,
                    messages,
                    num_images=num_images,
                    num_audios=num_audios
                )
                return prompt
            else:
                return request.Prompt
        else:
            # Fallback to empty prompt with multimodal template if we have media
            if num_images > 0 or num_audios > 0:
                messages = [{"role": "user", "content": ""}]
                prompt = apply_chat_template(
                    self.processor,
                    self.config,
                    messages,
                    num_images=num_images,
                    num_audios=num_audios
                )
                return prompt
            else:
                return ""

    def _build_generation_params(self, request, default_max_tokens=200):
        """
        Build generation parameters from request attributes and options for MLX-VLM.

        Args:
            request: The gRPC request.
            default_max_tokens: Default max_tokens if not specified.

        Returns:
            tuple: (max_tokens, generation_params dict)
        """
        # Extract max_tokens
        max_tokens = getattr(request, 'Tokens', default_max_tokens)
        if max_tokens == 0:
            max_tokens = default_max_tokens

        # Extract generation parameters from request attributes
        temp = getattr(request, 'Temperature', 0.0)
        if temp == 0.0:
            temp = 0.6  # Default temperature

        top_p = getattr(request, 'TopP', 0.0)
        if top_p == 0.0:
            top_p = 1.0  # Default top_p

        # Initialize generation parameters for MLX-VLM
        generation_params = {
            'temp': temp,
            'top_p': top_p,
        }

        # Add seed if specified
        seed = getattr(request, 'Seed', 0)
        if seed != 0:
            mx.random.seed(seed)

        # Override with options if available
        if hasattr(self, 'options'):
            # Max tokens from options
            if 'max_tokens' in self.options:
                max_tokens = self.options['max_tokens']

            # Generation parameters from options
            param_option_mapping = {
                'temp': 'temp',
                'temperature': 'temp',  # alias
                'top_p': 'top_p',
            }

            for option_key, param_key in param_option_mapping.items():
                if option_key in self.options:
                    generation_params[param_key] = self.options[option_key]

            # Handle seed from options
            if 'seed' in self.options:
                mx.random.seed(self.options['seed'])

        return max_tokens, generation_params
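    # Example (hypothetical): a request with Tokens=0, Temperature=0.0, TopP=0.0
    # and no overriding options resolves to max_tokens=default_max_tokens (200
    # for Predict, 512 for PredictStream) and {'temp': 0.6, 'top_p': 1.0}.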

    def load_image_from_base64(self, image_data: str):
        """
        Load an image from base64 encoded data.

        Args:
            image_data (str): Base64 encoded image data.

        Returns:
            str or None: Path to a temporary image file, or None on failure.
        """
        try:
            decoded_data = base64.b64decode(image_data)
            image = Image.open(io.BytesIO(decoded_data))

            # Save to temporary file for mlx-vlm
            with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp_file:
                image.save(tmp_file.name, format='JPEG')
                return tmp_file.name

        except Exception as e:
            print(f"Error loading image from base64: {e}", file=sys.stderr)
            return None

    def load_audio_from_base64(self, audio_data: str):
        """
        Load audio from base64 encoded data.

        Args:
            audio_data (str): Base64 encoded audio data.

        Returns:
            str or None: Path to a temporary audio file, or None on failure.
        """
        try:
            decoded_data = base64.b64decode(audio_data)

            # Save to temporary file for mlx-vlm
            with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp_file:
                tmp_file.write(decoded_data)
                return tmp_file.name

        except Exception as e:
            print(f"Error loading audio from base64: {e}", file=sys.stderr)
            return None

    def cleanup_temp_files(self, file_paths: List[str]):
        """
        Clean up temporary files.

        Args:
            file_paths (List[str]): List of file paths to clean up.
        """
        for file_path in file_paths:
            try:
                if file_path and os.path.exists(file_path):
                    os.remove(file_path)
            except Exception as e:
                print(f"Error removing temporary file {file_path}: {e}", file=sys.stderr)

async def serve(address):
    # Start asyncio gRPC server
    server = grpc.aio.server(migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
                             options=[
                                 ('grpc.max_message_length', 50 * 1024 * 1024),  # 50MB
                                 ('grpc.max_send_message_length', 50 * 1024 * 1024),  # 50MB
                                 ('grpc.max_receive_message_length', 50 * 1024 * 1024),  # 50MB
                             ])
    # Add the servicer to the server
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    # Bind the server to the address
    server.add_insecure_port(address)

    # Gracefully shutdown the server on SIGTERM or SIGINT
    loop = asyncio.get_event_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(
            sig, lambda: asyncio.ensure_future(server.stop(5))
        )

    # Start the server
    await server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)
    # Wait for the server to be terminated
    await server.wait_for_termination()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the gRPC server.")
    parser.add_argument(
        "--addr", default="localhost:50051", help="The address to bind the server to."
    )
    args = parser.parse_args()

    asyncio.run(serve(args.addr))
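For a sense of how a client talks to this server, here is a minimal sketch modeled on the test.py added below. It assumes the server from backend.py is running on localhost:50051, that the generated backend_pb2/backend_pb2_grpc stubs are importable, and the model id is a hypothetical example, not one pinned by this PR:

# Minimal client sketch (assumptions: server running on localhost:50051;
# stubs generated from backend.proto; hypothetical example model id).
import grpc
import backend_pb2
import backend_pb2_grpc

with grpc.insecure_channel("localhost:50051") as channel:
    stub = backend_pb2_grpc.BackendStub(channel)
    # Health check: the servicer replies with b"OK"
    print(stub.Health(backend_pb2.HealthMessage()).message)
    # Load a vision-language model by id or local path
    res = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Qwen2-VL-2B-Instruct-4bit"))
    assert res.success, res.message
    # One-shot generation; the Images field would take base64-encoded strings
    reply = stub.Predict(backend_pb2.PredictOptions(
        Prompt="Describe the image.",
        Tokens=64,
        Temperature=0.7,
    ))
    print(reply.message.decode("utf-8"))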
backend/python/mlx-vlm/install.sh: 14 lines (new executable file)
@@ -0,0 +1,14 @@
#!/bin/bash
set -e

USE_PIP=true

backend_dir=$(dirname $0)

if [ -d $backend_dir/common ]; then
    source $backend_dir/common/libbackend.sh
else
    source $backend_dir/../common/libbackend.sh
fi

installRequirements
backend/python/mlx-vlm/requirements-mps.txt: 1 line (new file)
@@ -0,0 +1 @@
git+https://github.com/Blaizzy/mlx-vlm
backend/python/mlx-vlm/requirements.txt: 4 lines (new file)
@@ -0,0 +1,4 @@
grpcio==1.71.0
protobuf
certifi
setuptools
backend/python/mlx-vlm/run.sh: 11 lines (new executable file)
@@ -0,0 +1,11 @@
#!/bin/bash

backend_dir=$(dirname $0)

if [ -d $backend_dir/common ]; then
    source $backend_dir/common/libbackend.sh
else
    source $backend_dir/../common/libbackend.sh
fi

startBackend $@
backend/python/mlx-vlm/test.py: 146 lines (new file)
@@ -0,0 +1,146 @@
import unittest
import subprocess
import time

import grpc

import backend_pb2
import backend_pb2_grpc


class TestBackendServicer(unittest.TestCase):
    """
    TestBackendServicer is the class that tests the gRPC service.

    This class contains methods to test the startup and shutdown of the gRPC service.
    """
    def setUp(self):
        self.service = subprocess.Popen(["python", "backend.py", "--addr", "localhost:50051"])
        time.sleep(10)

    def tearDown(self) -> None:
        self.service.terminate()
        self.service.wait()

    def test_server_startup(self):
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.Health(backend_pb2.HealthMessage())
                self.assertEqual(response.message, b'OK')
        except Exception as err:
            print(err)
            self.fail("Server failed to start")
        finally:
            self.tearDown()

    def test_load_model(self):
        """
        This method tests if the model is loaded successfully
        """
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m"))
                self.assertTrue(response.success)
                # Match the message actually returned by BackendServicer.LoadModel
                self.assertEqual(response.message, "MLX-VLM model loaded successfully")
        except Exception as err:
            print(err)
            self.fail("LoadModel service failed")
        finally:
            self.tearDown()

    def test_text(self):
        """
        This method tests if text is generated successfully
        """
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m"))
                self.assertTrue(response.success)
                req = backend_pb2.PredictOptions(Prompt="The capital of France is")
                resp = stub.Predict(req)
                self.assertIsNotNone(resp.message)
        except Exception as err:
            print(err)
            self.fail("text service failed")
        finally:
            self.tearDown()

    def test_sampling_params(self):
        """
        This method tests if all sampling parameters are correctly processed
        NOTE: this does NOT test for correctness, just that we received a compatible response
        """
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m"))
                self.assertTrue(response.success)

                req = backend_pb2.PredictOptions(
                    Prompt="The capital of France is",
                    TopP=0.8,
                    Tokens=50,
                    Temperature=0.7,
                    TopK=40,
                    PresencePenalty=0.1,
                    FrequencyPenalty=0.2,
                    RepetitionPenalty=1.1,
                    MinP=0.05,
                    Seed=42,
                    StopPrompts=["\n"],
                    StopTokenIds=[50256],
                    BadWords=["badword"],
                    IncludeStopStrInOutput=True,
                    IgnoreEOS=True,
                    MinTokens=5,
                    Logprobs=5,
                    PromptLogprobs=5,
                    SkipSpecialTokens=True,
                    SpacesBetweenSpecialTokens=True,
                    TruncatePromptTokens=10,
                    GuidedDecoding=True,
                    N=2,
                )
                resp = stub.Predict(req)
                self.assertIsNotNone(resp.message)
                self.assertIsNotNone(resp.logprobs)
        except Exception as err:
            print(err)
            self.fail("sampling params service failed")
        finally:
            self.tearDown()

    def test_embedding(self):
        """
        This method tests that the Embedding RPC is rejected as unimplemented,
        since the MLX-VLM backend does not support embeddings
        """
        try:
            self.setUp()
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="intfloat/e5-mistral-7b-instruct"))
                self.assertTrue(response.success)
                embedding_request = backend_pb2.PredictOptions(Embeddings="This is a test sentence.")
                # The servicer sets StatusCode.UNIMPLEMENTED, which surfaces
                # client-side as an RpcError
                with self.assertRaises(grpc.RpcError) as rpc_ctx:
                    stub.Embedding(embedding_request)
                self.assertEqual(rpc_ctx.exception.code(), grpc.StatusCode.UNIMPLEMENTED)
        except Exception as err:
            print(err)
            self.fail("Embedding service failed")
        finally:
            self.tearDown()
backend/python/mlx-vlm/test.sh: 12 lines (new executable file)
@@ -0,0 +1,12 @@
#!/bin/bash
set -e

backend_dir=$(dirname $0)

if [ -d $backend_dir/common ]; then
    source $backend_dir/common/libbackend.sh
else
    source $backend_dir/../common/libbackend.sh
fi

runUnittests
@@ -3,6 +3,8 @@
 set -ex
 
 export PORTABLE_PYTHON=true
+export BUILD_TYPE=mps
+export USE_PIP=true
 IMAGE_NAME="${IMAGE_NAME:-localai/llama-cpp-darwin}"
 mkdir -p backend-images
 make -C backend/python/${BACKEND}