mirror of
https://github.com/mudler/LocalAI.git
synced 2026-01-07 02:59:54 -06:00
chore: add Dia to the model gallery, fix backend (#5998)
* fix: correctly call OuteTTS and DiaTTS Signed-off-by: Ettore Di Giacinto <mudler@localai.io> * chore(model gallery): add dia Signed-off-by: Ettore Di Giacinto <mudler@localai.io> --------- Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
committed by
GitHub
parent
326fda3223
commit
4733adb983
@@ -229,8 +229,13 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
self.model = MusicgenForConditionalGeneration.from_pretrained(model_name)
|
||||
elif request.Type == "DiaForConditionalGeneration":
|
||||
autoTokenizer = False
|
||||
print("DiaForConditionalGeneration", file=sys.stderr)
|
||||
self.processor = AutoProcessor.from_pretrained(model_name)
|
||||
self.model = DiaForConditionalGeneration.from_pretrained(model_name)
|
||||
if self.CUDA:
|
||||
self.model = self.model.to("cuda")
|
||||
self.processor = self.processor.to("cuda")
|
||||
print("DiaForConditionalGeneration loaded", file=sys.stderr)
|
||||
self.DiaTTS = True
|
||||
elif request.Type == "OuteTTS":
|
||||
autoTokenizer = False
|
||||
@@ -536,7 +541,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
return backend_pb2.Result(success=True)
|
||||
|
||||
|
||||
def DiaTTS(self, request, context):
|
||||
def CallDiaTTS(self, request, context):
|
||||
"""
|
||||
Generates dialogue audio using the Dia model.
|
||||
|
||||
@@ -581,7 +586,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
return backend_pb2.Result(success=True)
|
||||
|
||||
|
||||
def OuteTTS(self, request, context):
|
||||
def CallOuteTTS(self, request, context):
|
||||
try:
|
||||
print("[OuteTTS] generating TTS", file=sys.stderr)
|
||||
gen_cfg = outetts.GenerationConfig(
|
||||
@@ -603,10 +608,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
# The TTS endpoint is older, and provides fewer features, but exists for compatibility reasons
|
||||
def TTS(self, request, context):
|
||||
if self.OuteTTS:
|
||||
return self.OuteTTS(request, context)
|
||||
return self.CallOuteTTS(request, context)
|
||||
|
||||
if self.DiaTTS:
|
||||
return self.DiaTTS(request, context)
|
||||
print("DiaTTS", file=sys.stderr)
|
||||
return self.CallDiaTTS(request, context)
|
||||
|
||||
model_name = request.model
|
||||
try:
|
||||
|
||||
@@ -191,6 +191,27 @@
|
||||
- filename: OpenAI-20B-NEOPlus-Uncensored-IQ4_NL.gguf
|
||||
sha256: 274ffaaf0783270c071006842ffe60af73600fc63c2b6153c0701b596fc3b122
|
||||
uri: huggingface://DavidAU/OpenAi-GPT-oss-20b-abliterated-uncensored-NEO-Imatrix-gguf/OpenAI-20B-NEOPlus-Uncensored-IQ4_NL.gguf
|
||||
- name: "dia"
|
||||
url: "github:mudler/LocalAI/gallery/virtual.yaml@master"
|
||||
icon: https://github.com/nari-labs/dia/raw/main/dia/static/images/banner.png
|
||||
urls:
|
||||
- https://github.com/nari-labs/dia
|
||||
- https://huggingface.co/nari-labs/Dia-1.6B-0626
|
||||
license: apache-2.0
|
||||
tags:
|
||||
- tts
|
||||
- dia
|
||||
- gpu
|
||||
- text-to-speech
|
||||
overrides:
|
||||
backend: "transformers"
|
||||
name: "dia"
|
||||
description: "Dia is a 1.6B parameter text to speech model created by Nari Labs."
|
||||
parameters:
|
||||
model: nari-labs/Dia-1.6B-0626
|
||||
type: DiaForConditionalGeneration
|
||||
known_usecases:
|
||||
- tts
|
||||
- &afm
|
||||
name: "arcee-ai_afm-4.5b"
|
||||
url: "github:mudler/LocalAI/gallery/chatml.yaml@master"
|
||||
|
||||
Reference in New Issue
Block a user