mirror of
https://github.com/mudler/LocalAI.git
synced 2025-12-30 22:20:20 -06:00
fix(chatterbox): chunk long text (#6407)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
aa8965b634
commit
20f1e842b3
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
This is an extra gRPC server of LocalAI for Bark TTS
|
||||
This is an extra gRPC server of LocalAI for Chatterbox TTS
|
||||
"""
|
||||
from concurrent import futures
|
||||
import time
|
||||
@@ -16,6 +16,7 @@ import torchaudio as ta
|
||||
from chatterbox.tts import ChatterboxTTS
|
||||
from chatterbox.mtl_tts import ChatterboxMultilingualTTS
|
||||
import grpc
|
||||
import tempfile
|
||||
|
||||
def is_float(s):
|
||||
"""Check if a string can be converted to float."""
|
||||
@@ -32,11 +33,79 @@ def is_int(s):
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def split_text_at_word_boundary(text, max_length=250):
    """
    Break *text* into pieces of at most *max_length* characters, cutting
    only between words so no word is ever truncated.

    A single word longer than *max_length* is emitted whole as its own
    oversized piece. Short (or falsy) input comes back unchanged as a
    one-element list.

    Returns:
        list[str]: the ordered text pieces.
    """
    # Fast path: nothing to split.
    if not text or len(text) <= max_length:
        return [text]

    pieces = []
    pending = ""

    for token in text.split():
        # Would appending this word (plus a separating space) still fit?
        if len(pending) + len(token) + 1 <= max_length:
            pending = f"{pending} {token}" if pending else token
        elif pending:
            # Flush the accumulated piece and start a new one with this word.
            pieces.append(pending)
            pending = token
        else:
            # The word alone exceeds the limit -- keep it whole anyway.
            pieces.append(token)

    # Flush whatever is left over.
    if pending:
        pieces.append(pending)

    return pieces
|
||||
|
||||
def merge_audio_files(audio_files, output_path, sample_rate):
    """
    Merge multiple audio files into a single audio file at *output_path*.

    Inputs whose sample rate differs from *sample_rate* are resampled
    before concatenation. All input files are treated as temporary and
    are removed once the merged output has been written.

    Args:
        audio_files: paths to the (temporary) audio chunk files, in order.
        output_path: destination path for the merged audio.
        sample_rate: target sample rate for the merged output.
    """
    if not audio_files:
        return

    if len(audio_files) == 1:
        # Single chunk: copy it into place, then remove the temporary
        # source file -- the multi-file path below cleans up its inputs,
        # and this path must do the same or each request leaks a temp file.
        import shutil
        shutil.copy2(audio_files[0], output_path)
        if os.path.exists(audio_files[0]):
            os.remove(audio_files[0])
        return

    # Load every chunk, resampling to the target rate when necessary.
    waveforms = []
    for audio_file in audio_files:
        waveform, sr = ta.load(audio_file)
        if sr != sample_rate:
            # Resample so all chunks share the output rate.
            resampler = ta.transforms.Resample(sr, sample_rate)
            waveform = resampler(waveform)
        waveforms.append(waveform)

    # Concatenate along dim=1 (the time axis in torchaudio's
    # (channels, frames) tensor layout).
    merged_waveform = torch.cat(waveforms, dim=1)

    # Save the merged audio.
    ta.save(output_path, merged_waveform, sample_rate)

    # Clean up the temporary chunk files.
    for audio_file in audio_files:
        if os.path.exists(audio_file):
            os.remove(audio_file)
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
|
||||
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
||||
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
||||
COQUI_LANGUAGE = os.environ.get('COQUI_LANGUAGE', None)
|
||||
|
||||
# Implement the BackendServicer class with the service methods
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
@@ -118,10 +187,33 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
# add options to kwargs
|
||||
kwargs.update(self.options)
|
||||
|
||||
# Generate audio using ChatterboxTTS
|
||||
wav = self.model.generate(request.text, **kwargs)
|
||||
# Save the generated audio
|
||||
ta.save(request.dst, wav, self.model.sr)
|
||||
# Check if text exceeds 250 characters
|
||||
# (chatterbox does not support long text)
|
||||
# https://github.com/resemble-ai/chatterbox/issues/60
|
||||
# https://github.com/resemble-ai/chatterbox/issues/110
|
||||
if len(request.text) > 250:
|
||||
# Split text at word boundaries
|
||||
text_chunks = split_text_at_word_boundary(request.text, max_length=250)
|
||||
print(f"Splitting text into chunks of 250 characters: {len(text_chunks)}", file=sys.stderr)
|
||||
# Generate audio for each chunk
|
||||
temp_audio_files = []
|
||||
for i, chunk in enumerate(text_chunks):
|
||||
# Generate audio for this chunk
|
||||
wav = self.model.generate(chunk, **kwargs)
|
||||
|
||||
# Create temporary file for this chunk
|
||||
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.wav')
|
||||
temp_file.close()
|
||||
ta.save(temp_file.name, wav, self.model.sr)
|
||||
temp_audio_files.append(temp_file.name)
|
||||
|
||||
# Merge all audio files
|
||||
merge_audio_files(temp_audio_files, request.dst, self.model.sr)
|
||||
else:
|
||||
# Generate audio using ChatterboxTTS for short text
|
||||
wav = self.model.generate(request.text, **kwargs)
|
||||
# Save the generated audio
|
||||
ta.save(request.dst, wav, self.model.sr)
|
||||
|
||||
except Exception as err:
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
|
||||
Reference in New Issue
Block a user