ggerganov committed
Commit e4f586b · unverified · 1 Parent(s): 3ad485f

whisper : token-level timestamp refactoring (#49, #120)


This turned out pretty well overall. The algorithm has been moved from
main.cpp to whisper.cpp and can be reused for all subtitle types. This
means you can now specify the maximum length of the generated lines:
simply provide the "-ml" argument with the maximum length in number of
characters.
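For example, assuming the default base.en model and the bundled samples/jfk.wav file,
running `./main -m models/ggml-base.en.bin -f samples/jfk.wav -ml 16` should wrap the
generated lines at roughly 16 characters.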

Files changed (5)
  1. README.md +2 -1
  2. examples/main/README.md +3 -2
  3. examples/main/main.cpp +70 -348
  4. whisper.cpp +423 -12
  5. whisper.h +18 -5
README.md CHANGED
@@ -101,13 +101,14 @@ options:
101
  -ot N, --offset-t N time offset in milliseconds (default: 0)
102
  -on N, --offset-n N segment index offset (default: 0)
103
  -mc N, --max-context N maximum number of text context tokens to store (default: max)
 
104
  -wt N, --word-thold N word timestamp probability threshold (default: 0.010000)
105
  -v, --verbose verbose output
106
  --translate translate from source language to english
107
  -otxt, --output-txt output result in a text file
108
  -ovtt, --output-vtt output result in a vtt file
109
  -osrt, --output-srt output result in a srt file
110
- -owts, --output-words output word-level timestamps to a text file
111
  -ps, --print_special print special tokens
112
  -pc, --print_colors print colors
113
  -nt, --no_timestamps do not print timestamps
 
101
  -ot N, --offset-t N time offset in milliseconds (default: 0)
102
  -on N, --offset-n N segment index offset (default: 0)
103
  -mc N, --max-context N maximum number of text context tokens to store (default: max)
104
+ -ml N, --max-len N maximum segment length in characters (default: 0)
105
  -wt N, --word-thold N word timestamp probability threshold (default: 0.010000)
106
  -v, --verbose verbose output
107
  --translate translate from source language to english
108
  -otxt, --output-txt output result in a text file
109
  -ovtt, --output-vtt output result in a vtt file
110
  -osrt, --output-srt output result in a srt file
111
+ -owts, --output-words output script for generating karaoke video
112
  -ps, --print_special print special tokens
113
  -pc, --print_colors print colors
114
  -nt, --no_timestamps do not print timestamps
examples/main/README.md CHANGED
@@ -8,7 +8,6 @@ It can be used as a reference for using the `whisper.cpp` library in other proje
8
 
9
  usage: ./bin/main [options] file0.wav file1.wav ...
10
 
11
- options:
12
  -h, --help show this help message and exit
13
  -s SEED, --seed SEED RNG seed (default: -1)
14
  -t N, --threads N number of threads to use during computation (default: 4)
@@ -16,18 +15,20 @@ options:
16
  -ot N, --offset-t N time offset in milliseconds (default: 0)
17
  -on N, --offset-n N segment index offset (default: 0)
18
  -mc N, --max-context N maximum number of text context tokens to store (default: max)
 
19
  -wt N, --word-thold N word timestamp probability threshold (default: 0.010000)
20
  -v, --verbose verbose output
21
  --translate translate from source language to english
22
  -otxt, --output-txt output result in a text file
23
  -ovtt, --output-vtt output result in a vtt file
24
  -osrt, --output-srt output result in a srt file
25
- -owts, --output-words output word-level timestamps to a text file
26
  -ps, --print_special print special tokens
27
  -pc, --print_colors print colors
28
  -nt, --no_timestamps do not print timestamps
29
  -l LANG, --language LANG spoken language (default: en)
30
  -m FNAME, --model FNAME model path (default: models/ggml-base.en.bin)
31
  -f FNAME, --file FNAME input WAV file path
 
32
 
33
  ```
 
8
 
9
  usage: ./bin/main [options] file0.wav file1.wav ...
10
 
 
11
  -h, --help show this help message and exit
12
  -s SEED, --seed SEED RNG seed (default: -1)
13
  -t N, --threads N number of threads to use during computation (default: 4)
 
15
  -ot N, --offset-t N time offset in milliseconds (default: 0)
16
  -on N, --offset-n N segment index offset (default: 0)
17
  -mc N, --max-context N maximum number of text context tokens to store (default: max)
18
+ -ml N, --max-len N maximum segment length in characters (default: 0)
19
  -wt N, --word-thold N word timestamp probability threshold (default: 0.010000)
20
  -v, --verbose verbose output
21
  --translate translate from source language to english
22
  -otxt, --output-txt output result in a text file
23
  -ovtt, --output-vtt output result in a vtt file
24
  -osrt, --output-srt output result in a srt file
25
+ -owts, --output-words output script for generating karaoke video
26
  -ps, --print_special print special tokens
27
  -pc, --print_colors print colors
28
  -nt, --no_timestamps do not print timestamps
29
  -l LANG, --language LANG spoken language (default: en)
30
  -m FNAME, --model FNAME model path (default: models/ggml-base.en.bin)
31
  -f FNAME, --file FNAME input WAV file path
32
+ -h, --help show this help message and exit
33
 
34
  ```
examples/main/main.cpp CHANGED
@@ -36,6 +36,7 @@ std::string to_timestamp(int64_t t, bool comma = false) {
36
  return std::string(buf);
37
  }
38
 
 
39
  void replace_all(std::string & s, const std::string & search, const std::string & replace) {
40
  for (size_t pos = 0; ; pos += replace.length()) {
41
  pos = s.find(search, pos);
@@ -45,31 +46,6 @@ void replace_all(std::string & s, const std::string & search, const std::string
45
  }
46
  }
47
 
48
- // a cost-function that is high for text that takes longer to pronounce
49
- float voice_length(const std::string & text) {
50
- float res = 0.0f;
51
-
52
- for (size_t i = 0; i < text.size(); ++i) {
53
- if (text[i] == ' ') {
54
- res += 0.01f;
55
- } else if (text[i] == ',') {
56
- res += 2.00f;
57
- } else if (text[i] == '.') {
58
- res += 3.00f;
59
- } else if (text[i] == '!') {
60
- res += 3.00f;
61
- } else if (text[i] == '?') {
62
- res += 3.00f;
63
- } else if (text[i] >= '0' && text[i] <= '9') {
64
- res += 3.00f;
65
- } else {
66
- res += 1.00f;
67
- }
68
- }
69
-
70
- return res;
71
- }
72
-
73
  // command-line parameters
74
  struct whisper_params {
75
  int32_t seed = -1; // RNG seed, not used currently
@@ -78,6 +54,7 @@ struct whisper_params {
78
  int32_t offset_t_ms = 0;
79
  int32_t offset_n = 0;
80
  int32_t max_context = -1;
 
81
 
82
  float word_thold = 0.01f;
83
 
@@ -120,6 +97,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
120
  params.offset_n = std::stoi(argv[++i]);
121
  } else if (arg == "-mc" || arg == "--max-context") {
122
  params.max_context = std::stoi(argv[++i]);
 
 
123
  } else if (arg == "-wt" || arg == "--word-thold") {
124
  params.word_thold = std::stof(argv[++i]);
125
  } else if (arg == "-v" || arg == "--verbose") {
@@ -176,13 +155,14 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
176
  fprintf(stderr, " -ot N, --offset-t N time offset in milliseconds (default: %d)\n", params.offset_t_ms);
177
  fprintf(stderr, " -on N, --offset-n N segment index offset (default: %d)\n", params.offset_n);
178
  fprintf(stderr, " -mc N, --max-context N maximum number of text context tokens to store (default: max)\n");
 
179
  fprintf(stderr, " -wt N, --word-thold N word timestamp probability threshold (default: %f)\n", params.word_thold);
180
  fprintf(stderr, " -v, --verbose verbose output\n");
181
  fprintf(stderr, " --translate translate from source language to english\n");
182
  fprintf(stderr, " -otxt, --output-txt output result in a text file\n");
183
  fprintf(stderr, " -ovtt, --output-vtt output result in a vtt file\n");
184
  fprintf(stderr, " -osrt, --output-srt output result in a srt file\n");
185
- fprintf(stderr, " -owts, --output-words output word-level timestamps to a text file\n");
186
  fprintf(stderr, " -ps, --print_special print special tokens\n");
187
  fprintf(stderr, " -pc, --print_colors print colors\n");
188
  fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n");
@@ -192,65 +172,67 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
192
  fprintf(stderr, "\n");
193
  }
194
 
195
- void whisper_print_segment_callback(struct whisper_context * ctx, void * user_data) {
196
  const whisper_params & params = *(whisper_params *) user_data;
197
 
198
  const int n_segments = whisper_full_n_segments(ctx);
199
 
200
- // print the last segment
201
- const int i = n_segments - 1;
202
- if (i == 0) {
203
  printf("\n");
204
  }
205
 
206
- if (params.no_timestamps) {
207
- if (params.print_colors) {
208
- for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
209
- if (params.print_special_tokens == false) {
210
- const whisper_token id = whisper_full_get_token_id(ctx, i, j);
211
- if (id >= whisper_token_eot(ctx)) {
212
- continue;
 
 
213
  }
214
- }
215
 
216
- const char * text = whisper_full_get_token_text(ctx, i, j);
217
- const float p = whisper_full_get_token_p (ctx, i, j);
218
 
219
- const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
220
 
221
- printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
 
 
 
 
222
  }
 
223
  } else {
224
- const char * text = whisper_full_get_segment_text(ctx, i);
225
- printf("%s", text);
226
- }
227
- fflush(stdout);
228
- } else {
229
- const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
230
- const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
231
-
232
- if (params.print_colors) {
233
- printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
234
- for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
235
- if (params.print_special_tokens == false) {
236
- const whisper_token id = whisper_full_get_token_id(ctx, i, j);
237
- if (id >= whisper_token_eot(ctx)) {
238
- continue;
239
  }
240
- }
241
 
242
- const char * text = whisper_full_get_token_text(ctx, i, j);
243
- const float p = whisper_full_get_token_p (ctx, i, j);
244
 
245
- const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
246
 
247
- printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
248
- }
249
- printf("\n");
250
- } else {
251
- const char * text = whisper_full_get_segment_text(ctx, i);
252
 
253
- printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
 
254
  }
255
  }
256
  }
@@ -320,297 +302,41 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_
320
  return true;
321
  }
322
 
323
- // word-level timestamps (experimental)
324
- // TODO: make ffmpeg output optional
325
- // TODO: extra pass to detect unused speech and assign to tokens
326
  // TODO: font parameter adjustments
327
- // TODO: move to whisper.h/whisper.cpp and add parameter to select max line-length of subtitles
328
- bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, const std::vector<float> & pcmf32) {
329
- std::vector<float> pcm_avg(pcmf32.size(), 0);
330
-
331
- // average the fabs of the signal
332
- {
333
- const int hw = 32;
334
-
335
- for (int i = 0; i < pcmf32.size(); i++) {
336
- float sum = 0;
337
- for (int j = -hw; j <= hw; j++) {
338
- if (i + j >= 0 && i + j < pcmf32.size()) {
339
- sum += fabs(pcmf32[i + j]);
340
- }
341
- }
342
- pcm_avg[i] = sum/(2*hw + 1);
343
- }
344
- }
345
-
346
- struct token_info {
347
- int64_t t0 = -1;
348
- int64_t t1 = -1;
349
-
350
- int64_t tt0 = -1;
351
- int64_t tt1 = -1;
352
-
353
- whisper_token id;
354
- whisper_token tid;
355
-
356
- float p = 0.0f;
357
- float pt = 0.0f;
358
- float ptsum = 0.0f;
359
-
360
- std::string text;
361
- float vlen = 0.0f; // voice length of this token
362
- };
363
-
364
- int64_t t_beg = 0;
365
- int64_t t_last = 0;
366
-
367
- whisper_token tid_last = 0;
368
-
369
  std::ofstream fout(fname);
370
 
371
  fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
372
 
 
 
 
373
  fout << "!/bin/bash" << "\n";
374
  fout << "\n";
375
 
376
- fout << "ffmpeg -i " << fname_inp << " -f lavfi -i color=size=1200x120:duration=" << float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE << ":rate=25:color=black -vf \"";
377
-
378
- bool is_first = true;
379
 
380
  for (int i = 0; i < whisper_full_n_segments(ctx); i++) {
381
  const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
382
  const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
383
 
384
- const char *text = whisper_full_get_segment_text(ctx, i);
385
-
386
- const int s0 = std::max(0, (int) (t0*WHISPER_SAMPLE_RATE/100));
387
- const int s1 = std::min((int) pcmf32.size(), (int) (t1*WHISPER_SAMPLE_RATE/100));
388
-
389
  const int n = whisper_full_n_tokens(ctx, i);
390
 
391
- std::vector<token_info> tokens(n);
392
-
393
- if (n <= 1) {
394
- continue;
395
- }
396
-
397
  for (int j = 0; j < n; ++j) {
398
- struct whisper_token_data token = whisper_full_get_token_data(ctx, i, j);
399
-
400
- if (j == 0) {
401
- if (token.id == whisper_token_beg(ctx)) {
402
- tokens[j ].t0 = t0;
403
- tokens[j ].t1 = t0;
404
- tokens[j + 1].t0 = t0;
405
-
406
- t_beg = t0;
407
- t_last = t0;
408
- tid_last = whisper_token_beg(ctx);
409
- } else {
410
- tokens[j ].t0 = t_last;
411
- }
412
- }
413
-
414
- const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(ctx));
415
-
416
- tokens[j].id = token.id;
417
- tokens[j].tid = token.tid;
418
- tokens[j].p = token.p;
419
- tokens[j].pt = token.pt;
420
- tokens[j].ptsum = token.ptsum;
421
-
422
- tokens[j].text = whisper_token_to_str(ctx, token.id);
423
- tokens[j].vlen = voice_length(tokens[j].text);
424
-
425
- if (token.pt > params.word_thold && token.ptsum > 0.01 && token.tid > tid_last && tt <= t1) {
426
- if (j > 0) {
427
- tokens[j - 1].t1 = tt;
428
- }
429
- tokens[j].t0 = tt;
430
- tid_last = token.tid;
431
- }
432
  }
433
 
434
- tokens[n - 2].t1 = t1;
435
- tokens[n - 1].t0 = t1;
436
- tokens[n - 1].t1 = t1;
437
-
438
- t_last = t1;
439
-
440
- // find intervals of tokens with unknown timestamps
441
- // fill the timestamps by proportionally splitting the interval based on the token voice lengths
442
- {
443
- int p0 = 0;
444
- int p1 = 0;
445
- while (true) {
446
- while (p1 < n && tokens[p1].t1 < 0) {
447
- p1++;
448
- }
449
-
450
- if (p1 >= n) {
451
- p1--;
452
- }
453
-
454
- if (p1 > p0) {
455
- double psum = 0.0;
456
- for (int j = p0; j <= p1; j++) {
457
- psum += tokens[j].vlen;
458
- }
459
-
460
- //printf("analyzing %d - %d, psum = %f\n", p0, p1, psum);
461
-
462
- const double dt = tokens[p1].t1 - tokens[p0].t0;
463
-
464
- // split the time proportionally to the voice length
465
- for (int j = p0 + 1; j <= p1; j++) {
466
- const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum;
467
-
468
- tokens[j - 1].t1 = ct;
469
- tokens[j ].t0 = ct;
470
- }
471
- }
472
-
473
- p1++;
474
- p0 = p1;
475
- if (p1 >= n) {
476
- break;
477
- }
478
- }
479
- }
480
-
481
- // fix up (just in case)
482
- for (int j = 0; j < n - 1; j++) {
483
- if (tokens[j].t1 < 0) {
484
- tokens[j + 1].t0 = tokens[j].t1;
485
- }
486
-
487
- if (j > 0) {
488
- if (tokens[j - 1].t1 > tokens[j].t0) {
489
- tokens[j].t0 = tokens[j - 1].t1;
490
- tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1);
491
- }
492
- }
493
-
494
- tokens[j].tt0 = tokens[j].t0;
495
- tokens[j].tt1 = tokens[j].t1;
496
- }
497
-
498
- // VAD
499
- // expand or contract tokens based on voice activity
500
- {
501
- const int hw = WHISPER_SAMPLE_RATE/8;
502
-
503
- for (int j = 0; j < n; j++) {
504
- if (tokens[j].id >= whisper_token_eot(ctx)) {
505
- continue;
506
- }
507
-
508
- const int64_t t0 = tokens[j].t0;
509
- const int64_t t1 = tokens[j].t1;
510
-
511
- int s0 = std::max(0, (int) (t0*WHISPER_SAMPLE_RATE/100));
512
- int s1 = std::min((int) pcmf32.size() - 1, (int) (t1*WHISPER_SAMPLE_RATE/100));
513
-
514
- const int ss0 = std::max(0, (int) (t0*WHISPER_SAMPLE_RATE/100) - hw);
515
- const int ss1 = std::min((int) pcmf32.size() - 1, (int) (t1*WHISPER_SAMPLE_RATE/100) + hw);
516
-
517
- const int n = ss1 - ss0;
518
-
519
- float sum = 0.0f;
520
-
521
- for (int k = ss0; k < ss1; k++) {
522
- sum += pcm_avg[k];
523
- }
524
-
525
- const float thold = 0.5*sum/n;
526
-
527
- {
528
- int k = s0;
529
- if (pcm_avg[k] > thold && j > 0) {
530
- while (k > 0 && pcm_avg[k] > thold) {
531
- k--;
532
- }
533
- tokens[j].t0 = (int64_t) (100*k/WHISPER_SAMPLE_RATE);
534
- if (tokens[j].t0 < tokens[j - 1].t1) {
535
- tokens[j].t0 = tokens[j - 1].t1;
536
- } else {
537
- s0 = k;
538
- }
539
- } else {
540
- while (pcm_avg[k] < thold && k < s1) {
541
- k++;
542
- }
543
- s0 = k;
544
- tokens[j].t0 = 100*k/WHISPER_SAMPLE_RATE;
545
- }
546
- }
547
-
548
- {
549
- int k = s1;
550
- if (pcm_avg[k] > thold) {
551
- while (k < (int) pcmf32.size() - 1 && pcm_avg[k] > thold) {
552
- k++;
553
- }
554
- tokens[j].t1 = 100*k/WHISPER_SAMPLE_RATE;
555
- if (j < n - 1 && tokens[j].t1 > tokens[j + 1].t0) {
556
- tokens[j].t1 = tokens[j + 1].t0;
557
- } else {
558
- s1 = k;
559
- }
560
- } else {
561
- while (pcm_avg[k] < thold && k > s0) {
562
- k--;
563
- }
564
- s1 = k;
565
- tokens[j].t1 = 100*k/WHISPER_SAMPLE_RATE;
566
- }
567
- }
568
- }
569
- }
570
-
571
- // fixed token expand (optional)
572
- {
573
- const int t_expand = 0;
574
-
575
- for (int j = 0; j < n; j++) {
576
- if (j > 0) {
577
- tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand));
578
- }
579
- if (j < n - 1) {
580
- tokens[j].t1 = tokens[j].t1 + t_expand;
581
- }
582
- }
583
- }
584
-
585
- // debug info
586
- // TODO: toggle via parameter
587
- for (int j = 0; j < n; ++j) {
588
- const auto & token = tokens[j];
589
- const auto tt = token.pt > params.word_thold && token.ptsum > 0.01 ? whisper_token_to_str(ctx, token.tid) : "[?]";
590
- printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__,
591
- tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, token.text.c_str());
592
-
593
- if (tokens[j].id >= whisper_token_eot(ctx)) {
594
- continue;
595
- }
596
-
597
- //printf("[%s --> %s] %s\n", to_timestamp(token.t0).c_str(), to_timestamp(token.t1).c_str(), whisper_token_to_str(ctx, token.id));
598
-
599
- //fout << "# " << to_timestamp(token.t0) << " --> " << to_timestamp(token.t1) << " " << whisper_token_to_str(ctx, token.id) << "\n";
600
- }
601
-
602
- // TODO: become parameters
603
- static const int line_wrap = 60;
604
- static const char * font = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
605
-
606
- if (!is_first) {
607
  fout << ",";
608
  }
609
 
610
  // background text
611
  fout << "drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='':enable='between(t," << t0/100.0 << "," << t0/100.0 << ")'";
612
 
613
- is_first = false;
614
 
615
  for (int j = 0; j < n; ++j) {
616
  const auto & token = tokens[j];
@@ -654,17 +380,6 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
654
  }
655
 
656
  ncnt += txt.size();
657
-
658
- if (ncnt > line_wrap) {
659
- if (k < j) {
660
- txt_bg = "> ";
661
- txt_fg = "> ";
662
- txt_ul = "\\ \\ ";
663
- ncnt = 0;
664
- } else {
665
- break;
666
- }
667
- }
668
  }
669
 
670
  ::replace_all(txt_bg, "'", "’");
@@ -673,8 +388,11 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
673
  ::replace_all(txt_fg, "\"", "\\\"");
674
  }
675
 
676
- // background text
677
- fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='" << txt_bg << "':enable='between(t," << token.tt0/100.0 << "," << token.tt1/100.0 << ")'";
 
 
 
678
 
679
  // foreground text
680
  fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=lightgreen:x=(w-text_w)/2+8:y=h/2:text='" << txt_fg << "':enable='between(t," << token.t0/100.0 << "," << token.t1/100.0 << ")'";
@@ -815,6 +533,10 @@ int main(int argc, char ** argv) {
815
  wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
816
  wparams.offset_ms = params.offset_t_ms;
817
 
 
 
 
 
818
  // this callback is called on each new segment
819
  if (!wparams.print_realtime) {
820
  wparams.new_segment_callback = whisper_print_segment_callback;
@@ -852,7 +574,7 @@ int main(int argc, char ** argv) {
852
  // output to WTS file
853
  if (params.output_wts) {
854
  const auto fname_wts = fname_inp + ".wts";
855
- output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, pcmf32);
856
  }
857
  }
858
  }
 
36
  return std::string(buf);
37
  }
38
 
39
+ // helper function to replace substrings
40
  void replace_all(std::string & s, const std::string & search, const std::string & replace) {
41
  for (size_t pos = 0; ; pos += replace.length()) {
42
  pos = s.find(search, pos);
 
46
  }
47
  }
48
 
49
  // command-line parameters
50
  struct whisper_params {
51
  int32_t seed = -1; // RNG seed, not used currently
 
54
  int32_t offset_t_ms = 0;
55
  int32_t offset_n = 0;
56
  int32_t max_context = -1;
57
+ int32_t max_len = 0;
58
 
59
  float word_thold = 0.01f;
60
 
 
97
  params.offset_n = std::stoi(argv[++i]);
98
  } else if (arg == "-mc" || arg == "--max-context") {
99
  params.max_context = std::stoi(argv[++i]);
100
+ } else if (arg == "-ml" || arg == "--max-len") {
101
+ params.max_len = std::stoi(argv[++i]);
102
  } else if (arg == "-wt" || arg == "--word-thold") {
103
  params.word_thold = std::stof(argv[++i]);
104
  } else if (arg == "-v" || arg == "--verbose") {
 
155
  fprintf(stderr, " -ot N, --offset-t N time offset in milliseconds (default: %d)\n", params.offset_t_ms);
156
  fprintf(stderr, " -on N, --offset-n N segment index offset (default: %d)\n", params.offset_n);
157
  fprintf(stderr, " -mc N, --max-context N maximum number of text context tokens to store (default: max)\n");
158
+ fprintf(stderr, " -ml N, --max-len N maximum segment length in characters (default: %d)\n", params.max_len);
159
  fprintf(stderr, " -wt N, --word-thold N word timestamp probability threshold (default: %f)\n", params.word_thold);
160
  fprintf(stderr, " -v, --verbose verbose output\n");
161
  fprintf(stderr, " --translate translate from source language to english\n");
162
  fprintf(stderr, " -otxt, --output-txt output result in a text file\n");
163
  fprintf(stderr, " -ovtt, --output-vtt output result in a vtt file\n");
164
  fprintf(stderr, " -osrt, --output-srt output result in a srt file\n");
165
+ fprintf(stderr, " -owts, --output-words output script for generating karaoke video\n");
166
  fprintf(stderr, " -ps, --print_special print special tokens\n");
167
  fprintf(stderr, " -pc, --print_colors print colors\n");
168
  fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n");
 
172
  fprintf(stderr, "\n");
173
  }
174
 
175
+ void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
176
  const whisper_params & params = *(whisper_params *) user_data;
177
 
178
  const int n_segments = whisper_full_n_segments(ctx);
179
 
180
+ // print the last n_new segments
181
+ const int s0 = n_segments - n_new;
182
+ if (s0 == 0) {
183
  printf("\n");
184
  }
185
 
186
+ for (int i = s0; i < n_segments; i++) {
187
+ if (params.no_timestamps) {
188
+ if (params.print_colors) {
189
+ for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
190
+ if (params.print_special_tokens == false) {
191
+ const whisper_token id = whisper_full_get_token_id(ctx, i, j);
192
+ if (id >= whisper_token_eot(ctx)) {
193
+ continue;
194
+ }
195
  }
 
196
 
197
+ const char * text = whisper_full_get_token_text(ctx, i, j);
198
+ const float p = whisper_full_get_token_p (ctx, i, j);
199
 
200
+ const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
201
 
202
+ printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
203
+ }
204
+ } else {
205
+ const char * text = whisper_full_get_segment_text(ctx, i);
206
+ printf("%s", text);
207
  }
208
+ fflush(stdout);
209
  } else {
210
+ const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
211
+ const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
212
+
213
+ if (params.print_colors) {
214
+ printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
215
+ for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
216
+ if (params.print_special_tokens == false) {
217
+ const whisper_token id = whisper_full_get_token_id(ctx, i, j);
218
+ if (id >= whisper_token_eot(ctx)) {
219
+ continue;
220
+ }
 
 
 
 
221
  }
 
222
 
223
+ const char * text = whisper_full_get_token_text(ctx, i, j);
224
+ const float p = whisper_full_get_token_p (ctx, i, j);
225
 
226
+ const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
227
 
228
+ printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
229
+ }
230
+ printf("\n");
231
+ } else {
232
+ const char * text = whisper_full_get_segment_text(ctx, i);
233
 
234
+ printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
235
+ }
236
  }
237
  }
238
  }
 
302
  return true;
303
  }
304
 
305
+ // karaoke video generation
306
+ // outputs a bash script that uses ffmpeg to generate a video with the subtitles
 
307
  // TODO: font parameter adjustments
308
+ bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec) {
309
  std::ofstream fout(fname);
310
 
311
  fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
312
 
313
+ // TODO: become parameter
314
+ static const char * font = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
315
+
316
  fout << "!/bin/bash" << "\n";
317
  fout << "\n";
318
 
319
+ fout << "ffmpeg -i " << fname_inp << " -f lavfi -i color=size=1200x120:duration=" << t_sec << ":rate=25:color=black -vf \"";
 
 
320
 
321
  for (int i = 0; i < whisper_full_n_segments(ctx); i++) {
322
  const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
323
  const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
324
 
 
 
 
 
 
325
  const int n = whisper_full_n_tokens(ctx, i);
326
 
327
+ std::vector<whisper_token_data> tokens(n);
 
 
 
 
 
328
  for (int j = 0; j < n; ++j) {
329
+ tokens[j] = whisper_full_get_token_data(ctx, i, j);
330
  }
331
 
332
+ if (i > 0) {
333
  fout << ",";
334
  }
335
 
336
  // background text
337
  fout << "drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='':enable='between(t," << t0/100.0 << "," << t0/100.0 << ")'";
338
 
339
+ bool is_first = true;
340
 
341
  for (int j = 0; j < n; ++j) {
342
  const auto & token = tokens[j];
 
380
  }
381
 
382
  ncnt += txt.size();
 
 
 
 
 
 
 
 
 
 
 
383
  }
384
 
385
  ::replace_all(txt_bg, "'", "’");
 
388
  ::replace_all(txt_fg, "\"", "\\\"");
389
  }
390
 
391
+ if (is_first) {
392
+ // background text
393
+ fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='" << txt_bg << "':enable='between(t," << t0/100.0 << "," << t1/100.0 << ")'";
394
+ is_first = false;
395
+ }
396
 
397
  // foreground text
398
  fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=lightgreen:x=(w-text_w)/2+8:y=h/2:text='" << txt_fg << "':enable='between(t," << token.t0/100.0 << "," << token.t1/100.0 << ")'";
 
533
  wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
534
  wparams.offset_ms = params.offset_t_ms;
535
 
536
+ wparams.token_timestamps = params.output_wts || params.max_len > 0;
537
+ wparams.thold_pt = params.word_thold;
538
+ wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
539
+
540
  // this callback is called on each new segment
541
  if (!wparams.print_realtime) {
542
  wparams.new_segment_callback = whisper_print_segment_callback;
 
574
  // output to WTS file
575
  if (params.output_wts) {
576
  const auto fname_wts = fname_inp + ".wts";
577
+ output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
578
  }
579
  }
580
  }
whisper.cpp CHANGED
@@ -418,6 +418,12 @@ struct whisper_context {
418
  std::vector<whisper_segment> result_all;
419
 
420
  std::vector<whisper_token> prompt_past;
 
 
 
 
 
 
421
  };
422
 
423
  // load the model from a ggml file
@@ -431,7 +437,7 @@ struct whisper_context {
431
  //
432
  // see the convert-pt-to-ggml.py script for details
433
  //
434
- bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
435
  fprintf(stderr, "%s: loading model from '%s'\n", __func__, fname.c_str());
436
 
437
  auto & model = wctx.model;
@@ -1062,7 +1068,7 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
1062
  // - n_threads: number of threads to use
1063
  // - mel_offset: offset in the mel spectrogram (i.e. audio offset)
1064
  //
1065
- bool whisper_encode(
1066
  whisper_context & wctx,
1067
  const int n_threads,
1068
  const int mel_offset) {
@@ -1448,7 +1454,7 @@ bool whisper_encode(
1448
  // - n_tokens: number of tokens in the prompt
1449
  // - n_past: number of past tokens to prefix the prompt with
1450
  //
1451
- bool whisper_decode(
1452
  whisper_context & wctx,
1453
  const int n_threads,
1454
  const whisper_token * tokens,
@@ -1811,10 +1817,12 @@ bool whisper_decode(
1811
  }
1812
 
1813
  // the most basic sampling scheme - select the top token
1814
- whisper_token_data whisper_sample_best(
1815
  const whisper_vocab & vocab,
1816
  const float * probs) {
1817
- whisper_token_data result;
 
 
1818
 
1819
  int n_logits = vocab.id_to_token.size();
1820
 
@@ -1887,7 +1895,7 @@ whisper_token_data whisper_sample_best(
1887
  }
1888
 
1889
  // samples only from the timestamps tokens
1890
- whisper_vocab::id whisper_sample_timestamp(
1891
  const whisper_vocab & vocab,
1892
  const float * probs) {
1893
  int n_logits = vocab.id_to_token.size();
@@ -1939,7 +1947,7 @@ static std::string to_timestamp(int64_t t, bool comma = false) {
1939
  // naive Discrete Fourier Transform
1940
  // input is real-valued
1941
  // output is complex-valued
1942
- void dft(const std::vector<float> & in, std::vector<float> & out) {
1943
  int N = in.size();
1944
 
1945
  out.resize(N*2);
@@ -1963,7 +1971,7 @@ void dft(const std::vector<float> & in, std::vector<float> & out) {
1963
  // poor man's implementation - use something better
1964
  // input is real-valued
1965
  // output is complex-valued
1966
- void fft(const std::vector<float> & in, std::vector<float> & out) {
1967
  out.resize(in.size()*2);
1968
 
1969
  int N = in.size();
@@ -2014,7 +2022,7 @@ void fft(const std::vector<float> & in, std::vector<float> & out) {
2014
  }
2015
 
2016
  // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
2017
- bool log_mel_spectrogram(
2018
  const float * samples,
2019
  const int n_samples,
2020
  const int sample_rate,
@@ -2339,6 +2347,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
2339
  /*.print_realtime =*/ false,
2340
  /*.print_timestamps =*/ true,
2341
 
 
 
 
 
 
2342
  /*.language =*/ "en",
2343
 
2344
  /*.greedy =*/ {
@@ -2371,6 +2384,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
2371
  /*.print_realtime =*/ false,
2372
  /*.print_timestamps =*/ true,
2373
 
 
 
 
 
 
2374
  /*.language =*/ "en",
2375
 
2376
  /*.greedy =*/ {
@@ -2392,6 +2410,68 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
2392
  return result;
2393
  }
2394
 
2395
  int whisper_full(
2396
  struct whisper_context * ctx,
2397
  struct whisper_full_params params,
@@ -2408,6 +2488,13 @@ int whisper_full(
2408
  return -1;
2409
  }
2410
 
 
 
 
 
 
 
 
2411
  const int seek_start = params.offset_ms/10;
2412
 
2413
  // if length of spectrogram is less than 1s (100 samples), then return
@@ -2557,6 +2644,7 @@ int whisper_full(
2557
  }
2558
  }
2559
 
 
2560
  tokens_cur.resize(result_len);
2561
 
2562
  for (const auto & r : tokens_cur) {
@@ -2595,8 +2683,19 @@ int whisper_full(
2595
  for (int j = i0; j <= i; j++) {
2596
  result_all.back().tokens.push_back(tokens_cur[j]);
2597
  }
 
 
 
 
 
 
 
 
 
 
 
2598
  if (params.new_segment_callback) {
2599
- params.new_segment_callback(ctx, params.new_segment_callback_user_data);
2600
  }
2601
  }
2602
  text = "";
@@ -2625,8 +2724,19 @@ int whisper_full(
2625
  for (int j = i0; j < (int) tokens_cur.size(); j++) {
2626
  result_all.back().tokens.push_back(tokens_cur[j]);
2627
  }
 
 
 
 
 
 
 
 
 
 
 
2628
  if (params.new_segment_callback) {
2629
- params.new_segment_callback(ctx, params.new_segment_callback_user_data);
2630
  }
2631
  }
2632
  }
@@ -2760,7 +2870,7 @@ int whisper_full_parallel(
2760
 
2761
  // call the new_segment_callback for each segment
2762
  if (params.new_segment_callback) {
2763
- params.new_segment_callback(ctx, params.new_segment_callback_user_data);
2764
  }
2765
  }
2766
 
@@ -2836,3 +2946,304 @@ const char * whisper_print_system_info() {
2836
 
2837
  return s.c_str();
2838
  }
418
  std::vector<whisper_segment> result_all;
419
 
420
  std::vector<whisper_token> prompt_past;
421
+
422
+ // [EXPERIMENTAL] token-level timestamps data
423
+ int64_t t_beg;
424
+ int64_t t_last;
425
+ whisper_token tid_last;
426
+ std::vector<float> energy; // PCM signal energy
427
  };
428
 
429
  // load the model from a ggml file
 
437
  //
438
  // see the convert-pt-to-ggml.py script for details
439
  //
440
+ static bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
441
  fprintf(stderr, "%s: loading model from '%s'\n", __func__, fname.c_str());
442
 
443
  auto & model = wctx.model;
 
1068
  // - n_threads: number of threads to use
1069
  // - mel_offset: offset in the mel spectrogram (i.e. audio offset)
1070
  //
1071
+ static bool whisper_encode(
1072
  whisper_context & wctx,
1073
  const int n_threads,
1074
  const int mel_offset) {
 
1454
  // - n_tokens: number of tokens in the prompt
1455
  // - n_past: number of past tokens to prefix the prompt with
1456
  //
1457
+ static bool whisper_decode(
1458
  whisper_context & wctx,
1459
  const int n_threads,
1460
  const whisper_token * tokens,
 
1817
  }
1818
 
1819
  // the most basic sampling scheme - select the top token
1820
+ static whisper_token_data whisper_sample_best(
1821
  const whisper_vocab & vocab,
1822
  const float * probs) {
1823
+ whisper_token_data result = {
1824
+ 0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
1825
+ };
1826
 
1827
  int n_logits = vocab.id_to_token.size();
1828
 
 
1895
  }
1896
 
1897
  // samples only from the timestamps tokens
1898
+ static whisper_vocab::id whisper_sample_timestamp(
1899
  const whisper_vocab & vocab,
1900
  const float * probs) {
1901
  int n_logits = vocab.id_to_token.size();
 
1947
  // naive Discrete Fourier Transform
1948
  // input is real-valued
1949
  // output is complex-valued
1950
+ static void dft(const std::vector<float> & in, std::vector<float> & out) {
1951
  int N = in.size();
1952
 
1953
  out.resize(N*2);
 
1971
  // poor man's implementation - use something better
1972
  // input is real-valued
1973
  // output is complex-valued
1974
+ static void fft(const std::vector<float> & in, std::vector<float> & out) {
1975
  out.resize(in.size()*2);
1976
 
1977
  int N = in.size();
 
2022
  }
2023
 
2024
  // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
2025
+ static bool log_mel_spectrogram(
2026
  const float * samples,
2027
  const int n_samples,
2028
  const int sample_rate,
 
2347
  /*.print_realtime =*/ false,
2348
  /*.print_timestamps =*/ true,
2349
 
2350
+ /*.token_timestamps =*/ false,
2351
+ /*.thold_pt =*/ 0.01f,
2352
+ /*.thold_ptsum =*/ 0.01f,
2353
+ /*.max_len =*/ 0,
2354
+
2355
  /*.language =*/ "en",
2356
 
2357
  /*.greedy =*/ {
 
2384
  /*.print_realtime =*/ false,
2385
  /*.print_timestamps =*/ true,
2386
 
2387
+ /*.token_timestamps =*/ false,
2388
+ /*.thold_pt =*/ 0.01f,
2389
+ /*.thold_ptsum =*/ 0.01f,
2390
+ /*.max_len =*/ 0,
2391
+
2392
  /*.language =*/ "en",
2393
 
2394
  /*.greedy =*/ {
 
2410
  return result;
2411
  }
2412
 
2413
+ // forward declarations
2414
+ static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window);
2415
+ static void whisper_exp_compute_token_level_timestamps(
2416
+ struct whisper_context * ctx,
2417
+ int i_segment,
2418
+ float thold_pt,
2419
+ float thold_ptsum);
2420
+
2421
+ // wrap the last segment to max_len characters
2422
+ // returns the number of new segments
2423
+ static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) {
2424
+ auto segment = ctx->result_all.back();
2425
+
2426
+ int res = 1;
2427
+ int acc = 0;
2428
+
2429
+ std::string text;
2430
+
2431
+ for (int i = 0; i < (int) segment.tokens.size(); i++) {
2432
+ const auto & token = segment.tokens[i];
2433
+ if (token.id >= whisper_token_eot(ctx)) {
2434
+ continue;
2435
+ }
2436
+
2437
+ const auto txt = whisper_token_to_str(ctx, token.id);
2438
+
2439
+ const int cur = strlen(txt);
2440
+
2441
+ if (acc + cur > max_len && i > 0) {
2442
+ // split here
2443
+ ctx->result_all.back().text = std::move(text);
2444
+ ctx->result_all.back().t1 = token.t0;
2445
+ ctx->result_all.back().tokens.resize(i);
2446
+
2447
+ ctx->result_all.push_back({});
2448
+ ctx->result_all.back().t0 = token.t0;
2449
+ ctx->result_all.back().t1 = segment.t1;
2450
+
2451
+ // add tokens [i, end] to the new segment
2452
+ ctx->result_all.back().tokens.insert(
2453
+ ctx->result_all.back().tokens.end(),
2454
+ segment.tokens.begin() + i,
2455
+ segment.tokens.end());
2456
+
2457
+ acc = 0;
2458
+ text = "";
2459
+
2460
+ segment = ctx->result_all.back();
2461
+ i = -1;
2462
+
2463
+ res++;
2464
+ } else {
2465
+ acc += cur;
2466
+ text += txt;
2467
+ }
2468
+ }
2469
+
2470
+ ctx->result_all.back().text = std::move(text);
2471
+
2472
+ return res;
2473
+ }
2474
+
2475
  int whisper_full(
2476
  struct whisper_context * ctx,
2477
  struct whisper_full_params params,
 
2488
  return -1;
2489
  }
2490
 
2491
+ if (params.token_timestamps) {
2492
+ ctx->t_beg = 0;
2493
+ ctx->t_last = 0;
2494
+ ctx->tid_last = 0;
2495
+ ctx->energy = get_signal_energy(samples, n_samples, 32);
2496
+ }
2497
+
2498
  const int seek_start = params.offset_ms/10;
2499
 
2500
  // if length of spectrogram is less than 1s (100 samples), then return
 
2644
  }
2645
  }
2646
 
2647
+ // shrink down to result_len
2648
  tokens_cur.resize(result_len);
2649
 
2650
  for (const auto & r : tokens_cur) {
 
2683
  for (int j = i0; j <= i; j++) {
2684
  result_all.back().tokens.push_back(tokens_cur[j]);
2685
  }
2686
+
2687
+ int n_new = 1;
2688
+
2689
+ if (params.token_timestamps) {
2690
+ whisper_exp_compute_token_level_timestamps(
2691
+ ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
2692
+
2693
+ if (params.max_len > 0) {
2694
+ n_new = whisper_wrap_segment(ctx, params.max_len);
2695
+ }
2696
+ }
2697
  if (params.new_segment_callback) {
2698
+ params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
2699
  }
2700
  }
2701
  text = "";
 
2724
  for (int j = i0; j < (int) tokens_cur.size(); j++) {
2725
  result_all.back().tokens.push_back(tokens_cur[j]);
2726
  }
2727
+
2728
+ int n_new = 1;
2729
+
2730
+ if (params.token_timestamps) {
2731
+ whisper_exp_compute_token_level_timestamps(
2732
+ ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
2733
+
2734
+ if (params.max_len > 0) {
2735
+ n_new = whisper_wrap_segment(ctx, params.max_len);
2736
+ }
2737
+ }
2738
  if (params.new_segment_callback) {
2739
+ params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
2740
  }
2741
  }
2742
  }
 
2870
 
2871
  // call the new_segment_callback for each segment
2872
  if (params.new_segment_callback) {
2873
+ params.new_segment_callback(ctx, 1, params.new_segment_callback_user_data);
2874
  }
2875
  }
2876
 
 
2946
 
2947
  return s.c_str();
2948
  }
2949
+
2950
+ // =================================================================================================
2951
+
2952
+ //
2953
+ // Experimental stuff below
2954
+ //
2955
+ // Not sure if these should be part of the library at all, because the quality of the results is not
2956
+ // guaranteed. Might get removed at some point unless a robust algorithm implementation is found
2957
+ //
2958
+
2959
+ // =================================================================================================
2960
+
2961
+ //
2962
+ // token-level timestamps
2963
+ //
2964
+
2965
+ static int timestamp_to_sample(int64_t t, int n_samples) {
2966
+ return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
2967
+ }
2968
+
2969
+ static int64_t sample_to_timestamp(int i_sample) {
2970
+ return (100*i_sample)/WHISPER_SAMPLE_RATE;
2971
+ }
2972
+
2973
+ // a cost-function / heuristic that is high for text that takes longer to pronounce
2974
+ // obviously, can be improved
2975
+ static float voice_length(const std::string & text) {
2976
+ float res = 0.0f;
2977
+
2978
+ for (size_t i = 0; i < text.size(); ++i) {
2979
+ if (text[i] == ' ') {
2980
+ res += 0.01f;
2981
+ } else if (text[i] == ',') {
2982
+ res += 2.00f;
2983
+ } else if (text[i] == '.') {
2984
+ res += 3.00f;
2985
+ } else if (text[i] == '!') {
2986
+ res += 3.00f;
2987
+ } else if (text[i] == '?') {
2988
+ res += 3.00f;
2989
+ } else if (text[i] >= '0' && text[i] <= '9') {
2990
+ res += 3.00f;
2991
+ } else {
2992
+ res += 1.00f;
2993
+ }
2994
+ }
2995
+
2996
+ return res;
2997
+ }
2998
+
2999
+ // average the fabs of the signal
3000
+ static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window) {
3001
+ const int hw = n_samples_per_half_window;
3002
+
3003
+ std::vector<float> result(n_samples);
3004
+
3005
+ for (int i = 0; i < n_samples; i++) {
3006
+ float sum = 0;
3007
+ for (int j = -hw; j <= hw; j++) {
3008
+ if (i + j >= 0 && i + j < n_samples) {
3009
+ sum += fabs(signal[i + j]);
3010
+ }
3011
+ }
3012
+ result[i] = sum/(2*hw + 1);
3013
+ }
3014
+
3015
+ return result;
3016
+ }
3017
+
3018
+ static void whisper_exp_compute_token_level_timestamps(
3019
+ struct whisper_context * ctx,
3020
+ int i_segment,
3021
+ float thold_pt,
3022
+ float thold_ptsum) {
3023
+ auto & segment = ctx->result_all[i_segment];
3024
+ auto & tokens = segment.tokens;
3025
+
3026
+ const int n_samples = ctx->energy.size();
3027
+
3028
+ if (n_samples == 0) {
3029
+ fprintf(stderr, "%s: no signal data available\n", __func__);
3030
+ return;
3031
+ }
3032
+
3033
+ const int64_t t0 = segment.t0;
3034
+ const int64_t t1 = segment.t1;
3035
+
3036
+ const int s0 = timestamp_to_sample(t0, n_samples);
3037
+ const int s1 = timestamp_to_sample(t1, n_samples);
3038
+
3039
+ const int n = tokens.size();
3040
+
3041
+ if (n == 0) {
3042
+ return;
3043
+ }
3044
+
3045
+ if (n == 1) {
3046
+ tokens[0].t0 = t0;
3047
+ tokens[0].t1 = t1;
3048
+
3049
+ return;
3050
+ }
3051
+
3052
+ auto & t_beg = ctx->t_beg;
3053
+ auto & t_last = ctx->t_last;
3054
+ auto & tid_last = ctx->tid_last;
3055
+
3056
+ for (int j = 0; j < n; ++j) {
3057
+ auto & token = tokens[j];
3058
+
3059
+ if (j == 0) {
3060
+ if (token.id == whisper_token_beg(ctx)) {
3061
+ tokens[j ].t0 = t0;
3062
+ tokens[j ].t1 = t0;
3063
+ tokens[j + 1].t0 = t0;
3064
+
3065
+ t_beg = t0;
3066
+ t_last = t0;
3067
+ tid_last = whisper_token_beg(ctx);
3068
+ } else {
3069
+ tokens[j ].t0 = t_last;
3070
+ }
3071
+ }
3072
+
3073
+ const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(ctx));
3074
+
3075
+ tokens[j].id = token.id;
3076
+ tokens[j].tid = token.tid;
3077
+ tokens[j].p = token.p;
3078
+ tokens[j].pt = token.pt;
3079
+ tokens[j].ptsum = token.ptsum;
3080
+
3081
+ tokens[j].vlen = voice_length(whisper_token_to_str(ctx, token.id));
3082
+
3083
+ if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) {
3084
+ if (j > 0) {
3085
+ tokens[j - 1].t1 = tt;
3086
+ }
3087
+ tokens[j].t0 = tt;
3088
+ tid_last = token.tid;
3089
+ }
3090
+ }
3091
+
3092
+ tokens[n - 2].t1 = t1;
3093
+ tokens[n - 1].t0 = t1;
3094
+ tokens[n - 1].t1 = t1;
3095
+
3096
+ t_last = t1;
3097
+
3098
+ // find intervals of tokens with unknown timestamps
3099
+ // fill the timestamps by proportionally splitting the interval based on the token voice lengths
3100
+ {
3101
+ int p0 = 0;
3102
+ int p1 = 0;
3103
+
3104
+ while (true) {
3105
+ while (p1 < n && tokens[p1].t1 < 0) {
3106
+ p1++;
3107
+ }
3108
+
3109
+ if (p1 >= n) {
3110
+ p1--;
3111
+ }
3112
+
3113
+ if (p1 > p0) {
3114
+ double psum = 0.0;
3115
+ for (int j = p0; j <= p1; j++) {
3116
+ psum += tokens[j].vlen;
3117
+ }
3118
+
3119
+ //printf("analyzing %d - %d, psum = %f\n", p0, p1, psum);
3120
+
3121
+ const double dt = tokens[p1].t1 - tokens[p0].t0;
3122
+
3123
+ // split the time proportionally to the voice length
3124
+ for (int j = p0 + 1; j <= p1; j++) {
3125
+ const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum;
3126
+
3127
+ tokens[j - 1].t1 = ct;
3128
+ tokens[j ].t0 = ct;
3129
+ }
3130
+ }
3131
+
3132
+ p1++;
3133
+ p0 = p1;
3134
+ if (p1 >= n) {
3135
+ break;
3136
+ }
3137
+ }
3138
+ }
3139
+
3140
+ // fix up (just in case)
3141
+ for (int j = 0; j < n - 1; j++) {
3142
+ if (tokens[j].t1 < 0) {
3143
+ tokens[j + 1].t0 = tokens[j].t1;
3144
+ }
3145
+
3146
+ if (j > 0) {
3147
+ if (tokens[j - 1].t1 > tokens[j].t0) {
3148
+ tokens[j].t0 = tokens[j - 1].t1;
3149
+ tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1);
3150
+ }
3151
+ }
3152
+ }
3153
+
3154
+ // VAD
3155
+ // expand or contract tokens based on voice activity
3156
+ {
3157
+ const int hw = WHISPER_SAMPLE_RATE/8;
3158
+
3159
+ for (int j = 0; j < n; j++) {
3160
+ if (tokens[j].id >= whisper_token_eot(ctx)) {
3161
+ continue;
3162
+ }
3163
+
3164
+ int s0 = timestamp_to_sample(tokens[j].t0, n_samples);
3165
+ int s1 = timestamp_to_sample(tokens[j].t1, n_samples);
3166
+
3167
+ const int ss0 = std::max(s0 - hw, 0);
3168
+ const int ss1 = std::min(s1 + hw, n_samples);
3169
+
3170
+ const int ns = ss1 - ss0;
3171
+
3172
+ float sum = 0.0f;
3173
+
3174
+ for (int k = ss0; k < ss1; k++) {
3175
+ sum += ctx->energy[k];
3176
+ }
3177
+
3178
+ const float thold = 0.5*sum/ns;
3179
+
3180
+ {
3181
+ int k = s0;
3182
+ if (ctx->energy[k] > thold && j > 0) {
3183
+ while (k > 0 && ctx->energy[k] > thold) {
3184
+ k--;
3185
+ }
3186
+ tokens[j].t0 = sample_to_timestamp(k);
3187
+ if (tokens[j].t0 < tokens[j - 1].t1) {
3188
+ tokens[j].t0 = tokens[j - 1].t1;
3189
+ } else {
3190
+ s0 = k;
3191
+ }
3192
+ } else {
3193
+ while (ctx->energy[k] < thold && k < s1) {
3194
+ k++;
3195
+ }
3196
+ s0 = k;
3197
+ tokens[j].t0 = sample_to_timestamp(k);
3198
+ }
3199
+ }
3200
+
3201
+ {
3202
+ int k = s1;
3203
+ if (ctx->energy[k] > thold) {
3204
+ while (k < n_samples - 1 && ctx->energy[k] > thold) {
3205
+ k++;
3206
+ }
3207
+ tokens[j].t1 = sample_to_timestamp(k);
3208
+ if (j < ns - 1 && tokens[j].t1 > tokens[j + 1].t0) {
3209
+ tokens[j].t1 = tokens[j + 1].t0;
3210
+ } else {
3211
+ s1 = k;
3212
+ }
3213
+ } else {
3214
+ while (ctx->energy[k] < thold && k > s0) {
3215
+ k--;
3216
+ }
3217
+ s1 = k;
3218
+ tokens[j].t1 = sample_to_timestamp(k);
3219
+ }
3220
+ }
3221
+ }
3222
+ }
3223
+
3224
+ // fixed token expand (optional)
3225
+ //{
3226
+ // const int t_expand = 0;
3227
+
3228
+ // for (int j = 0; j < n; j++) {
3229
+ // if (j > 0) {
3230
+ // tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand));
3231
+ // }
3232
+ // if (j < n - 1) {
3233
+ // tokens[j].t1 = tokens[j].t1 + t_expand;
3234
+ // }
3235
+ // }
3236
+ //}
3237
+
3238
+ // debug info
3239
+ //for (int j = 0; j < n; ++j) {
3240
+ // const auto & token = tokens[j];
3241
+ // const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(ctx, token.tid) : "[?]";
3242
+ // printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__,
3243
+ // tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(ctx, token.id));
3244
+
3245
+ // if (tokens[j].id >= whisper_token_eot(ctx)) {
3246
+ // continue;
3247
+ // }
3248
+ //}
3249
+ }
whisper.h CHANGED
@@ -68,14 +68,21 @@ extern "C" {
68
 
69
  typedef int whisper_token;
70
 
71
- struct whisper_token_data {
72
  whisper_token id; // token id
73
  whisper_token tid; // forced timestamp token id
74
 
75
  float p; // probability of the token
76
  float pt; // probability of the timestamp token
77
  float ptsum; // sum of probabilities of all timestamp tokens
78
- };
 
 
 
 
 
 
 
79
 
80
  // Allocates all memory needed for the model and loads the model from the given file.
81
  // Returns NULL on failure.
@@ -129,7 +136,7 @@ extern "C" {
129
  // You can also implement your own sampling method using the whisper_get_probs() function.
130
  // whisper_sample_best() returns the token with the highest probability
131
  // whisper_sample_timestamp() returns the most probable timestamp token
132
- WHISPER_API struct whisper_token_data whisper_sample_best(struct whisper_context * ctx);
133
  WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx);
134
 
135
  // Return the id of the specified language, returns -1 if not found
@@ -172,7 +179,7 @@ extern "C" {
172
  // Text segment callback
173
  // Called on every newly generated text segment
174
  // Use the whisper_full_...() functions to obtain the text segments
175
- typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, void * user_data);
176
 
177
  struct whisper_full_params {
178
  enum whisper_sampling_strategy strategy;
@@ -188,6 +195,12 @@ extern "C" {
188
  bool print_realtime;
189
  bool print_timestamps;
190
 
 
 
 
 
 
 
191
  const char * language;
192
 
193
  struct {
@@ -244,7 +257,7 @@ extern "C" {
244
 
245
  // Get token data for the specified token in the specified segment.
246
  // This contains probabilities, timestamps, etc.
247
- WHISPER_API struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token);
248
 
249
  // Get the probability of the specified token in the specified segment.
250
  WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);
 
68
 
69
  typedef int whisper_token;
70
 
71
+ typedef struct whisper_token_data {
72
  whisper_token id; // token id
73
  whisper_token tid; // forced timestamp token id
74
 
75
  float p; // probability of the token
76
  float pt; // probability of the timestamp token
77
  float ptsum; // sum of probabilities of all timestamp tokens
78
+
79
+ // token-level timestamp data
80
+ // do not use if you haven't computed token-level timestamps
81
+ int64_t t0; // start time of the token
82
+ int64_t t1; // end time of the token
83
+
84
+ float vlen; // voice length of the token
85
+ } whisper_token_data;
86
 
87
  // Allocates all memory needed for the model and loads the model from the given file.
88
  // Returns NULL on failure.
 
136
  // You can also implement your own sampling method using the whisper_get_probs() function.
137
  // whisper_sample_best() returns the token with the highest probability
138
  // whisper_sample_timestamp() returns the most probable timestamp token
139
+ WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
140
  WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx);
141
 
142
  // Return the id of the specified language, returns -1 if not found
 
179
  // Text segment callback
180
  // Called on every newly generated text segment
181
  // Use the whisper_full_...() functions to obtain the text segments
182
+ typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data);
183
 
184
  struct whisper_full_params {
185
  enum whisper_sampling_strategy strategy;
 
195
  bool print_realtime;
196
  bool print_timestamps;
197
 
198
+ // [EXPERIMENTAL] token-level timestamps
199
+ bool token_timestamps; // enable token-level timestamps
200
+ float thold_pt; // timestamp token probability threshold (~0.01)
201
+ float thold_ptsum; // timestamp token sum probability threshold (~0.01)
202
+ int max_len; // max segment length in characters
203
+
204
  const char * language;
205
 
206
  struct {
 
257
 
258
  // Get token data for the specified token in the specified segment.
259
  // This contains probabilities, timestamps, etc.
260
+ WHISPER_API whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token);
261
 
262
  // Get the probability of the specified token in the specified segment.
263
  WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);
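
For reference, here is a minimal sketch (not part of this commit) of how the new token-level timestamp API declared above might be used. It assumes a loaded `whisper_context` and 16 kHz mono float PCM samples, and omits error handling:

```
#include <cstdio>
#include <vector>

#include "whisper.h"

// Enable the experimental token-level timestamps and print every token together
// with its start/end time (t0/t1 are in units of 10 ms, as elsewhere in whisper.cpp).
void print_token_timestamps(struct whisper_context * ctx, const std::vector<float> & pcmf32) {
    whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

    wparams.token_timestamps = true;  // [EXPERIMENTAL] compute per-token t0/t1
    wparams.thold_pt         = 0.01f; // timestamp token probability threshold
    wparams.max_len          = 60;    // wrap segments at ~60 characters (0 = no wrapping)

    if (whisper_full(ctx, wparams, pcmf32.data(), (int) pcmf32.size()) != 0) {
        return;
    }

    for (int i = 0; i < whisper_full_n_segments(ctx); ++i) {
        for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
            const whisper_token_data data = whisper_full_get_token_data(ctx, i, j);

            printf("[%6d --> %6d] %s\n",
                   (int) data.t0, (int) data.t1,
                   whisper_full_get_token_text(ctx, i, j));
        }
    }
}
```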