@@ -232,7 +232,9 @@ class WhisperModel {
232
232
public:
233
233
struct whisper_context * ctx;
234
234
WhisperModel (std::string model){
235
- ctx = whisper_init_from_file (model.c_str ());
235
+ struct whisper_context_params cparams;
236
+ cparams.use_gpu = false ;
237
+ ctx = whisper_init_from_file_with_params (model.c_str (), cparams);
236
238
}
237
239
~WhisperModel (){
238
240
whisper_free (ctx);
@@ -262,7 +264,6 @@ Rcpp::List whisper_encode(SEXP model, std::string path, std::string language,
262
264
int max_context = -1 ) {
263
265
whisper_params params;
264
266
params.language = language;
265
- // params.model = model;
266
267
params.translate = translate;
267
268
params.print_special = print_special;
268
269
params.duration_ms = duration;
@@ -277,13 +278,10 @@ Rcpp::List whisper_encode(SEXP model, std::string path, std::string language,
277
278
params.best_of = best_of;
278
279
params.split_on_word = split_on_word;
279
280
params.max_context = max_context;
280
-
281
- // std::string language = "en";
282
- // std::string model = "models/ggml-base.en.bin";
283
281
if (params.fname_inp .empty ()) {
284
282
Rcpp::stop (" error: no input files specified" );
285
283
}
286
-
284
+
287
285
if (params.language != " auto" && whisper_lang_id (params.language .c_str ()) == -1 ) {
288
286
Rcpp::stop (" Unknown language" );
289
287
}
@@ -293,83 +291,12 @@ Rcpp::List whisper_encode(SEXP model, std::string path, std::string language,
293
291
struct whisper_context * ctx = whispermodel->ctx ;
294
292
// Rcpp::XPtr<whisper_context> ctx(model);
295
293
// struct whisper_context * ctx = whisper_init(params.model.c_str());
296
-
297
- // initial prompt
298
- std::vector<whisper_token> prompt_tokens;
299
-
300
- if (!params.prompt .empty ()) {
301
- prompt_tokens.resize (1024 );
302
- prompt_tokens.resize (whisper_tokenize (ctx, params.prompt .c_str (), prompt_tokens.data (), prompt_tokens.size ()));
303
-
304
- Rprintf (" \n " );
305
- Rprintf (" initial prompt: '%s'\n " , params.prompt .c_str ());
306
- Rprintf (" initial tokens: [ " );
307
- for (int i = 0 ; i < (int ) prompt_tokens.size (); ++i) {
308
- Rprintf (" %d " , prompt_tokens[i]);
309
- }
310
- Rprintf (" ]\n " );
311
- }
312
-
294
+
313
295
for (int f = 0 ; f < (int ) params.fname_inp .size (); ++f) {
314
296
const auto fname_inp = params.fname_inp [f];
315
297
std::vector<float > pcmf32; // mono-channel F32 PCM
316
298
std::vector<std::vector<float >> pcmf32s; // stereo-channel F32 PCM
317
- // // WAV input
318
- // {
319
- // drwav wav;
320
- // std::vector<uint8_t> wav_data; // used for pipe input from stdin
321
- //
322
- // if (drwav_init_file(&wav, fname_inp.c_str(), NULL) == false) {
323
- // Rcpp::stop("Failed to open the file as WAV file: ", fname_inp);
324
- // }
325
- //
326
- // if (wav.channels != 1 && wav.channels != 2) {
327
- // Rcpp::stop("WAV file must be mono or stereo: ", fname_inp);
328
- // }
329
- //
330
- // if (params.diarize && wav.channels != 2 && params.no_timestamps == false) {
331
- // Rcpp::stop("WAV file must be stereo for diarization and timestamps have to be enabled: ", fname_inp);
332
- // }
333
- //
334
- // if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
335
- // Rcpp::stop("WAV file must be 16 kHz: ", fname_inp);
336
- // }
337
- //
338
- // if (wav.bitsPerSample != 16) {
339
- // Rcpp::stop("WAV file must be 16 bit: ", fname_inp);
340
- // }
341
- //
342
- // const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
343
- //
344
- // std::vector<int16_t> pcm16;
345
- // pcm16.resize(n*wav.channels);
346
- // drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
347
- // drwav_uninit(&wav);
348
- //
349
- // // convert to mono, float
350
- // pcmf32.resize(n);
351
- // if (wav.channels == 1) {
352
- // for (uint64_t i = 0; i < n; i++) {
353
- // pcmf32[i] = float(pcm16[i])/32768.0f;
354
- // }
355
- // } else {
356
- // for (uint64_t i = 0; i < n; i++) {
357
- // pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
358
- // }
359
- // }
360
- //
361
- // if (params.diarize) {
362
- // // convert to stereo, float
363
- // pcmf32s.resize(2);
364
- //
365
- // pcmf32s[0].resize(n);
366
- // pcmf32s[1].resize(n);
367
- // for (uint64_t i = 0; i < n; i++) {
368
- // pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
369
- // pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
370
- // }
371
- // }
372
- // }
299
+
373
300
if (!::read_wav (fname_inp, pcmf32, pcmf32s, params.diarize )) {
374
301
Rprintf (" error: failed to read WAV file '%s'\n " , fname_inp.c_str ());
375
302
Rcpp::stop (" The input audio needs to be a 16-bit .wav file." );
@@ -391,63 +318,78 @@ Rcpp::List whisper_encode(SEXP model, std::string path, std::string language,
391
318
Rcpp::warning (" WARNING: model is not multilingual, ignoring language and translation options" );
392
319
}
393
320
}
394
- Rcpp::Rcout << " Processing " << fname_inp << " (" << int (pcmf32.size ()) << " samples, " << float (pcmf32.size ())/WHISPER_SAMPLE_RATE << " sec)" << " , lang = " << params.language << " , translate = " << params.translate << " , timestamps = " << token_timestamps << " \n " ;
321
+ Rcpp::Rcout << " Processing " << fname_inp << " (" << int (pcmf32.size ()) << " samples, " << float (pcmf32.size ())/WHISPER_SAMPLE_RATE << " sec)" << " , lang = " << params.language << " , translate = " << params.translate << " , timestamps = " << token_timestamps << " , beam_size = " << params. beam_size << " , best_of = " << params. best_of << " \n " ;
395
322
}
396
323
397
324
// run the inference
398
325
{
399
326
whisper_full_params wparams = whisper_full_default_params (WHISPER_SAMPLING_GREEDY);
400
327
401
328
wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
402
- wparams.print_realtime = trace;
403
- wparams.print_progress = false ;
329
+
330
+ wparams.print_realtime = false ;
331
+ wparams.print_progress = params.print_progress ;
404
332
wparams.print_timestamps = !params.no_timestamps ;
405
333
wparams.print_special = params.print_special ;
406
334
wparams.translate = params.translate ;
407
335
wparams.language = params.language .c_str ();
336
+ wparams.detect_language = params.detect_language ;
408
337
wparams.n_threads = params.n_threads ;
409
338
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx ;
410
339
wparams.offset_ms = params.offset_t_ms ;
411
340
wparams.duration_ms = params.duration_ms ;
412
-
413
- wparams.token_timestamps = params.output_wts || params.max_len > 0 ;
414
- wparams.token_timestamps = token_timestamps;
341
+
342
+ wparams.token_timestamps = params.output_wts || params.output_jsn_full || params.max_len > 0 ;
415
343
wparams.thold_pt = params.word_thold ;
416
344
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len ;
417
345
wparams.split_on_word = params.split_on_word ;
418
-
346
+
419
347
wparams.speed_up = params.speed_up ;
420
-
421
- wparams.prompt_tokens = prompt_tokens.empty () ? nullptr : prompt_tokens.data ();
422
- wparams.prompt_n_tokens = prompt_tokens.empty () ? 0 : prompt_tokens.size ();
423
-
348
+ wparams.debug_mode = params.debug_mode ;
349
+
350
+ wparams.tdrz_enable = params.tinydiarize ; // [TDRZ]
351
+
352
+ wparams.initial_prompt = params.prompt .c_str ();
353
+
424
354
wparams.greedy .best_of = params.best_of ;
425
355
wparams.beam_search .beam_size = params.beam_size ;
426
-
356
+
427
357
wparams.temperature_inc = params.no_fallback ? 0 .0f : wparams.temperature_inc ;
428
358
wparams.entropy_thold = params.entropy_thold ;
429
359
wparams.logprob_thold = params.logprob_thold ;
430
360
431
- whisper_print_user_data user_data = { ¶ms, &pcmf32s };
361
+ whisper_print_user_data user_data = { ¶ms, &pcmf32s, 0 };
432
362
433
363
// this callback is called on each new segment
434
364
if (!wparams.print_realtime ) {
435
365
wparams.new_segment_callback = whisper_print_segment_callback;
436
366
wparams.new_segment_callback_user_data = &user_data;
437
367
}
438
368
439
- // example for abort mechanism
440
- // in this example, we do not abort the processing, but we could if the flag is set to true
369
+ // examples for abort mechanism
370
+ // in examples below, we do not abort the processing, but we could if the flag is set to true
371
+
441
372
// the callback is called before every encoder run - if it returns false, the processing is aborted
442
373
{
443
374
static bool is_aborted = false ; // NOTE: this should be atomic to avoid data race
444
-
445
- wparams.encoder_begin_callback = [](struct whisper_context * ctx, void * user_data) {
375
+
376
+ wparams.encoder_begin_callback = [](struct whisper_context * /* ctx*/ , struct whisper_state * /* state */ , void * user_data) {
446
377
bool is_aborted = *(bool *)user_data;
447
378
return !is_aborted;
448
379
};
449
380
wparams.encoder_begin_callback_user_data = &is_aborted;
450
381
}
382
+
383
+ // the callback is called before every computation - if it returns true, the computation is aborted
384
+ {
385
+ static bool is_aborted = false ; // NOTE: this should be atomic to avoid data race
386
+
387
+ wparams.abort_callback = [](void * user_data) {
388
+ bool is_aborted = *(bool *)user_data;
389
+ return is_aborted;
390
+ };
391
+ wparams.abort_callback_user_data = &is_aborted;
392
+ }
451
393
452
394
if (whisper_full_parallel (ctx, wparams, pcmf32.data (), pcmf32.size (), params.n_processors ) != 0 ) {
453
395
Rcpp::stop (" failed to process audio" );
@@ -549,7 +491,7 @@ Rcpp::List whisper_encode(SEXP model, std::string path, std::string language,
549
491
550
492
551
493
552
-
494
+ /*
553
495
// [[Rcpp::export]]
554
496
void whisper_print_benchmark(SEXP model, int n_threads = 1) {
555
497
whisper_params params;
@@ -566,4 +508,5 @@ void whisper_print_benchmark(SEXP model, int n_threads = 1) {
566
508
Rprintf("error: failed to encode model: %d\n", ret);
567
509
}
568
510
whisper_print_timings(ctx);
569
- }
511
+ }
512
+ */
0 commit comments