mirror of
https://github.com/mudler/LocalAI.git
synced 2025-12-30 22:20:20 -06:00
feat(mlx-audio): Add mlx-audio backend (#6138)
* feat(mlx-audio): Add mlx-audio backend Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * improve loading Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * CI tests Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * fix: set C_INCLUDE_PATH to point to python install Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
f8a8cf3e95
commit
3c3f477854
2310
.github/workflows/backend.yml
vendored
2310
.github/workflows/backend.yml
vendored
File diff suppressed because it is too large
Load Diff
4
Makefile
4
Makefile
@@ -388,6 +388,10 @@ backends/mlx-vlm:
|
||||
BACKEND=mlx-vlm $(MAKE) build-darwin-python-backend
|
||||
./local-ai backends install "ocifile://$(abspath ./backend-images/mlx-vlm.tar)"
|
||||
|
||||
backends/mlx-audio:
|
||||
BACKEND=mlx-audio $(MAKE) build-darwin-python-backend
|
||||
./local-ai backends install "ocifile://$(abspath ./backend-images/mlx-audio.tar)"
|
||||
|
||||
backend-images:
|
||||
mkdir -p backend-images
|
||||
|
||||
|
||||
@@ -159,6 +159,23 @@
|
||||
- vision-language
|
||||
- LLM
|
||||
- MLX
|
||||
- &mlx-audio
|
||||
name: "mlx-audio"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-mlx-audio"
|
||||
icon: https://avatars.githubusercontent.com/u/102832242?s=200&v=4
|
||||
urls:
|
||||
- https://github.com/Blaizzy/mlx-audio
|
||||
mirrors:
|
||||
- localai/localai-backends:latest-metal-darwin-arm64-mlx-audio
|
||||
license: MIT
|
||||
description: |
|
||||
Run Audio Models with MLX
|
||||
tags:
|
||||
- audio-to-text
|
||||
- audio-generation
|
||||
- text-to-audio
|
||||
- LLM
|
||||
- MLX
|
||||
- &rerankers
|
||||
name: "rerankers"
|
||||
alias: "rerankers"
|
||||
@@ -415,6 +432,11 @@
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-mlx-vlm"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-mlx-vlm
|
||||
- !!merge <<: *mlx-audio
|
||||
name: "mlx-audio-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-mlx-audio"
|
||||
mirrors:
|
||||
- localai/localai-backends:master-metal-darwin-arm64-mlx-audio
|
||||
- !!merge <<: *kitten-tts
|
||||
name: "kitten-tts-development"
|
||||
uri: "quay.io/go-skynet/local-ai-backends:master-kitten-tts"
|
||||
|
||||
@@ -384,6 +384,11 @@ function installRequirements() {
|
||||
requirementFiles+=("${EDIR}/requirements-${BUILD_PROFILE}-after.txt")
|
||||
fi
|
||||
|
||||
# This is needed to build wheels that e.g. depends on Python.h
|
||||
if [ "x${PORTABLE_PYTHON}" == "xtrue" ]; then
|
||||
export C_INCLUDE_PATH="${C_INCLUDE_PATH:-}:$(_portable_dir)/include/python${PYTHON_VERSION}"
|
||||
fi
|
||||
|
||||
for reqFile in ${requirementFiles[@]}; do
|
||||
if [ -f "${reqFile}" ]; then
|
||||
echo "starting requirements install for ${reqFile}"
|
||||
|
||||
23
backend/python/mlx-audio/Makefile
Normal file
23
backend/python/mlx-audio/Makefile
Normal file
@@ -0,0 +1,23 @@
|
||||
.PHONY: mlx-audio
|
||||
mlx-audio:
|
||||
bash install.sh
|
||||
|
||||
.PHONY: run
|
||||
run: mlx-audio
|
||||
@echo "Running mlx-audio..."
|
||||
bash run.sh
|
||||
@echo "mlx run."
|
||||
|
||||
.PHONY: test
|
||||
test: mlx-audio
|
||||
@echo "Testing mlx-audio..."
|
||||
bash test.sh
|
||||
@echo "mlx tested."
|
||||
|
||||
.PHONY: protogen-clean
|
||||
protogen-clean:
|
||||
$(RM) backend_pb2_grpc.py backend_pb2.py
|
||||
|
||||
.PHONY: clean
|
||||
clean: protogen-clean
|
||||
rm -rf venv __pycache__
|
||||
466
backend/python/mlx-audio/backend.py
Normal file
466
backend/python/mlx-audio/backend.py
Normal file
@@ -0,0 +1,466 @@
|
||||
#!/usr/bin/env python3
|
||||
import asyncio
|
||||
from concurrent import futures
|
||||
import argparse
|
||||
import signal
|
||||
import sys
|
||||
import os
|
||||
import shutil
|
||||
import glob
|
||||
from typing import List
|
||||
import time
|
||||
import tempfile
|
||||
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
from mlx_audio.tts.utils import load_model
|
||||
import soundfile as sf
|
||||
import numpy as np
|
||||
import uuid
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
|
||||
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
||||
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
||||
|
||||
# Implement the BackendServicer class with the service methods
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
"""
|
||||
A gRPC servicer that implements the Backend service defined in backend.proto.
|
||||
This backend provides TTS (Text-to-Speech) functionality using MLX-Audio.
|
||||
"""
|
||||
|
||||
def _is_float(self, s):
|
||||
"""Check if a string can be converted to float."""
|
||||
try:
|
||||
float(s)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def _is_int(self, s):
|
||||
"""Check if a string can be converted to int."""
|
||||
try:
|
||||
int(s)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def Health(self, request, context):
|
||||
"""
|
||||
Returns a health check message.
|
||||
|
||||
Args:
|
||||
request: The health check request.
|
||||
context: The gRPC context.
|
||||
|
||||
Returns:
|
||||
backend_pb2.Reply: The health check reply.
|
||||
"""
|
||||
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
||||
|
||||
async def LoadModel(self, request, context):
|
||||
"""
|
||||
Loads a TTS model using MLX-Audio.
|
||||
|
||||
Args:
|
||||
request: The load model request.
|
||||
context: The gRPC context.
|
||||
|
||||
Returns:
|
||||
backend_pb2.Result: The load model result.
|
||||
"""
|
||||
try:
|
||||
print(f"Loading MLX-Audio TTS model: {request.Model}", file=sys.stderr)
|
||||
print(f"Request: {request}", file=sys.stderr)
|
||||
|
||||
# Parse options like in the kokoro backend
|
||||
options = request.Options
|
||||
self.options = {}
|
||||
|
||||
# The options are a list of strings in this form optname:optvalue
|
||||
# We store all the options in a dict for later use
|
||||
for opt in options:
|
||||
if ":" not in opt:
|
||||
continue
|
||||
key, value = opt.split(":", 1) # Split only on first colon to handle values with colons
|
||||
|
||||
# Convert numeric values to appropriate types
|
||||
if self._is_float(value):
|
||||
value = float(value)
|
||||
elif self._is_int(value):
|
||||
value = int(value)
|
||||
elif value.lower() in ["true", "false"]:
|
||||
value = value.lower() == "true"
|
||||
|
||||
self.options[key] = value
|
||||
|
||||
print(f"Options: {self.options}", file=sys.stderr)
|
||||
|
||||
# Load the model using MLX-Audio's load_model function
|
||||
try:
|
||||
self.tts_model = load_model(request.Model)
|
||||
self.model_path = request.Model
|
||||
print(f"TTS model loaded successfully from {request.Model}", file=sys.stderr)
|
||||
except Exception as model_err:
|
||||
print(f"Error loading TTS model: {model_err}", file=sys.stderr)
|
||||
return backend_pb2.Result(success=False, message=f"Failed to load model: {model_err}")
|
||||
|
||||
except Exception as err:
|
||||
print(f"Error loading MLX-Audio TTS model {err=}, {type(err)=}", file=sys.stderr)
|
||||
return backend_pb2.Result(success=False, message=f"Error loading MLX-Audio TTS model: {err}")
|
||||
|
||||
print("MLX-Audio TTS model loaded successfully", file=sys.stderr)
|
||||
return backend_pb2.Result(message="MLX-Audio TTS model loaded successfully", success=True)
|
||||
|
||||
def TTS(self, request, context):
|
||||
"""
|
||||
Generates TTS audio from text using MLX-Audio.
|
||||
|
||||
Args:
|
||||
request: A TTSRequest object containing text, model, destination, voice, and language.
|
||||
context: A grpc.ServicerContext object that provides information about the RPC.
|
||||
|
||||
Returns:
|
||||
A Result object indicating success or failure.
|
||||
"""
|
||||
try:
|
||||
# Check if model is loaded
|
||||
if not hasattr(self, 'tts_model') or self.tts_model is None:
|
||||
return backend_pb2.Result(success=False, message="TTS model not loaded. Please call LoadModel first.")
|
||||
|
||||
print(f"Generating TTS with MLX-Audio - text: {request.text[:50]}..., voice: {request.voice}, language: {request.language}", file=sys.stderr)
|
||||
|
||||
# Handle speed parameter based on model type
|
||||
speed_value = self._handle_speed_parameter(request, self.model_path)
|
||||
|
||||
# Map language names to codes if needed
|
||||
lang_code = self._map_language_code(request.language, request.voice)
|
||||
|
||||
# Prepare generation parameters
|
||||
gen_params = {
|
||||
"text": request.text,
|
||||
"speed": speed_value,
|
||||
"verbose": False,
|
||||
}
|
||||
|
||||
# Add model-specific parameters
|
||||
if request.voice and request.voice.strip():
|
||||
gen_params["voice"] = request.voice
|
||||
|
||||
# Check if model supports language codes (primarily Kokoro)
|
||||
if "kokoro" in self.model_path.lower():
|
||||
gen_params["lang_code"] = lang_code
|
||||
|
||||
# Add pitch and gender for Spark models
|
||||
if "spark" in self.model_path.lower():
|
||||
gen_params["pitch"] = 1.0 # Default to moderate
|
||||
gen_params["gender"] = "female" # Default to female
|
||||
|
||||
print(f"Generation parameters: {gen_params}", file=sys.stderr)
|
||||
|
||||
# Generate audio using the loaded model
|
||||
try:
|
||||
results = self.tts_model.generate(**gen_params)
|
||||
except Exception as gen_err:
|
||||
print(f"Error during TTS generation: {gen_err}", file=sys.stderr)
|
||||
return backend_pb2.Result(success=False, message=f"TTS generation failed: {gen_err}")
|
||||
|
||||
# Process the generated audio segments
|
||||
audio_arrays = []
|
||||
for segment in results:
|
||||
audio_arrays.append(segment.audio)
|
||||
|
||||
# If no segments, return error
|
||||
if not audio_arrays:
|
||||
print("No audio segments generated", file=sys.stderr)
|
||||
return backend_pb2.Result(success=False, message="No audio generated")
|
||||
|
||||
# Concatenate all segments
|
||||
cat_audio = np.concatenate(audio_arrays, axis=0)
|
||||
|
||||
# Generate output filename and path
|
||||
if request.dst:
|
||||
output_path = request.dst
|
||||
else:
|
||||
unique_id = str(uuid.uuid4())
|
||||
filename = f"tts_{unique_id}.wav"
|
||||
output_path = filename
|
||||
|
||||
# Write the audio as a WAV
|
||||
try:
|
||||
sf.write(output_path, cat_audio, 24000)
|
||||
print(f"Successfully wrote audio file to {output_path}", file=sys.stderr)
|
||||
|
||||
# Verify the file exists and has content
|
||||
if not os.path.exists(output_path):
|
||||
print(f"File was not created at {output_path}", file=sys.stderr)
|
||||
return backend_pb2.Result(success=False, message="Failed to create audio file")
|
||||
|
||||
file_size = os.path.getsize(output_path)
|
||||
if file_size == 0:
|
||||
print("File was created but is empty", file=sys.stderr)
|
||||
return backend_pb2.Result(success=False, message="Generated audio file is empty")
|
||||
|
||||
print(f"Audio file size: {file_size} bytes", file=sys.stderr)
|
||||
|
||||
except Exception as write_err:
|
||||
print(f"Error writing audio file: {write_err}", file=sys.stderr)
|
||||
return backend_pb2.Result(success=False, message=f"Failed to save audio: {write_err}")
|
||||
|
||||
return backend_pb2.Result(success=True, message=f"TTS audio generated successfully: {output_path}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in MLX-Audio TTS: {e}", file=sys.stderr)
|
||||
return backend_pb2.Result(success=False, message=f"TTS generation failed: {str(e)}")
|
||||
|
||||
async def Predict(self, request, context):
|
||||
"""
|
||||
Generates TTS audio based on the given prompt using MLX-Audio TTS.
|
||||
This is a fallback method for compatibility with the Predict endpoint.
|
||||
|
||||
Args:
|
||||
request: The predict request.
|
||||
context: The gRPC context.
|
||||
|
||||
Returns:
|
||||
backend_pb2.Reply: The predict result.
|
||||
"""
|
||||
try:
|
||||
# Check if model is loaded
|
||||
if not hasattr(self, 'tts_model') or self.tts_model is None:
|
||||
context.set_code(grpc.StatusCode.FAILED_PRECONDITION)
|
||||
context.set_details("TTS model not loaded. Please call LoadModel first.")
|
||||
return backend_pb2.Reply(message=bytes("", encoding='utf-8'))
|
||||
|
||||
# For TTS, we expect the prompt to contain the text to synthesize
|
||||
if not request.Prompt:
|
||||
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
|
||||
context.set_details("Prompt is required for TTS generation")
|
||||
return backend_pb2.Reply(message=bytes("", encoding='utf-8'))
|
||||
|
||||
# Handle speed parameter based on model type
|
||||
speed_value = self._handle_speed_parameter(request, self.model_path)
|
||||
|
||||
# Map language names to codes if needed
|
||||
lang_code = self._map_language_code(None, None) # Use defaults for Predict
|
||||
|
||||
# Prepare generation parameters
|
||||
gen_params = {
|
||||
"text": request.Prompt,
|
||||
"speed": speed_value,
|
||||
"verbose": False,
|
||||
}
|
||||
|
||||
# Add model-specific parameters
|
||||
if hasattr(self, 'options') and 'voice' in self.options:
|
||||
gen_params["voice"] = self.options['voice']
|
||||
|
||||
# Check if model supports language codes (primarily Kokoro)
|
||||
if "kokoro" in self.model_path.lower():
|
||||
gen_params["lang_code"] = lang_code
|
||||
|
||||
print(f"Generating TTS with MLX-Audio - text: {request.Prompt[:50]}..., params: {gen_params}", file=sys.stderr)
|
||||
|
||||
# Generate audio using the loaded model
|
||||
try:
|
||||
results = self.tts_model.generate(**gen_params)
|
||||
except Exception as gen_err:
|
||||
print(f"Error during TTS generation: {gen_err}", file=sys.stderr)
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details(f"TTS generation failed: {gen_err}")
|
||||
return backend_pb2.Reply(message=bytes("", encoding='utf-8'))
|
||||
|
||||
# Process the generated audio segments
|
||||
audio_arrays = []
|
||||
for segment in results:
|
||||
audio_arrays.append(segment.audio)
|
||||
|
||||
# If no segments, return error
|
||||
if not audio_arrays:
|
||||
print("No audio segments generated", file=sys.stderr)
|
||||
return backend_pb2.Reply(message=bytes("No audio generated", encoding='utf-8'))
|
||||
|
||||
# Concatenate all segments
|
||||
cat_audio = np.concatenate(audio_arrays, axis=0)
|
||||
duration = len(cat_audio) / 24000 # Assuming 24kHz sample rate
|
||||
|
||||
# Return success message with audio information
|
||||
response = f"TTS audio generated successfully. Duration: {duration:.2f}s, Sample rate: 24000Hz"
|
||||
return backend_pb2.Reply(message=bytes(response, encoding='utf-8'))
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in MLX-Audio TTS Predict: {e}", file=sys.stderr)
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
context.set_details(f"TTS generation failed: {str(e)}")
|
||||
return backend_pb2.Reply(message=bytes("", encoding='utf-8'))
|
||||
|
||||
def _handle_speed_parameter(self, request, model_path):
|
||||
"""
|
||||
Handle speed parameter based on model type.
|
||||
|
||||
Args:
|
||||
request: The TTSRequest object.
|
||||
model_path: The model path to determine model type.
|
||||
|
||||
Returns:
|
||||
float: The processed speed value.
|
||||
"""
|
||||
# Get speed from options if available
|
||||
speed = 1.0
|
||||
if hasattr(self, 'options') and 'speed' in self.options:
|
||||
speed = self.options['speed']
|
||||
|
||||
# Handle speed parameter based on model type
|
||||
if "spark" in model_path.lower():
|
||||
# Spark actually expects float values that map to speed descriptions
|
||||
speed_map = {
|
||||
"very_low": 0.0,
|
||||
"low": 0.5,
|
||||
"moderate": 1.0,
|
||||
"high": 1.5,
|
||||
"very_high": 2.0,
|
||||
}
|
||||
if isinstance(speed, str) and speed in speed_map:
|
||||
speed_value = speed_map[speed]
|
||||
else:
|
||||
# Try to use as float, default to 1.0 (moderate) if invalid
|
||||
try:
|
||||
speed_value = float(speed)
|
||||
if speed_value not in [0.0, 0.5, 1.0, 1.5, 2.0]:
|
||||
speed_value = 1.0 # Default to moderate
|
||||
except:
|
||||
speed_value = 1.0 # Default to moderate
|
||||
else:
|
||||
# Other models use float speed values
|
||||
try:
|
||||
speed_value = float(speed)
|
||||
if speed_value < 0.5 or speed_value > 2.0:
|
||||
speed_value = 1.0 # Default to 1.0 if out of range
|
||||
except ValueError:
|
||||
speed_value = 1.0 # Default to 1.0 if invalid
|
||||
|
||||
return speed_value
|
||||
|
||||
def _map_language_code(self, language, voice):
|
||||
"""
|
||||
Map language names to codes if needed.
|
||||
|
||||
Args:
|
||||
language: The language parameter from the request.
|
||||
voice: The voice parameter from the request.
|
||||
|
||||
Returns:
|
||||
str: The language code.
|
||||
"""
|
||||
if not language:
|
||||
# Default to voice[0] if not found
|
||||
return voice[0] if voice else "a"
|
||||
|
||||
# Map language names to codes if needed
|
||||
language_map = {
|
||||
"american_english": "a",
|
||||
"british_english": "b",
|
||||
"spanish": "e",
|
||||
"french": "f",
|
||||
"hindi": "h",
|
||||
"italian": "i",
|
||||
"portuguese": "p",
|
||||
"japanese": "j",
|
||||
"mandarin_chinese": "z",
|
||||
# Also accept direct language codes
|
||||
"a": "a", "b": "b", "e": "e", "f": "f", "h": "h", "i": "i", "p": "p", "j": "j", "z": "z",
|
||||
}
|
||||
|
||||
return language_map.get(language.lower(), language)
|
||||
|
||||
def _build_generation_params(self, request, default_speed=1.0):
|
||||
"""
|
||||
Build generation parameters from request attributes and options for MLX-Audio TTS.
|
||||
|
||||
Args:
|
||||
request: The gRPC request.
|
||||
default_speed: Default speed if not specified.
|
||||
|
||||
Returns:
|
||||
dict: Generation parameters for MLX-Audio
|
||||
"""
|
||||
# Initialize generation parameters for MLX-Audio TTS
|
||||
generation_params = {
|
||||
'speed': default_speed,
|
||||
'voice': 'af_heart', # Default voice
|
||||
'lang_code': 'a', # Default language code
|
||||
}
|
||||
|
||||
# Extract parameters from request attributes
|
||||
if hasattr(request, 'Temperature') and request.Temperature > 0:
|
||||
# Temperature could be mapped to speed variation
|
||||
generation_params['speed'] = 1.0 + (request.Temperature - 0.5) * 0.5
|
||||
|
||||
# Override with options if available
|
||||
if hasattr(self, 'options'):
|
||||
# Speed from options
|
||||
if 'speed' in self.options:
|
||||
generation_params['speed'] = self.options['speed']
|
||||
|
||||
# Voice from options
|
||||
if 'voice' in self.options:
|
||||
generation_params['voice'] = self.options['voice']
|
||||
|
||||
# Language code from options
|
||||
if 'lang_code' in self.options:
|
||||
generation_params['lang_code'] = self.options['lang_code']
|
||||
|
||||
# Model-specific parameters
|
||||
param_option_mapping = {
|
||||
'temp': 'speed',
|
||||
'temperature': 'speed',
|
||||
'top_p': 'speed', # Map top_p to speed variation
|
||||
}
|
||||
|
||||
for option_key, param_key in param_option_mapping.items():
|
||||
if option_key in self.options:
|
||||
if param_key == 'speed':
|
||||
# Ensure speed is within reasonable bounds
|
||||
speed_val = float(self.options[option_key])
|
||||
if 0.5 <= speed_val <= 2.0:
|
||||
generation_params[param_key] = speed_val
|
||||
|
||||
return generation_params
|
||||
|
||||
async def serve(address):
|
||||
# Start asyncio gRPC server
|
||||
server = grpc.aio.server(migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
|
||||
options=[
|
||||
('grpc.max_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB
|
||||
('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB
|
||||
])
|
||||
# Add the servicer to the server
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
# Bind the server to the address
|
||||
server.add_insecure_port(address)
|
||||
|
||||
# Gracefully shutdown the server on SIGTERM or SIGINT
|
||||
loop = asyncio.get_event_loop()
|
||||
for sig in (signal.SIGINT, signal.SIGTERM):
|
||||
loop.add_signal_handler(
|
||||
sig, lambda: asyncio.ensure_future(server.stop(5))
|
||||
)
|
||||
|
||||
# Start the server
|
||||
await server.start()
|
||||
print("MLX-Audio TTS Server started. Listening on: " + address, file=sys.stderr)
|
||||
# Wait for the server to be terminated
|
||||
await server.wait_for_termination()
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run the MLX-Audio TTS gRPC server.")
|
||||
parser.add_argument(
|
||||
"--addr", default="localhost:50051", help="The address to bind the server to."
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
asyncio.run(serve(args.addr))
|
||||
14
backend/python/mlx-audio/install.sh
Executable file
14
backend/python/mlx-audio/install.sh
Executable file
@@ -0,0 +1,14 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
USE_PIP=true
|
||||
|
||||
backend_dir=$(dirname $0)
|
||||
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
installRequirements
|
||||
1
backend/python/mlx-audio/requirements-mps.txt
Normal file
1
backend/python/mlx-audio/requirements-mps.txt
Normal file
@@ -0,0 +1 @@
|
||||
git+https://github.com/Blaizzy/mlx-audio
|
||||
7
backend/python/mlx-audio/requirements.txt
Normal file
7
backend/python/mlx-audio/requirements.txt
Normal file
@@ -0,0 +1,7 @@
|
||||
grpcio==1.71.0
|
||||
protobuf
|
||||
certifi
|
||||
setuptools
|
||||
mlx-audio
|
||||
soundfile
|
||||
numpy
|
||||
11
backend/python/mlx-audio/run.sh
Executable file
11
backend/python/mlx-audio/run.sh
Executable file
@@ -0,0 +1,11 @@
|
||||
#!/bin/bash
|
||||
|
||||
backend_dir=$(dirname $0)
|
||||
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
startBackend $@
|
||||
142
backend/python/mlx-audio/test.py
Normal file
142
backend/python/mlx-audio/test.py
Normal file
@@ -0,0 +1,142 @@
|
||||
import unittest
|
||||
import subprocess
|
||||
import time
|
||||
import backend_pb2
|
||||
import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
|
||||
import unittest
|
||||
import subprocess
|
||||
import time
|
||||
import grpc
|
||||
import backend_pb2_grpc
|
||||
import backend_pb2
|
||||
|
||||
class TestBackendServicer(unittest.TestCase):
|
||||
"""
|
||||
TestBackendServicer is the class that tests the gRPC service.
|
||||
|
||||
This class contains methods to test the startup and shutdown of the gRPC service.
|
||||
"""
|
||||
def setUp(self):
|
||||
self.service = subprocess.Popen(["python", "backend.py", "--addr", "localhost:50051"])
|
||||
time.sleep(10)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.service.terminate()
|
||||
self.service.wait()
|
||||
|
||||
def test_server_startup(self):
|
||||
try:
|
||||
self.setUp()
|
||||
with grpc.insecure_channel("localhost:50051") as channel:
|
||||
stub = backend_pb2_grpc.BackendStub(channel)
|
||||
response = stub.Health(backend_pb2.HealthMessage())
|
||||
self.assertEqual(response.message, b'OK')
|
||||
except Exception as err:
|
||||
print(err)
|
||||
self.fail("Server failed to start")
|
||||
finally:
|
||||
self.tearDown()
|
||||
def test_load_model(self):
|
||||
"""
|
||||
This method tests if the TTS model is loaded successfully
|
||||
"""
|
||||
try:
|
||||
self.setUp()
|
||||
with grpc.insecure_channel("localhost:50051") as channel:
|
||||
stub = backend_pb2_grpc.BackendStub(channel)
|
||||
response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Kokoro-82M-4bit"))
|
||||
self.assertTrue(response.success)
|
||||
self.assertEqual(response.message, "MLX-Audio TTS model loaded successfully")
|
||||
except Exception as err:
|
||||
print(err)
|
||||
self.fail("LoadModel service failed")
|
||||
finally:
|
||||
self.tearDown()
|
||||
|
||||
def test_tts_generation(self):
|
||||
"""
|
||||
This method tests if TTS audio is generated successfully
|
||||
"""
|
||||
try:
|
||||
self.setUp()
|
||||
with grpc.insecure_channel("localhost:50051") as channel:
|
||||
stub = backend_pb2_grpc.BackendStub(channel)
|
||||
response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Kokoro-82M-4bit"))
|
||||
self.assertTrue(response.success)
|
||||
|
||||
# Test TTS generation
|
||||
tts_req = backend_pb2.TTSRequest(
|
||||
text="Hello, this is a test of the MLX-Audio TTS system.",
|
||||
model="mlx-community/Kokoro-82M-4bit",
|
||||
voice="af_heart",
|
||||
language="a"
|
||||
)
|
||||
tts_resp = stub.TTS(tts_req)
|
||||
self.assertTrue(tts_resp.success)
|
||||
self.assertIn("TTS audio generated successfully", tts_resp.message)
|
||||
except Exception as err:
|
||||
print(err)
|
||||
self.fail("TTS service failed")
|
||||
finally:
|
||||
self.tearDown()
|
||||
|
||||
def test_tts_with_options(self):
|
||||
"""
|
||||
This method tests if TTS works with various options and parameters
|
||||
"""
|
||||
try:
|
||||
self.setUp()
|
||||
with grpc.insecure_channel("localhost:50051") as channel:
|
||||
stub = backend_pb2_grpc.BackendStub(channel)
|
||||
response = stub.LoadModel(backend_pb2.ModelOptions(
|
||||
Model="mlx-community/Kokoro-82M-4bit",
|
||||
Options=["voice:af_soft", "speed:1.2", "lang_code:b"]
|
||||
))
|
||||
self.assertTrue(response.success)
|
||||
|
||||
# Test TTS generation with different voice and language
|
||||
tts_req = backend_pb2.TTSRequest(
|
||||
text="Hello, this is a test with British English accent.",
|
||||
model="mlx-community/Kokoro-82M-4bit",
|
||||
voice="af_soft",
|
||||
language="b"
|
||||
)
|
||||
tts_resp = stub.TTS(tts_req)
|
||||
self.assertTrue(tts_resp.success)
|
||||
self.assertIn("TTS audio generated successfully", tts_resp.message)
|
||||
except Exception as err:
|
||||
print(err)
|
||||
self.fail("TTS with options service failed")
|
||||
finally:
|
||||
self.tearDown()
|
||||
|
||||
|
||||
def test_tts_multilingual(self):
|
||||
"""
|
||||
This method tests if TTS works with different languages
|
||||
"""
|
||||
try:
|
||||
self.setUp()
|
||||
with grpc.insecure_channel("localhost:50051") as channel:
|
||||
stub = backend_pb2_grpc.BackendStub(channel)
|
||||
response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Kokoro-82M-4bit"))
|
||||
self.assertTrue(response.success)
|
||||
|
||||
# Test Spanish TTS
|
||||
tts_req = backend_pb2.TTSRequest(
|
||||
text="Hola, esto es una prueba del sistema TTS MLX-Audio.",
|
||||
model="mlx-community/Kokoro-82M-4bit",
|
||||
voice="af_heart",
|
||||
language="e"
|
||||
)
|
||||
tts_resp = stub.TTS(tts_req)
|
||||
self.assertTrue(tts_resp.success)
|
||||
self.assertIn("TTS audio generated successfully", tts_resp.message)
|
||||
except Exception as err:
|
||||
print(err)
|
||||
self.fail("Multilingual TTS service failed")
|
||||
finally:
|
||||
self.tearDown()
|
||||
12
backend/python/mlx-audio/test.sh
Executable file
12
backend/python/mlx-audio/test.sh
Executable file
@@ -0,0 +1,12 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
backend_dir=$(dirname $0)
|
||||
|
||||
if [ -d $backend_dir/common ]; then
|
||||
source $backend_dir/common/libbackend.sh
|
||||
else
|
||||
source $backend_dir/../common/libbackend.sh
|
||||
fi
|
||||
|
||||
runUnittests
|
||||
Reference in New Issue
Block a user