Update model.py
Browse files
model.py
CHANGED
|
@@ -206,6 +206,10 @@ def get_pretrained_model(
|
|
| 206 |
return cantonese_models[repo_id](
|
| 207 |
repo_id, decoding_method=decoding_method, num_active_paths=num_active_paths
|
| 208 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
elif repo_id in tibetan_models:
|
| 210 |
return tibetan_models[repo_id](
|
| 211 |
repo_id, decoding_method=decoding_method, num_active_paths=num_active_paths
|
|
@@ -473,6 +477,116 @@ def _get_yifan_thai_pretrained_model(
|
|
| 473 |
|
| 474 |
return recognizer
|
| 475 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 476 |
|
| 477 |
@lru_cache(maxsize=10)
|
| 478 |
def _get_zrjin_cantonese_pre_trained_model(
|
|
@@ -2293,6 +2407,14 @@ cantonese_models = {
|
|
| 2293 |
"zrjin/icefall-asr-mdcc-zipformer-2024-03-11": _get_zrjin_cantonese_pre_trained_model,
|
| 2294 |
}
|
| 2295 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2296 |
korean_models = {
|
| 2297 |
"k2-fsa/sherpa-onnx-zipformer-korean-2024-06-24": _get_offline_pre_trained_model,
|
| 2298 |
"k2-fsa/sherpa-onnx-streaming-zipformer-korean-2024-06-16": _get_streaming_zipformer_pre_trained_model,
|
|
@@ -2325,6 +2447,7 @@ all_models = {
|
|
| 2325 |
**chinese_cantonese_english_models,
|
| 2326 |
**chinese_cantonese_english_japanese_korean_models,
|
| 2327 |
**cantonese_models,
|
|
|
|
| 2328 |
**japanese_models,
|
| 2329 |
**tibetan_models,
|
| 2330 |
**arabic_models,
|
|
@@ -2351,6 +2474,7 @@ language_to_models = {
|
|
| 2351 |
),
|
| 2352 |
"Arabic": list(arabic_models.keys()),
|
| 2353 |
"Cantonese": list(cantonese_models.keys()),
|
|
|
|
| 2354 |
"French": list(french_models.keys()),
|
| 2355 |
"German": list(german_models.keys()),
|
| 2356 |
"Japanese": list(japanese_models.keys()),
|
|
|
|
| 206 |
return cantonese_models[repo_id](
|
| 207 |
repo_id, decoding_method=decoding_method, num_active_paths=num_active_paths
|
| 208 |
)
|
| 209 |
+
elif repo_id in revolab_models:
|
| 210 |
+
return revolab_models[repo_id](
|
| 211 |
+
repo_id, decoding_method=decoding_method, num_active_paths=num_active_paths
|
| 212 |
+
)
|
| 213 |
elif repo_id in tibetan_models:
|
| 214 |
return tibetan_models[repo_id](
|
| 215 |
repo_id, decoding_method=decoding_method, num_active_paths=num_active_paths
|
|
|
|
| 477 |
|
| 478 |
return recognizer
|
| 479 |
|
| 480 |
+
@lru_cache(maxsize=10)
|
| 481 |
+
def _get_revolab_pretrained_model(
|
| 482 |
+
repo_id: str, decoding_method: str, num_active_paths: int
|
| 483 |
+
) -> sherpa_onnx.OfflineRecognizer:
|
| 484 |
+
assert 'Revolab' in repo_id
|
| 485 |
+
|
| 486 |
+
if repo_id == "Revolab/zipformer-large-145M":
|
| 487 |
+
real_repo = 'Revolab/malaysian-pruned_transducer_stateless7'
|
| 488 |
+
encoder_model = _get_nn_model_filename(
|
| 489 |
+
repo_id=real_repo,
|
| 490 |
+
filename="encoder-epoch-19-avg-1.onnx",
|
| 491 |
+
subfolder="zipformer-large-20k/export",
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
decoder_model = _get_nn_model_filename(
|
| 495 |
+
repo_id=real_repo,
|
| 496 |
+
filename="decoder-epoch-19-avg-1.onnx",
|
| 497 |
+
subfolder="zipformer-large-20k/export",
|
| 498 |
+
)
|
| 499 |
+
|
| 500 |
+
joiner_model = _get_nn_model_filename(
|
| 501 |
+
repo_id=real_repo,
|
| 502 |
+
filename="joiner-epoch-19-avg-1.onnx",
|
| 503 |
+
subfolder="zipformer-large-20k/export",
|
| 504 |
+
)
|
| 505 |
+
|
| 506 |
+
tokens = _get_token_filename(repo_id=real_repo, subfolder="zipformer-large-20k/exp/lang_bpe_500")
|
| 507 |
+
|
| 508 |
+
recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
|
| 509 |
+
tokens=tokens,
|
| 510 |
+
encoder=encoder_model,
|
| 511 |
+
decoder=decoder_model,
|
| 512 |
+
joiner=joiner_model,
|
| 513 |
+
num_threads=2,
|
| 514 |
+
sample_rate=16000,
|
| 515 |
+
feature_dim=80,
|
| 516 |
+
decoding_method=decoding_method,
|
| 517 |
+
)
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
return recognizer
|
| 521 |
+
elif repo_id == "Revolab/zipformer-large-finetuned-145M":
|
| 522 |
+
real_repo = 'Revolab/malaysian-pruned_transducer_stateless7'
|
| 523 |
+
|
| 524 |
+
encoder_model = _get_nn_model_filename(
|
| 525 |
+
repo_id=real_repo,
|
| 526 |
+
filename="encoder-epoch-17-avg-1.onnx",
|
| 527 |
+
subfolder="zipformer-large-finetune-SFO/export",
|
| 528 |
+
)
|
| 529 |
+
decoder_model = _get_nn_model_filename(
|
| 530 |
+
repo_id=real_repo,
|
| 531 |
+
filename="decoder-epoch-19-avg-3.onnx",
|
| 532 |
+
subfolder="zipformer-large-finetune-SFO/export",
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
joiner_model = _get_nn_model_filename(
|
| 536 |
+
repo_id=real_repo,
|
| 537 |
+
filename="joiner-epoch-19-avg-3.onnx",
|
| 538 |
+
subfolder="zipformer-large-finetune-SFO/export",
|
| 539 |
+
)
|
| 540 |
+
|
| 541 |
+
tokens = _get_token_filename(repo_id=real_repo, subfolder="zipformer-large-20k/exp/lang_bpe_500")
|
| 542 |
+
|
| 543 |
+
recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
|
| 544 |
+
tokens=tokens,
|
| 545 |
+
encoder=encoder_model,
|
| 546 |
+
decoder=decoder_model,
|
| 547 |
+
joiner=joiner_model,
|
| 548 |
+
num_threads=2,
|
| 549 |
+
sample_rate=16000,
|
| 550 |
+
feature_dim=80,
|
| 551 |
+
decoding_method=decoding_method,
|
| 552 |
+
)
|
| 553 |
+
|
| 554 |
+
return recognizer
|
| 555 |
+
|
| 556 |
+
elif repo_id == "Revolab/pruned-transducer-65M":
|
| 557 |
+
real_repo = 'Revolab/malaysian-pruned_transducer_stateless7'
|
| 558 |
+
|
| 559 |
+
encoder_model = _get_nn_model_filename(
|
| 560 |
+
repo_id=real_repo,
|
| 561 |
+
filename="encoder-epoch-19-avg-3.onnx",
|
| 562 |
+
subfolder="PT7-stage1/export",
|
| 563 |
+
)
|
| 564 |
+
decoder_model = _get_nn_model_filename(
|
| 565 |
+
repo_id=real_repo,
|
| 566 |
+
filename="decoder-epoch-19-avg-3.onnx",
|
| 567 |
+
subfolder="PT7-stage1/export",
|
| 568 |
+
)
|
| 569 |
+
|
| 570 |
+
joiner_model = _get_nn_model_filename(
|
| 571 |
+
repo_id=real_repo,
|
| 572 |
+
filename="joiner-epoch-19-avg-3.onnx",
|
| 573 |
+
subfolder="PT7-stage1/export",
|
| 574 |
+
)
|
| 575 |
+
|
| 576 |
+
tokens = _get_token_filename(repo_id=real_repo, subfolder="pts-16k-all/exp/lang_bpe_500")
|
| 577 |
+
|
| 578 |
+
recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
|
| 579 |
+
tokens=tokens,
|
| 580 |
+
encoder=encoder_model,
|
| 581 |
+
decoder=decoder_model,
|
| 582 |
+
joiner=joiner_model,
|
| 583 |
+
num_threads=2,
|
| 584 |
+
sample_rate=16000,
|
| 585 |
+
feature_dim=80,
|
| 586 |
+
decoding_method=decoding_method,
|
| 587 |
+
)
|
| 588 |
+
|
| 589 |
+
return recognizer
|
| 590 |
|
| 591 |
@lru_cache(maxsize=10)
|
| 592 |
def _get_zrjin_cantonese_pre_trained_model(
|
|
|
|
| 2407 |
"zrjin/icefall-asr-mdcc-zipformer-2024-03-11": _get_zrjin_cantonese_pre_trained_model,
|
| 2408 |
}
|
| 2409 |
|
| 2410 |
+
revolab_models = {
|
| 2411 |
+
"Revolab/zipformer-large-145M": _get_revolab_pretrained_model,
|
| 2412 |
+
"Revolab/pruned-transducer-65M": _get_revolab_pretrained_model,
|
| 2413 |
+
"Revolab/zipformer-large-finetuned-145M":_get_revolab_pretrained_model,
|
| 2414 |
+
|
| 2415 |
+
}
|
| 2416 |
+
|
| 2417 |
+
|
| 2418 |
korean_models = {
|
| 2419 |
"k2-fsa/sherpa-onnx-zipformer-korean-2024-06-24": _get_offline_pre_trained_model,
|
| 2420 |
"k2-fsa/sherpa-onnx-streaming-zipformer-korean-2024-06-16": _get_streaming_zipformer_pre_trained_model,
|
|
|
|
| 2447 |
**chinese_cantonese_english_models,
|
| 2448 |
**chinese_cantonese_english_japanese_korean_models,
|
| 2449 |
**cantonese_models,
|
| 2450 |
+
**revolab_models,
|
| 2451 |
**japanese_models,
|
| 2452 |
**tibetan_models,
|
| 2453 |
**arabic_models,
|
|
|
|
| 2474 |
),
|
| 2475 |
"Arabic": list(arabic_models.keys()),
|
| 2476 |
"Cantonese": list(cantonese_models.keys()),
|
| 2477 |
+
"Malay": list(revolab_models.keys()),
|
| 2478 |
"French": list(french_models.keys()),
|
| 2479 |
"German": list(german_models.keys()),
|
| 2480 |
"Japanese": list(japanese_models.keys()),
|