mirror of
https://github.com/mudler/LocalAI.git
synced 2026-01-07 11:10:11 -06:00
feat(whisper): Add prompt to condition transcription output (#7624)
* chore(makefile): Add buildargs for sd and cuda when building backend Signed-off-by: Richard Palethorpe <io@richiejp.com> * feat(whisper): Add prompt to condition transcription output Signed-off-by: Richard Palethorpe <io@richiejp.com> --------- Signed-off-by: Richard Palethorpe <io@richiejp.com>
This commit is contained in:
committed by
GitHub
parent
247983265d
commit
716dba94b4
2
Makefile
2
Makefile
@@ -514,7 +514,7 @@ docker-save-diffusers: backend-images
|
||||
docker save local-ai-backend:diffusers -o backend-images/diffusers.tar
|
||||
|
||||
docker-build-whisper:
|
||||
docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) -t local-ai-backend:whisper -f backend/Dockerfile.golang --build-arg BACKEND=whisper .
|
||||
docker build --build-arg BUILD_TYPE=$(BUILD_TYPE) --build-arg BASE_IMAGE=$(BASE_IMAGE) --build-arg CUDA_MAJOR_VERSION=$(CUDA_MAJOR_VERSION) --build-arg CUDA_MINOR_VERSION=$(CUDA_MINOR_VERSION) -t local-ai-backend:whisper -f backend/Dockerfile.golang --build-arg BACKEND=whisper .
|
||||
|
||||
docker-save-whisper: backend-images
|
||||
docker save local-ai-backend:whisper -o backend-images/whisper.tar
|
||||
|
||||
@@ -282,6 +282,7 @@ message TranscriptRequest {
|
||||
uint32 threads = 4;
|
||||
bool translate = 5;
|
||||
bool diarize = 6;
|
||||
string prompt = 7;
|
||||
}
|
||||
|
||||
message TranscriptResult {
|
||||
|
||||
4
backend/go/whisper/.gitignore
vendored
4
backend/go/whisper/.gitignore
vendored
@@ -3,5 +3,5 @@ sources/
|
||||
build/
|
||||
package/
|
||||
whisper
|
||||
libgowhisper.so
|
||||
|
||||
*.so
|
||||
compile_commands.json
|
||||
|
||||
@@ -107,7 +107,7 @@ int vad(float pcmf32[], size_t pcmf32_len, float **segs_out,
|
||||
}
|
||||
|
||||
int transcribe(uint32_t threads, char *lang, bool translate, bool tdrz,
|
||||
float pcmf32[], size_t pcmf32_len, size_t *segs_out_len) {
|
||||
float pcmf32[], size_t pcmf32_len, size_t *segs_out_len, char *prompt) {
|
||||
whisper_full_params wparams =
|
||||
whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
|
||||
@@ -122,8 +122,10 @@ int transcribe(uint32_t threads, char *lang, bool translate, bool tdrz,
|
||||
wparams.debug_mode = true;
|
||||
wparams.print_progress = true;
|
||||
wparams.tdrz_enable = tdrz;
|
||||
wparams.initial_prompt = prompt;
|
||||
|
||||
fprintf(stderr, "info: Enable tdrz: %d\n", tdrz);
|
||||
fprintf(stderr, "info: Initial prompt: \"%s\"\n", prompt);
|
||||
|
||||
if (whisper_full(ctx, wparams, pcmf32, pcmf32_len)) {
|
||||
fprintf(stderr, "error: transcription failed\n");
|
||||
|
||||
@@ -17,7 +17,7 @@ var (
|
||||
CppLoadModel func(modelPath string) int
|
||||
CppLoadModelVAD func(modelPath string) int
|
||||
CppVAD func(pcmf32 []float32, pcmf32Size uintptr, segsOut unsafe.Pointer, segsOutLen unsafe.Pointer) int
|
||||
CppTranscribe func(threads uint32, lang string, translate bool, diarize bool, pcmf32 []float32, pcmf32Len uintptr, segsOutLen unsafe.Pointer) int
|
||||
CppTranscribe func(threads uint32, lang string, translate bool, diarize bool, pcmf32 []float32, pcmf32Len uintptr, segsOutLen unsafe.Pointer, prompt string) int
|
||||
CppGetSegmentText func(i int) string
|
||||
CppGetSegmentStart func(i int) int64
|
||||
CppGetSegmentEnd func(i int) int64
|
||||
@@ -123,7 +123,7 @@ func (w *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (pb.TranscriptR
|
||||
segsLen := uintptr(0xdeadbeef)
|
||||
segsLenPtr := unsafe.Pointer(&segsLen)
|
||||
|
||||
if ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr); ret != 0 {
|
||||
if ret := CppTranscribe(opts.Threads, opts.Language, opts.Translate, opts.Diarize, data, uintptr(len(data)), segsLenPtr, opts.Prompt); ret != 0 {
|
||||
return pb.TranscriptResult{}, fmt.Errorf("Failed Transcribe")
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,8 @@ int load_model_vad(const char *const model_path);
|
||||
int vad(float pcmf32[], size_t pcmf32_size, float **segs_out,
|
||||
size_t *segs_out_len);
|
||||
int transcribe(uint32_t threads, char *lang, bool translate, bool tdrz,
|
||||
float pcmf32[], size_t pcmf32_len, size_t *segs_out_len);
|
||||
float pcmf32[], size_t pcmf32_len, size_t *segs_out_len,
|
||||
char *prompt);
|
||||
const char *get_segment_text(int i);
|
||||
int64_t get_segment_t0(int i);
|
||||
int64_t get_segment_t1(int i);
|
||||
|
||||
@@ -12,7 +12,7 @@ import (
|
||||
"github.com/mudler/LocalAI/pkg/model"
|
||||
)
|
||||
|
||||
func ModelTranscription(audio, language string, translate bool, diarize bool, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
|
||||
func ModelTranscription(audio, language string, translate bool, diarize bool, prompt string, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
|
||||
|
||||
if modelConfig.Backend == "" {
|
||||
modelConfig.Backend = model.WhisperBackend
|
||||
@@ -35,6 +35,7 @@ func ModelTranscription(audio, language string, translate bool, diarize bool, ml
|
||||
Translate: translate,
|
||||
Diarize: diarize,
|
||||
Threads: uint32(*modelConfig.Threads),
|
||||
Prompt: prompt,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -23,6 +23,7 @@ type TranscriptCMD struct {
|
||||
Diarize bool `short:"d" help:"Mark speaker turns"`
|
||||
Threads int `short:"t" default:"1" help:"Number of threads used for parallel computation"`
|
||||
ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
|
||||
Prompt string `short:"p" help:"Previous transcribed text or words that hint at what the model should expect"`
|
||||
}
|
||||
|
||||
func (t *TranscriptCMD) Run(ctx *cliContext.Context) error {
|
||||
@@ -57,7 +58,7 @@ func (t *TranscriptCMD) Run(ctx *cliContext.Context) error {
|
||||
}
|
||||
}()
|
||||
|
||||
tr, err := backend.ModelTranscription(t.Filename, t.Language, t.Translate, t.Diarize, ml, c, opts)
|
||||
tr, err := backend.ModelTranscription(t.Filename, t.Language, t.Translate, t.Diarize, t.Prompt, ml, c, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -593,6 +593,11 @@ func updateTransSession(session *Session, update *types.ClientSession, cl *confi
|
||||
session.ModelInterface = m
|
||||
}
|
||||
|
||||
if trUpd != nil {
|
||||
trCur.Language = trUpd.Language
|
||||
trCur.Prompt = trUpd.Prompt
|
||||
}
|
||||
|
||||
if update.TurnDetection != nil && update.TurnDetection.Type != "" {
|
||||
session.TurnDetection.Type = types.ServerTurnDetectionType(update.TurnDetection.Type)
|
||||
session.TurnDetection.TurnDetectionParams = update.TurnDetection.TurnDetectionParams
|
||||
@@ -790,6 +795,7 @@ func commitUtterance(ctx context.Context, utt []byte, cfg *config.ModelConfig, e
|
||||
Language: session.InputAudioTranscription.Language,
|
||||
Translate: false,
|
||||
Threads: uint32(*cfg.Threads),
|
||||
Prompt: session.InputAudioTranscription.Prompt,
|
||||
})
|
||||
if err != nil {
|
||||
sendError(c, "transcription_failed", err.Error(), "", "event_TODO")
|
||||
|
||||
@@ -37,6 +37,7 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
|
||||
}
|
||||
|
||||
diarize := c.FormValue("diarize") != "false"
|
||||
prompt := c.FormValue("prompt")
|
||||
|
||||
// retrieve the file data from the request
|
||||
file, err := c.FormFile("file")
|
||||
@@ -69,7 +70,7 @@ func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, app
|
||||
|
||||
log.Debug().Msgf("Audio file copied to: %+v", dst)
|
||||
|
||||
tr, err := backend.ModelTranscription(dst, input.Language, input.Translate, diarize, ml, *config, appConfig)
|
||||
tr, err := backend.ModelTranscription(dst, input.Language, input.Translate, diarize, prompt, ml, *config, appConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user