<enhance>: support using float16 in inference to speed up
Files changed:
- README.md +7 -12
- musetalk/models/unet.py +3 -3
- musetalk/utils/utils.py +7 -4
- musetalk/whisper/audio2feature.py +5 -1
- scripts/inference.py +15 -5
- scripts/realtime_inference.py +66 -26
README.md
CHANGED
@@ -267,10 +267,8 @@ As a complete solution to virtual human generation, you are suggested to first a
 
 Here, we provide the inference script. This script first applies necessary pre-processing such as face detection, face parsing and VAE encode in advance. During inference, only UNet and the VAE decoder are involved, which makes MuseTalk real-time.
 
-Note that in this script, the generation time is also limited by I/O (e.g. saving images).
-
 ```
-python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml
+python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --batch_size 4
 ```
 configs/inference/realtime.yaml is the path to the real-time inference configuration file, including `preparation`, `video_path`, `bbox_shift` and `audio_clips`.
 
@@ -280,17 +278,14 @@ configs/inference/realtime.yaml is the path to the real-time inference configura
 Inferring using: data/audio/yongen.wav
 ```
 1. While MuseTalk is inferring, sub-threads can simultaneously stream the results to the users. The generation process can achieve 30fps+ on an NVIDIA Tesla V100.
-```
-2%|███       | 3/141 [00:00<00:32, 4.30it/s]  # inference process
-Displaying the 6-th frame with FPS: 48.58  # display process
-Displaying the 7-th frame with FPS: 48.74
-Displaying the 8-th frame with FPS: 49.17
-3%|████      | 4/141 [00:00<00:32, 4.21it/s]
-```
 1. Set `preparation` to `False` and run this script if you want to generate more videos using the same avatar.
 
-
-
+##### Note for Real-time inference
+1. If you want to generate multiple videos using the same avatar/video, you can also use this script to **SIGNIFICANTLY** expedite the generation process.
+1. In the previous script, the generation time is also limited by I/O (e.g. saving images). If you just want to test the generation speed without saving the images, you can run
+```
+python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --skip_save_images
+```
 
 # Acknowledgement
 1. We thank open-source components like [whisper](https://github.com/openai/whisper), [dwpose](https://github.com/IDEA-Research/DWPose), [face-alignment](https://github.com/1adrianb/face-alignment), [face-parsing](https://github.com/zllrunning/face-parsing.PyTorch), [S3FD](https://github.com/yxlijun/S3FD.pytorch).
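For orientation, the sketch below guesses at the shape of the realtime config the README refers to and shows how scripts/realtime_inference.py walks it. The field names (`preparation`, `video_path`, `bbox_shift`, `audio_clips`) come from the README and from how the script indexes the config; the avatar id, video path, and exact nesting are hypothetical.

```python
# Illustrative sketch only: the exact schema of configs/inference/realtime.yaml is not
# part of this commit. Field names follow the README; "avatar_1" and the video path
# are hypothetical placeholders.
from omegaconf import OmegaConf

inference_config = OmegaConf.create({
    "avatar_1": {                          # one block per avatar
        "preparation": True,               # False -> reuse a previously prepared avatar
        "video_path": "data/video/avatar_1.mp4",
        "bbox_shift": 5,
        "audio_clips": {
            "audio_0": "data/audio/yongen.wav",
        },
    },
})

# scripts/realtime_inference.py iterates the config roughly like this:
for avatar_id in inference_config:
    audio_clips = inference_config[avatar_id]["audio_clips"]
    for audio_num, audio_path in audio_clips.items():
        print("Inferring using:", audio_path)
```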
musetalk/models/unet.py
CHANGED
@@ -37,11 +37,11 @@ class UNet():
         self.model = UNet2DConditionModel(**unet_config)
         self.pe = PositionalEncoding(d_model=384)
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-        self.model.load_state_dict(
+        weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device)
+        self.model.load_state_dict(weights)
         if use_float16:
             self.model = self.model.half()
         self.model.to(self.device)
 
 if __name__ == "__main__":
-    unet = UNet()
+    unet = UNet()
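The change above makes the checkpoint load work on CPU-only machines (via `map_location`) and optionally casts the UNet to float16. Below is a minimal, self-contained sketch of the same loading pattern; the `nn.Linear` is a toy stand-in for `UNet2DConditionModel`, and the checkpoint path is made up for the example.

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(384, 384)                 # toy stand-in for UNet2DConditionModel(**unet_config)

model_path = "toy_checkpoint.pth"           # hypothetical path
torch.save(model.state_dict(), model_path)  # saved first so the sketch is runnable end to end

# Load on GPU if available, otherwise remap the stored tensors onto the CPU device.
weights = torch.load(model_path) if torch.cuda.is_available() \
    else torch.load(model_path, map_location=device)
model.load_state_dict(weights)

use_float16 = True
if use_float16:
    model = model.half()                    # cast parameters to fp16 for faster inference
model.to(device)
print(next(model.parameters()).dtype)       # torch.float16
```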
musetalk/utils/utils.py
CHANGED
@@ -39,7 +39,10 @@ def get_video_fps(video_path):
     video.release()
     return fps
 
-def datagen(whisper_chunks,vae_encode_latents,batch_size=8,delay_frame = 0):
+def datagen(whisper_chunks,
+            vae_encode_latents,
+            batch_size=8,
+            delay_frame=0):
     whisper_batch, latent_batch = [], []
     for i, w in enumerate(whisper_chunks):
         idx = (i+delay_frame)%len(vae_encode_latents)
@@ -48,14 +51,14 @@ def datagen(whisper_chunks,vae_encode_latents,batch_size=8,delay_frame = 0):
         latent_batch.append(latent)
 
         if len(latent_batch) >= batch_size:
-            whisper_batch = np.
+            whisper_batch = np.stack(whisper_batch)
             latent_batch = torch.cat(latent_batch, dim=0)
             yield whisper_batch, latent_batch
             whisper_batch, latent_batch = [], []
 
     # the last batch may smaller than batch size
     if len(latent_batch) > 0:
-        whisper_batch = np.
+        whisper_batch = np.stack(whisper_batch)
         latent_batch = torch.cat(latent_batch, dim=0)
 
-        yield whisper_batch, latent_batch
+        yield whisper_batch, latent_batch
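To show what the reshaped `datagen` yields, here is a condensed re-implementation with dummy inputs. The lines the diff viewer hides (picking the latent and appending the whisper chunk) are filled in with the obvious choices, so treat this as a sketch of the batching behaviour rather than a verbatim copy; the shapes are placeholders.

```python
import numpy as np
import torch

def datagen_sketch(whisper_chunks, vae_encode_latents, batch_size=8, delay_frame=0):
    # Same batching idea as musetalk/utils/utils.py: stack audio features with numpy,
    # concatenate latents with torch, and emit a final short batch for the remainder.
    whisper_batch, latent_batch = [], []
    for i, w in enumerate(whisper_chunks):
        idx = (i + delay_frame) % len(vae_encode_latents)
        whisper_batch.append(w)
        latent_batch.append(vae_encode_latents[idx])
        if len(latent_batch) >= batch_size:
            yield np.stack(whisper_batch), torch.cat(latent_batch, dim=0)
            whisper_batch, latent_batch = [], []
    if len(latent_batch) > 0:              # the last batch may be smaller than batch_size
        yield np.stack(whisper_batch), torch.cat(latent_batch, dim=0)

# 10 dummy whisper chunks and 4 cached latents (placeholder shapes)
chunks = [np.zeros((10, 384), dtype=np.float32) for _ in range(10)]
latents = [torch.zeros(1, 8, 32, 32) for _ in range(4)]
for whisper_batch, latent_batch in datagen_sketch(chunks, latents, batch_size=4):
    print(whisper_batch.shape, latent_batch.shape)
# (4, 10, 384) torch.Size([4, 8, 32, 32])  -- twice, then a final batch of 2
```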
musetalk/whisper/audio2feature.py
CHANGED
@@ -13,7 +13,11 @@ class Audio2Feature():
         self.whisper_model_type = whisper_model_type
         self.model = load_model(model_path) #
 
-    def get_sliced_feature(self,
+    def get_sliced_feature(self,
+                           feature_array,
+                           vid_idx,
+                           audio_feat_length=[2,2],
+                           fps=25):
         """
         Get sliced features based on a given index
         :param feature_array:
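The diff only reshapes the `get_sliced_feature` signature; the body is not shown. Purely to illustrate what the parameters suggest, the sketch below picks the audio feature frames around one video frame, assuming the usual 50 Whisper encoder frames per second (2 per video frame at 25 fps). It is an assumption-laden illustration, not the repository's implementation.

```python
import numpy as np

def sliced_feature_sketch(feature_array, vid_idx, audio_feat_length=(2, 2), fps=25):
    """Illustrative only: select the feature window around one video frame.

    Assumes 50 whisper feature frames per second (2 per video frame at 25 fps);
    the real Audio2Feature.get_sliced_feature may pad or reshape differently.
    """
    feats_per_frame = 50 // fps                                   # 2 at 25 fps
    center = vid_idx * feats_per_frame
    left = max(center - audio_feat_length[0] * feats_per_frame, 0)
    right = min(center + (audio_feat_length[1] + 1) * feats_per_frame, len(feature_array))
    return feature_array[left:right]

features = np.zeros((500, 384), dtype=np.float32)                 # dummy whisper features
print(sliced_feature_sketch(features, vid_idx=10).shape)          # (10, 384)
```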
scripts/inference.py
CHANGED
@@ -16,12 +16,18 @@ from musetalk.utils.utils import load_all_model
 import shutil
 
 # load model weights
-audio_processor,vae,unet,pe
+audio_processor, vae, unet, pe = load_all_model()
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 timesteps = torch.tensor([0], device=device)
 
 @torch.no_grad()
 def main(args):
+    global pe
+    if args.use_float16 is True:
+        pe = pe.half()
+        vae.vae = vae.vae.half()
+        unet.model = unet.model.half()
+
     inference_config = OmegaConf.load(args.inference_config)
     print(inference_config)
     for task_id in inference_config:
@@ -96,10 +102,11 @@ def main(args):
         gen = datagen(whisper_chunks,input_latent_list_cycle,batch_size)
         res_frame_list = []
         for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
-
-
-
+            audio_feature_batch = torch.from_numpy(whisper_batch)
+            audio_feature_batch = audio_feature_batch.to(device=unet.device,
+                                                         dtype=unet.model.dtype) # torch, B, 5*N,384
             audio_feature_batch = pe(audio_feature_batch)
+            latent_batch = latent_batch.to(dtype=unet.model.dtype)
 
             pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
             recon = vae.decode_latents(pred_latents)
@@ -145,7 +152,10 @@ if __name__ == "__main__":
     parser.add_argument("--use_saved_coord",
                         action="store_true",
                         help='use saved coordinate to save time')
-
+    parser.add_argument("--use_float16",
+                        action="store_true",
+                        help="Whether use float16 to speed up inference",
+                        )
 
     args = parser.parse_args()
     main(args)
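The loop change above converts the numpy whisper batch to a tensor and moves both inputs onto the UNet's device and dtype, so the same code path works whether the models run in float32 or float16. A toy sketch of that dtype-matching pattern follows; an `nn.Linear` stands in for `unet.model`, `model.weight.dtype` replaces the `dtype` property the diffusers model exposes, and the shapes are placeholders.

```python
import numpy as np
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_float16 = torch.cuda.is_available()        # fp16 only pays off on GPU

model = nn.Linear(384, 384).to(device)         # toy stand-in for unet.model
if use_float16:
    model = model.half()

whisper_batch = np.random.rand(4, 10, 384).astype(np.float32)   # placeholder batch
latent_batch = torch.randn(4, 384)

# Cast the inputs to whatever dtype the model currently uses (fp32 or fp16).
audio_feature_batch = torch.from_numpy(whisper_batch).to(device=device, dtype=model.weight.dtype)
latent_batch = latent_batch.to(device=device, dtype=model.weight.dtype)

with torch.no_grad():
    out = model(audio_feature_batch)
print(out.dtype)   # float16 when use_float16 is set, float32 otherwise
```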
scripts/realtime_inference.py
CHANGED
@@ -22,10 +22,12 @@ import queue
 import time
 
 # load model weights
-audio_processor,vae,unet,pe
+audio_processor, vae, unet, pe = load_all_model()
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 timesteps = torch.tensor([0], device=device)
-
+pe = pe.half()
+vae.vae = vae.vae.half()
+unet.model = unet.model.half()
 
 def video2imgs(vid_path, save_path, ext = '.png',cut_frame = 10000000):
     cap = cv2.VideoCapture(vid_path)
@@ -99,6 +101,10 @@ class Avatar:
             osmakedirs([self.avatar_path,self.full_imgs_path,self.video_out_path,self.mask_out_path])
             self.prepare_material()
         else:
+            if not os.path.exists(self.avatar_path):
+                print(f"{self.avatar_id} does not exist, you should set preparation to True")
+                sys.exit()
+
             with open(self.avatar_info_path, "r") as f:
                 avatar_info = json.load(f)
 
@@ -182,7 +188,10 @@ class Avatar:
         torch.save(self.input_latent_list_cycle, os.path.join(self.latents_out_path))
         #
 
-    def process_frames(self,
+    def process_frames(self,
+                       res_frame_queue,
+                       video_len,
+                       skip_save_images):
         print(video_len)
         while True:
             if self.idx>=video_len-1:
@@ -205,44 +214,62 @@ class Avatar:
             #combine_frame = get_image(ori_frame,res_frame,bbox)
             combine_frame = get_image_blending(ori_frame,res_frame,bbox,mask,mask_crop_box)
 
-
-
-            cv2.imwrite(f"{self.avatar_path}/tmp/{str(self.idx).zfill(8)}.png",combine_frame)
+            if skip_save_images is False:
+                cv2.imwrite(f"{self.avatar_path}/tmp/{str(self.idx).zfill(8)}.png",combine_frame)
             self.idx = self.idx + 1
 
-    def inference(self,
+    def inference(self,
+                  audio_path,
+                  out_vid_name,
+                  fps,
+                  skip_save_images):
         os.makedirs(self.avatar_path+'/tmp',exist_ok =True)
+        print("start inference")
         ############################################## extract audio feature ##############################################
+        start_time = time.time()
         whisper_feature = audio_processor.audio2feat(audio_path)
         whisper_chunks = audio_processor.feature2chunks(feature_array=whisper_feature,fps=fps)
+        print(f"processing audio:{audio_path} costs {(time.time() - start_time) * 1000}ms")
         ############################################## inference batch by batch ##############################################
         video_num = len(whisper_chunks)
-        print("start inference")
         res_frame_queue = queue.Queue()
         self.idx = 0
         # # Create a sub-thread and start it
-        process_thread = threading.Thread(target=self.process_frames, args=(res_frame_queue,video_num))
+        process_thread = threading.Thread(target=self.process_frames, args=(res_frame_queue, video_num, skip_save_images))
         process_thread.start()
-
-        gen = datagen(whisper_chunks,
-
+
+        gen = datagen(whisper_chunks,
+                      self.input_latent_list_cycle,
+                      self.batch_size)
         start_time = time.time()
         res_frame_list = []
 
         for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/self.batch_size)))):
-
-
-
+            audio_feature_batch = torch.from_numpy(whisper_batch)
+            audio_feature_batch = audio_feature_batch.to(device=unet.device,
+                                                         dtype=unet.model.dtype)
            audio_feature_batch = pe(audio_feature_batch)
-
-
+            latent_batch = latent_batch.to(dtype=unet.model.dtype)
+
+            pred_latents = unet.model(latent_batch,
+                                      timesteps,
+                                      encoder_hidden_states=audio_feature_batch).sample
             recon = vae.decode_latents(pred_latents)
             for res_frame in recon:
                 res_frame_queue.put(res_frame)
         # Close the queue and sub-thread after all tasks are completed
         process_thread.join()
 
-        if
+        if args.skip_save_images is True:
+            print('Total process time of {} frames without saving images = {}s'.format(
+                video_num,
+                time.time()-start_time))
+        else:
+            print('Total process time of {} frames including saving images = {}s'.format(
+                video_num,
+                time.time()-start_time))
+
+        if out_vid_name is not None and args.skip_save_images is False:
             # optional
             cmd_img2video = f"ffmpeg -y -v warning -r {fps} -f image2 -i {self.avatar_path}/tmp/%08d.png -vcodec libx264 -vf format=rgb24,scale=out_color_matrix=bt709,format=yuv420p -crf 18 {self.avatar_path}/temp.mp4"
             print(cmd_img2video)
@@ -256,20 +283,31 @@
             os.remove(f"{self.avatar_path}/temp.mp4")
             shutil.rmtree(f"{self.avatar_path}/tmp")
             print(f"result is save to {output_vid}")
+        print("\n")
 
 
-
-
-
 if __name__ == "__main__":
     '''
     This script is used to simulate online chatting and applies necessary pre-processing such as face detection and face parsing in advance. During online chatting, only UNet and the VAE decoder are involved, which makes MuseTalk real-time.
     '''
 
     parser = argparse.ArgumentParser()
-    parser.add_argument("--inference_config",
-
-
+    parser.add_argument("--inference_config",
+                        type=str,
+                        default="configs/inference/realtime.yaml",
+                        )
+    parser.add_argument("--fps",
+                        type=int,
+                        default=25,
+                        )
+    parser.add_argument("--batch_size",
+                        type=int,
+                        default=4,
+                        )
+    parser.add_argument("--skip_save_images",
+                        action="store_true",
+                        help="Whether skip saving images for better generation speed calculation",
+                        )
 
     args = parser.parse_args()
 
@@ -291,5 +329,7 @@ if __name__ == "__main__":
         audio_clips = inference_config[avatar_id]["audio_clips"]
         for audio_num, audio_path in audio_clips.items():
             print("Inferring using:",audio_path)
-            avatar.inference(audio_path,
-
+            avatar.inference(audio_path,
+                             audio_num,
+                             args.fps,
+                             args.skip_save_images)