whisper : token-level timestamp refactoring (#49, #120)
This turned out pretty good overall. The algorithm has been moved from
main.cpp to whisper.cpp and can be reused for all subtitle types. This
means that you can now specify the maximum length of the generated
lines: simply provide the "-ml" argument with the maximum length in
number of characters.
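
For example, assuming a downloaded model and a 16 kHz WAV input (the file names below are placeholders, not taken from this commit), wrapping subtitle lines at 16 characters while writing an SRT file would look something like `./main -m models/ggml-base.en.bin -f input.wav -osrt -ml 16`.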
- README.md +2 -1
- examples/main/README.md +3 -2
- examples/main/main.cpp +70 -348
- whisper.cpp +423 -12
- whisper.h +18 -5
README.md
CHANGED
@@ -101,13 +101,14 @@ options:
   -ot N,     --offset-t N    time offset in milliseconds (default: 0)
   -on N,     --offset-n N    segment index offset (default: 0)
   -mc N,     --max-context N maximum number of text context tokens to store (default: max)
+  -ml N,     --max-len N     maximum segment length in characters (default: 0)
   -wt N,     --word-thold N  word timestamp probability threshold (default: 0.010000)
   -v,        --verbose       verbose output
              --translate     translate from source language to english
   -otxt,     --output-txt    output result in a text file
   -ovtt,     --output-vtt    output result in a vtt file
   -osrt,     --output-srt    output result in a srt file
-  -owts,     --output-words  output
+  -owts,     --output-words  output script for generating karaoke video
   -ps,       --print_special print special tokens
   -pc,       --print_colors  print colors
   -nt,       --no_timestamps do not print timestamps
examples/main/README.md
CHANGED
@@ -8,7 +8,6 @@ It can be used as a reference for using the `whisper.cpp` library in other proje

 usage: ./bin/main [options] file0.wav file1.wav ...

-options:
   -h,        --help           show this help message and exit
   -s SEED,   --seed SEED      RNG seed (default: -1)
   -t N,      --threads N      number of threads to use during computation (default: 4)
@@ -16,18 +15,20 @@ options:
   -ot N,     --offset-t N     time offset in milliseconds (default: 0)
   -on N,     --offset-n N     segment index offset (default: 0)
   -mc N,     --max-context N  maximum number of text context tokens to store (default: max)
+  -ml N,     --max-len N      maximum segment length in characters (default: 0)
   -wt N,     --word-thold N   word timestamp probability threshold (default: 0.010000)
   -v,        --verbose        verbose output
              --translate      translate from source language to english
   -otxt,     --output-txt     output result in a text file
   -ovtt,     --output-vtt     output result in a vtt file
   -osrt,     --output-srt     output result in a srt file
-  -owts,     --output-words   output
+  -owts,     --output-words   output script for generating karaoke video
   -ps,       --print_special  print special tokens
   -pc,       --print_colors   print colors
   -nt,       --no_timestamps  do not print timestamps
   -l LANG,   --language LANG  spoken language (default: en)
   -m FNAME,  --model FNAME    model path (default: models/ggml-base.en.bin)
   -f FNAME,  --file FNAME     input WAV file path
+  -h,        --help           show this help message and exit

 ```
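As a usage note (an assumption, not spelled out in this diff): the `.wts` file written by `-owts` is a bash script whose body is a single `ffmpeg` invocation built from `drawtext` filters, so running it with `bash`, with `ffmpeg` available on the system, should render the karaoke-style video for the corresponding input file.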
examples/main/main.cpp
CHANGED
@@ -36,6 +36,7 @@ std::string to_timestamp(int64_t t, bool comma = false) {
     return std::string(buf);
 }

+// helper function to replace substrings
 void replace_all(std::string & s, const std::string & search, const std::string & replace) {
     for (size_t pos = 0; ; pos += replace.length()) {
         pos = s.find(search, pos);
@@ -45,31 +46,6 @@ void replace_all(std::string & s, const std::string & search, const std::string
     }
 }

-// a cost-function that is high for text that takes longer to pronounce
-float voice_length(const std::string & text) {
-    float res = 0.0f;
-
-    for (size_t i = 0; i < text.size(); ++i) {
-        if (text[i] == ' ') {
-            res += 0.01f;
-        } else if (text[i] == ',') {
-            res += 2.00f;
-        } else if (text[i] == '.') {
-            res += 3.00f;
-        } else if (text[i] == '!') {
-            res += 3.00f;
-        } else if (text[i] == '?') {
-            res += 3.00f;
-        } else if (text[i] >= '0' && text[i] <= '9') {
-            res += 3.00f;
-        } else {
-            res += 1.00f;
-        }
-    }
-
-    return res;
-}
-
 // command-line parameters
 struct whisper_params {
     int32_t seed = -1; // RNG seed, not used currently
@@ -78,6 +54,7 @@ struct whisper_params {
     int32_t offset_t_ms = 0;
     int32_t offset_n    = 0;
     int32_t max_context = -1;
+    int32_t max_len     = 0;

     float word_thold = 0.01f;

@@ -120,6 +97,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
             params.offset_n = std::stoi(argv[++i]);
         } else if (arg == "-mc" || arg == "--max-context") {
             params.max_context = std::stoi(argv[++i]);
+        } else if (arg == "-ml" || arg == "--max-len") {
+            params.max_len = std::stoi(argv[++i]);
         } else if (arg == "-wt" || arg == "--word-thold") {
             params.word_thold = std::stof(argv[++i]);
         } else if (arg == "-v" || arg == "--verbose") {
@@ -176,13 +155,14 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
     fprintf(stderr, "  -ot N,    --offset-t N     time offset in milliseconds (default: %d)\n", params.offset_t_ms);
     fprintf(stderr, "  -on N,    --offset-n N     segment index offset (default: %d)\n", params.offset_n);
     fprintf(stderr, "  -mc N,    --max-context N  maximum number of text context tokens to store (default: max)\n");
+    fprintf(stderr, "  -ml N,    --max-len N      maximum segment length in characters (default: %d)\n", params.max_len);
     fprintf(stderr, "  -wt N,    --word-thold N   word timestamp probability threshold (default: %f)\n", params.word_thold);
     fprintf(stderr, "  -v,       --verbose        verbose output\n");
     fprintf(stderr, "            --translate      translate from source language to english\n");
     fprintf(stderr, "  -otxt,    --output-txt     output result in a text file\n");
     fprintf(stderr, "  -ovtt,    --output-vtt     output result in a vtt file\n");
     fprintf(stderr, "  -osrt,    --output-srt     output result in a srt file\n");
-    fprintf(stderr, "  -owts,    --output-words   output
+    fprintf(stderr, "  -owts,    --output-words   output script for generating karaoke video\n");
     fprintf(stderr, "  -ps,      --print_special  print special tokens\n");
     fprintf(stderr, "  -pc,      --print_colors   print colors\n");
     fprintf(stderr, "  -nt,      --no_timestamps  do not print timestamps\n");
@@ -192,65 +172,67 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
     fprintf(stderr, "\n");
 }

-void whisper_print_segment_callback(struct whisper_context * ctx, void * user_data) {
+void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) {
     const whisper_params & params = *(whisper_params *) user_data;

     const int n_segments = whisper_full_n_segments(ctx);

-    // print the last
-    const int
-    if (
+    // print the last n_new segments
+    const int s0 = n_segments - n_new;
+    if (s0 == 0) {
         printf("\n");
     }

 [the remainder of the previous per-segment printing loop was removed; it is too heavily truncated in this extraction to reproduce and is superseded by the loop below]

+    for (int i = s0; i < n_segments; i++) {
+        if (params.no_timestamps) {
+            if (params.print_colors) {
+                for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
+                    if (params.print_special_tokens == false) {
+                        const whisper_token id = whisper_full_get_token_id(ctx, i, j);
+                        if (id >= whisper_token_eot(ctx)) {
+                            continue;
+                        }
+                    }
+
+                    const char * text = whisper_full_get_token_text(ctx, i, j);
+                    const float  p    = whisper_full_get_token_p   (ctx, i, j);
+
+                    const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
+
+                    printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
+                }
+            } else {
+                const char * text = whisper_full_get_segment_text(ctx, i);
+                printf("%s", text);
+            }
+            fflush(stdout);
+        } else {
+            const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
+            const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
+
+            if (params.print_colors) {
+                printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
+                for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
+                    if (params.print_special_tokens == false) {
+                        const whisper_token id = whisper_full_get_token_id(ctx, i, j);
+                        if (id >= whisper_token_eot(ctx)) {
+                            continue;
+                        }
+                    }
+
+                    const char * text = whisper_full_get_token_text(ctx, i, j);
+                    const float  p    = whisper_full_get_token_p   (ctx, i, j);
+
+                    const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
+
+                    printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
+                }
+                printf("\n");
+            } else {
+                const char * text = whisper_full_get_segment_text(ctx, i);
+
+                printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
+            }
+        }
+    }
 }
@@ -320,297 +302,41 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_
     return true;
 }

-//
-//
-// TODO: extra pass to detect unused speech and assign to tokens
+// karaoke video generation
+// outputs a bash script that uses ffmpeg to generate a video with the subtitles
 // TODO: font parameter adjustments
-
-bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, const std::vector<float> & pcmf32) {
+bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec) {

 [roughly 60 removed lines followed here: the smoothed signal-energy averaging, the local token_info struct and the t_beg/t_last/tid_last state; this machinery now lives in whisper.cpp]

     std::ofstream fout(fname);

     fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);

+    // TODO: become parameter
+    static const char * font = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
+
     fout << "!/bin/bash" << "\n";
     fout << "\n";

-    fout << "ffmpeg -i " << fname_inp << " -f lavfi -i color=size=1200x120:duration=" <<
-
-    bool is_first = true;
+    fout << "ffmpeg -i " << fname_inp << " -f lavfi -i color=size=1200x120:duration=" << t_sec << ":rate=25:color=black -vf \"";

     for (int i = 0; i < whisper_full_n_segments(ctx); i++) {
         const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
         const int64_t t1 = whisper_full_get_segment_t1(ctx, i);

-        const char *text = whisper_full_get_segment_text(ctx, i);
-
-        const int s0 = std::max(0, (int) (t0*WHISPER_SAMPLE_RATE/100));
-        const int s1 = std::min((int) pcmf32.size(), (int) (t1*WHISPER_SAMPLE_RATE/100));
-
         const int n = whisper_full_n_tokens(ctx, i);

-        std::vector<
-
-        if (n <= 1) {
-            continue;
-        }
-
+        std::vector<whisper_token_data> tokens(n);
         for (int j = 0; j < n; ++j) {
+            tokens[j] = whisper_full_get_token_data(ctx, i, j);
         }

 [roughly 200 removed lines followed here: the per-token timestamp estimation (forced-timestamp anchoring, proportional splitting by voice length, VAD expand/contract, fixed expand, debug printing) and the line_wrap constant; see the new code at the bottom of whisper.cpp]

+        if (i > 0) {
             fout << ",";
         }

         // background text
         fout << "drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='':enable='between(t," << t0/100.0 << "," << t0/100.0 << ")'";

-        is_first =
+        bool is_first = true;

         for (int j = 0; j < n; ++j) {
             const auto & token = tokens[j];
@@ -654,17 +380,6 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
             }

             ncnt += txt.size();
-
-            if (ncnt > line_wrap) {
-                if (k < j) {
-                    txt_bg = "> ";
-                    txt_fg = "> ";
-                    txt_ul = "\\ \\ ";
-                    ncnt = 0;
-                } else {
-                    break;
-                }
-            }
         }

         ::replace_all(txt_bg, "'", "’");
@@ -673,8 +388,11 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
             ::replace_all(txt_fg, "\"", "\\\"");
         }

+        if (is_first) {
+            // background text
+            fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='" << txt_bg << "':enable='between(t," << t0/100.0 << "," << t1/100.0 << ")'";
+            is_first = false;
+        }

         // foreground text
         fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=lightgreen:x=(w-text_w)/2+8:y=h/2:text='" << txt_fg << "':enable='between(t," << token.t0/100.0 << "," << token.t1/100.0 << ")'";
@@ -815,6 +533,10 @@ int main(int argc, char ** argv) {
         wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
         wparams.offset_ms      = params.offset_t_ms;

+        wparams.token_timestamps = params.output_wts || params.max_len > 0;
+        wparams.thold_pt         = params.word_thold;
+        wparams.max_len          = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
+
         // this callback is called on each new segment
         if (!wparams.print_realtime) {
             wparams.new_segment_callback = whisper_print_segment_callback;
@@ -852,7 +574,7 @@ int main(int argc, char ** argv) {
             // output to WTS file
             if (params.output_wts) {
                 const auto fname_wts = fname_inp + ".wts";
-                output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, pcmf32);
+                output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
             }
         }
     }
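For code that registers its own segment callback, the breaking change in this file is the extra `int n_new` argument. Below is a minimal sketch of a callback written against the new signature; the function name and output format are illustrative, and only API calls that appear in this diff are used.

#include <cstdio>

#include "whisper.h"

// Minimal sketch: print only the segments added since the previous callback.
// The last n_new entries of whisper_full_n_segments(ctx) are the new ones.
static void my_print_segments(struct whisper_context * ctx, int n_new, void * user_data) {
    (void) user_data; // unused in this sketch

    const int n_segments = whisper_full_n_segments(ctx);

    for (int i = n_segments - n_new; i < n_segments; ++i) {
        const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
        const int64_t t1 = whisper_full_get_segment_t1(ctx, i);

        printf("[%lld --> %lld] %s\n",
               (long long) t0, (long long) t1,
               whisper_full_get_segment_text(ctx, i));
    }
}

It would be registered the same way the example program does it, via `wparams.new_segment_callback = my_print_segments;`.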
whisper.cpp
CHANGED
@@ -418,6 +418,12 @@ struct whisper_context {
     std::vector<whisper_segment> result_all;

     std::vector<whisper_token> prompt_past;
+
+    // [EXPERIMENTAL] token-level timestamps data
+    int64_t t_beg;
+    int64_t t_last;
+    whisper_token tid_last;
+    std::vector<float> energy; // PCM signal energy
 };

 // load the model from a ggml file
@@ -431,7 +437,7 @@ struct whisper_context {
 //
 // see the convert-pt-to-ggml.py script for details
 //
-bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
+static bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
     fprintf(stderr, "%s: loading model from '%s'\n", __func__, fname.c_str());

     auto & model = wctx.model;
@@ -1062,7 +1068,7 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
 //   - n_threads:  number of threads to use
 //   - mel_offset: offset in the mel spectrogram (i.e. audio offset)
 //
-bool whisper_encode(
+static bool whisper_encode(
         whisper_context & wctx,
         const int n_threads,
         const int mel_offset) {
@@ -1448,7 +1454,7 @@ bool whisper_encode(
 //   - n_tokens:  number of tokens in the prompt
 //   - n_past:    number of past tokens to prefix the prompt with
 //
-bool whisper_decode(
+static bool whisper_decode(
         whisper_context & wctx,
         const int n_threads,
         const whisper_token * tokens,
@@ -1811,10 +1817,12 @@ bool whisper_decode(
 }

 // the most basic sampling scheme - select the top token
-whisper_token_data whisper_sample_best(
+static whisper_token_data whisper_sample_best(
         const whisper_vocab & vocab,
         const float * probs) {
-    whisper_token_data result
+    whisper_token_data result = {
+        0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
+    };

     int n_logits = vocab.id_to_token.size();
@@ -1887,7 +1895,7 @@ whisper_token_data whisper_sample_best(
 }

 // samples only from the timestamps tokens
-whisper_vocab::id whisper_sample_timestamp(
+static whisper_vocab::id whisper_sample_timestamp(
         const whisper_vocab & vocab,
         const float * probs) {
     int n_logits = vocab.id_to_token.size();
@@ -1939,7 +1947,7 @@ static std::string to_timestamp(int64_t t, bool comma = false) {
 // naive Discrete Fourier Transform
 // input is real-valued
 // output is complex-valued
-void dft(const std::vector<float> & in, std::vector<float> & out) {
+static void dft(const std::vector<float> & in, std::vector<float> & out) {
     int N = in.size();

     out.resize(N*2);
@@ -1963,7 +1971,7 @@ void dft(const std::vector<float> & in, std::vector<float> & out) {
 // poor man's implementation - use something better
 // input is real-valued
 // output is complex-valued
-void fft(const std::vector<float> & in, std::vector<float> & out) {
+static void fft(const std::vector<float> & in, std::vector<float> & out) {
     out.resize(in.size()*2);

     int N = in.size();
@@ -2014,7 +2022,7 @@ void fft(const std::vector<float> & in, std::vector<float> & out) {
 }

 // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
-bool log_mel_spectrogram(
+static bool log_mel_spectrogram(
     const float * samples,
     const int n_samples,
     const int sample_rate,
@@ -2339,6 +2347,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
             /*.print_realtime   =*/ false,
             /*.print_timestamps =*/ true,

+            /*.token_timestamps =*/ false,
+            /*.thold_pt         =*/ 0.01f,
+            /*.thold_ptsum      =*/ 0.01f,
+            /*.max_len          =*/ 0,
+
             /*.language         =*/ "en",

             /*.greedy           =*/ {
@@ -2371,6 +2384,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
             /*.print_realtime   =*/ false,
             /*.print_timestamps =*/ true,

+            /*.token_timestamps =*/ false,
+            /*.thold_pt         =*/ 0.01f,
+            /*.thold_ptsum      =*/ 0.01f,
+            /*.max_len          =*/ 0,
+
             /*.language         =*/ "en",

             /*.greedy           =*/ {
@@ -2392,6 +2410,68 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
     return result;
 }

+// forward declarations
+static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window);
+static void whisper_exp_compute_token_level_timestamps(
+        struct whisper_context * ctx,
+        int   i_segment,
+        float thold_pt,
+        float thold_ptsum);
+
+// wrap the last segment to max_len characters
+// returns the number of new segments
+static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) {
+    auto segment = ctx->result_all.back();
+
+    int res = 1;
+    int acc = 0;
+
+    std::string text;
+
+    for (int i = 0; i < (int) segment.tokens.size(); i++) {
+        const auto & token = segment.tokens[i];
+        if (token.id >= whisper_token_eot(ctx)) {
+            continue;
+        }
+
+        const auto txt = whisper_token_to_str(ctx, token.id);
+
+        const int cur = strlen(txt);
+
+        if (acc + cur > max_len && i > 0) {
+            // split here
+            ctx->result_all.back().text = std::move(text);
+            ctx->result_all.back().t1 = token.t0;
+            ctx->result_all.back().tokens.resize(i);
+
+            ctx->result_all.push_back({});
+            ctx->result_all.back().t0 = token.t0;
+            ctx->result_all.back().t1 = segment.t1;
+
+            // add tokens [i, end] to the new segment
+            ctx->result_all.back().tokens.insert(
+                ctx->result_all.back().tokens.end(),
+                    segment.tokens.begin() + i,
+                    segment.tokens.end());
+
+            acc = 0;
+            text = "";
+
+            segment = ctx->result_all.back();
+            i = -1;
+
+            res++;
+        } else {
+            acc += cur;
+            text += txt;
+        }
+    }
+
+    ctx->result_all.back().text = std::move(text);
+
+    return res;
+}
+
 int whisper_full(
         struct whisper_context * ctx,
         struct whisper_full_params params,
@@ -2408,6 +2488,13 @@ int whisper_full(
         return -1;
     }

+    if (params.token_timestamps) {
+        ctx->t_beg    = 0;
+        ctx->t_last   = 0;
+        ctx->tid_last = 0;
+        ctx->energy   = get_signal_energy(samples, n_samples, 32);
+    }
+
     const int seek_start = params.offset_ms/10;

     // if length of spectrogram is less than 1s (100 samples), then return
@@ -2557,6 +2644,7 @@ int whisper_full(
             }
         }

+        // shrink down to result_len
         tokens_cur.resize(result_len);

         for (const auto & r : tokens_cur) {
@@ -2595,8 +2683,19 @@ int whisper_full(
                     for (int j = i0; j <= i; j++) {
                         result_all.back().tokens.push_back(tokens_cur[j]);
                     }
+
+                    int n_new = 1;
+
+                    if (params.token_timestamps) {
+                        whisper_exp_compute_token_level_timestamps(
+                                ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
+
+                        if (params.max_len > 0) {
+                            n_new = whisper_wrap_segment(ctx, params.max_len);
+                        }
+                    }
                     if (params.new_segment_callback) {
-                        params.new_segment_callback(ctx, params.new_segment_callback_user_data);
+                        params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
                     }
                 }
                 text = "";
@@ -2625,8 +2724,19 @@ int whisper_full(
                 for (int j = i0; j < (int) tokens_cur.size(); j++) {
                     result_all.back().tokens.push_back(tokens_cur[j]);
                 }
+
+                int n_new = 1;
+
+                if (params.token_timestamps) {
+                    whisper_exp_compute_token_level_timestamps(
+                            ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
+
+                    if (params.max_len > 0) {
+                        n_new = whisper_wrap_segment(ctx, params.max_len);
+                    }
+                }
                 if (params.new_segment_callback) {
-                    params.new_segment_callback(ctx, params.new_segment_callback_user_data);
+                    params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
                 }
             }
         }
@@ -2760,7 +2870,7 @@ int whisper_full_parallel(

         // call the new_segment_callback for each segment
         if (params.new_segment_callback) {
-            params.new_segment_callback(ctx, params.new_segment_callback_user_data);
+            params.new_segment_callback(ctx, 1, params.new_segment_callback_user_data);
         }
     }
@@ -2836,3 +2946,304 @@ const char * whisper_print_system_info() {

     return s.c_str();
 }
+
+// =================================================================================================
+
+//
+// Experimental stuff below
+//
+// Not sure if these should be part of the library at all, because the quality of the results is not
+// guaranteed. Might get removed at some point unless a robust algorithm implementation is found
+//
+
+// =================================================================================================
+
+//
+// token-level timestamps
+//
+
+static int timestamp_to_sample(int64_t t, int n_samples) {
+    return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
+}
+
+static int64_t sample_to_timestamp(int i_sample) {
+    return (100*i_sample)/WHISPER_SAMPLE_RATE;
+}
+
+// a cost-function / heuristic that is high for text that takes longer to pronounce
+// obviously, can be improved
+static float voice_length(const std::string & text) {
+    float res = 0.0f;
+
+    for (size_t i = 0; i < text.size(); ++i) {
+        if (text[i] == ' ') {
+            res += 0.01f;
+        } else if (text[i] == ',') {
+            res += 2.00f;
+        } else if (text[i] == '.') {
+            res += 3.00f;
+        } else if (text[i] == '!') {
+            res += 3.00f;
+        } else if (text[i] == '?') {
+            res += 3.00f;
+        } else if (text[i] >= '0' && text[i] <= '9') {
+            res += 3.00f;
+        } else {
+            res += 1.00f;
+        }
+    }
+
+    return res;
+}
+
+// average the fabs of the signal
+static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window) {
+    const int hw = n_samples_per_half_window;
+
+    std::vector<float> result(n_samples);
+
+    for (int i = 0; i < n_samples; i++) {
+        float sum = 0;
+        for (int j = -hw; j <= hw; j++) {
+            if (i + j >= 0 && i + j < n_samples) {
+                sum += fabs(signal[i + j]);
+            }
+        }
+        result[i] = sum/(2*hw + 1);
+    }
+
+    return result;
+}
+
+static void whisper_exp_compute_token_level_timestamps(
+        struct whisper_context * ctx,
+        int   i_segment,
+        float thold_pt,
+        float thold_ptsum) {
+    auto & segment = ctx->result_all[i_segment];
+    auto & tokens  = segment.tokens;
+
+    const int n_samples = ctx->energy.size();
+
+    if (n_samples == 0) {
+        fprintf(stderr, "%s: no signal data available\n", __func__);
+        return;
+    }
+
+    const int64_t t0 = segment.t0;
+    const int64_t t1 = segment.t1;
+
+    const int s0 = timestamp_to_sample(t0, n_samples);
+    const int s1 = timestamp_to_sample(t1, n_samples);
+
+    const int n = tokens.size();
+
+    if (n == 0) {
+        return;
+    }
+
+    if (n == 1) {
+        tokens[0].t0 = t0;
+        tokens[0].t1 = t1;
+
+        return;
+    }
+
+    auto & t_beg    = ctx->t_beg;
+    auto & t_last   = ctx->t_last;
+    auto & tid_last = ctx->tid_last;
+
+    for (int j = 0; j < n; ++j) {
+        auto & token = tokens[j];
+
+        if (j == 0) {
+            if (token.id == whisper_token_beg(ctx)) {
+                tokens[j    ].t0 = t0;
+                tokens[j    ].t1 = t0;
+                tokens[j + 1].t0 = t0;
+
+                t_beg    = t0;
+                t_last   = t0;
+                tid_last = whisper_token_beg(ctx);
+            } else {
+                tokens[j    ].t0 = t_last;
+            }
+        }
+
+        const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(ctx));
+
+        tokens[j].id    = token.id;
+        tokens[j].tid   = token.tid;
+        tokens[j].p     = token.p;
+        tokens[j].pt    = token.pt;
+        tokens[j].ptsum = token.ptsum;
+
+        tokens[j].vlen = voice_length(whisper_token_to_str(ctx, token.id));
+
+        if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) {
+            if (j > 0) {
+                tokens[j - 1].t1 = tt;
+            }
+            tokens[j].t0 = tt;
+            tid_last = token.tid;
+        }
+    }
+
+    tokens[n - 2].t1 = t1;
+    tokens[n - 1].t0 = t1;
+    tokens[n - 1].t1 = t1;
+
+    t_last = t1;
+
+    // find intervals of tokens with unknown timestamps
+    // fill the timestamps by proportionally splitting the interval based on the token voice lengths
+    {
+        int p0 = 0;
+        int p1 = 0;
+
+        while (true) {
+            while (p1 < n && tokens[p1].t1 < 0) {
+                p1++;
+            }
+
+            if (p1 >= n) {
+                p1--;
+            }
+
+            if (p1 > p0) {
+                double psum = 0.0;
+                for (int j = p0; j <= p1; j++) {
+                    psum += tokens[j].vlen;
+                }
+
+                //printf("analyzing %d - %d, psum = %f\n", p0, p1, psum);
+
+                const double dt = tokens[p1].t1 - tokens[p0].t0;
+
+                // split the time proportionally to the voice length
+                for (int j = p0 + 1; j <= p1; j++) {
+                    const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum;
+
+                    tokens[j - 1].t1 = ct;
+                    tokens[j    ].t0 = ct;
+                }
+            }
+
+            p1++;
+            p0 = p1;
+            if (p1 >= n) {
+                break;
+            }
+        }
+    }
+
+    // fix up (just in case)
+    for (int j = 0; j < n - 1; j++) {
+        if (tokens[j].t1 < 0) {
+            tokens[j + 1].t0 = tokens[j].t1;
+        }
+
+        if (j > 0) {
+            if (tokens[j - 1].t1 > tokens[j].t0) {
+                tokens[j].t0 = tokens[j - 1].t1;
+                tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1);
+            }
+        }
+    }
+
+    // VAD
+    // expand or contract tokens based on voice activity
+    {
+        const int hw = WHISPER_SAMPLE_RATE/8;
+
+        for (int j = 0; j < n; j++) {
+            if (tokens[j].id >= whisper_token_eot(ctx)) {
+                continue;
+            }
+
+            int s0 = timestamp_to_sample(tokens[j].t0, n_samples);
+            int s1 = timestamp_to_sample(tokens[j].t1, n_samples);
+
+            const int ss0 = std::max(s0 - hw, 0);
+            const int ss1 = std::min(s1 + hw, n_samples);
+
+            const int ns = ss1 - ss0;
+
+            float sum = 0.0f;
+
+            for (int k = ss0; k < ss1; k++) {
+                sum += ctx->energy[k];
+            }
+
+            const float thold = 0.5*sum/ns;
+
+            {
+                int k = s0;
+                if (ctx->energy[k] > thold && j > 0) {
+                    while (k > 0 && ctx->energy[k] > thold) {
+                        k--;
+                    }
+                    tokens[j].t0 = sample_to_timestamp(k);
+                    if (tokens[j].t0 < tokens[j - 1].t1) {
+                        tokens[j].t0 = tokens[j - 1].t1;
+                    } else {
+                        s0 = k;
+                    }
+                } else {
+                    while (ctx->energy[k] < thold && k < s1) {
+                        k++;
+                    }
+                    s0 = k;
+                    tokens[j].t0 = sample_to_timestamp(k);
+                }
+            }
+
+            {
+                int k = s1;
+                if (ctx->energy[k] > thold) {
+                    while (k < n_samples - 1 && ctx->energy[k] > thold) {
+                        k++;
+                    }
+                    tokens[j].t1 = sample_to_timestamp(k);
+                    if (j < ns - 1 && tokens[j].t1 > tokens[j + 1].t0) {
+                        tokens[j].t1 = tokens[j + 1].t0;
+                    } else {
+                        s1 = k;
+                    }
+                } else {
+                    while (ctx->energy[k] < thold && k > s0) {
+                        k--;
+                    }
+                    s1 = k;
+                    tokens[j].t1 = sample_to_timestamp(k);
+                }
+            }
+        }
+    }
+
+    // fixed token expand (optional)
+    //{
+    //    const int t_expand = 0;
+
+    //    for (int j = 0; j < n; j++) {
+    //        if (j > 0) {
+    //            tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand));
+    //        }
+    //        if (j < n - 1) {
+    //            tokens[j].t1 = tokens[j].t1 + t_expand;
+    //        }
+    //    }
+    //}
+
+    // debug info
+    //for (int j = 0; j < n; ++j) {
+    //    const auto & token = tokens[j];
+    //    const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(ctx, token.tid) : "[?]";
+    //    printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__,
+    //            tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(ctx, token.id));
+
+    //    if (tokens[j].id >= whisper_token_eot(ctx)) {
+    //        continue;
+    //    }
+    //}
+}
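To make the proportional-splitting step above concrete, here is a small stand-alone sketch (not library code; the `vlen` values are made up for illustration) that applies the same update `ct = tokens[j-1].t0 + dt*tokens[j-1].vlen/psum` to a toy gap between two known boundaries:

#include <cstdio>
#include <vector>

// Toy illustration of the proportional split: three tokens span
// 100 -> 400 (in 10 ms ticks), only the outer boundaries are known,
// and the inner boundaries are placed in proportion to voice length.
int main() {
    struct Tok { double t0, t1, vlen; };
    std::vector<Tok> tokens = {
        { 100,  -1, 1.0 },
        {  -1,  -1, 3.0 },
        {  -1, 400, 2.0 },
    };

    const int p0 = 0, p1 = 2;

    double psum = 0.0;
    for (int j = p0; j <= p1; j++) psum += tokens[j].vlen;

    const double dt = tokens[p1].t1 - tokens[p0].t0; // 300 ticks

    for (int j = p0 + 1; j <= p1; j++) {
        const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum;
        tokens[j - 1].t1 = ct;
        tokens[j    ].t0 = ct;
    }

    for (const auto & t : tokens) {
        printf("t0 = %6.1f  t1 = %6.1f  vlen = %.2f\n", t.t0, t.t1, t.vlen);
    }
    // inner boundaries land at 150 and 300, i.e. at 1/6 and 4/6 of the 300-tick span
    return 0;
}

Note that each boundary is advanced by the preceding token's share of the interval, so the last token simply receives whatever time remains.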
whisper.h
CHANGED
@@ -68,14 +68,21 @@ extern "C" {

     typedef int whisper_token;

-    struct whisper_token_data {
+    typedef struct whisper_token_data {
         whisper_token id;  // token id
         whisper_token tid; // forced timestamp token id

         float p;     // probability of the token
         float pt;    // probability of the timestamp token
         float ptsum; // sum of probabilities of all timestamp tokens
-    };
+
+        // token-level timestamp data
+        // do not use if you haven't computed token-level timestamps
+        int64_t t0; // start time of the token
+        int64_t t1; //   end time of the token
+
+        float vlen; // voice length of the token
+    } whisper_token_data;

     // Allocates all memory needed for the model and loads the model from the given file.
     // Returns NULL on failure.
@@ -129,7 +136,7 @@ extern "C" {
     // You can also implement your own sampling method using the whisper_get_probs() function.
     // whisper_sample_best() returns the token with the highest probability
     // whisper_sample_timestamp() returns the most probable timestamp token
-    WHISPER_API
+    WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
     WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx);

     // Return the id of the specified language, returns -1 if not found
@@ -172,7 +179,7 @@ extern "C" {
     // Text segment callback
     // Called on every newly generated text segment
     // Use the whisper_full_...() functions to obtain the text segments
-    typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, void * user_data);
+    typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data);

     struct whisper_full_params {
         enum whisper_sampling_strategy strategy;
@@ -188,6 +195,12 @@ extern "C" {
         bool print_realtime;
         bool print_timestamps;

+        // [EXPERIMENTAL] token-level timestamps
+        bool  token_timestamps; // enable token-level timestamps
+        float thold_pt;         // timestamp token probability threshold (~0.01)
+        float thold_ptsum;      // timestamp token sum probability threshold (~0.01)
+        int   max_len;          // max segment length in characters
+
         const char * language;

         struct {
@@ -244,7 +257,7 @@ extern "C" {

     // Get token data for the specified token in the specified segment.
     // This contains probabilities, timestamps, etc.
-    WHISPER_API
+    WHISPER_API whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token);

     // Get the probability of the specified token in the specified segment.
     WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);
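Putting the header changes together, here is a hedged sketch of how an application could enable the experimental token-level timestamps and read them back. `whisper_init`, `whisper_free` and `WHISPER_SAMPLING_GREEDY` come from parts of `whisper.h` that are not shown in this diff, and loading 16 kHz mono float PCM into `pcmf32` is assumed to happen elsewhere; everything else uses only the declarations visible above.

#include <cstdio>
#include <vector>

#include "whisper.h"

// Sketch: transcribe, wrapping segments at 60 characters, and dump per-token timestamps.
int transcribe_with_token_timestamps(const char * model_path, const std::vector<float> & pcmf32) {
    struct whisper_context * ctx = whisper_init(model_path);
    if (ctx == nullptr) {
        return 1;
    }

    whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

    // enable the experimental token-level timestamps and segment wrapping
    wparams.token_timestamps = true;
    wparams.thold_pt         = 0.01f;
    wparams.max_len          = 60;

    if (whisper_full(ctx, wparams, pcmf32.data(), (int) pcmf32.size()) != 0) {
        whisper_free(ctx);
        return 1;
    }

    const int n_segments = whisper_full_n_segments(ctx);
    for (int i = 0; i < n_segments; ++i) {
        for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
            const whisper_token_data data = whisper_full_get_token_data(ctx, i, j);

            // t0/t1 are in the same 10 ms units used elsewhere in the library
            printf("%8lld %8lld %s\n",
                   (long long) data.t0, (long long) data.t1,
                   whisper_full_get_token_text(ctx, i, j));
        }
    }

    whisper_free(ctx);
    return 0;
}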