czk32611 committed on
Commit d6f1e39 · 1 Parent(s): 916813d

<enhance>: support using float16 in inference to speed up

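In short: `scripts/inference.py` gains an opt-in `--use_float16` flag that casts the positional encoding, VAE and UNet to half precision, while `scripts/realtime_inference.py` now always runs them in float16. A minimal usage sketch for the offline script (the config path below is a placeholder; pass whatever inference config you normally use):

```
python -m scripts.inference --inference_config configs/inference/test.yaml --use_float16
```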
README.md CHANGED
@@ -267,10 +267,8 @@ As a complete solution to virtual human generation, you are suggested to first a
 
 Here, we provide the inference script. This script first applies necessary pre-processing such as face detection, face parsing and VAE encode in advance. During inference, only UNet and the VAE decoder are involved, which makes MuseTalk real-time.
 
-Note that in this script, the generation time is also limited by I/O (e.g. saving images).
-
 ```
-python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml
+python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --batch_size 4
 ```
 configs/inference/realtime.yaml is the path to the real-time inference configuration file, including `preparation`, `video_path` , `bbox_shift` and `audio_clips`.
 
@@ -280,17 +278,14 @@ configs/inference/realtime.yaml is the path to the real-time inference configura
 Inferring using: data/audio/yongen.wav
 ```
 1. While MuseTalk is inferring, sub-threads can simultaneously stream the results to the users. The generation process can achieve 30fps+ on an NVIDIA Tesla V100.
-```
- 2%|██▏       | 3/141 [00:00<00:32,  4.30it/s] # inference process
-Displaying the 6-th frame with FPS: 48.58 # display process
-Displaying the 7-th frame with FPS: 48.74
-Displaying the 8-th frame with FPS: 49.17
- 3%|███▎      | 4/141 [00:00<00:32,  4.21it/s]
-```
 1. Set `preparation` to `False` and run this script if you want to genrate more videos using the same avatar.
 
-If you want to generate multiple videos using the same avatar/video, you can also use this script to **SIGNIFICANTLY** expedite the generation process.
-
+##### Note for Real-time inference
+1. If you want to generate multiple videos using the same avatar/video, you can also use this script to **SIGNIFICANTLY** expedite the generation process.
+1. In the previous script, the generation time is also limited by I/O (e.g. saving images). If you just want to test the generation speed without saving the images, you can run
+```
+python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --skip_save_images
+```
 
 # Acknowledgement
 1. We thank open-source components like [whisper](https://github.com/openai/whisper), [dwpose](https://github.com/IDEA-Research/DWPose), [face-alignment](https://github.com/1adrianb/face-alignment), [face-parsing](https://github.com/zllrunning/face-parsing.PyTorch), [S3FD](https://github.com/yxlijun/S3FD.pytorch).
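For orientation, a hedged sketch of what one entry of configs/inference/realtime.yaml looks like, based only on the fields named above and on how the script indexes the config; the avatar id, paths and values are placeholders:

```
avatar_1:
  preparation: True
  bbox_shift: 5
  video_path: "data/video/yongen.mp4"
  audio_clips:
    audio_0: "data/audio/yongen.wav"
```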
musetalk/models/unet.py CHANGED
@@ -37,11 +37,11 @@ class UNet():
         self.model = UNet2DConditionModel(**unet_config)
         self.pe = PositionalEncoding(d_model=384)
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device)
-        self.model.load_state_dict(self.weights)
+        weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device)
+        self.model.load_state_dict(weights)
         if use_float16:
             self.model = self.model.half()
         self.model.to(self.device)
 
 if __name__ == "__main__":
-    unet = UNet()
+    unet = UNet()
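For context, a minimal standalone sketch of the loading-and-casting pattern this hunk moves to (the `load_unet` helper and its arguments are illustrative, not part of the repository; the real class also builds a positional encoding). Keeping the loaded state dict in a local variable instead of `self.weights` lets the float32 copy be garbage-collected once the model is optionally cast to float16:

```
import torch
from diffusers import UNet2DConditionModel

def load_unet(unet_config: dict, model_path: str, use_float16: bool = False):
    # Build the UNet and load weights; the state dict stays in a local variable,
    # so no float32 copy is kept alive on the instance after loading.
    model = UNet2DConditionModel(**unet_config)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    weights = torch.load(model_path, map_location=device)
    model.load_state_dict(weights)
    if use_float16:
        model = model.half()  # cast parameters to float16 for faster inference
    return model.to(device)
```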
musetalk/utils/utils.py CHANGED
@@ -39,7 +39,10 @@ def get_video_fps(video_path):
     video.release()
     return fps
 
-def datagen(whisper_chunks,vae_encode_latents,batch_size=8,delay_frame = 0):
+def datagen(whisper_chunks,
+            vae_encode_latents,
+            batch_size=8,
+            delay_frame=0):
     whisper_batch, latent_batch = [], []
     for i, w in enumerate(whisper_chunks):
         idx = (i+delay_frame)%len(vae_encode_latents)
@@ -48,14 +51,14 @@ def datagen(whisper_chunks,vae_encode_latents,batch_size=8,delay_frame = 0):
         latent_batch.append(latent)
 
         if len(latent_batch) >= batch_size:
-            whisper_batch = np.asarray(whisper_batch)
+            whisper_batch = np.stack(whisper_batch)
             latent_batch = torch.cat(latent_batch, dim=0)
             yield whisper_batch, latent_batch
             whisper_batch, latent_batch = [], []
 
     # the last batch may smaller than batch size
     if len(latent_batch) > 0:
-        whisper_batch = np.asarray(whisper_batch)
+        whisper_batch = np.stack(whisper_batch)
         latent_batch = torch.cat(latent_batch, dim=0)
 
-        yield whisper_batch, latent_batch
+        yield whisper_batch, latent_batch
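A small, self-contained illustration of what the `np.stack` call produces (the chunk shape below is an assumption based on the `# torch, B, 5*N,384` comment in the inference scripts): each whisper chunk is a 2-D array, and stacking a list of equally shaped chunks yields a batch with an explicit leading batch axis, which `torch.from_numpy` can then consume directly.

```
import numpy as np
import torch

# Four dummy whisper chunks; the (10, 384) shape is illustrative only.
chunks = [np.random.rand(10, 384).astype(np.float32) for _ in range(4)]

whisper_batch = np.stack(chunks)                 # shape (4, 10, 384): batch axis added
audio_feature_batch = torch.from_numpy(whisper_batch)
print(whisper_batch.shape, audio_feature_batch.dtype)   # (4, 10, 384) torch.float32
```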
musetalk/whisper/audio2feature.py CHANGED
@@ -13,7 +13,11 @@ class Audio2Feature():
         self.whisper_model_type = whisper_model_type
         self.model = load_model(model_path) #
 
-    def get_sliced_feature(self,feature_array, vid_idx, audio_feat_length= [2,2],fps = 25):
+    def get_sliced_feature(self,
+                           feature_array,
+                           vid_idx,
+                           audio_feat_length=[2,2],
+                           fps=25):
         """
         Get sliced features based on a given index
         :param feature_array:
scripts/inference.py CHANGED
@@ -16,12 +16,18 @@ from musetalk.utils.utils import load_all_model
 import shutil
 
 # load model weights
-audio_processor,vae,unet,pe = load_all_model()
+audio_processor, vae, unet, pe = load_all_model()
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 timesteps = torch.tensor([0], device=device)
 
 @torch.no_grad()
 def main(args):
+    global pe
+    if args.use_float16 is True:
+        pe = pe.half()
+        vae.vae = vae.vae.half()
+        unet.model = unet.model.half()
+
     inference_config = OmegaConf.load(args.inference_config)
     print(inference_config)
     for task_id in inference_config:
@@ -96,10 +102,11 @@ def main(args):
         gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size)
         res_frame_list = []
         for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
-
-            tensor_list = [torch.FloatTensor(arr) for arr in whisper_batch]
-            audio_feature_batch = torch.stack(tensor_list).to(unet.device) # torch, B, 5*N,384
+            audio_feature_batch = torch.from_numpy(whisper_batch)
+            audio_feature_batch = audio_feature_batch.to(device=unet.device,
+                                                         dtype=unet.model.dtype) # torch, B, 5*N,384
             audio_feature_batch = pe(audio_feature_batch)
+            latent_batch = latent_batch.to(dtype=unet.model.dtype)
 
             pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
             recon = vae.decode_latents(pred_latents)
@@ -145,7 +152,10 @@ if __name__ == "__main__":
     parser.add_argument("--use_saved_coord",
                         action="store_true",
                         help='use saved coordinate to save time')
-
+    parser.add_argument("--use_float16",
+                        action="store_true",
+                        help="Whether use float16 to speed up inference",
+                        )
 
     args = parser.parse_args()
     main(args)
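The core of the float16 path is that batch inputs are cast to whatever dtype the UNet currently holds, so the same loop works in float32 and float16. A condensed, hedged sketch of that pattern (the `run_batch` helper is illustrative; `pe`, `unet`, `vae` and `timesteps` stand in for the objects the script loads):

```
import numpy as np
import torch

def run_batch(whisper_batch: np.ndarray, latent_batch: torch.Tensor,
              pe, unet, vae, timesteps: torch.Tensor):
    # Cast audio features and latents to the UNet's dtype (float16 when
    # --use_float16 was passed, float32 otherwise) before the forward pass.
    audio = torch.from_numpy(whisper_batch).to(device=unet.device,
                                               dtype=unet.model.dtype)
    audio = pe(audio)
    latents = latent_batch.to(device=unet.device, dtype=unet.model.dtype)
    pred_latents = unet.model(latents, timesteps,
                              encoder_hidden_states=audio).sample
    return vae.decode_latents(pred_latents)
```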
scripts/realtime_inference.py CHANGED
@@ -22,10 +22,12 @@ import queue
 import time
 
 # load model weights
-audio_processor,vae,unet,pe = load_all_model()
+audio_processor, vae, unet, pe = load_all_model()
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 timesteps = torch.tensor([0], device=device)
-
+pe = pe.half()
+vae.vae = vae.vae.half()
+unet.model = unet.model.half()
 
 def video2imgs(vid_path, save_path, ext = '.png',cut_frame = 10000000):
     cap = cv2.VideoCapture(vid_path)
@@ -99,6 +101,10 @@ class Avatar:
             osmakedirs([self.avatar_path,self.full_imgs_path,self.video_out_path,self.mask_out_path])
             self.prepare_material()
         else:
+            if not os.path.exists(self.avatar_path):
+                print(f"{self.avatar_id} does not exist, you should set preparation to True")
+                sys.exit()
+
             with open(self.avatar_info_path, "r") as f:
                 avatar_info = json.load(f)
 
@@ -182,7 +188,10 @@ class Avatar:
         torch.save(self.input_latent_list_cycle, os.path.join(self.latents_out_path))
         #
 
-    def process_frames(self, res_frame_queue,video_len):
+    def process_frames(self,
+                       res_frame_queue,
+                       video_len,
+                       skip_save_images):
         print(video_len)
         while True:
             if self.idx>=video_len-1:
@@ -205,44 +214,62 @@ class Avatar:
                 #combine_frame = get_image(ori_frame,res_frame,bbox)
                 combine_frame = get_image_blending(ori_frame,res_frame,bbox,mask,mask_crop_box)
 
-                fps = 1/(time.time()-start+1e-6)
-                print(f"Displaying the {self.idx}-th frame with FPS: {fps:.2f}")
-                cv2.imwrite(f"{self.avatar_path}/tmp/{str(self.idx).zfill(8)}.png",combine_frame)
+                if skip_save_images is False:
+                    cv2.imwrite(f"{self.avatar_path}/tmp/{str(self.idx).zfill(8)}.png",combine_frame)
                 self.idx = self.idx + 1
 
-    def inference(self, audio_path, out_vid_name, fps):
+    def inference(self,
+                  audio_path,
+                  out_vid_name,
+                  fps,
+                  skip_save_images):
         os.makedirs(self.avatar_path+'/tmp',exist_ok =True)
+        print("start inference")
         ############################################## extract audio feature ##############################################
+        start_time = time.time()
         whisper_feature = audio_processor.audio2feat(audio_path)
         whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
+        print(f"processing audio:{audio_path} costs {(time.time() - start_time) * 1000}ms")
        ############################################## inference batch by batch ##############################################
         video_num = len(whisper_chunks)
-        print("start inference")
         res_frame_queue = queue.Queue()
         self.idx = 0
         # # Create a sub-thread and start it
-        process_thread = threading.Thread(target=self.process_frames, args=(res_frame_queue,video_num))
+        process_thread = threading.Thread(target=self.process_frames, args=(res_frame_queue, video_num, skip_save_images))
         process_thread.start()
-        start_time = time.time()
-        gen = datagen(whisper_chunks,self.input_latent_list_cycle, self.batch_size)
-        print(f"processing audio:{audio_path} costs {(time.time() - start_time) * 1000}ms")
+
+        gen = datagen(whisper_chunks,
+                      self.input_latent_list_cycle,
+                      self.batch_size)
         start_time = time.time()
         res_frame_list = []
 
        for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/self.batch_size)))):
-            start_time = time.time()
-            tensor_list = [torch.FloatTensor(arr) for arr in whisper_batch]
-            audio_feature_batch = torch.stack(tensor_list).to(unet.device) # torch, B, 5*N,384
+            audio_feature_batch = torch.from_numpy(whisper_batch)
+            audio_feature_batch = audio_feature_batch.to(device=unet.device,
+                                                         dtype=unet.model.dtype)
             audio_feature_batch = pe(audio_feature_batch)
-
-            pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
+            latent_batch = latent_batch.to(dtype=unet.model.dtype)
+
+            pred_latents = unet.model(latent_batch,
+                                      timesteps,
+                                      encoder_hidden_states=audio_feature_batch).sample
             recon = vae.decode_latents(pred_latents)
             for res_frame in recon:
                 res_frame_queue.put(res_frame)
         # Close the queue and sub-thread after all tasks are completed
         process_thread.join()
 
-        if out_vid_name is not None:
+        if args.skip_save_images is True:
+            print('Total process time of {} frames without saving images = {}s'.format(
+                video_num,
+                time.time()-start_time))
+        else:
+            print('Total process time of {} frames including saving images = {}s'.format(
+                video_num,
+                time.time()-start_time))
+
+        if out_vid_name is not None and args.skip_save_images is False:
             # optional
             cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {self.avatar_path}/tmp/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 {self.avatar_path}/temp.mp4"
             print(cmd_img2video)
@@ -256,20 +283,31 @@ class Avatar:
             os.remove(f"{self.avatar_path}/temp.mp4")
             shutil.rmtree(f"{self.avatar_path}/tmp")
             print(f"result is save to {output_vid}")
+            print("\n")
 
 
-
-
-
 if __name__ == "__main__":
     '''
     This script is used to simulate online chatting and applies necessary pre-processing such as face detection and face parsing in advance. During online chatting, only UNet and the VAE decoder are involved, which makes MuseTalk real-time.
     '''
 
     parser = argparse.ArgumentParser()
-    parser.add_argument("--inference_config", type=str, default="configs/inference/realtime.yaml")
-    parser.add_argument("--fps", type=int, default=25)
-    parser.add_argument("--batch_size", type=int, default=4)
+    parser.add_argument("--inference_config",
+                        type=str,
+                        default="configs/inference/realtime.yaml",
+                        )
+    parser.add_argument("--fps",
+                        type=int,
+                        default=25,
+                        )
+    parser.add_argument("--batch_size",
+                        type=int,
+                        default=4,
+                        )
+    parser.add_argument("--skip_save_images",
+                        action="store_true",
+                        help="Whether skip saving images for better generation speed calculation",
+                        )
 
     args = parser.parse_args()
 
@@ -291,5 +329,7 @@ if __name__ == "__main__":
         audio_clips = inference_config[avatar_id]["audio_clips"]
         for audio_num, audio_path in audio_clips.items():
             print("Inferring using:",audio_path)
-            avatar.inference(audio_path, audio_num, args.fps)
-
+            avatar.inference(audio_path,
                             audio_num,
                             args.fps,
                             args.skip_save_images)
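The new `--skip_save_images` flag mainly affects the consumer side of the script's producer/consumer design: the main thread generates frames batch by batch and pushes them into a queue, while `process_frames` runs in a sub-thread, blends each frame and only writes it to disk when saving is enabled. A stripped-down, hedged sketch of that pattern (placeholders replace the real blending and `cv2.imwrite` calls):

```
import queue
import threading

def process_frames(frame_queue: "queue.Queue", video_len: int, skip_save_images: bool):
    # Consumer: drain generated frames; writing to disk is optional.
    idx = 0
    while idx < video_len:
        try:
            frame = frame_queue.get(block=True, timeout=1)
        except queue.Empty:
            continue
        if not skip_save_images:
            pass  # the real script blends the frame and cv2.imwrite()s it here
        idx += 1

frame_queue = queue.Queue()
worker = threading.Thread(target=process_frames, args=(frame_queue, 100, True))
worker.start()
for fake_frame in range(100):   # stands in for VAE-decoded frames
    frame_queue.put(fake_frame)
worker.join()                   # returns once all 100 frames are consumed
```

Because frame writing happens off the main thread, skipping it isolates the UNet + VAE compute time, which is what the new "Total process time of N frames without saving images" message reports.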