Skip to content

Commit f7d5792

Browse files
committed
disable benchmark, change to whisper_init_from_file_with_params
1 parent 8880e83 commit f7d5792

File tree

4 files changed

+42
-115
lines changed

4 files changed

+42
-115
lines changed

R/RcppExports.R

-4
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,3 @@ whisper_encode <- function(model, path, language, token_timestamps = FALSE, tran
99
.Call('_audio_whisper_whisper_encode', PACKAGE = 'audio.whisper', model, path, language, token_timestamps, translate, print_special, duration, offset, trace, n_threads, n_processors, entropy_thold, logprob_thold, beam_size, best_of, split_on_word, max_context)
1010
}
1111

12-
whisper_print_benchmark <- function(model, n_threads = 1L) {
13-
invisible(.Call('_audio_whisper_whisper_print_benchmark', PACKAGE = 'audio.whisper', model, n_threads))
14-
}
15-

R/whisper.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,6 @@ whisper_download_model <- function(x = c("tiny", "tiny.en", "base", "base.en", "
221221
whisper_benchmark <- function(object = whisper(system.file(package = "audio.whisper", "models", "for-tests-ggml-tiny.bin")),
222222
threads = 1){
223223
stopifnot(inherits(object, "whisper"))
224-
whisper_print_benchmark(object$model, threads)
224+
#whisper_print_benchmark(object$model, threads)
225225
invisible()
226226
}

src/RcppExports.cpp

-12
Original file line numberDiff line numberDiff line change
@@ -43,22 +43,10 @@ BEGIN_RCPP
4343
return rcpp_result_gen;
4444
END_RCPP
4545
}
46-
// whisper_print_benchmark
47-
void whisper_print_benchmark(SEXP model, int n_threads);
48-
RcppExport SEXP _audio_whisper_whisper_print_benchmark(SEXP modelSEXP, SEXP n_threadsSEXP) {
49-
BEGIN_RCPP
50-
Rcpp::RNGScope rcpp_rngScope_gen;
51-
Rcpp::traits::input_parameter< SEXP >::type model(modelSEXP);
52-
Rcpp::traits::input_parameter< int >::type n_threads(n_threadsSEXP);
53-
whisper_print_benchmark(model, n_threads);
54-
return R_NilValue;
55-
END_RCPP
56-
}
5746

5847
static const R_CallMethodDef CallEntries[] = {
5948
{"_audio_whisper_whisper_load_model", (DL_FUNC) &_audio_whisper_whisper_load_model, 1},
6049
{"_audio_whisper_whisper_encode", (DL_FUNC) &_audio_whisper_whisper_encode, 17},
61-
{"_audio_whisper_whisper_print_benchmark", (DL_FUNC) &_audio_whisper_whisper_print_benchmark, 2},
6250
{NULL, NULL, 0}
6351
};
6452

src/rcpp_whisper.cpp

+41-98
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,9 @@ class WhisperModel {
232232
public:
233233
struct whisper_context * ctx;
234234
WhisperModel(std::string model){
235-
ctx = whisper_init_from_file(model.c_str());
235+
struct whisper_context_params cparams;
236+
cparams.use_gpu = false;
237+
ctx = whisper_init_from_file_with_params(model.c_str(), cparams);
236238
}
237239
~WhisperModel(){
238240
whisper_free(ctx);
@@ -262,7 +264,6 @@ Rcpp::List whisper_encode(SEXP model, std::string path, std::string language,
262264
int max_context = -1) {
263265
whisper_params params;
264266
params.language = language;
265-
//params.model = model;
266267
params.translate = translate;
267268
params.print_special = print_special;
268269
params.duration_ms = duration;
@@ -277,13 +278,10 @@ Rcpp::List whisper_encode(SEXP model, std::string path, std::string language,
277278
params.best_of = best_of;
278279
params.split_on_word = split_on_word;
279280
params.max_context = max_context;
280-
281-
//std::string language = "en";
282-
//std::string model = "models/ggml-base.en.bin";
283281
if (params.fname_inp.empty()) {
284282
Rcpp::stop("error: no input files specified");
285283
}
286-
284+
287285
if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
288286
Rcpp::stop("Unknown language");
289287
}
@@ -293,83 +291,12 @@ Rcpp::List whisper_encode(SEXP model, std::string path, std::string language,
293291
struct whisper_context * ctx = whispermodel->ctx;
294292
//Rcpp::XPtr<whisper_context> ctx(model);
295293
//struct whisper_context * ctx = whisper_init(params.model.c_str());
296-
297-
// initial prompt
298-
std::vector<whisper_token> prompt_tokens;
299-
300-
if (!params.prompt.empty()) {
301-
prompt_tokens.resize(1024);
302-
prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));
303-
304-
Rprintf("\n");
305-
Rprintf("initial prompt: '%s'\n", params.prompt.c_str());
306-
Rprintf("initial tokens: [ ");
307-
for (int i = 0; i < (int) prompt_tokens.size(); ++i) {
308-
Rprintf("%d ", prompt_tokens[i]);
309-
}
310-
Rprintf("]\n");
311-
}
312-
294+
313295
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
314296
const auto fname_inp = params.fname_inp[f];
315297
std::vector<float> pcmf32; // mono-channel F32 PCM
316298
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
317-
// // WAV input
318-
// {
319-
// drwav wav;
320-
// std::vector<uint8_t> wav_data; // used for pipe input from stdin
321-
//
322-
// if (drwav_init_file(&wav, fname_inp.c_str(), NULL) == false) {
323-
// Rcpp::stop("Failed to open the file as WAV file: ", fname_inp);
324-
// }
325-
//
326-
// if (wav.channels != 1 && wav.channels != 2) {
327-
// Rcpp::stop("WAV file must be mono or stereo: ", fname_inp);
328-
// }
329-
//
330-
// if (params.diarize && wav.channels != 2 && params.no_timestamps == false) {
331-
// Rcpp::stop("WAV file must be stereo for diarization and timestamps have to be enabled: ", fname_inp);
332-
// }
333-
//
334-
// if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
335-
// Rcpp::stop("WAV file must be 16 kHz: ", fname_inp);
336-
// }
337-
//
338-
// if (wav.bitsPerSample != 16) {
339-
// Rcpp::stop("WAV file must be 16 bit: ", fname_inp);
340-
// }
341-
//
342-
// const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
343-
//
344-
// std::vector<int16_t> pcm16;
345-
// pcm16.resize(n*wav.channels);
346-
// drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
347-
// drwav_uninit(&wav);
348-
//
349-
// // convert to mono, float
350-
// pcmf32.resize(n);
351-
// if (wav.channels == 1) {
352-
// for (uint64_t i = 0; i < n; i++) {
353-
// pcmf32[i] = float(pcm16[i])/32768.0f;
354-
// }
355-
// } else {
356-
// for (uint64_t i = 0; i < n; i++) {
357-
// pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
358-
// }
359-
// }
360-
//
361-
// if (params.diarize) {
362-
// // convert to stereo, float
363-
// pcmf32s.resize(2);
364-
//
365-
// pcmf32s[0].resize(n);
366-
// pcmf32s[1].resize(n);
367-
// for (uint64_t i = 0; i < n; i++) {
368-
// pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
369-
// pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
370-
// }
371-
// }
372-
// }
299+
373300
if (!::read_wav(fname_inp, pcmf32, pcmf32s, params.diarize)) {
374301
Rprintf("error: failed to read WAV file '%s'\n", fname_inp.c_str());
375302
Rcpp::stop("The input audio needs to be a 16-bit .wav file.");
@@ -391,63 +318,78 @@ Rcpp::List whisper_encode(SEXP model, std::string path, std::string language,
391318
Rcpp::warning("WARNING: model is not multilingual, ignoring language and translation options");
392319
}
393320
}
394-
Rcpp::Rcout << "Processing " << fname_inp << " (" << int(pcmf32.size()) << " samples, " << float(pcmf32.size())/WHISPER_SAMPLE_RATE << " sec)" << ", lang = " << params.language << ", translate = " << params.translate << ", timestamps = " << token_timestamps << "\n";
321+
Rcpp::Rcout << "Processing " << fname_inp << " (" << int(pcmf32.size()) << " samples, " << float(pcmf32.size())/WHISPER_SAMPLE_RATE << " sec)" << ", lang = " << params.language << ", translate = " << params.translate << ", timestamps = " << token_timestamps << ", beam_size = " << params.beam_size << ", best_of = " << params.best_of << "\n";
395322
}
396323

397324
// run the inference
398325
{
399326
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
400327

401328
wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
402-
wparams.print_realtime = trace;
403-
wparams.print_progress = false;
329+
330+
wparams.print_realtime = false;
331+
wparams.print_progress = params.print_progress;
404332
wparams.print_timestamps = !params.no_timestamps;
405333
wparams.print_special = params.print_special;
406334
wparams.translate = params.translate;
407335
wparams.language = params.language.c_str();
336+
wparams.detect_language = params.detect_language;
408337
wparams.n_threads = params.n_threads;
409338
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
410339
wparams.offset_ms = params.offset_t_ms;
411340
wparams.duration_ms = params.duration_ms;
412-
413-
wparams.token_timestamps = params.output_wts || params.max_len > 0;
414-
wparams.token_timestamps = token_timestamps;
341+
342+
// Enable token-level timestamps when the caller asks for them through the
// token_timestamps argument, or when an output mode / max_len setting needs them.
wparams.token_timestamps = token_timestamps || params.output_wts || params.output_jsn_full || params.max_len > 0;
415343
wparams.thold_pt = params.word_thold;
416344
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
417345
wparams.split_on_word = params.split_on_word;
418-
346+
419347
wparams.speed_up = params.speed_up;
420-
421-
wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
422-
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
423-
348+
wparams.debug_mode = params.debug_mode;
349+
350+
wparams.tdrz_enable = params.tinydiarize; // [TDRZ]
351+
352+
wparams.initial_prompt = params.prompt.c_str();
353+
424354
wparams.greedy.best_of = params.best_of;
425355
wparams.beam_search.beam_size = params.beam_size;
426-
356+
427357
wparams.temperature_inc = params.no_fallback ? 0.0f : wparams.temperature_inc;
428358
wparams.entropy_thold = params.entropy_thold;
429359
wparams.logprob_thold = params.logprob_thold;
430360

431-
whisper_print_user_data user_data = { &params, &pcmf32s };
361+
whisper_print_user_data user_data = { &params, &pcmf32s, 0 };
432362

433363
// this callback is called on each new segment
434364
if (!wparams.print_realtime) {
435365
wparams.new_segment_callback = whisper_print_segment_callback;
436366
wparams.new_segment_callback_user_data = &user_data;
437367
}
438368

439-
// example for abort mechanism
440-
// in this example, we do not abort the processing, but we could if the flag is set to true
369+
// examples for abort mechanism
370+
// in examples below, we do not abort the processing, but we could if the flag is set to true
371+
441372
// the callback is called before every encoder run - if it returns false, the processing is aborted
442373
{
443374
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
444-
445-
wparams.encoder_begin_callback = [](struct whisper_context * ctx, void * user_data) {
375+
376+
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
446377
bool is_aborted = *(bool*)user_data;
447378
return !is_aborted;
448379
};
449380
wparams.encoder_begin_callback_user_data = &is_aborted;
450381
}
382+
383+
// the callback is called before every computation - if it returns true, the computation is aborted
384+
{
385+
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
386+
387+
wparams.abort_callback = [](void * user_data) {
388+
bool is_aborted = *(bool*)user_data;
389+
return is_aborted;
390+
};
391+
wparams.abort_callback_user_data = &is_aborted;
392+
}
451393

452394
if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
453395
Rcpp::stop("failed to process audio");
@@ -549,7 +491,7 @@ Rcpp::List whisper_encode(SEXP model, std::string path, std::string language,
549491

550492

551493

552-
494+
/*
553495
// [[Rcpp::export]]
554496
void whisper_print_benchmark(SEXP model, int n_threads = 1) {
555497
whisper_params params;
@@ -566,4 +508,5 @@ void whisper_print_benchmark(SEXP model, int n_threads = 1) {
566508
Rprintf("error: failed to encode model: %d\n", ret);
567509
}
568510
whisper_print_timings(ctx);
569-
}
511+
}
512+
*/

0 commit comments

Comments
 (0)