From 4733adb983b43e24a5848a0dd360bdeab606e1a6 Mon Sep 17 00:00:00 2001 From: Ettore Di Giacinto Date: Fri, 8 Aug 2025 12:40:16 +0200 Subject: [PATCH] chore: add Dia to the model gallery, fix backend (#5998) * fix: correctly call OuteTTS and DiaTTS Signed-off-by: Ettore Di Giacinto * chore(model gallery): add dia Signed-off-by: Ettore Di Giacinto --------- Signed-off-by: Ettore Di Giacinto --- backend/python/transformers/backend.py | 14 ++++++++++---- gallery/index.yaml | 21 +++++++++++++++++++++ 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/backend/python/transformers/backend.py b/backend/python/transformers/backend.py index 3d34132fd..ef8a2fd40 100644 --- a/backend/python/transformers/backend.py +++ b/backend/python/transformers/backend.py @@ -229,8 +229,13 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): self.model = MusicgenForConditionalGeneration.from_pretrained(model_name) elif request.Type == "DiaForConditionalGeneration": autoTokenizer = False + print("DiaForConditionalGeneration", file=sys.stderr) self.processor = AutoProcessor.from_pretrained(model_name) self.model = DiaForConditionalGeneration.from_pretrained(model_name) + if self.CUDA: + self.model = self.model.to("cuda") + self.processor = self.processor.to("cuda") + print("DiaForConditionalGeneration loaded", file=sys.stderr) self.DiaTTS = True elif request.Type == "OuteTTS": autoTokenizer = False @@ -536,7 +541,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): return backend_pb2.Result(success=True) - def DiaTTS(self, request, context): + def CallDiaTTS(self, request, context): """ Generates dialogue audio using the Dia model. @@ -581,7 +586,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): return backend_pb2.Result(success=True) - def OuteTTS(self, request, context): + def CallOuteTTS(self, request, context): try: print("[OuteTTS] generating TTS", file=sys.stderr) gen_cfg = outetts.GenerationConfig( @@ -603,10 +608,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer): # The TTS endpoint is older, and provides fewer features, but exists for compatibility reasons def TTS(self, request, context): if self.OuteTTS: - return self.OuteTTS(request, context) + return self.CallOuteTTS(request, context) if self.DiaTTS: - return self.DiaTTS(request, context) + print("DiaTTS", file=sys.stderr) + return self.CallDiaTTS(request, context) model_name = request.model try: diff --git a/gallery/index.yaml b/gallery/index.yaml index 6ebc096cc..420c20edf 100644 --- a/gallery/index.yaml +++ b/gallery/index.yaml @@ -191,6 +191,27 @@ - filename: OpenAI-20B-NEOPlus-Uncensored-IQ4_NL.gguf sha256: 274ffaaf0783270c071006842ffe60af73600fc63c2b6153c0701b596fc3b122 uri: huggingface://DavidAU/OpenAi-GPT-oss-20b-abliterated-uncensored-NEO-Imatrix-gguf/OpenAI-20B-NEOPlus-Uncensored-IQ4_NL.gguf +- name: "dia" + url: "github:mudler/LocalAI/gallery/virtual.yaml@master" + icon: https://github.com/nari-labs/dia/raw/main/dia/static/images/banner.png + urls: + - https://github.com/nari-labs/dia + - https://huggingface.co/nari-labs/Dia-1.6B-0626 + license: apache-2.0 + tags: + - tts + - dia + - gpu + - text-to-speech + overrides: + backend: "transformers" + name: "dia" + description: "Dia is a 1.6B parameter text to speech model created by Nari Labs." + parameters: + model: nari-labs/Dia-1.6B-0626 + type: DiaForConditionalGeneration + known_usecases: + - tts - &afm name: "arcee-ai_afm-4.5b" url: "github:mudler/LocalAI/gallery/chatml.yaml@master"