chore: add Dia to the model gallery, fix backend (#5998)

* fix: correctly call OuteTTS and DiaTTS

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* chore(model gallery): add dia

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto
2025-08-08 12:40:16 +02:00
committed by GitHub
parent 326fda3223
commit 4733adb983
2 changed files with 31 additions and 4 deletions

View File

@@ -229,8 +229,13 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
elif request.Type == "DiaForConditionalGeneration":
autoTokenizer = False
print("DiaForConditionalGeneration", file=sys.stderr)
self.processor = AutoProcessor.from_pretrained(model_name)
self.model = DiaForConditionalGeneration.from_pretrained(model_name)
if self.CUDA:
self.model = self.model.to("cuda")
self.processor = self.processor.to("cuda")
print("DiaForConditionalGeneration loaded", file=sys.stderr)
self.DiaTTS = True
elif request.Type == "OuteTTS":
autoTokenizer = False
@@ -536,7 +541,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
return backend_pb2.Result(success=True)
def DiaTTS(self, request, context):
def CallDiaTTS(self, request, context):
"""
Generates dialogue audio using the Dia model.
@@ -581,7 +586,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
return backend_pb2.Result(success=True)
def OuteTTS(self, request, context):
def CallOuteTTS(self, request, context):
try:
print("[OuteTTS] generating TTS", file=sys.stderr)
gen_cfg = outetts.GenerationConfig(
@@ -603,10 +608,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
# The TTS endpoint is older, and provides fewer features, but exists for compatibility reasons
def TTS(self, request, context):
if self.OuteTTS:
return self.OuteTTS(request, context)
return self.CallOuteTTS(request, context)
if self.DiaTTS:
return self.DiaTTS(request, context)
print("DiaTTS", file=sys.stderr)
return self.CallDiaTTS(request, context)
model_name = request.model
try:

View File

@@ -191,6 +191,27 @@
- filename: OpenAI-20B-NEOPlus-Uncensored-IQ4_NL.gguf
sha256: 274ffaaf0783270c071006842ffe60af73600fc63c2b6153c0701b596fc3b122
uri: huggingface://DavidAU/OpenAi-GPT-oss-20b-abliterated-uncensored-NEO-Imatrix-gguf/OpenAI-20B-NEOPlus-Uncensored-IQ4_NL.gguf
- name: "dia"
url: "github:mudler/LocalAI/gallery/virtual.yaml@master"
icon: https://github.com/nari-labs/dia/raw/main/dia/static/images/banner.png
urls:
- https://github.com/nari-labs/dia
- https://huggingface.co/nari-labs/Dia-1.6B-0626
license: apache-2.0
tags:
- tts
- dia
- gpu
- text-to-speech
overrides:
backend: "transformers"
name: "dia"
description: "Dia is a 1.6B parameter text to speech model created by Nari Labs."
parameters:
model: nari-labs/Dia-1.6B-0626
type: DiaForConditionalGeneration
known_usecases:
- tts
- &afm
name: "arcee-ai_afm-4.5b"
url: "github:mudler/LocalAI/gallery/chatml.yaml@master"