zzzweakman committed
Commit 4ff91cd · 1 Parent(s): cea0eec

fix: infer bug

.gitignore CHANGED
@@ -4,7 +4,8 @@
 .vscode/
 *.pyc
 .ipynb_checkpoints
-models
 results/
-data/audio/*.wav
-data/video/*.mp4
+./models
+**/__pycache__/
+*.py[cod]
+*$py.class
README.md CHANGED
@@ -177,7 +177,7 @@ You can download weights manually as follows:
 
 2. Download the weights of other components:
    - [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse)
-   - [whisper](https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt)
+   - [whisper](https://huggingface.co/openai/whisper-tiny/tree/main)
    - [dwpose](https://huggingface.co/yzd-v/DWPose/tree/main)
    - [face-parse-bisent](https://github.com/zllrunning/face-parsing.PyTorch)
    - [resnet18](https://download.pytorch.org/models/resnet18-5c106cde.pth)
@@ -201,7 +201,10 @@ Finally, these weights should be organized in `models` as follows:
 │   ├── config.json
 │   └── diffusion_pytorch_model.bin
 └── whisper
-    └── tiny.pt
+    ├── config.json
+    ├── pytorch_model.bin
+    └── preprocessor_config.json
+
 ```
 ## Quickstart
 
@@ -210,7 +213,7 @@ We provide inference scripts for both versions of MuseTalk:
 
 #### MuseTalk 1.5 (Recommended)
 ```bash
-python3 -m scripts.inference_alpha --inference_config configs/inference/test.yaml --unet_model_path ./models/musetalkV15/unet.pth
+sh inference.sh v1.5
 ```
 This inference script supports both MuseTalk 1.5 and 1.0 models:
 - For MuseTalk 1.5: Use the command above with the V1.5 model path
@@ -221,7 +224,7 @@ The video_path should be either a video file, an image file or a directory of im
 
 #### MuseTalk 1.0
 ```bash
-python3 -m scripts.inference --inference_config configs/inference/test.yaml
+sh inference.sh v1.0
 ```
 You are recommended to input video with `25fps`, the same fps used when training the model. If your video is far less than 25fps, you are recommended to apply frame interpolation or directly convert the video to 25fps using ffmpeg.
 <details close>
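A note on the two README changes above, with a sketch of how they might be applied locally. The download tool and file names below are assumptions, not part of this commit; the README only specifies the Hugging Face repo and the target `models/whisper` layout.

```bash
# Assumes huggingface_hub's CLI and ffmpeg are installed; paths are illustrative.
# Fetch the Whisper weights in the layout the README now expects
# (config.json, pytorch_model.bin, preprocessor_config.json under models/whisper):
huggingface-cli download openai/whisper-tiny --local-dir ./models/whisper

# Re-time an input video to the recommended 25 fps before inference:
ffmpeg -i input.mp4 -r 25 input_25fps.mp4
```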
inference.sh ADDED
@@ -0,0 +1,42 @@
+#!/bin/bash
+
+# This script runs inference based on the version specified by the user.
+# Usage:
+#   To run v1.0 inference: sh inference.sh v1.0
+#   To run v1.5 inference: sh inference.sh v1.5
+
+# Check if the correct number of arguments is provided
+if [ "$#" -ne 1 ]; then
+    echo "Usage: $0 <version>"
+    echo "Example: $0 v1.0 or $0 v1.5"
+    exit 1
+fi
+
+# Get the version from the user input
+version=$1
+config_path="./configs/inference/test.yaml"
+
+# Define the model paths based on the version
+if [ "$version" = "v1.0" ]; then
+    model_dir="./models/musetalk"
+    unet_model_path="$model_dir/pytorch_model.bin"
+elif [ "$version" = "v1.5" ]; then
+    model_dir="./models/musetalkV15"
+    unet_model_path="$model_dir/unet.pth"
+else
+    echo "Invalid version specified. Please use v1.0 or v1.5."
+    exit 1
+fi
+
+# Run inference based on the version
+if [ "$version" = "v1.0" ]; then
+    python3 -m scripts.inference \
+        --inference_config "$config_path" \
+        --result_dir "./results/test" \
+        --unet_model_path "$unet_model_path"
+elif [ "$version" = "v1.5" ]; then
+    python3 -m scripts.inference_alpha \
+        --inference_config "$config_path" \
+        --result_dir "./results/test" \
+        --unet_model_path "$unet_model_path"
+fi
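The wrapper above pins the config to `./configs/inference/test.yaml` and the output directory to `./results/test`. A minimal usage sketch, plus the direct call it expands to for v1.5 if those hard-coded paths need to change; the arguments are taken from the script body, not new behavior:

```bash
# Run the v1.5 pipeline through the wrapper:
sh inference.sh v1.5

# Equivalent direct invocation (same arguments the wrapper passes):
python3 -m scripts.inference_alpha \
    --inference_config ./configs/inference/test.yaml \
    --result_dir ./results/test \
    --unet_model_path ./models/musetalkV15/unet.pth
```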
musetalk/utils/audio_processor.py CHANGED
@@ -91,7 +91,7 @@ class AudioProcessor:
 
 if __name__ == "__main__":
     audio_processor = AudioProcessor()
-    wav_path = "/cfs-workspace/users/gozhong/codes/musetalk_opensource2/data/audio/2.wav"
+    wav_path = "./2.wav"
     audio_feature, librosa_feature_length = audio_processor.get_audio_feature(wav_path)
     print("Audio Feature shape:", audio_feature.shape)
     print("librosa_feature_length:", librosa_feature_length)
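The smoke test in `__main__` now reads a relative `./2.wav` instead of an internal absolute path. One way to exercise it, assuming the repo root is the working directory and a short test clip has been placed there:

```bash
# 2.wav is a placeholder clip you supply; the module prints the audio feature
# shape and the librosa feature length for that file.
python3 -m musetalk.utils.audio_processor
```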
scripts/inference.py CHANGED
@@ -1,32 +1,58 @@
-import argparse
 import os
-from omegaconf import OmegaConf
-import numpy as np
 import cv2
-import torch
+import copy
 import glob
+import torch
+import shutil
 import pickle
+import argparse
+import numpy as np
 from tqdm import tqdm
-import copy
+from omegaconf import OmegaConf
+from transformers import WhisperModel
 
-from musetalk.utils.utils import get_file_type,get_video_fps,datagen
-from musetalk.utils.preprocessing import get_landmark_and_bbox,read_imgs,coord_placeholder
 from musetalk.utils.blending import get_image
-from musetalk.utils.utils import load_all_model
-import shutil
+from musetalk.utils.face_parsing import FaceParsing
+from musetalk.utils.audio_processor import AudioProcessor
+from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
+from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder
+
 
-# load model weights
-audio_processor, vae, unet, pe = load_all_model()
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-timesteps = torch.tensor([0], device=device)
 
 @torch.no_grad()
 def main(args):
-    global pe
+    # Configure ffmpeg path
+    if args.ffmpeg_path not in os.getenv('PATH'):
+        print("Adding ffmpeg to PATH")
+        os.environ["PATH"] = f"{args.ffmpeg_path}:{os.environ['PATH']}"
+
+    # Set computing device
+    device = torch.device(f"cuda:{args.gpu_id}" if torch.cuda.is_available() else "cpu")
+
+    # Load model weights
+    vae, unet, pe = load_all_model(
+        unet_model_path=args.unet_model_path,
+        vae_type=args.vae_type,
+        unet_config=args.unet_config,
+        device=device
+    )
+    timesteps = torch.tensor([0], device=device)
+
+
     if args.use_float16 is True:
         pe = pe.half()
         vae.vae = vae.vae.half()
         unet.model = unet.model.half()
+
+    # Initialize audio processor and Whisper model
+    audio_processor = AudioProcessor(feature_extractor_path=args.whisper_dir)
+    weight_dtype = unet.model.dtype
+    whisper = WhisperModel.from_pretrained(args.whisper_dir)
+    whisper = whisper.to(device=device, dtype=weight_dtype).eval()
+    whisper.requires_grad_(False)
+
+    # Initialize face parser
+    fp = FaceParsing()
 
     inference_config = OmegaConf.load(args.inference_config)
     print(inference_config)
@@ -64,10 +90,20 @@ def main(args):
     else:
         raise ValueError(f"{video_path} should be a video file, an image file or a directory of images")
 
-    #print(input_img_list)
     ############################################## extract audio feature ##############################################
-    whisper_feature = audio_processor.audio2feat(audio_path)
-    whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
+    # Extract audio features
+    whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
+    whisper_chunks = audio_processor.get_whisper_chunk(
+        whisper_input_features,
+        device,
+        weight_dtype,
+        whisper,
+        librosa_length,
+        fps=fps,
+        audio_padding_length_left=args.audio_padding_length_left,
+        audio_padding_length_right=args.audio_padding_length_right,
+    )
+
     ############################################## preprocess input image ##############################################
     if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
         print("using extracted coordinates")
@@ -102,10 +138,7 @@ def main(args):
     gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size)
     res_frame_list = []
     for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
-        audio_feature_batch = torch.from_numpy(whisper_batch)
-        audio_feature_batch = audio_feature_batch.to(device=unet.device,
-                                                     dtype=unet.model.dtype) # torch, B, 5*N,384
-        audio_feature_batch = pe(audio_feature_batch)
+        audio_feature_batch = pe(whisper_batch)
         latent_batch = latent_batch.to(dtype=unet.model.dtype)
 
         pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
@@ -122,10 +155,10 @@ def main(args):
         try:
             res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
         except:
-            # print(bbox)
             continue
 
-        combine_frame = get_image(ori_frame,res_frame,bbox)
+        # Merge results
+        combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], fp=fp)
         cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame)
 
     cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {result_img_save_path}/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 temp.mp4"
@@ -142,11 +175,11 @@ def main(args):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
+    parser.add_argument("--ffmpeg_path", type=str, default="./ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable")
     parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml")
     parser.add_argument("--bbox_shift", type=int, default=0)
     parser.add_argument("--result_dir", default='./results', help="path to output")
-
-    parser.add_argument("--fps", type=int, default=25)
+    parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use")
     parser.add_argument("--batch_size", type=int, default=8)
     parser.add_argument("--output_vid_name", type=str, default=None)
    parser.add_argument("--use_saved_coord",
@@ -156,6 +189,12 @@ if __name__ == "__main__":
                         action="store_true",
                         help="Whether use float16 to speed up inference",
     )
-
+    parser.add_argument("--fps", type=int, default=25, help="Video frames per second")
+    parser.add_argument("--unet_model_path", type=str, default="./models/musetalk/pytorch_model.bin", help="Path to UNet model weights")
+    parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model")
+    parser.add_argument("--unet_config", type=str, default="./models/musetalk/config.json", help="Path to UNet configuration file")
+    parser.add_argument("--whisper_dir", type=str, default="./models/whisper", help="Directory containing Whisper model")
+    parser.add_argument("--audio_padding_length_left", type=int, default=2, help="Left padding length for audio")
+    parser.add_argument("--audio_padding_length_right", type=int, default=2, help="Right padding length for audio")
     args = parser.parse_args()
     main(args)
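With this refactor the v1.0 script no longer loads weights at import time; the device, UNet, VAE, ffmpeg, and Whisper locations all come from the new CLI flags. An illustrative invocation using the defaults defined in the argparse block above (only `--inference_config` typically needs changing); the config path mirrors what `inference.sh` passes for v1.0:

```bash
# All flags below exist in the new argparse block; values are the shipped defaults.
python3 -m scripts.inference \
    --inference_config configs/inference/test.yaml \
    --unet_model_path ./models/musetalk/pytorch_model.bin \
    --unet_config ./models/musetalk/config.json \
    --whisper_dir ./models/whisper \
    --ffmpeg_path ./ffmpeg-4.4-amd64-static/ \
    --gpu_id 0 \
    --result_dir ./results
```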
scripts/inference_alpha.py CHANGED
@@ -72,8 +72,7 @@ def main(args):
         audio_path = inference_config[task_id]["audio_path"]
         if "result_name" in inference_config[task_id]:
             args.output_vid_name = inference_config[task_id]["result_name"]
-        bbox_shift = inference_config[task_id].get("bbox_shift", args.bbox_shift)
-
+        bbox_shift = args.bbox_shift
         # Set output paths
         input_basename = os.path.basename(video_path).split('.')[0]
         audio_basename = os.path.basename(audio_path).split('.')[0]
@@ -228,12 +227,12 @@
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--ffmpeg_path", type=str, default="/cfs-workspace/users/gozhong/ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable")
+    parser.add_argument("--ffmpeg_path", type=str, default="./ffmpeg-4.4-amd64-static/", help="Path to ffmpeg executable")
     parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID to use")
     parser.add_argument("--vae_type", type=str, default="sd-vae", help="Type of VAE model")
     parser.add_argument("--unet_config", type=str, default="./models/musetalk/config.json", help="Path to UNet configuration file")
-    parser.add_argument("--unet_model_path", type=str, default="/cfs-datasets/users/gozhong/codes/musetalk_exp/exp_out/stage1_bs40/unet-20000.pth", help="Path to UNet model weights")
-    parser.add_argument("--whisper_dir", type=str, default="/cfs-datasets/public_models/whisper-tiny", help="Directory containing Whisper model")
+    parser.add_argument("--unet_model_path", type=str, default="./models/musetalkV15/unet.pth", help="Path to UNet model weights")
+    parser.add_argument("--whisper_dir", type=str, default="./models/whisper", help="Directory containing Whisper model")
     parser.add_argument("--inference_config", type=str, default="configs/inference/test_img.yaml", help="Path to inference configuration file")
     parser.add_argument("--bbox_shift", type=int, default=0, help="Bounding box shift value")
     parser.add_argument("--result_dir", default='./results', help="Directory for output results")