feat(whisper): Add prompt to condition transcription output (#7624)

* chore(makefile): Add buildargs for sd and cuda when building backend

Signed-off-by: Richard Palethorpe <io@richiejp.com>

* feat(whisper): Add prompt to condition transcription output

Signed-off-by: Richard Palethorpe <io@richiejp.com>

---------

Signed-off-by: Richard Palethorpe <io@richiejp.com>
Author: Richard Palethorpe
Date: 2025-12-18 13:40:45 +00:00
Committed by: GitHub
Parent: 247983265d
Commit: 716dba94b4
10 changed files with 23 additions and 10 deletions

View File

@@ -514,7 +514,7 @@ docker-save-diffusers: backend-images
docker save local-ai-backend:diffusers -o backend-images/diffusers.tar
docker-build-whisper:
-docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:whisper -f backend/Dockerfile.golang --build-arg BACKEND=whisper .
+docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) --build-arg CUDA_MAJOR_VERSION=$(CUDA_MAJOR_VERSION) --build-arg CUDA_MINOR_VERSION=$(CUDA_MINOR_VERSION) -t local-ai-backend:whisper -f backend/Dockerfile.golang --build-arg BACKEND=whisper .
docker-save-whisper: backend-images
docker save local-ai-backend:whisper -o backend-images/whisper.tar

View File

@@ -282,6 +282,7 @@ message TranscriptRequest {
uint32 threads = 4;
bool translate = 5;
bool diarize = 6;
+string prompt = 7;
}
message TranscriptResult {
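
Note: the new prompt field is appended as tag 7, so older clients that never set it keep working and the backend simply sees the proto3 default (empty string). A minimal client-side sketch in Go, assuming the generated package is imported as pb as in the rest of this diff; the values are illustrative only:

// Sketch only: field names follow the proto above, values are made up.
req := &pb.TranscriptRequest{
	Language:  "en",
	Threads:   4,
	Translate: false,
	Diarize:   false,
	// New: prior text or expected vocabulary used to condition decoding.
	Prompt: "LocalAI, whisper.cpp, CUDA",
}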

View File

@@ -3,5 +3,5 @@ sources/
build/
package/
whisper
-libgowhisper.so
+*.so
compile_commands.json

View File

@@ -107,7 +107,7 @@ int vad(float pcmf32[], size_t pcmf32_len, float **segs_out,
}
int transcribe(uint32_t threads, char *lang, bool translate, bool tdrz,
-float pcmf32[], size_t pcmf32_len, size_t *segs_out_len) {
+float pcmf32[], size_t pcmf32_len, size_t *segs_out_len, char *prompt) {
whisper_full_params wparams =
whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
@@ -122,8 +122,10 @@ int transcribe(uint32_t threads, char *lang, bool translate, bool tdrz,
wparams.debug_mode = true;
wparams.print_progress = true;
wparams.tdrz_enable = tdrz;
+wparams.initial_prompt = prompt;
fprintf(stderr, "info: Enable tdrz: %d\n", tdrz);
fprintf(stderr, "info: Initial prompt: \"%s\"\n", prompt);
if (whisper_full(ctx, wparams, pcmf32, pcmf32_len)) {
fprintf(stderr, "error: transcription failed\n");

View File

@@ -17,7 +17,7 @@ var (
CppLoadModel func(modelPath string) int
CppLoadModelVAD func(modelPath string) int
CppVAD func(pcmf32 []float32, pcmf32Size uintptr, segsOut unsafe.Pointer, segsOutLen unsafe.Pointer) int
-CppTranscribe func(threads uint32, lang string, translate bool, diarize bool, pcmf32 []float32, pcmf32Len uintptr, segsOutLen unsafe.Pointer) int
+CppTranscribe func(threads uint32, lang string, translate bool, diarize bool, pcmf32 []float32, pcmf32Len uintptr, segsOutLen unsafe.Pointer, prompt string) int
CppGetSegmentText func(i int) string
CppGetSegmentStart func(i int) int64
CppGetSegmentEnd func(i int) int64
@@ -123,7 +123,7 @@ func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptR
segsLen := uintptr(0xdeadbeef)
segsLenPtr := unsafe.Pointer(&segsLen)
-if ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr); ret != 0 {
+if ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr, opts.Prompt); ret != 0 {
return pb.TranscriptResult{}, fmt.Errorf("Failed Transcribe")
}
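
For context, the Cpp* function variables above are purego bindings into the whisper shared library, and purego marshals the trailing Go string to the char *prompt expected by the C transcribe function. A rough sketch of the registration, assuming the library is opened with purego.Dlopen; the file name and helper are assumptions, the actual loader lives outside this hunk and may differ:

// Hypothetical loader sketch; the library file name is an assumption.
import "github.com/ebitengine/purego"

func loadTranscribe() error {
	lib, err := purego.Dlopen("libgowhisper.so", purego.RTLD_NOW|purego.RTLD_GLOBAL)
	if err != nil {
		return err
	}
	// The Go signature of CppTranscribe must stay in sync with the C side,
	// including the new trailing prompt argument, or the call will misbehave.
	purego.RegisterLibFunc(&CppTranscribe, lib, "transcribe")
	return nil
}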

View File

@@ -7,7 +7,8 @@ int load_model_vad(const char *const model_path);
int vad(float pcmf32[], size_t pcmf32_size, float **segs_out,
size_t *segs_out_len);
int transcribe(uint32_t threads, char *lang, bool translate, bool tdrz,
-float pcmf32[], size_t pcmf32_len, size_t *segs_out_len);
+float pcmf32[], size_t pcmf32_len, size_t *segs_out_len,
+char *prompt);
const char *get_segment_text(int i);
int64_t get_segment_t0(int i);
int64_t get_segment_t1(int i);

View File

@@ -12,7 +12,7 @@ import (
"github.com/mudler/LocalAI/pkg/model"
)
-func ModelTranscription(audio, language string, translate bool, diarize bool, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
+func ModelTranscription(audio, language string, translate bool, diarize bool, prompt string, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
if modelConfig.Backend == "" {
modelConfig.Backend = model.WhisperBackend
@@ -35,6 +35,7 @@ func ModelTranscription(audio, language string, translate bool, diarize bool, ml
Translate: translate,
Diarize: diarize,
Threads: uint32(*modelConfig.Threads),
+Prompt: prompt,
})
if err != nil {
return nil, err

View File

@@ -23,6 +23,7 @@ type TranscriptCMD struct {
Diarize bool `short:"d" help:"Mark speaker turns"`
Threads int `short:"t" default:"1" help:"Number of threads used for parallel computation"`
ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
+Prompt string `short:"p" help:"Previous transcribed text or words that hint at what the model should expect"`
}
func (t *TranscriptCMD) Run(ctx *cliContext.Context) error {
@@ -57,7 +58,7 @@ func (t *TranscriptCMD) Run(ctx *cliContext.Context) error {
}
}()
-tr, err := backend.ModelTranscription(t.Filename, t.Language, t.Translate, t.Diarize, ml, c, opts)
+tr, err := backend.ModelTranscription(t.Filename, t.Language, t.Translate, t.Diarize, t.Prompt, ml, c, opts)
if err != nil {
return err
}

View File

@@ -593,6 +593,11 @@ func updateTransSession(session *Session, update *types.ClientSession, cl *confi
session.ModelInterface = m
}
+if trUpd != nil {
+trCur.Language = trUpd.Language
+trCur.Prompt = trUpd.Prompt
+}
if update.TurnDetection != nil && update.TurnDetection.Type != "" {
session.TurnDetection.Type = types.ServerTurnDetectionType(update.TurnDetection.Type)
session.TurnDetection.TurnDetectionParams = update.TurnDetection.TurnDetectionParams
@@ -790,6 +795,7 @@ func commitUtterance(ctx context.Context, utt []byte, cfg *config.ModelConfig, e
Language: session.InputAudioTranscription.Language,
Translate: false,
Threads: uint32(*cfg.Threads),
+Prompt: session.InputAudioTranscription.Prompt,
})
if err != nil {
sendError(c, "transcription_failed", err.Error(), "", "event_TODO")

View File

@@ -37,6 +37,7 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
}
diarize := c.FormValue("diarize") != "false"
+prompt := c.FormValue("prompt")
// retrieve the file data from the request
file, err := c.FormFile("file")
@@ -69,7 +70,7 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
log.Debug().Msgf("Audio file copied to: %+v", dst)
-tr, err := backend.ModelTranscription(dst, input.Language, input.Translate, diarize, ml, *config, appConfig)
+tr, err := backend.ModelTranscription(dst, input.Language, input.Translate, diarize, prompt, ml, *config, appConfig)
if err != nil {
return err
}
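
Usage sketch for the endpoint above: the prompt travels as an ordinary multipart form field next to the uploaded audio. Only the "file", "diarize", and "prompt" fields appear in this hunk; the endpoint path and the "model" field are assumptions. A minimal Go client under those assumptions:

// Hypothetical client; adjust the URL and model name to the deployment.
package main

import (
	"bytes"
	"io"
	"mime/multipart"
	"net/http"
	"os"
	"path/filepath"
)

func transcribeWithPrompt(audioPath, prompt string) error {
	var body bytes.Buffer
	w := multipart.NewWriter(&body)

	f, err := os.Open(audioPath)
	if err != nil {
		return err
	}
	defer f.Close()

	part, err := w.CreateFormFile("file", filepath.Base(audioPath))
	if err != nil {
		return err
	}
	if _, err := io.Copy(part, f); err != nil {
		return err
	}

	// New in this change: prior text or expected vocabulary for conditioning.
	_ = w.WriteField("prompt", prompt)
	_ = w.WriteField("model", "whisper-base") // model name is an assumption
	w.Close()

	resp, err := http.Post("http://localhost:8080/v1/audio/transcriptions",
		w.FormDataContentType(), &body)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	_, err = io.Copy(os.Stdout, resp.Body)
	return err
}

func main() {
	if err := transcribeWithPrompt("audio.wav", "LocalAI, whisper.cpp, diarization"); err != nil {
		panic(err)
	}
}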