Skip to content

Commit 114c54d

Browse files
committed
#14 Add option to pass on float entropy_thold, logprob_thold, beam_size, best_of, split_on_word when doing the prediction
1 parent 4992a73 commit 114c54d

7 files changed

+45
-11
lines changed

DESCRIPTION

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
Package: audio.whisper
22
Type: Package
33
Title: Transcribe Audio Files using the "Whisper" Automatic Speech Recognition Model
4-
Version: 0.2.1-1
4+
Version: 0.2.2
55
Maintainer: Jan Wijffels <jwijffels@bnosac.be>
66
Authors@R: c(
77
person('Jan', 'Wijffels', role = c('aut', 'cre', 'cph'), email = 'jwijffels@bnosac.be', comment = "R wrapper"),

NEWS.md

+4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
## CHANGES IN audio.whisper VERSION 0.2.2
2+
3+
- Add options to pass on entropy_thold and logprob_thold (floats), beam_size, best_of and split_on_word when doing the prediction
4+
15
## CHANGES IN audio.whisper VERSION 0.2.1-1
26

37
- whisper_download_model now deprecates downloading from https://ggml.ggerganov.com and the URLs have been changed to download models from Hugging Face (Issue #18)

R/RcppExports.R

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ whisper_load_model <- function(model) {
55
.Call('_audio_whisper_whisper_load_model', PACKAGE = 'audio.whisper', model)
66
}
77

8-
whisper_encode <- function(model, path, language, token_timestamps = FALSE, translate = FALSE, print_special = FALSE, duration = 0L, offset = 0L, trace = FALSE, n_threads = 1L, n_processors = 1L) {
9-
.Call('_audio_whisper_whisper_encode', PACKAGE = 'audio.whisper', model, path, language, token_timestamps, translate, print_special, duration, offset, trace, n_threads, n_processors)
8+
whisper_encode <- function(model, path, language, token_timestamps = FALSE, translate = FALSE, print_special = FALSE, duration = 0L, offset = 0L, trace = FALSE, n_threads = 1L, n_processors = 1L, entropy_thold = 2.40, logprob_thold = -1.00, beam_size = -1L, best_of = 5L, split_on_word = FALSE) {
9+
.Call('_audio_whisper_whisper_encode', PACKAGE = 'audio.whisper', model, path, language, token_timestamps, translate, print_special, duration, offset, trace, n_threads, n_processors, entropy_thold, logprob_thold, beam_size, best_of, split_on_word)
1010
}
1111

1212
whisper_print_benchmark <- function(model, n_threads = 1L) {

R/whisper.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
#' \itemize{
1212
#' \item{n_segments: the number of audio segments}
1313
#' \item{data: a data.frame with the transcription with columns segment, text, from and to}
14-
#' \item{tokens: a data.frame with the transcription tokens with columns segment, token, token_prob indicating the token probability given the context}
14+
#' \item{tokens: a data.frame with the transcription tokens with columns segment, token_id, token, token_prob indicating the token probability given the context}
1515
#' \item{params: a list with parameters used for inference}
1616
#' }
1717
#' @export

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ This repository contains an R package which is an Rcpp wrapper around the [whisp
1919

2020
### Installation
2121

22-
For the *stable* version of this package: `remotes::install_github("bnosac/audio.whisper", ref = "0.2.1-1")` <br>
22+
For the *stable* version of this package: `remotes::install_github("bnosac/audio.whisper", ref = "0.2.2")` <br>
2323
Look to the documentation of the functions: `help(package = "audio.whisper")`
2424

2525
- For the *development* version of this package: `remotes::install_github("bnosac/audio.whisper")`

src/RcppExports.cpp

+9-4
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ BEGIN_RCPP
1717
END_RCPP
1818
}
1919
// whisper_encode
20-
Rcpp::List whisper_encode(SEXP model, std::string path, std::string language, bool token_timestamps, bool translate, bool print_special, int duration, int offset, bool trace, int n_threads, int n_processors);
21-
RcppExport SEXP _audio_whisper_whisper_encode(SEXP modelSEXP, SEXP pathSEXP, SEXP languageSEXP, SEXP token_timestampsSEXP, SEXP translateSEXP, SEXP print_specialSEXP, SEXP durationSEXP, SEXP offsetSEXP, SEXP traceSEXP, SEXP n_threadsSEXP, SEXP n_processorsSEXP) {
20+
Rcpp::List whisper_encode(SEXP model, std::string path, std::string language, bool token_timestamps, bool translate, bool print_special, int duration, int offset, bool trace, int n_threads, int n_processors, float entropy_thold, float logprob_thold, int beam_size, int best_of, bool split_on_word);
21+
RcppExport SEXP _audio_whisper_whisper_encode(SEXP modelSEXP, SEXP pathSEXP, SEXP languageSEXP, SEXP token_timestampsSEXP, SEXP translateSEXP, SEXP print_specialSEXP, SEXP durationSEXP, SEXP offsetSEXP, SEXP traceSEXP, SEXP n_threadsSEXP, SEXP n_processorsSEXP, SEXP entropy_tholdSEXP, SEXP logprob_tholdSEXP, SEXP beam_sizeSEXP, SEXP best_ofSEXP, SEXP split_on_wordSEXP) {
2222
BEGIN_RCPP
2323
Rcpp::RObject rcpp_result_gen;
2424
Rcpp::RNGScope rcpp_rngScope_gen;
@@ -33,7 +33,12 @@ BEGIN_RCPP
3333
Rcpp::traits::input_parameter< bool >::type trace(traceSEXP);
3434
Rcpp::traits::input_parameter< int >::type n_threads(n_threadsSEXP);
3535
Rcpp::traits::input_parameter< int >::type n_processors(n_processorsSEXP);
36-
rcpp_result_gen = Rcpp::wrap(whisper_encode(model, path, language, token_timestamps, translate, print_special, duration, offset, trace, n_threads, n_processors));
36+
Rcpp::traits::input_parameter< float >::type entropy_thold(entropy_tholdSEXP);
37+
Rcpp::traits::input_parameter< float >::type logprob_thold(logprob_tholdSEXP);
38+
Rcpp::traits::input_parameter< int >::type beam_size(beam_sizeSEXP);
39+
Rcpp::traits::input_parameter< int >::type best_of(best_ofSEXP);
40+
Rcpp::traits::input_parameter< bool >::type split_on_word(split_on_wordSEXP);
41+
rcpp_result_gen = Rcpp::wrap(whisper_encode(model, path, language, token_timestamps, translate, print_special, duration, offset, trace, n_threads, n_processors, entropy_thold, logprob_thold, beam_size, best_of, split_on_word));
3742
return rcpp_result_gen;
3843
END_RCPP
3944
}
@@ -51,7 +56,7 @@ END_RCPP
5156

5257
static const R_CallMethodDef CallEntries[] = {
5358
{"_audio_whisper_whisper_load_model", (DL_FUNC) &_audio_whisper_whisper_load_model, 1},
54-
{"_audio_whisper_whisper_encode", (DL_FUNC) &_audio_whisper_whisper_encode, 11},
59+
{"_audio_whisper_whisper_encode", (DL_FUNC) &_audio_whisper_whisper_encode, 16},
5560
{"_audio_whisper_whisper_print_benchmark", (DL_FUNC) &_audio_whisper_whisper_print_benchmark, 2},
5661
{NULL, NULL, 0}
5762
};

src/rcpp_whisper.cpp

+27-2
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,12 @@ SEXP whisper_load_model(std::string model) {
220220
// [[Rcpp::export]]
221221
Rcpp::List whisper_encode(SEXP model, std::string path, std::string language,
222222
bool token_timestamps = false, bool translate = false, bool print_special = false, int duration = 0, int offset = 0, bool trace = false,
223-
int n_threads = 1, int n_processors = 1) {
223+
int n_threads = 1, int n_processors = 1,
224+
float entropy_thold = 2.40,
225+
float logprob_thold = -1.00,
226+
int beam_size = -1,
227+
int best_of = 5,
228+
bool split_on_word = false) {
224229
whisper_params params;
225230
params.language = language;
226231
//params.model = model;
@@ -232,6 +237,13 @@ Rcpp::List whisper_encode(SEXP model, std::string path, std::string language,
232237
params.n_threads = n_threads;
233238
params.n_processors = n_processors;
234239

240+
params.entropy_thold = entropy_thold;
241+
params.logprob_thold = logprob_thold;
242+
params.beam_size = beam_size;
243+
params.best_of = best_of;
244+
params.split_on_word = split_on_word;
245+
//params.no_speech_thold = no_speech_thold;
246+
//params.temperature_inc = temperature_inc;
235247

236248
//std::string language = "en";
237249
//std::string model = "models/ggml-base.en.bin";
@@ -417,6 +429,7 @@ Rcpp::List whisper_encode(SEXP model, std::string path, std::string language,
417429
Rcpp::StringVector transcriptions_from(n_segments);
418430
Rcpp::StringVector transcriptions_to(n_segments);
419431
std::vector<int> token_segment_nr;
432+
std::vector<int> token_segment_id;
420433
std::vector<std::string> token_segment_text;
421434
std::vector<float> token_segment_probability;
422435
std::vector<std::string> token_segment_from;
@@ -439,7 +452,9 @@ Rcpp::List whisper_encode(SEXP model, std::string path, std::string language,
439452
}
440453
const char * text = whisper_full_get_token_text(ctx, i, j);
441454
const float p = whisper_full_get_token_p (ctx, i, j);
455+
const int tokenid = whisper_full_get_token_id (ctx, i, j);
442456
token_segment_nr.push_back(i + 1);
457+
token_segment_id.push_back(tokenid);
443458
std::string str(text);
444459
token_segment_text.push_back(str);
445460
token_segment_probability.push_back(p);
@@ -456,6 +471,7 @@ Rcpp::List whisper_encode(SEXP model, std::string path, std::string language,
456471
if(token_timestamps){
457472
tokens = Rcpp::DataFrame::create(
458473
Rcpp::Named("segment") = token_segment_nr,
474+
Rcpp::Named("token_id") = token_segment_id,
459475
Rcpp::Named("token") = token_segment_text,
460476
Rcpp::Named("token_prob") = token_segment_probability,
461477
Rcpp::Named("token_from") = token_segment_from,
@@ -464,6 +480,7 @@ Rcpp::List whisper_encode(SEXP model, std::string path, std::string language,
464480
}else{
465481
tokens = Rcpp::DataFrame::create(
466482
Rcpp::Named("segment") = token_segment_nr,
483+
Rcpp::Named("token_id") = token_segment_id,
467484
Rcpp::Named("token") = token_segment_text,
468485
Rcpp::Named("token_prob") = token_segment_probability,
469486
Rcpp::Named("stringsAsFactors") = false);
@@ -485,7 +502,15 @@ Rcpp::List whisper_encode(SEXP model, std::string path, std::string language,
485502
Rcpp::Named("duration") = duration,
486503
Rcpp::Named("translate") = params.translate,
487504
Rcpp::Named("token_timestamps") = token_timestamps,
488-
Rcpp::Named("word_threshold") = params.word_thold));
505+
Rcpp::Named("word_threshold") = params.word_thold,
506+
Rcpp::Named("entropy_thold") = params.entropy_thold,
507+
Rcpp::Named("logprob_thold") = params.logprob_thold,
508+
Rcpp::Named("beam_size") = params.beam_size,
509+
Rcpp::Named("best_of") = params.best_of,
510+
Rcpp::Named("split_on_word") = params.split_on_word));
511+
512+
513+
489514
return output;
490515
}
491516

0 commit comments

Comments
 (0)