zzzweakman committed · commit 4ff91cd · Parent(s): cea0eec

fix: infer bug
Files changed:
- .gitignore (+4 −3)
- README.md (+7 −4)
- inference.sh (+42 −0)
- musetalk/utils/audio_processor.py (+1 −1)
- scripts/inference.py (+65 −26)
- scripts/inference_alpha.py (+4 −5)
.gitignore CHANGED
@@ -4,7 +4,8 @@
 .vscode/
 *.pyc
 .ipynb_checkpoints
-models
 results/
-
-
+./models
+**/__pycache__/
+*.py[cod]
+*$py.class
README.md CHANGED
@@ -177,7 +177,7 @@ You can download weights manually as follows:
 
 2. Download the weights of other components:
    - [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse)
-   - [whisper](https://
+   - [whisper](https://huggingface.co/openai/whisper-tiny/tree/main)
    - [dwpose](https://huggingface.co/yzd-v/DWPose/tree/main)
    - [face-parse-bisent](https://github.com/zllrunning/face-parsing.PyTorch)
    - [resnet18](https://download.pytorch.org/models/resnet18-5c106cde.pth)
@@ -201,7 +201,10 @@ Finally, these weights should be organized in `models` as follows:
 │   ├── config.json
 │   └── diffusion_pytorch_model.bin
 └── whisper
-
+    ├── config.json
+    ├── pytorch_model.bin
+    └── preprocessor_config.json
+
 ```
 ## Quickstart
@@ -210,7 +213,7 @@ We provide inference scripts for both versions of MuseTalk:
 
 #### MuseTalk 1.5 (Recommended)
 ```bash
-
+sh inference.sh v1.5
 ```
 This inference script supports both MuseTalk 1.5 and 1.0 models:
 - For MuseTalk 1.5: Use the command above with the V1.5 model path
@@ -221,7 +224,7 @@ The video_path should be either a video file, an image file or a directory of images.
 
 #### MuseTalk 1.0
 ```bash
-
+sh inference.sh v1.0
 ```
 You are recommended to input video with `25fps`, the same fps used when training the model. If your video is far less than 25fps, you are recommended to apply frame interpolation or directly convert the video to 25fps using ffmpeg.
 <details close>
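The updated README expects `config.json`, `pytorch_model.bin` and `preprocessor_config.json` from the openai/whisper-tiny repository under `models/whisper`. A minimal sketch of fetching them, assuming the standard Hugging Face `resolve/main` download URLs (this command is not part of the commit):

```bash
# Sketch: fetch the three whisper-tiny files the updated README expects.
# Assumes the standard Hugging Face resolve/main URL layout.
mkdir -p ./models/whisper
for f in config.json pytorch_model.bin preprocessor_config.json; do
    wget -P ./models/whisper "https://huggingface.co/openai/whisper-tiny/resolve/main/${f}"
done
```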
inference.sh ADDED
@@ -0,0 +1,42 @@
+#!/bin/bash
+
+# This script runs inference based on the version specified by the user.
+# Usage:
+# To run v1.0 inference: sh inference.sh v1.0
+# To run v1.5 inference: sh inference.sh v1.5
+
+# Check if the correct number of arguments is provided
+if [ "$#" -ne 1 ]; then
+    echo "Usage: $0 <version>"
+    echo "Example: $0 v1.0 or $0 v1.5"
+    exit 1
+fi
+
+# Get the version from the user input
+version=$1
+config_path="./configs/inference/test.yaml"
+
+# Define the model paths based on the version
+if [ "$version" = "v1.0" ]; then
+    model_dir="./models/musetalk"
+    unet_model_path="$model_dir/pytorch_model.bin"
+elif [ "$version" = "v1.5" ]; then
+    model_dir="./models/musetalkV15"
+    unet_model_path="$model_dir/unet.pth"
+else
+    echo "Invalid version specified. Please use v1.0 or v1.5."
+    exit 1
+fi
+
+# Run inference based on the version
+if [ "$version" = "v1.0" ]; then
+    python3 -m scripts.inference \
+        --inference_config "$config_path" \
+        --result_dir "./results/test" \
+        --unet_model_path "$unet_model_path"
+elif [ "$version" = "v1.5" ]; then
+    python3 -m scripts.inference_alpha \
+        --inference_config "$config_path" \
+        --result_dir "./results/test" \
+        --unet_model_path "$unet_model_path"
+fi
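Since the new script ships with a bash shebang, it can also be made executable and run directly; the README's `sh inference.sh <version>` form works without this step:

```bash
# Optional: run the new script directly instead of via `sh`.
chmod +x inference.sh
./inference.sh v1.5   # or ./inference.sh v1.0
```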
musetalk/utils/audio_processor.py CHANGED
@@ -91,7 +91,7 @@ class AudioProcessor:
 
 if __name__ == "__main__":
     audio_processor = AudioProcessor()
-    wav_path = "
+    wav_path = "./2.wav"
     audio_feature, librosa_feature_length = audio_processor.get_audio_feature(wav_path)
     print("Audio Feature shape:", audio_feature.shape)
     print("librosa_feature_length:", librosa_feature_length)
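The `__main__` block now points at a local `./2.wav`. A hedged way to exercise it from the repository root (running the file as a module and having a `2.wav` in the working directory are assumptions, not shown in the diff):

```bash
# Sketch: run the AudioProcessor smoke test from the repo root.
# Assumes ./2.wav exists, as hard-coded in the __main__ block above.
python3 -m musetalk.utils.audio_processor
```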
scripts/inference.py CHANGED
@@ -1,32 +1,58 @@
-import argparse
 import os
-from omegaconf import OmegaConf
-import numpy as np
 import cv2
+import copy
-import 
 import glob
+import torch
+import shutil
 import pickle
+import argparse
+import numpy as np
 from tqdm import tqdm
-import 
+from omegaconf import OmegaConf
+from transformers import WhisperModel
 
-from musetalk.utils.utils import get_file_type,get_video_fps,datagen
-from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
 from musetalk.utils.blending import get_image
-from musetalk.utils.
-import 
+from musetalk.utils.face_parsing import FaceParsing
+from musetalk.utils.audio_processor import AudioProcessor
+from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
+from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
+
 
-# load model weights
-audio_processor, vae, unet, pe = load_all_model()
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-timesteps = torch.tensor([0], device=device)
 
 @torch.no_grad()
 def main(args):
-
+    # Configure ffmpeg path
+    if args.ffmpeg_path not in os.getenv('PATH'):
+        print("Adding ffmpeg to PATH")
+        os.environ["PATH"] = f"{args.ffmpeg_path}:{os.environ['PATH']}"
+
+    # Set computing device
+    device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
+
+    # Load model weights
+    vae, unet, pe = load_all_model(
+        unet_model_path=args.unet_model_path,
+        vae_type=args.vae_type,
+        unet_config=args.unet_config,
+        device=device
+    )
+    timesteps = torch.tensor([0], device=device)
+
+
     if args.use_float16 is True:
         pe = pe.half()
         vae.vae = vae.vae.half()
         unet.model = unet.model.half()
+
+    # Initialize audio processor and Whisper model
+    audio_processor = AudioProcessor(feature_extractor_path=args.whisper_dir)
+    weight_dtype = unet.model.dtype
+    whisper = WhisperModel.from_pretrained(args.whisper_dir)
+    whisper = whisper.to(device=device, dtype=weight_dtype).eval()
+    whisper.requires_grad_(False)
+
+    # Initialize face parser
+    fp = FaceParsing()
 
     inference_config = OmegaConf.load(args.inference_config)
     print(inference_config)
@@ -64,10 +90,20 @@ def main(args):
     else:
         raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
 
-    #print(input_img_list)
     ############################################## extract audio feature ##############################################
-
-
+    # Extract audio features
+    whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
+    whisper_chunks = audio_processor.get_whisper_chunk(
+        whisper_input_features,
+        device,
+        weight_dtype,
+        whisper,
+        librosa_length,
+        fps=fps,
+        audio_padding_length_left=args.audio_padding_length_left,
+        audio_padding_length_right=args.audio_padding_length_right,
+    )
+
     ############################################## preprocess input image ##############################################
     if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
         print("using extracted coordinates")
@@ -102,10 +138,7 @@ def main(args):
     gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size)
     res_frame_list = []
     for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
-        audio_feature_batch = 
-        audio_feature_batch = audio_feature_batch.to(device=unet.device,
-                                                     dtype=unet.model.dtype) # torch, B, 5*N,384
-        audio_feature_batch = pe(audio_feature_batch)
+        audio_feature_batch = pe(whisper_batch)
         latent_batch = latent_batch.to(dtype=unet.model.dtype)
 
         pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
@@ -122,10 +155,10 @@ def main(args):
         try:
             res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
         except:
-            # print(bbox)
             continue
 
-
+        # Merge results
+        combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=fp)
         cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame)
 
     cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
@@ -142,11 +175,11 @@ def main(args):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
+    parser.add_argument("--ffmpeg_path", type=str, default="./ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable")
     parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml")
     parser.add_argument("--bbox_shift", type=int, default=0)
     parser.add_argument("--result_dir", default='./results', help="path to output")
-
-    parser.add_argument("--fps", type=int, default=25)
+    parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use")
     parser.add_argument("--batch_size", type=int, default=8)
     parser.add_argument("--output_vid_name", type=str, default=None)
     parser.add_argument("--use_saved_coord",
@@ -156,6 +189,12 @@ if __name__ == "__main__":
                         action="store_true",
                         help="Whether use float16 to speed up inference",
                         )
-
+    parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
+    parser.add_argument("--unet_model_path", type=str, default="./models/musetalk/pytorch_model.bin", help="Path to UNet model weights")
+    parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model")
+    parser.add_argument("--unet_config", type=str, default="./models/musetalk/config.json", help="Path to UNet configuration file")
+    parser.add_argument("--whisper_dir", type=str, default="./models/whisper", help="Directory containing Whisper model")
+    parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio")
+    parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio")
     args = parser.parse_args()
     main(args)
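After this change, scripts/inference.py no longer loads models at import time; model paths, the Whisper directory and the device are all driven by the new CLI flags. A hedged example of a direct v1.0 invocation (flag names come from the diff above; the specific values are the repository defaults, used here for illustration):

```bash
# Sketch: invoke the refactored v1.0 inference entry point directly.
# Flag names are those added in this commit; paths mirror the defaults.
python3 -m scripts.inference \
    --inference_config ./configs/inference/test.yaml \
    --result_dir ./results/test \
    --unet_model_path ./models/musetalk/pytorch_model.bin \
    --unet_config ./models/musetalk/config.json \
    --whisper_dir ./models/whisper \
    --gpu_id 0 \
    --fps 25
```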
scripts/inference_alpha.py CHANGED
@@ -72,8 +72,7 @@ def main(args):
         audio_path = inference_config[task_id]["audio_path"]
         if "result_name" in inference_config[task_id]:
             args.output_vid_name = inference_config[task_id]["result_name"]
-        bbox_shift = 
-
+        bbox_shift = args.bbox_shift
         # Set output paths
         input_basename = os.path.basename(video_path).split('.')[0]
         audio_basename = os.path.basename(audio_path).split('.')[0]
@@ -228,12 +227,12 @@ def main(args):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--ffmpeg_path", type=str, default="
+    parser.add_argument("--ffmpeg_path", type=str, default="./ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable")
     parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use")
    parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model")
     parser.add_argument("--unet_config", type=str, default="./models/musetalk/config.json", help="Path to UNet configuration file")
-    parser.add_argument("--unet_model_path", type=str, default="/
-    parser.add_argument("--whisper_dir", type=str, default="/
+    parser.add_argument("--unet_model_path", type=str, default="./models/musetalkV15/unet.pth", help="Path to UNet model weights")
+    parser.add_argument("--whisper_dir", type=str, default="./models/whisper", help="Directory containing Whisper model")
     parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml", help="Path to inference configuration file")
     parser.add_argument("--bbox_shift", type=int, default=0, help="Bounding box shift value")
     parser.add_argument("--result_dir", default='./results', help="Directory for output results")
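With the new defaults pointing at the v1.5 layout (./models/musetalkV15/unet.pth and ./models/whisper), the v1.5 path needs little beyond a config. A hedged equivalent of what `sh inference.sh v1.5` runs, relying on those defaults:

```bash
# Sketch: v1.5 inference relying on the defaults set in this commit.
python3 -m scripts.inference_alpha \
    --inference_config ./configs/inference/test.yaml \
    --result_dir ./results/test
```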