HikariDawn committed on
Commit 5000b0a · 1 Parent(s): 1577493

feat: initial push

.gitignore ADDED
@@ -0,0 +1,33 @@
1
+ *.csv
2
+ *.mp4
3
+ *.png
4
+ *.jpg
5
+ *.err
6
+ *.txt
7
+ *.log
8
+ *.pyc
9
+ *.pth
10
+ *.DS_Store*
11
+ *.o
12
+ *.so
13
+ *.egg*
14
+ *.json
15
+ *.zip
16
+ *.jpeg
17
+ *.pkl
18
+ *.gif
19
+ *.pem
20
+ *.npy
21
+ *.sh
22
+
23
+
24
+ pretrained/*
25
+ checkpoints/*
26
+ preprocess/sam2_code
27
+
28
+ !preprocess/oneformer_code/oneformer/data/bpe_simple_vocab_16e6.txt
29
+ !config/*.json
30
+ !requirements.txt
31
+ !requirements/*
32
+ !__assets__/*
33
+ !__assets__/page/*
app.py ADDED
@@ -0,0 +1,893 @@
1
+ import os, sys, shutil
2
+ import csv
3
+ import numpy as np
4
+ import ffmpeg
5
+ import cv2
6
+ import collections
7
+ import json
8
+ import math
9
+ import time
10
+ import imageio
11
+ import random
12
+ import ast
13
+ import gradio as gr
14
+ from omegaconf import OmegaConf
15
+ from PIL import Image
16
+ from segment_anything import SamPredictor, sam_model_registry
17
+
18
+
19
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
20
+ import torch
21
+ from torch.utils.data import DataLoader, Dataset
22
+ from torchvision import transforms
23
+ from diffusers import AutoencoderKLCogVideoX
24
+ from transformers import T5EncoderModel
25
+ from diffusers.utils import export_to_video, load_image
26
+
27
+
28
+ # Import files from the local folder
29
+ root_path = os.path.abspath('.')
30
+ sys.path.append(root_path)
31
+ from pipelines.pipeline_cogvideox_i2v_motion_FrameINO import CogVideoXImageToVideoPipeline
32
+ from architecture.cogvideox_transformer_3d import CogVideoXTransformer3DModel
33
+ from data_loader.video_dataset_motion import VideoDataset_Motion
34
+ from architecture.transformer_wan import WanTransformer3DModel
35
+ from pipelines.pipeline_wan_i2v_motion_FrameINO import WanImageToVideoPipeline
36
+ from architecture.autoencoder_kl_wan import AutoencoderKLWan
37
+
38
+
39
+
40
+ MARKDOWN = \
41
+ """
42
+ <div align='center'>
43
+ <h1> Frame In-N-Out </h1> \
44
+ <h2 style='font-weight: 450; font-size: 1rem; margin-bottom: 1rem;'>\
45
+ <a href='https://kiteretsu77.github.io/BoyangWang/'>Boyang Wang</a>, <a href='https://xuweiyichen.github.io/'>Xuweiyi Chen</a>, <a href='http://mgadelha.me/'>Matheus Gadelha</a>, <a href='https://sites.google.com/site/zezhoucheng/'>Zezhou Cheng</a>\
46
+ </h2> \
47
+
48
+ <div style="display: flex; flex-wrap: wrap; justify-content: center; gap: 2rem; margin-bottom: 1rem;">
49
+ <!-- First row of buttons -->
50
+ <a href="https://arxiv.org/abs/2505.21491" target="_blank"
51
+ style="display: inline-flex; align-items: center; padding: 0.5rem 1rem; background-color: #f0f0f0; /* light gray background */ color: #333; /* dark text */ text-decoration: none; border-radius: 9999px; font-weight: 500; transition: background-color 0.3s;">
52
+ <span style="margin-right: 0.5rem;">📄</span> <!-- Document icon -->
53
+ <span>Paper</span>
54
+ </a>
55
+ <a href="https://github.com/UVA-Computer-Vision-Lab/FrameINO" target="_blank"
56
+ style="display: inline-flex; align-items: center; padding: 0.5rem 1rem; background-color: #f0f0f0; color: #333; text-decoration: none; border-radius: 9999px; font-weight: 500; transition: background-color 0.3s;">
57
+ <span style="margin-right: 0.5rem;">💻</span> <!-- Computer icon -->
58
+ <span>GitHub</span>
59
+ </a>
60
+ <a href="https://uva-computer-vision-lab.github.io/Frame-In-N-Out" target="_blank"
61
+ style="display: inline-flex; align-items: center; padding: 0.5rem 1rem; background-color: #f0f0f0; color: #333; text-decoration: none; border-radius: 9999px; font-weight: 500; transition: background-color 0.3s;">
62
+ <span style="margin-right: 0.5rem;">🤖</span>
63
+ <span>Project Page</span>
64
+ </a>
65
+ <a href="https://huggingface.co/collections/uva-cv-lab/frame-in-n-out" target="_blank"
66
+ style="display: inline-flex; align-items: center; padding: 0.5rem 1rem; background-color: #f0f0f0; color: #333; text-decoration: none; border-radius: 9999px; font-weight: 500; transition: background-color 0.3s;">
67
+ <span style="margin-right: 0.5rem;">🤗</span>
68
+ <span>HF Model and Data</span>
69
+ </a>
70
+ </div>
71
+
72
+
73
+ </div>
74
+
75
+ Frame In-N-Out expands the first-frame condition to a broader canvas region, set by the top-left and bottom-right expansion amounts.
76
+ Users can provide a motion trajectory for existing objects, bring a brand-new identity into the scene along a motion trajectory, or do both. <br>
77
+ The model used here is <b>Wan2.2-5B</b> trained with our Frame In-N-Out control mechanism.
78
+
79
+
80
+ <br>
81
+ <b>Easiest way:</b>
82
+ Choose one example and then simply click <b>Generate</b>.
83
+
84
+ <br>
85
+ <br>
86
+ ❗️❗️❗️Instruction Steps:<br>
87
+ 1️⃣ Upload your first frame image. Set the size you want to resize to for <b>Resized Height for Input Image</b> and <b>Resized Width for Input Image</b>. <br>
88
+ 2️⃣ Set your <b>canvas top-left</b> and <b>bottom-right expansion</b>. The combined canvas height and width must each be a multiple of 32. <br>
89
+ PLEASE ENSURE that <b>Canvas HEIGHT = 704</b> and <b>Canvas WIDTH = 1280</b> for the best performance (current training resolution). <br>
90
+ 3️⃣ Click <b>Build the Canvas</b>. <br>
91
+ 4️⃣ Provide the trajectory of the main object in the canvas by clicking on the <b>Expanded Canvas</b>. <br>
92
+ 5️⃣ Provide the ID reference image and its trajectory (optional). Also, write a detailed <b>text prompt</b>. <br>
93
+ 6️⃣ Click the <b>Generate</b> button to start the video generation. <br>
94
+
95
+
96
+ If **Frame In-N-Out** is helpful, please help star the [GitHub Repo](https://github.com/UVA-Computer-Vision-Lab/FrameINO?tab=readme-ov-file). Thanks!
97
+
98
+ """
99
+
100
+
101
+
102
+ # Color
103
+ all_color_codes = [(255, 0, 0), (255, 255, 0), (0, 255, 0), (0, 255, 255),
104
+ (255, 0, 255), (0, 0, 255), (128, 128, 128), (64, 224, 208),
105
+ (233, 150, 122)]
106
+ for _ in range(100): # Should not be over 100 colors
107
+ all_color_codes.append((random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)))
108
+
109
+ # Data Transforms
110
+ train_transforms = transforms.Compose(
111
+ [
112
+ transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
113
+ ]
114
+ )
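A quick check of the normalization above: the Lambda maps uint8 pixel values from [0, 255] into the [-1, 1] range used by the rest of the pipeline.

# x / 255.0 * 2.0 - 1.0
#   0.0  -> -1.0
# 127.5  ->  0.0
# 255.0  ->  1.0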
115
+
116
+
117
+
118
+
119
+
120
+ ######################################################## CogVideoX #################################################################
121
+
122
+ # Path Setting
123
+ model_code_name = "CogVideoX"
124
+ base_model_id = "zai-org/CogVideoX-5b-I2V"
125
+ transformer_ckpt_path = "uva-cv-lab/FrameINO_CogVideoX_Stage2_MotionINO_v1.0"
126
+
127
+ # Load Model
128
+ transformer = CogVideoXTransformer3DModel.from_pretrained(transformer_ckpt_path, torch_dtype=torch.float16)
129
+ text_encoder = T5EncoderModel.from_pretrained(base_model_id, subfolder="text_encoder", torch_dtype=torch.float16)
130
+ vae = AutoencoderKLCogVideoX.from_pretrained(base_model_id, subfolder="vae", torch_dtype=torch.float16)
131
+
132
+ # Create pipeline and run inference
133
+ pipe = CogVideoXImageToVideoPipeline.from_pretrained(
134
+ base_model_id,
135
+ text_encoder = text_encoder,
136
+ transformer = transformer,
137
+ vae = vae,
138
+ torch_dtype = torch.float16,
139
+ )
140
+ pipe.enable_model_cpu_offload()
141
+
142
+ #####################################################################################################################################
143
+
144
+
145
+
146
+
147
+ ######################################################## Wan2.2 5B #################################################################
148
+
149
+ # Path Setting
150
+ model_code_name = "Wan"
151
+ base_model_id = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"
152
+ transformer_ckpt_path = "uva-cv-lab/FrameINO_Wan2.2_5B_Stage2_MotionINO_v1.5"
153
+
154
+
155
+ # Load model
156
+ print("Loading the model!")
157
+ transformer = WanTransformer3DModel.from_pretrained(transformer_ckpt_path, torch_dtype=torch.float16)
158
+ vae = AutoencoderKLWan.from_pretrained(base_model_id, subfolder="vae", torch_dtype=torch.float32)
159
+
160
+ # Create the pipeline
161
+ print("Loading the pipeline!")
162
+ pipe = WanImageToVideoPipeline.from_pretrained(base_model_id, transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)
163
+ pipe.to("cuda")
164
+ pipe.enable_model_cpu_offload()
165
+
166
+ #####################################################################################################################################
167
+
168
+
169
+
170
+
171
+ ########################################################## Other Auxiliary Func #################################################################
172
+
173
+ # Init SAM model
174
+ model_type = "vit_h" # vit_h has the largest number of parameters
175
+ sam_pretrained_path = "pretrained/sam_vit_h_4b8939.pth"
176
+ if not os.path.exists(sam_pretrained_path):
177
+ os.system("wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -P pretrained/")
178
+ sam = sam_model_registry[model_type](checkpoint = sam_pretrained_path).to(device="cuda")
179
+ sam_predictor = SamPredictor(sam) # Wraps SAM with its default prediction settings
180
+
181
+ #####################################################################################################################################
182
+
183
+
184
+
185
+
186
+ # Examples Sample
187
+ def get_example():
188
+ case = [
189
+ [
190
+ '__assets__/horse.jpg',
191
+ 480,
192
+ 736,
193
+ 128,
194
+ 224,
195
+ 96,
196
+ 320,
197
+ '__assets__/sheep.png',
198
+ "A brown horse with a black mane walks to the right on a wooden path in a green forest, and then a white sheep enters from the left and walks toward it. Natural daylight, realistic texture, smooth motion, cinematic focus, 4K detail.",
199
+ [[[[299, 241], [390, 236], [461, 245], [521, 249], [565, 240], [612, 246], [666, 245]], [[449, 224], [488, 212], [512, 206], [531, 209], [552, 202], [581, 204], [609, 210], [657, 206], [703, 202], [716, 211]]], [[[24, 305], [104, 300], [167, 299], [219, 303], [270, 296], [295, 304]]]],
200
+ ],
201
+
202
+ [
203
+ '__assets__/cup.jpg',
204
+ 448,
205
+ 736,
206
+ 256,
207
+ 64,
208
+ 0,
209
+ 480,
210
+ '__assets__/hand2.png',
211
+ "A human hand reaches into the frame, gently grabbing the black metal cup with a golden character design on the front, lifting it off the table and taking it away.",
212
+ [[[[565, 324], [473, 337], [386, 345], [346, 340], [339, 324], [352, 212], [328, 114], [328, 18], [348, 0]]]],
213
+ ],
214
+
215
+ [
216
+ '__assets__/grass.jpg',
217
+ 512,
218
+ 800,
219
+ 64,
220
+ 64,
221
+ 160,
222
+ 416,
223
+ '__assets__/dog.png',
224
+ "A fluffy, adorable puppy joyfully sprints onto the bright green grass, its fur bouncing with each step as sunlight highlights its soft coat. The scene takes place in a peaceful park filled with tall trees casting gentle shadows across the lawn. After dashing forward with enthusiasm, the puppy slows to a happy trot, continuing farther ahead into the deeper area of the park, disappearing toward the more shaded grass beneath the trees.",
225
+ [[[[600, 412], [512, 394], [408, 358], [333, 336], [270, 313], [259, 260], [236, 222], [231, 180]], [[592, 392], [295, 305], [256, 217], [243, 163]]]],
226
+ ],
227
+
228
+ [
229
+ '__assets__/man_scene.jpg',
230
+ 576,
231
+ 1024,
232
+ 64,
233
+ 32,
234
+ 64,
235
+ 224,
236
+ None,
237
+ "A single hiker, equipped with a backpack, walks toward the right side of a rugged mountainside trail. The bright sunlight highlights the pale rocky terrain around him, while massive stone cliffs loom in the background. Sparse patches of grass and scattered boulders sit along the path, emphasizing the isolation and vastness of the mountain environment as he steadily continues his journey.",
238
+ [[[[342, 247], [415, 247], [478, 262], [518, 271], [570, 275], [613, 283], [646, 308], [690, 307], [705, 325]], [[349, 227], [461, 232], [536, 254], [595, 252], [638, 269], [691, 289], [715, 291]], [[341, 283], [415, 291], [500, 316], [590, 317], [632, 354], [675, 362], [711, 372]]]],
239
+ ]
240
+
241
+ ]
242
+ return case
243
+
244
+
245
+
246
+
247
+ def on_example_click(
248
+ input_image, resized_height, resized_width,
249
+ top_left_height, top_left_width, bottom_right_height, bottom_right_width,
250
+ identity_image, text_prompt, traj_lists,
251
+ ):
252
+
253
+ # Convert
254
+ traj_lists = ast.literal_eval(traj_lists)
255
+ # Note: the remaining arguments (e.g., resized_width and resized_height) need no conversion here; they are handled inside the called functions
256
+
257
+
258
+ # First build the canvas (the empty traj_lists & traj_instance_idx returned by build_canvas are discarded)
259
+ visual_canvas, initial_visual_canvas, inference_canvas, _, _ = build_canvas(input_image, resized_height, resized_width, top_left_height, top_left_width, bottom_right_height, bottom_right_width)
260
+
261
+
262
+ # Sequentially load the Trajs of all instances on the canvas
263
+ visual_canvas, traj_instance_idx = fn_vis_all_instance_traj(visual_canvas, traj_lists)
264
+
265
+
266
+ return visual_canvas, initial_visual_canvas, inference_canvas, traj_instance_idx
267
+
268
+
269
+
270
+ def build_canvas(input_image_path, resized_height, resized_width, top_left_height, top_left_width, bottom_right_height, bottom_right_width):
271
+
272
+ # Init
273
+ canvas_color = (250, 249, 246) # An off-white similar to painting paper
274
+
275
+
276
+ # Convert the string to integer
277
+ if not resized_height.isdigit():
278
+ raise gr.Error("resized_height must be an integer!")
279
+ resized_height = int(resized_height)
280
+
281
+ if not resized_width.isdigit():
282
+ raise gr.Error("resized_width must be an integer!")
283
+ resized_width = int(resized_width)
284
+
285
+ if not top_left_height.isdigit():
286
+ raise gr.Error("top_left_height must be an integer!")
287
+ top_left_height = int(top_left_height)
288
+
289
+ if not top_left_width.isdigit():
290
+ raise gr.Error("top_left_width must be an integer!")
291
+ top_left_width = int(top_left_width)
292
+
293
+ if not bottom_right_height.isdigit():
294
+ raise gr.Error("bottom_right_height must be an integer!")
295
+ bottom_right_height = int(bottom_right_height)
296
+
297
+ if not bottom_right_width.isdigit():
298
+ raise gr.Error("bottom_right_width must be an integer!")
299
+ bottom_right_width = int(bottom_right_width)
300
+
301
+
302
+
303
+ # Read the original image and prepare the placeholder
304
+ first_frame_img = np.uint8(np.asarray(Image.open(input_image_path))) # NOTE: PIL loads RGB; mind the channel order in the later OpenCV calls for the ID Reference
305
+ # print("first_frame_img shape is ", first_frame_img.shape)
306
+
307
+ # Resize to a uniform resolution
308
+ first_frame_img = cv2.resize(first_frame_img, (resized_width, resized_height), interpolation = cv2.INTER_AREA)
309
+
310
+ # Expand to Outside Region to form the Canvas
311
+ expand_height = resized_height + top_left_height + bottom_right_height
312
+ expand_width = resized_width + top_left_width + bottom_right_width
313
+ inference_canvas = np.uint8(np.zeros((expand_height, expand_width, 3))) # Whole Black Canvas, same as other inference
314
+ visual_canvas = np.full((expand_height, expand_width, 3), canvas_color, dtype=np.uint8)
315
+
316
+
317
+ # Sanity Check
318
+ if expand_height % 32 != 0:
319
+ raise gr.Error("The Height of resized_height + top_left_height + bottom_right_height must be divisible by 32!")
320
+ if expand_width % 32 != 0:
321
+ raise gr.Error("The Width of resized_width + top_left_width + bottom_right_width must be divisible by 32!")
322
+
323
+
324
+ # Draw the Region Box Region (Original Resolution)
325
+ bottom_len = inference_canvas.shape[0] - bottom_right_height
326
+ right_len = inference_canvas.shape[1] - bottom_right_width
327
+ inference_canvas[top_left_height:bottom_len, top_left_width:right_len, :] = first_frame_img
328
+ visual_canvas[top_left_height:bottom_len, top_left_width:right_len, :] = first_frame_img
329
+
330
+
331
+ # Resize to the uniform height and width
332
+ visual_canvas = cv2.resize(visual_canvas, (uniform_width, uniform_height), interpolation = cv2.INTER_AREA)
333
+
334
+
335
+
336
+ # Return the visual_canvas (for visualization) and the canvas map
337
+ # Corresponds to: visual_canvas, initial_visual_canvas, inference_canvas, traj_instance_idx, traj_lists
338
+ return visual_canvas, visual_canvas.copy(), inference_canvas, 0, [ [ [] ] ] # The last two are the initialized trajectory instance idx and trajectory list
339
+
340
+
341
+
342
+
343
+ def process_points(traj_list, num_frames=49):
344
+
345
+
346
+ if len(traj_list) < 2: # First point
347
+ return [traj_list[0]] * num_frames
348
+
349
+ elif len(traj_list) >= num_frames:
350
+ gr.Info("The number of trajectory points exceeds the frame limit; the trajectory will be subsampled.") # gr.Info is shown, not raised
351
+ skip = len(traj_list) // num_frames
352
+ return traj_list[::skip][: num_frames - 1] + traj_list[-1:]
353
+
354
+ else:
355
+
356
+ insert_num = num_frames - len(traj_list)
357
+ insert_num_dict = {}
358
+ interval = len(traj_list) - 1
359
+ n = insert_num // interval
360
+ m = insert_num % interval
361
+
362
+ for i in range(interval):
363
+ insert_num_dict[i] = n
364
+
365
+ for i in range(m):
366
+ insert_num_dict[i] += 1
367
+
368
+ res = []
369
+ for i in range(interval):
370
+ insert_points = []
371
+ x0, y0 = traj_list[i]
372
+ x1, y1 = traj_list[i + 1]
373
+
374
+ delta_x = x1 - x0
375
+ delta_y = y1 - y0
376
+ for j in range(insert_num_dict[i]):
377
+ x = x0 + (j + 1) / (insert_num_dict[i] + 1) * delta_x
378
+ y = y0 + (j + 1) / (insert_num_dict[i] + 1) * delta_y
379
+ insert_points.append([int(x), int(y)])
380
+
381
+ res += traj_list[i : i + 1] + insert_points
382
+ res += traj_list[-1:]
383
+
384
+ # return
385
+ return res
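A short worked example of the interpolation branch above: with fewer clicked points than frames, intermediate points are inserted evenly between consecutive clicks.

# process_points([[0, 0], [6, 0]], num_frames=4)
# -> [[0, 0], [2, 0], [4, 0], [6, 0]]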
386
+
387
+
388
+
389
+ def fn_vis_realtime_traj(visual_canvas, traj_list, traj_instance_idx): # Visualize the traj on canvas
390
+
391
+ # Process Points
392
+ points = process_points(traj_list)
393
+
394
+ # Draw straight line to connect
395
+ for i in range(len(points) - 1):
396
+ p = points[i]
397
+ p1 = points[i + 1]
398
+ cv2.line(visual_canvas, p, p1, all_color_codes[traj_instance_idx], 5)
399
+
400
+ return visual_canvas
401
+
402
+
403
+ def fn_vis_all_instance_traj(visual_canvas, traj_lists): # Visualize all traj from all instances on canvas
404
+
405
+ for traj_instance_idx, traj_list_instance in enumerate(traj_lists):
406
+ for traj_list_line in traj_list_instance:
407
+ visual_canvas = fn_vis_realtime_traj(visual_canvas, traj_list_line, traj_instance_idx)
408
+
409
+ return visual_canvas, traj_instance_idx # Also return the instance idx
410
+
411
+
412
+ def add_traj_point(
413
+ visual_canvas,
414
+ traj_lists,
415
+ traj_instance_idx,
416
+ evt: gr.SelectData,
417
+ ): # Add new Traj and then visualize
418
+
419
+ # Convert
420
+ traj_lists = ast.literal_eval(traj_lists)
421
+
422
+ # Mark New Trajectory Key Point
423
+ horizontal, vertical = evt.index
424
+
425
+ # traj_lists data structure is: (Num of Instances, Num of Trajectories, Num of Points, [X, Y])
426
+ traj_lists[-1][-1].append( [int(horizontal), int(vertical)] )
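# An illustrative example of this nested structure (coordinates taken from the first case
# in get_example() above): two instances, the first with two trajectory lines, the second with one:
#   [ [ [ [299, 241], [390, 236], ... ], [ [449, 224], [488, 212], ... ] ],
#     [ [ [24, 305], [104, 300], ... ] ] ]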
427
+
428
+ # Draw new trajectory on the Canvas image
429
+ visual_canvas = fn_vis_realtime_traj(visual_canvas, traj_lists[-1][-1], traj_instance_idx)
430
+
431
+
432
+ # Return New Traj Marked Canvas image
433
+ return visual_canvas, traj_lists
434
+
435
+
436
+
437
+ def clear_traj_points(initial_visual_canvas):
438
+
439
+
440
+ return initial_visual_canvas.copy(), 0, [ [ [] ] ] # 1st is the initial-state canvas; 2nd is the traj instance idx; 3rd is the traj list (same data structure)
441
+
442
+
443
+ def traj_point_update(traj_lists):
444
+
445
+ # Convert
446
+ traj_lists = ast.literal_eval(traj_lists)
447
+
448
+ # Append a new trajectory line to the last instance
449
+ traj_lists[-1].append([])
450
+
451
+ return traj_lists
452
+
453
+
454
+
455
+ def traj_instance_update(traj_instance_idx, traj_lists):
456
+
457
+ # Convert
458
+ traj_lists = ast.literal_eval(traj_lists)
459
+
460
+ # Update one index
461
+ if traj_instance_idx >= len(all_color_codes):
462
+ raise gr.Error("The trajectory instance number is over the limit!")
463
+
464
+ # Add one for the traj instance
465
+ traj_instance_idx = traj_instance_idx + 1
466
+
467
+ # Append a new empty list to the traj lists
468
+ traj_lists.append([[]])
469
+
470
+ # Return
471
+ return traj_instance_idx, traj_lists
472
+
473
+
474
+
475
+ def sample_traj_by_length(points, num_samples):
476
+ # Sample points evenly along the trajectory based on Euclidean arc length
477
+
478
+ pts = np.array(points, dtype=float) # shape (M, 2)
479
+
480
+ # 1) Length of each segment
481
+ seg = pts[1:] - pts[:-1]
482
+ seg_len = np.sqrt((seg**2).sum(axis=1)) # shape (M-1,)
483
+
484
+ # 2) Cumulative length
485
+ cum = np.cumsum(seg_len)
486
+ total_length = cum[-1]
487
+
488
+ # 3) Target positions, equally spaced along the total length
489
+ target = np.linspace(0, total_length, num_samples)
490
+
491
+ res = []
492
+ for t in target:
493
+ # 4) Find which segment it falls in
494
+ idx = np.searchsorted(cum, t)
495
+ if idx == 0:
496
+ prev = 0.
497
+ else:
498
+ prev = cum[idx-1]
499
+
500
+ # 5) Interpolate within that segment
501
+ ratio = (t - prev) / seg_len[idx]
502
+ p = pts[idx] * (1 - ratio) + pts[idx+1] * ratio # linear interpolation within the segment: start * (1 - ratio) + end * ratio
505
+ res.append(p)
506
+ return np.array(res)
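A small worked example of the arc-length sampling above: for a polyline made of two equal segments, the samples are evenly spaced along the path.

# sample_traj_by_length([[0, 0], [10, 0], [20, 0]], num_samples=5)
# -> array([[ 0.,  0.], [ 5.,  0.], [10.,  0.], [15.,  0.], [20.,  0.]])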
507
+
508
+
509
+
510
+ def inference(inference_canvas, visual_canvas, text_prompt, traj_lists, main_reference_img,
511
+ resized_height, resized_width, top_left_height, top_left_width, bottom_right_height, bottom_right_width):
512
+
513
+ # TODO: enhance the text prompt by Qwen3-VL-32B?
514
+
515
+
516
+ # Convert
517
+ resized_height = int(resized_height)
518
+ resized_width = int(resized_width)
519
+ top_left_height = int(top_left_height)
520
+ top_left_width = int(top_left_width)
521
+ bottom_right_height = int(bottom_right_height)
522
+ bottom_right_width = int(bottom_right_width)
523
+ traj_lists = ast.literal_eval(traj_lists)
524
+
525
+
526
+
527
+ # Init Some Fixed Setting
528
+ if model_code_name == "Wan":
529
+ config_path = "config/train_wan_motion_FrameINO.yaml"
530
+ dot_radius = 7
531
+ num_frames = 81
532
+ elif model_code_name == "CogVideoX":
533
+ config_path = "config/train_cogvideox_i2v_motion_FrameINO.yaml"
534
+ dot_radius = 6
535
+ num_frames = 49
536
+ config = OmegaConf.load(config_path)
537
+
538
+
539
+ # Prepare tmp folders
540
+ print()
541
+ store_folder_path = "tmp_app_example_" + str(int(time.time()))
542
+ if os.path.exists(store_folder_path):
543
+ shutil.rmtree(store_folder_path)
544
+ os.makedirs(store_folder_path)
545
+
546
+
547
+ # Write the visual canvas
548
+ visual_canvas_store_path = os.path.join(store_folder_path, "visual_canvas.png")
549
+ cv2.imwrite( visual_canvas_store_path, cv2.cvtColor(visual_canvas, cv2.COLOR_BGR2RGB) )
550
+
551
+
552
+
553
+ # Resize the map
554
+ canvas_width = resized_width + top_left_width + bottom_right_width
555
+ canvas_height = resized_height + top_left_height + bottom_right_height
556
+ # inference_canvas = cv2.resize(visual_canvas, (canvas_width, canvas_height), interpolation = cv2.INTER_AREA)
557
+ print("Canvas Shape is", str(canvas_height) + "x" + str(canvas_width) )
558
+
559
+
560
+ # TODO: also enhance this text prompt so its complexity feels consistent with the Qwen-generated prompts...
561
+
562
+ # Save the text prompt
563
+ print("Text Prompt is", text_prompt)
564
+ with open(os.path.join(store_folder_path, 'text_prompt.txt'), 'w') as file:
565
+ file.write(text_prompt)
566
+
567
+
568
+ ################################################## Motion Trajectory Condition #####################################################
569
+
570
+ # Prepare the points in a linear way
571
+ full_pred_tracks = [[] for _ in range(num_frames)]
572
+ ID_tensor = None
573
+
574
+ # Iterate all tracking information for all objects
575
+ print("traj_lists is", traj_lists)
576
+ for instance_idx, traj_list_per_object in enumerate(traj_lists):
577
+
578
+ # Iterate all trajectory lines in one instance
579
+ for traj_idx, single_trajectory in enumerate(traj_list_per_object):
580
+
581
+ # Sanity Check
582
+ if len(single_trajectory) < 2:
583
+ raise gr.Error("One of the provided trajectories is too short!")
584
+
585
+
586
+ # Sample the points based on the Euclidean arc length
587
+ sampled_points = sample_traj_by_length(single_trajectory, num_frames)
588
+
589
+
590
+ # Iterate all points
591
+ temporal_idx = 0
592
+ for (raw_point_x, raw_point_y) in sampled_points:
593
+
594
+ # Scale the point coordinate to the inference size (actual canvas size)
595
+ point_x, point_y = int(raw_point_x * canvas_width / uniform_width), int(raw_point_y * canvas_height / uniform_height) # Clicking on the board is with respect to the Uniform Preset Height and Width
596
+
597
+ if traj_idx == 0: # Needs to init the list in list
598
+ full_pred_tracks[temporal_idx].append( [] )
599
+ full_pred_tracks[temporal_idx][-1].append( (point_x, point_y) ) # [-1] and [instance_idx] should have the same effect
600
+ temporal_idx += 1
601
+
602
+
603
+ # Create the traj tensor
604
+ traj_tensor, traj_imgs_np, _, img_with_traj = VideoDataset_Motion.prepare_traj_tensor(
605
+ full_pred_tracks, canvas_height, canvas_width,
606
+ [], dot_radius, canvas_width, canvas_height,
607
+ idx=0, first_frame_img = inference_canvas
608
+ )
609
+
610
+
611
+ # Store Trajectory
612
+ imageio.mimsave(os.path.join(store_folder_path, "traj_video.mp4"), traj_imgs_np, fps=8)
613
+
614
+ ######################################################################################################################################################
615
+
616
+
617
+
618
+ ########################################## Prepare the Identity Reference Condition #####################################################
619
+
620
+
621
+ # ID reference preparation
622
+ if main_reference_img is not None:
623
+ print("We have an ID reference being used!")
624
+
625
+ # Fetch
626
+ ref_h, ref_w, _ = main_reference_img.shape
627
+
628
+
629
+ # Set the reference image for SAM prediction
630
+ sam_predictor.set_image(np.uint8(main_reference_img))
631
+
632
+
633
+ # Define the sample point
634
+ sam_points = [(ref_w//2, ref_h//2)] # A single center point is enough as the SAM prompt
635
+
636
+
637
+ # Build positive point prompts for SAM
638
+ positive_point_cords = np.array(sam_points)
639
+ positive_point_labels = np.ones(len(positive_point_cords))
640
+
641
+ # Predict the mask based on the point and bounding box designed
642
+ masks, scores, logits = sam_predictor.predict(
643
+ point_coords = positive_point_cords,
644
+ point_labels = positive_point_labels,
645
+ multimask_output = False,
646
+ )
647
+ mask = masks[0]
648
+ main_reference_img[mask == False] = 0 # Zero out everything outside the predicted mask
649
+
650
+
651
+ # Resize so that the reference fits within the canvas resolution
652
+ scale_h = canvas_height / max(ref_h, ref_w)
653
+ scale_w = canvas_width / max(ref_h, ref_w)
654
+ new_h, new_w = int(ref_h * scale_h), int(ref_w * scale_w)
655
+ main_reference_img = cv2.resize(main_reference_img, (new_w, new_h), interpolation = cv2.INTER_AREA)
656
+
657
+ # Calculate padding amounts on all direction
658
+ pad_height1 = (canvas_height - main_reference_img.shape[0]) // 2
659
+ pad_height2 = canvas_height - main_reference_img.shape[0] - pad_height1
660
+ pad_width1 = (canvas_width - main_reference_img.shape[1]) // 2
661
+ pad_width2 = canvas_width - main_reference_img.shape[1] - pad_width1
662
+
663
+ # Apply padding to match the resolution of the training frames
664
+ main_reference_img = np.pad(
665
+ main_reference_img,
666
+ ((pad_height1, pad_height2), (pad_width1, pad_width2), (0, 0)),
667
+ mode = 'constant',
668
+ constant_values = 0
669
+ )
670
+
671
+ cv2.imwrite(os.path.join(store_folder_path, "ID.png"), cv2.cvtColor(main_reference_img, cv2.COLOR_BGR2RGB))
672
+
673
+ elif main_reference_img is None:
674
+ # Whole Black Color placeholder
675
+ main_reference_img = np.uint8(np.zeros((canvas_height, canvas_width, 3)))
676
+
677
+
678
+ # Convert to tensor
679
+ ID_tensor = torch.tensor(main_reference_img)
680
+ ID_tensor = train_transforms(ID_tensor).permute(2, 0, 1).contiguous()
681
+
682
+ if model_code_name == "Wan": # Needs to be the shape (B, C, F, H, W)
683
+ ID_tensor = ID_tensor.unsqueeze(0).unsqueeze(2)
684
+
685
+ ###############################################################################################################################################
686
+
687
+
688
+
689
+ ############################################# Call the Inference Pipeline ##########################################################
690
+
691
+ image = Image.fromarray(inference_canvas)
692
+
693
+ if model_code_name == "Wan":
694
+ video = pipe(
695
+ image = image,
696
+ prompt = text_prompt, negative_prompt = "", # Empty string as negative text prompt
697
+ traj_tensor = traj_tensor, # Should be shape (F, C, H, W)
698
+ ID_tensor = ID_tensor, # Should be shape (B, C, F, H, W)
699
+ height = canvas_height, width = canvas_width, num_frames = num_frames,
700
+ num_inference_steps = 50, # 38 is also ok
701
+ guidance_scale = 5.0,
702
+ ).frames[0]
703
+
704
+ elif model_code_name == "CogVideoX":
705
+ video = pipe(
706
+ image = image,
707
+ prompt = text_prompt,
708
+ traj_tensor = traj_tensor,
709
+ ID_tensor = ID_tensor,
710
+ height = canvas_height, width = canvas_width, num_frames = len(traj_tensor),
711
+ guidance_scale = 6, use_dynamic_cfg = False,
712
+ num_inference_steps = 50,
713
+ add_ID_reference_augment_noise = True,
714
+ ).frames[0]
715
+
716
+
717
+
718
+ # Store the result
719
+ export_to_video(video, os.path.join(store_folder_path, "generated_video_padded.mp4"), fps=8)
720
+
721
+
722
+
723
+ # Save frames
724
+ print("Writing as Frames")
725
+ video_file_path = os.path.join(store_folder_path, "generated_video.mp4")
726
+ writer = imageio.get_writer(video_file_path, fps = 8)
727
+ for frame_idx, frame in enumerate(video):
728
+
729
+ # Extract Unpadded version
730
+ # frame = np.uint8(frame)
731
+ if model_code_name == "CogVideoX":
732
+ frame = np.asarray(frame) # PIL to RGB
733
+ bottom_right_y = frame.shape[0] - bottom_right_height
734
+ bottom_right_x = frame.shape[1] - bottom_right_width
735
+ cropped_region_frame = np.uint8(frame[top_left_height: bottom_right_y, top_left_width : bottom_right_x] * 255)
736
+ writer.append_data(cropped_region_frame)
737
+
738
+ writer.close()
739
+
740
+ #####################################################################################################################################
741
+
742
+
743
+ return gr.update(value = video_file_path, width = uniform_width, height = uniform_height)
744
+
745
+
746
+
747
+
748
+ if __name__ == '__main__':
749
+
750
+
751
+ # Global Setting
752
+ uniform_height = 480 # Visual Canvas as 480x720 is decent
753
+ uniform_width = 720
754
+
755
+
756
+ # Draw the Website
757
+ block = gr.Blocks().queue(max_size=10)
758
+ with block:
759
+
760
+
761
+ with gr.Row():
762
+ gr.Markdown(MARKDOWN)
763
+
764
+ with gr.Row(elem_classes=["container"]):
765
+
766
+ with gr.Column(scale=2):
767
+ # Input image
768
+ input_image = gr.Image(type="filepath", label="Input Image 🖼️ ")
769
+ # uploaded_files = gr.Gallery(label="Your images", visible=False, columns=5, rows=1, height=200)
770
+
771
+ with gr.Column(scale=2):
772
+
773
+ # Input image
774
+ resized_height = gr.Textbox(label="Resized Height for Input Image")
775
+ resized_width = gr.Textbox(label="Resized Width for Input Image")
776
+ # gr.Number(value=unit_height, label="Fixed", interactive=False)
777
+ # gr.Number(value=unit_height * 1.77777, label="Fixed", interactive=False)
778
+
779
+ # Input the expansion factor
780
+ top_left_height = gr.Textbox(label="Top-Left Expand Height")
781
+ top_left_width = gr.Textbox(label="Top-Left Expand Width")
782
+ bottom_right_height = gr.Textbox(label="Bottom-Right Expand Height")
783
+ bottom_right_width = gr.Textbox(label="Bottom-Right Expand Width")
784
+
785
+ # Button
786
+ build_canvas_btn = gr.Button(value="Build the Canvas")
787
+
788
+
789
+ with gr.Row():
790
+
791
+ with gr.Column(scale=3):
792
+ with gr.Row(scale=3):
793
+ visual_canvas = gr.Image(height = uniform_height, width = uniform_width, type="numpy", label='Expanded Canvas 🖼️ ')
794
+ # inference_canvas = gr.Image(height = uniform_height, width = uniform_width, type="numpy")
795
+ # inference_canvas = None
796
+
797
+ with gr.Row(scale=1):
798
+ # TODO: still missing an option to clear a single trajectory
799
+ add_point = gr.Button(value = "Add New Traj Line (Same Obj)", visible = True) # Add new trajectory for the same instance
800
+ add_traj = gr.Button(value = "Add New Instance (New Obj, including new ID)", visible = True)
801
+ clear_traj_button = gr.Button("Clear All Traj", visible=True)
802
+
803
+ with gr.Column(scale=2):
804
+
805
+ with gr.Row(scale=2):
806
+ identity_image = gr.Image(type="numpy", label="Identity Reference (SAM on center point only) 🖼️ ")
807
+
808
+ with gr.Row(scale=2):
809
+ text_prompt = gr.Textbox(label="Text Prompt", lines=3)
810
+
811
+
812
+ with gr.Row():
813
+
814
+ # Button
815
+ generation_btn = gr.Button(value="Generate")
816
+
817
+
818
+ with gr.Row():
819
+ generated_video = gr.Video(value = None, label="Generated Video", show_label = True, height = uniform_height, width = uniform_width)
820
+
821
+
822
+
823
+ ################################################################## Click + Select + Any Effect Area ###########################################################################
824
+
825
+ # Init some states that will be supporting purposes
826
+ traj_lists = gr.Textbox(label="Trajectory", visible = False) # Alternative: gr.State(None). Data structure: (Number of Instances, Number of Trajectories, Points); initialized as [ [ [] ] ]
827
+ inference_canvas = gr.State(None)
828
+ traj_instance_idx = gr.State(0)
829
+ initial_visual_canvas = gr.State(None) # Alternative: gr.Image(height = uniform_height, width = uniform_width, type="numpy", label='Canvas Expanded Image (Initial State)'). This is the initial visual, restored when clearing trajectories
830
+
831
+
832
+ # Canvas Click
833
+ build_canvas_btn.click(
834
+ build_canvas,
835
+ inputs = [input_image, resized_height, resized_width, top_left_height, top_left_width, bottom_right_height, bottom_right_width],
836
+ outputs = [visual_canvas, initial_visual_canvas, inference_canvas, traj_instance_idx, traj_lists] # inference_canvas is used for inference; visual_canvas is for gradio visualization
837
+ )
838
+
839
+
840
+ # Draw Trajectory for each click on the canvas
841
+ visual_canvas.select(
842
+ fn = add_traj_point,
843
+ inputs = [visual_canvas, traj_lists, traj_instance_idx],
844
+ outputs = [visual_canvas, traj_lists]
845
+ )
846
+
847
+
848
+ # Add new Trajectory
849
+ add_point.click(
850
+ fn = traj_point_update,
851
+ inputs = [traj_lists],
852
+ outputs = [traj_lists],
853
+ )
854
+ add_traj.click(
855
+ fn = traj_instance_update,
856
+ inputs = [traj_instance_idx, traj_lists],
857
+ outputs = [traj_instance_idx, traj_lists],
858
+ )
859
+
860
+ # Clean all the traj points
861
+ clear_traj_button.click(
862
+ clear_traj_points,
863
+ [initial_visual_canvas],
864
+ [visual_canvas, traj_instance_idx, traj_lists],
865
+ )
866
+
867
+
868
+ # Inference Generation
869
+ generation_btn.click(
870
+ inference,
871
+ inputs = [inference_canvas, visual_canvas, text_prompt, traj_lists, identity_image, resized_height, resized_width, top_left_height, top_left_width, bottom_right_height, bottom_right_width],
872
+ outputs = [generated_video],
873
+ )
874
+
875
+
876
+
877
+
878
+ # Load Examples
879
+ with gr.Row(elem_classes=["container"]):
880
+ gr.Examples(
881
+ examples = get_example(),
882
+ inputs = [input_image, resized_height, resized_width, top_left_height, top_left_width, bottom_right_height, bottom_right_width, identity_image, text_prompt, traj_lists],
883
+ run_on_click = True,
884
+ fn = on_example_click,
885
+ outputs = [visual_canvas, initial_visual_canvas, inference_canvas, traj_instance_idx],
886
+ )
887
+
888
+
889
+ block.launch(share=True)
890
+
891
+
892
+
893
+
architecture/attention_processor.py ADDED
The diff for this file is too large to render. See raw diff
 
architecture/autoencoder_kl_wan.py ADDED
@@ -0,0 +1,1419 @@
1
+ # Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import List, Optional, Tuple, Union
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+ import torch.utils.checkpoint
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.loaders import FromOriginalModelMixin
24
+ from diffusers.utils import logging
25
+ from diffusers.utils.accelerate_utils import apply_forward_hook
26
+ from diffusers.models.activations import get_activation
27
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
28
+ from diffusers.models.modeling_utils import ModelMixin
29
+ from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
30
+
31
+
32
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
33
+
34
+ CACHE_T = 2
35
+
36
+
37
+ class AvgDown3D(nn.Module):
38
+ def __init__(
39
+ self,
40
+ in_channels,
41
+ out_channels,
42
+ factor_t,
43
+ factor_s=1,
44
+ ):
45
+ super().__init__()
46
+ self.in_channels = in_channels
47
+ self.out_channels = out_channels
48
+ self.factor_t = factor_t
49
+ self.factor_s = factor_s
50
+ self.factor = self.factor_t * self.factor_s * self.factor_s
51
+
52
+ assert in_channels * self.factor % out_channels == 0
53
+ self.group_size = in_channels * self.factor // out_channels
54
+
55
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
56
+ pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
57
+ pad = (0, 0, 0, 0, pad_t, 0)
58
+ x = F.pad(x, pad)
59
+ B, C, T, H, W = x.shape
60
+ x = x.view(
61
+ B,
62
+ C,
63
+ T // self.factor_t,
64
+ self.factor_t,
65
+ H // self.factor_s,
66
+ self.factor_s,
67
+ W // self.factor_s,
68
+ self.factor_s,
69
+ )
70
+ x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
71
+ x = x.view(
72
+ B,
73
+ C * self.factor,
74
+ T // self.factor_t,
75
+ H // self.factor_s,
76
+ W // self.factor_s,
77
+ )
78
+ x = x.view(
79
+ B,
80
+ self.out_channels,
81
+ self.group_size,
82
+ T // self.factor_t,
83
+ H // self.factor_s,
84
+ W // self.factor_s,
85
+ )
86
+ x = x.mean(dim=2)
87
+ return x
88
+
89
+
90
+ class DupUp3D(nn.Module):
91
+ def __init__(
92
+ self,
93
+ in_channels: int,
94
+ out_channels: int,
95
+ factor_t,
96
+ factor_s=1,
97
+ ):
98
+ super().__init__()
99
+ self.in_channels = in_channels
100
+ self.out_channels = out_channels
101
+
102
+ self.factor_t = factor_t
103
+ self.factor_s = factor_s
104
+ self.factor = self.factor_t * self.factor_s * self.factor_s
105
+
106
+ assert out_channels * self.factor % in_channels == 0
107
+ self.repeats = out_channels * self.factor // in_channels
108
+
109
+ def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
110
+ x = x.repeat_interleave(self.repeats, dim=1)
111
+ x = x.view(
112
+ x.size(0),
113
+ self.out_channels,
114
+ self.factor_t,
115
+ self.factor_s,
116
+ self.factor_s,
117
+ x.size(2),
118
+ x.size(3),
119
+ x.size(4),
120
+ )
121
+ x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
122
+ x = x.view(
123
+ x.size(0),
124
+ self.out_channels,
125
+ x.size(2) * self.factor_t,
126
+ x.size(4) * self.factor_s,
127
+ x.size(6) * self.factor_s,
128
+ )
129
+ if first_chunk:
130
+ x = x[:, :, self.factor_t - 1 :, :, :]
131
+ return x
132
+
133
+
134
+ class WanCausalConv3d(nn.Conv3d):
135
+ r"""
136
+ A custom 3D causal convolution layer with feature caching support.
137
+
138
+ This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature
139
+ caching for efficient inference.
140
+
141
+ Args:
142
+ in_channels (int): Number of channels in the input image
143
+ out_channels (int): Number of channels produced by the convolution
144
+ kernel_size (int or tuple): Size of the convolving kernel
145
+ stride (int or tuple, optional): Stride of the convolution. Default: 1
146
+ padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0
147
+ """
148
+
149
+ def __init__(
150
+ self,
151
+ in_channels: int,
152
+ out_channels: int,
153
+ kernel_size: Union[int, Tuple[int, int, int]],
154
+ stride: Union[int, Tuple[int, int, int]] = 1,
155
+ padding: Union[int, Tuple[int, int, int]] = 0,
156
+ ) -> None:
157
+ super().__init__(
158
+ in_channels=in_channels,
159
+ out_channels=out_channels,
160
+ kernel_size=kernel_size,
161
+ stride=stride,
162
+ padding=padding,
163
+ )
164
+
165
+ # Set up causal padding
166
+ self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
167
+ self.padding = (0, 0, 0)
168
+
169
+ def forward(self, x, cache_x=None):
170
+ padding = list(self._padding)
171
+ if cache_x is not None and self._padding[4] > 0:
172
+ cache_x = cache_x.to(x.device)
173
+ x = torch.cat([cache_x, x], dim=2)
174
+ padding[4] -= cache_x.shape[2]
175
+ x = F.pad(x, padding)
176
+ return super().forward(x)
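A shape sketch of the causal padding above, assuming kernel_size=3 and padding=1 for illustration: _padding becomes (1, 1, 1, 1, 2, 0), so F.pad adds two zero frames before the clip and none after it, and the output at time t only depends on inputs at times <= t.

# x: (B, C, T, H, W) = (1, 16, 4, 32, 32)
# F.pad(x, (1, 1, 1, 1, 2, 0)) -> (1, 16, 6, 34, 34)
# Conv3d(kernel_size=3, stride=1, padding=0) -> (1, 16, 4, 32, 32)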
177
+
178
+
179
+ class WanRMS_norm(nn.Module):
180
+ r"""
181
+ A custom RMS normalization layer.
182
+
183
+ Args:
184
+ dim (int): The number of dimensions to normalize over.
185
+ channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
186
+ Default is True.
187
+ images (bool, optional): Whether the input represents image data. Default is True.
188
+ bias (bool, optional): Whether to include a learnable bias term. Default is False.
189
+ """
190
+
191
+ def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
192
+ super().__init__()
193
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
194
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
195
+
196
+ self.channel_first = channel_first
197
+ self.scale = dim**0.5
198
+ self.gamma = nn.Parameter(torch.ones(shape))
199
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
200
+
201
+ def forward(self, x):
202
+ return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
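# In formula form, the forward pass above computes
#   y = gamma * sqrt(dim) * x / ||x||_2 + bias
# along the normalized dimension, i.e. the usual RMS norm y = gamma * x / RMS(x) + bias
# with RMS(x) = ||x||_2 / sqrt(dim).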
203
+
204
+
205
+ class WanUpsample(nn.Upsample):
206
+ r"""
207
+ Perform upsampling while ensuring the output tensor has the same data type as the input.
208
+
209
+ Args:
210
+ x (torch.Tensor): Input tensor to be upsampled.
211
+
212
+ Returns:
213
+ torch.Tensor: Upsampled tensor with the same data type as the input.
214
+ """
215
+
216
+ def forward(self, x):
217
+ return super().forward(x.float()).type_as(x)
218
+
219
+
220
+ class WanResample(nn.Module):
221
+ r"""
222
+ A custom resampling module for 2D and 3D data.
223
+
224
+ Args:
225
+ dim (int): The number of input/output channels.
226
+ mode (str): The resampling mode. Must be one of:
227
+ - 'none': No resampling (identity operation).
228
+ - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.
229
+ - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.
230
+ - 'downsample2d': 2D downsampling with zero-padding and convolution.
231
+ - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
232
+ """
233
+
234
+ def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None:
235
+ super().__init__()
236
+ self.dim = dim
237
+ self.mode = mode
238
+
239
+ # default to dim //2
240
+ if upsample_out_dim is None:
241
+ upsample_out_dim = dim // 2
242
+
243
+ # layers
244
+ if mode == "upsample2d":
245
+ self.resample = nn.Sequential(
246
+ WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
247
+ nn.Conv2d(dim, upsample_out_dim, 3, padding=1),
248
+ )
249
+ elif mode == "upsample3d":
250
+ self.resample = nn.Sequential(
251
+ WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
252
+ nn.Conv2d(dim, upsample_out_dim, 3, padding=1),
253
+ )
254
+ self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
255
+
256
+ elif mode == "downsample2d":
257
+ self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
258
+ elif mode == "downsample3d":
259
+ self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
260
+ self.time_conv = WanCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
261
+
262
+ else:
263
+ self.resample = nn.Identity()
264
+
265
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
266
+ b, c, t, h, w = x.size()
267
+ if self.mode == "upsample3d":
268
+ if feat_cache is not None:
269
+ idx = feat_idx[0]
270
+ if feat_cache[idx] is None:
271
+ feat_cache[idx] = "Rep"
272
+ feat_idx[0] += 1
273
+ else:
274
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
275
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
276
+ # cache last frame of last two chunk
277
+ cache_x = torch.cat(
278
+ [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
279
+ )
280
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
281
+ cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
282
+ if feat_cache[idx] == "Rep":
283
+ x = self.time_conv(x)
284
+ else:
285
+ x = self.time_conv(x, feat_cache[idx])
286
+ feat_cache[idx] = cache_x
287
+ feat_idx[0] += 1
288
+
289
+ x = x.reshape(b, 2, c, t, h, w)
290
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
291
+ x = x.reshape(b, c, t * 2, h, w)
292
+ t = x.shape[2]
293
+ x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
294
+ x = self.resample(x)
295
+ x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)
296
+
297
+ if self.mode == "downsample3d":
298
+ if feat_cache is not None:
299
+ idx = feat_idx[0]
300
+ if feat_cache[idx] is None:
301
+ feat_cache[idx] = x.clone()
302
+ feat_idx[0] += 1
303
+ else:
304
+ cache_x = x[:, :, -1:, :, :].clone()
305
+ x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
306
+ feat_cache[idx] = cache_x
307
+ feat_idx[0] += 1
308
+ return x
309
+
310
+
311
+ class WanResidualBlock(nn.Module):
312
+ r"""
313
+ A custom residual block module.
314
+
315
+ Args:
316
+ in_dim (int): Number of input channels.
317
+ out_dim (int): Number of output channels.
318
+ dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0.
319
+ non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
320
+ """
321
+
322
+ def __init__(
323
+ self,
324
+ in_dim: int,
325
+ out_dim: int,
326
+ dropout: float = 0.0,
327
+ non_linearity: str = "silu",
328
+ ) -> None:
329
+ super().__init__()
330
+ self.in_dim = in_dim
331
+ self.out_dim = out_dim
332
+ self.nonlinearity = get_activation(non_linearity)
333
+
334
+ # layers
335
+ self.norm1 = WanRMS_norm(in_dim, images=False)
336
+ self.conv1 = WanCausalConv3d(in_dim, out_dim, 3, padding=1)
337
+ self.norm2 = WanRMS_norm(out_dim, images=False)
338
+ self.dropout = nn.Dropout(dropout)
339
+ self.conv2 = WanCausalConv3d(out_dim, out_dim, 3, padding=1)
340
+ self.conv_shortcut = WanCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
341
+
342
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
343
+ # Apply shortcut connection
344
+ h = self.conv_shortcut(x)
345
+
346
+ # First normalization and activation
347
+ x = self.norm1(x)
348
+ x = self.nonlinearity(x)
349
+
350
+ if feat_cache is not None:
351
+ idx = feat_idx[0]
352
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
353
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
354
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
355
+
356
+ x = self.conv1(x, feat_cache[idx])
357
+ feat_cache[idx] = cache_x
358
+ feat_idx[0] += 1
359
+ else:
360
+ x = self.conv1(x)
361
+
362
+ # Second normalization and activation
363
+ x = self.norm2(x)
364
+ x = self.nonlinearity(x)
365
+
366
+ # Dropout
367
+ x = self.dropout(x)
368
+
369
+ if feat_cache is not None:
370
+ idx = feat_idx[0]
371
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
372
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
373
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
374
+
375
+ x = self.conv2(x, feat_cache[idx])
376
+ feat_cache[idx] = cache_x
377
+ feat_idx[0] += 1
378
+ else:
379
+ x = self.conv2(x)
380
+
381
+ # Add residual connection
382
+ return x + h
383
+
384
+
385
+ class WanAttentionBlock(nn.Module):
386
+ r"""
387
+ Causal self-attention with a single head.
388
+
389
+ Args:
390
+ dim (int): The number of channels in the input tensor.
391
+ """
392
+
393
+ def __init__(self, dim):
394
+ super().__init__()
395
+ self.dim = dim
396
+
397
+ # layers
398
+ self.norm = WanRMS_norm(dim)
399
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
400
+ self.proj = nn.Conv2d(dim, dim, 1)
401
+
402
+ def forward(self, x):
403
+ identity = x
404
+ batch_size, channels, time, height, width = x.size()
405
+
406
+ x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width)
407
+ x = self.norm(x)
408
+
409
+ # compute query, key, value
410
+ qkv = self.to_qkv(x)
411
+ qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
412
+ qkv = qkv.permute(0, 1, 3, 2).contiguous()
413
+ q, k, v = qkv.chunk(3, dim=-1)
414
+
415
+ # apply attention
416
+ x = F.scaled_dot_product_attention(q, k, v)
417
+
418
+ x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width)
419
+
420
+ # output projection
421
+ x = self.proj(x)
422
+
423
+ # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w]
424
+ x = x.view(batch_size, time, channels, height, width)
425
+ x = x.permute(0, 2, 1, 3, 4)
426
+
427
+ return x + identity
428
+
429
+
430
+ class WanMidBlock(nn.Module):
431
+ """
432
+ Middle block for WanVAE encoder and decoder.
433
+
434
+ Args:
435
+ dim (int): Number of input/output channels.
436
+ dropout (float): Dropout rate.
437
+ non_linearity (str): Type of non-linearity to use.
438
+ """
439
+
440
+ def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1):
441
+ super().__init__()
442
+ self.dim = dim
443
+
444
+ # Create the components
445
+ resnets = [WanResidualBlock(dim, dim, dropout, non_linearity)]
446
+ attentions = []
447
+ for _ in range(num_layers):
448
+ attentions.append(WanAttentionBlock(dim))
449
+ resnets.append(WanResidualBlock(dim, dim, dropout, non_linearity))
450
+ self.attentions = nn.ModuleList(attentions)
451
+ self.resnets = nn.ModuleList(resnets)
452
+
453
+ self.gradient_checkpointing = False
454
+
455
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
456
+ # First residual block
457
+ x = self.resnets[0](x, feat_cache, feat_idx)
458
+
459
+ # Process through attention and residual blocks
460
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
461
+ if attn is not None:
462
+ x = attn(x)
463
+
464
+ x = resnet(x, feat_cache, feat_idx)
465
+
466
+ return x
467
+
468
+
469
+ class WanResidualDownBlock(nn.Module):
470
+ def __init__(self, in_dim, out_dim, dropout, num_res_blocks, temperal_downsample=False, down_flag=False):
471
+ super().__init__()
472
+
473
+ # Shortcut path with downsample
474
+ self.avg_shortcut = AvgDown3D(
475
+ in_dim,
476
+ out_dim,
477
+ factor_t=2 if temperal_downsample else 1,
478
+ factor_s=2 if down_flag else 1,
479
+ )
480
+
481
+ # Main path with residual blocks and downsample
482
+ resnets = []
483
+ for _ in range(num_res_blocks):
484
+ resnets.append(WanResidualBlock(in_dim, out_dim, dropout))
485
+ in_dim = out_dim
486
+ self.resnets = nn.ModuleList(resnets)
487
+
488
+ # Add the final downsample block
489
+ if down_flag:
490
+ mode = "downsample3d" if temperal_downsample else "downsample2d"
491
+ self.downsampler = WanResample(out_dim, mode=mode)
492
+ else:
493
+ self.downsampler = None
494
+
495
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
496
+ x_copy = x.clone()
497
+ for resnet in self.resnets:
498
+ x = resnet(x, feat_cache, feat_idx)
499
+ if self.downsampler is not None:
500
+ x = self.downsampler(x, feat_cache, feat_idx)
501
+
502
+ return x + self.avg_shortcut(x_copy)
503
+
504
+
505
+ class WanEncoder3d(nn.Module):
506
+ r"""
507
+ A 3D encoder module.
508
+
509
+ Args:
510
+ dim (int): The base number of channels in the first layer.
511
+ z_dim (int): The dimensionality of the latent space.
512
+ dim_mult (list of int): Multipliers for the number of channels in each block.
513
+ num_res_blocks (int): Number of residual blocks in each block.
514
+ attn_scales (list of float): Scales at which to apply attention mechanisms.
515
+ temperal_downsample (list of bool): Whether to downsample temporally in each block.
516
+ dropout (float): Dropout rate for the dropout layers.
517
+ non_linearity (str): Type of non-linearity to use.
518
+ """
519
+
520
+ def __init__(
521
+ self,
522
+ in_channels: int = 3,
523
+ dim=128,
524
+ z_dim=4,
525
+ dim_mult=[1, 2, 4, 4],
526
+ num_res_blocks=2,
527
+ attn_scales=[],
528
+ temperal_downsample=[True, True, False],
529
+ dropout=0.0,
530
+ non_linearity: str = "silu",
531
+ is_residual: bool = False, # the Wan 2.2 VAE uses a residual downblock
532
+ ):
533
+ super().__init__()
534
+ self.dim = dim
535
+ self.z_dim = z_dim
536
+ self.dim_mult = dim_mult
537
+ self.num_res_blocks = num_res_blocks
538
+ self.attn_scales = attn_scales
539
+ self.temperal_downsample = temperal_downsample
540
+ self.nonlinearity = get_activation(non_linearity)
541
+
542
+ # dimensions
543
+ dims = [dim * u for u in [1] + dim_mult]
544
+ scale = 1.0
545
+
546
+ # init block
547
+ self.conv_in = WanCausalConv3d(in_channels, dims[0], 3, padding=1)
548
+
549
+ # downsample blocks
550
+ self.down_blocks = nn.ModuleList([])
551
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
552
+ # residual (+attention) blocks
553
+ if is_residual:
554
+ self.down_blocks.append(
555
+ WanResidualDownBlock(
556
+ in_dim,
557
+ out_dim,
558
+ dropout,
559
+ num_res_blocks,
560
+ temperal_downsample=temperal_downsample[i] if i != len(dim_mult) - 1 else False,
561
+ down_flag=i != len(dim_mult) - 1,
562
+ )
563
+ )
564
+ else:
565
+ for _ in range(num_res_blocks):
566
+ self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
567
+ if scale in attn_scales:
568
+ self.down_blocks.append(WanAttentionBlock(out_dim))
569
+ in_dim = out_dim
570
+
571
+ # downsample block
572
+ if i != len(dim_mult) - 1:
573
+ mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
574
+ self.down_blocks.append(WanResample(out_dim, mode=mode))
575
+ scale /= 2.0
576
+
577
+ # middle blocks
578
+ self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)
579
+
580
+ # output blocks
581
+ self.norm_out = WanRMS_norm(out_dim, images=False)
582
+ self.conv_out = WanCausalConv3d(out_dim, z_dim, 3, padding=1)
583
+
584
+ self.gradient_checkpointing = False
585
+
586
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
587
+ if feat_cache is not None:
588
+ idx = feat_idx[0]
589
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
590
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
591
+ # cache the last frame of the last two chunks
592
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
593
+ x = self.conv_in(x, feat_cache[idx])
594
+ feat_cache[idx] = cache_x
595
+ feat_idx[0] += 1
596
+ else:
597
+ x = self.conv_in(x)
598
+
599
+ ## downsamples
600
+ for layer in self.down_blocks:
601
+ if feat_cache is not None:
602
+ x = layer(x, feat_cache, feat_idx)
603
+ else:
604
+ x = layer(x)
605
+
606
+ ## middle
607
+ x = self.mid_block(x, feat_cache, feat_idx)
608
+
609
+ ## head
610
+ x = self.norm_out(x)
611
+ x = self.nonlinearity(x)
612
+ if feat_cache is not None:
613
+ idx = feat_idx[0]
614
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
615
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
616
+ # cache the last frame of the last two chunks
617
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
618
+ x = self.conv_out(x, feat_cache[idx])
619
+ feat_cache[idx] = cache_x
620
+ feat_idx[0] += 1
621
+ else:
622
+ x = self.conv_out(x)
623
+ return x
624
+
625
+
626
+ class WanResidualUpBlock(nn.Module):
627
+ """
628
+ A block that handles upsampling for the WanVAE decoder.
629
+
630
+ Args:
631
+ in_dim (int): Input dimension
632
+ out_dim (int): Output dimension
633
+ num_res_blocks (int): Number of residual blocks
634
+ dropout (float): Dropout rate
635
+ temperal_upsample (bool): Whether to upsample on temporal dimension
636
+ up_flag (bool): Whether to upsample or not
637
+ non_linearity (str): Type of non-linearity to use
638
+ """
639
+
640
+ def __init__(
641
+ self,
642
+ in_dim: int,
643
+ out_dim: int,
644
+ num_res_blocks: int,
645
+ dropout: float = 0.0,
646
+ temperal_upsample: bool = False,
647
+ up_flag: bool = False,
648
+ non_linearity: str = "silu",
649
+ ):
650
+ super().__init__()
651
+ self.in_dim = in_dim
652
+ self.out_dim = out_dim
653
+
654
+ if up_flag:
655
+ self.avg_shortcut = DupUp3D(
656
+ in_dim,
657
+ out_dim,
658
+ factor_t=2 if temperal_upsample else 1,
659
+ factor_s=2,
660
+ )
661
+ else:
662
+ self.avg_shortcut = None
663
+
664
+ # create residual blocks
665
+ resnets = []
666
+ current_dim = in_dim
667
+ for _ in range(num_res_blocks + 1):
668
+ resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
669
+ current_dim = out_dim
670
+
671
+ self.resnets = nn.ModuleList(resnets)
672
+
673
+ # Add upsampling layer if needed
674
+ if up_flag:
675
+ upsample_mode = "upsample3d" if temperal_upsample else "upsample2d"
676
+ self.upsampler = WanResample(out_dim, mode=upsample_mode, upsample_out_dim=out_dim)
677
+ else:
678
+ self.upsampler = None
679
+
680
+ self.gradient_checkpointing = False
681
+
682
+ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
683
+ """
684
+ Forward pass through the upsampling block.
685
+
686
+ Args:
687
+ x (torch.Tensor): Input tensor
688
+ feat_cache (list, optional): Feature cache for causal convolutions
689
+ feat_idx (list, optional): Feature index for cache management
690
+
691
+ Returns:
692
+ torch.Tensor: Output tensor
693
+ """
694
+ x_copy = x.clone()
695
+
696
+ for resnet in self.resnets:
697
+ if feat_cache is not None:
698
+ x = resnet(x, feat_cache, feat_idx)
699
+ else:
700
+ x = resnet(x)
701
+
702
+ if self.upsampler is not None:
703
+ if feat_cache is not None:
704
+ x = self.upsampler(x, feat_cache, feat_idx)
705
+ else:
706
+ x = self.upsampler(x)
707
+
708
+ if self.avg_shortcut is not None:
709
+ x = x + self.avg_shortcut(x_copy, first_chunk=first_chunk)
710
+
711
+ return x
712
+
713
+
714
+ class WanUpBlock(nn.Module):
715
+ """
716
+ A block that handles upsampling for the WanVAE decoder.
717
+
718
+ Args:
719
+ in_dim (int): Input dimension
720
+ out_dim (int): Output dimension
721
+ num_res_blocks (int): Number of residual blocks
722
+ dropout (float): Dropout rate
723
+ upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')
724
+ non_linearity (str): Type of non-linearity to use
725
+ """
726
+
727
+ def __init__(
728
+ self,
729
+ in_dim: int,
730
+ out_dim: int,
731
+ num_res_blocks: int,
732
+ dropout: float = 0.0,
733
+ upsample_mode: Optional[str] = None,
734
+ non_linearity: str = "silu",
735
+ ):
736
+ super().__init__()
737
+ self.in_dim = in_dim
738
+ self.out_dim = out_dim
739
+
740
+ # Create layers list
741
+ resnets = []
742
+ # Add residual blocks and attention if needed
743
+ current_dim = in_dim
744
+ for _ in range(num_res_blocks + 1):
745
+ resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
746
+ current_dim = out_dim
747
+
748
+ self.resnets = nn.ModuleList(resnets)
749
+
750
+ # Add upsampling layer if needed
751
+ self.upsamplers = None
752
+ if upsample_mode is not None:
753
+ self.upsamplers = nn.ModuleList([WanResample(out_dim, mode=upsample_mode)])
754
+
755
+ self.gradient_checkpointing = False
756
+
757
+ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=None):
758
+ """
759
+ Forward pass through the upsampling block.
760
+
761
+ Args:
762
+ x (torch.Tensor): Input tensor
763
+ feat_cache (list, optional): Feature cache for causal convolutions
764
+ feat_idx (list, optional): Feature index for cache management
765
+
766
+ Returns:
767
+ torch.Tensor: Output tensor
768
+ """
769
+ for resnet in self.resnets:
770
+ if feat_cache is not None:
771
+ x = resnet(x, feat_cache, feat_idx)
772
+ else:
773
+ x = resnet(x)
774
+
775
+ if self.upsamplers is not None:
776
+ if feat_cache is not None:
777
+ x = self.upsamplers[0](x, feat_cache, feat_idx)
778
+ else:
779
+ x = self.upsamplers[0](x)
780
+ return x
781
+
782
+
783
+ class WanDecoder3d(nn.Module):
784
+ r"""
785
+ A 3D decoder module.
786
+
787
+ Args:
788
+ dim (int): The base number of channels in the first layer.
789
+ z_dim (int): The dimensionality of the latent space.
790
+ dim_mult (list of int): Multipliers for the number of channels in each block.
791
+ num_res_blocks (int): Number of residual blocks in each block.
792
+ attn_scales (list of float): Scales at which to apply attention mechanisms.
793
+ temperal_upsample (list of bool): Whether to upsample temporally in each block.
794
+ dropout (float): Dropout rate for the dropout layers.
795
+ non_linearity (str): Type of non-linearity to use.
796
+ """
797
+
798
+ def __init__(
799
+ self,
800
+ dim=128,
801
+ z_dim=4,
802
+ dim_mult=[1, 2, 4, 4],
803
+ num_res_blocks=2,
804
+ attn_scales=[],
805
+ temperal_upsample=[False, True, True],
806
+ dropout=0.0,
807
+ non_linearity: str = "silu",
808
+ out_channels: int = 3,
809
+ is_residual: bool = False,
810
+ ):
811
+ super().__init__()
812
+ self.dim = dim
813
+ self.z_dim = z_dim
814
+ self.dim_mult = dim_mult
815
+ self.num_res_blocks = num_res_blocks
816
+ self.attn_scales = attn_scales
817
+ self.temperal_upsample = temperal_upsample
818
+
819
+ self.nonlinearity = get_activation(non_linearity)
820
+
821
+ # dimensions
822
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
823
+
824
+ # init block
825
+ self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1)
826
+
827
+ # middle blocks
828
+ self.mid_block = WanMidBlock(dims[0], dropout, non_linearity, num_layers=1)
829
+
830
+ # upsample blocks
831
+ self.up_blocks = nn.ModuleList([])
832
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
833
+ # residual (+attention) blocks
834
+ if i > 0 and not is_residual:
835
+ # wan vae 2.1
836
+ in_dim = in_dim // 2
837
+
838
+ # determine if we need upsampling
839
+ up_flag = i != len(dim_mult) - 1
840
+ # determine upsampling mode, if not upsampling, set to None
841
+ upsample_mode = None
842
+ if up_flag and temperal_upsample[i]:
843
+ upsample_mode = "upsample3d"
844
+ elif up_flag:
845
+ upsample_mode = "upsample2d"
846
+ # Create and add the upsampling block
847
+ if is_residual:
848
+ up_block = WanResidualUpBlock(
849
+ in_dim=in_dim,
850
+ out_dim=out_dim,
851
+ num_res_blocks=num_res_blocks,
852
+ dropout=dropout,
853
+ temperal_upsample=temperal_upsample[i] if up_flag else False,
854
+ up_flag=up_flag,
855
+ non_linearity=non_linearity,
856
+ )
857
+ else:
858
+ up_block = WanUpBlock(
859
+ in_dim=in_dim,
860
+ out_dim=out_dim,
861
+ num_res_blocks=num_res_blocks,
862
+ dropout=dropout,
863
+ upsample_mode=upsample_mode,
864
+ non_linearity=non_linearity,
865
+ )
866
+ self.up_blocks.append(up_block)
867
+
868
+ # output blocks
869
+ self.norm_out = WanRMS_norm(out_dim, images=False)
870
+ self.conv_out = WanCausalConv3d(out_dim, out_channels, 3, padding=1)
871
+
872
+ self.gradient_checkpointing = False
873
+
874
+ def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
875
+ ## conv1
876
+ if feat_cache is not None:
877
+ idx = feat_idx[0]
878
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
879
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
880
+ # cache the last frame of the last two chunks
881
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
882
+ x = self.conv_in(x, feat_cache[idx])
883
+ feat_cache[idx] = cache_x
884
+ feat_idx[0] += 1
885
+ else:
886
+ x = self.conv_in(x)
887
+
888
+ ## middle
889
+ x = self.mid_block(x, feat_cache, feat_idx)
890
+
891
+ ## upsamples
892
+ for up_block in self.up_blocks:
893
+ x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk)
894
+
895
+ ## head
896
+ x = self.norm_out(x)
897
+ x = self.nonlinearity(x)
898
+ if feat_cache is not None:
899
+ idx = feat_idx[0]
900
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
901
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
902
+ # cache the last frame of the last two chunks
903
+ cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
904
+ x = self.conv_out(x, feat_cache[idx])
905
+ feat_cache[idx] = cache_x
906
+ feat_idx[0] += 1
907
+ else:
908
+ x = self.conv_out(x)
909
+ return x
910
+
911
+
912
+ def patchify(x, patch_size):
913
+ if patch_size == 1:
914
+ return x
915
+
916
+ if x.dim() != 5:
917
+ raise ValueError(f"Invalid input shape: {x.shape}")
918
+ # x shape: [batch_size, channels, frames, height, width]
919
+ batch_size, channels, frames, height, width = x.shape
920
+
921
+ # Ensure height and width are divisible by patch_size
922
+ if height % patch_size != 0 or width % patch_size != 0:
923
+ raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})")
924
+
925
+ # Reshape to [batch_size, channels, frames, height//patch_size, patch_size, width//patch_size, patch_size]
926
+ x = x.view(batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size)
927
+
928
+ # Rearrange to [batch_size, channels * patch_size * patch_size, frames, height//patch_size, width//patch_size]
929
+ x = x.permute(0, 1, 6, 4, 2, 3, 5).contiguous()
930
+ x = x.view(batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size)
931
+
932
+ return x
933
+
934
+
935
+ def unpatchify(x, patch_size):
936
+ if patch_size == 1:
937
+ return x
938
+
939
+ if x.dim() != 5:
940
+ raise ValueError(f"Invalid input shape: {x.shape}")
941
+ # x shape: [batch_size, (channels * patch_size * patch_size), frame, height, width]
942
+ batch_size, c_patches, frames, height, width = x.shape
943
+ channels = c_patches // (patch_size * patch_size)
944
+
945
+ # Reshape to [b, c, patch_size, patch_size, f, h, w]
946
+ x = x.view(batch_size, channels, patch_size, patch_size, frames, height, width)
947
+
948
+ # Rearrange to [b, c, f, h * patch_size, w * patch_size]
949
+ x = x.permute(0, 1, 4, 5, 3, 6, 2).contiguous()
950
+ x = x.view(batch_size, channels, frames, height * patch_size, width * patch_size)
951
+
952
+ return x
953
+
954
+
955
+ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
956
+ r"""
957
+ A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
958
+ Introduced in [Wan 2.1].
959
+
960
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
961
+ for all models (such as downloading or saving).
962
+ """
963
+
964
+ _supports_gradient_checkpointing = False
965
+
966
+ @register_to_config
967
+ def __init__(
968
+ self,
969
+ base_dim: int = 96,
970
+ decoder_base_dim: Optional[int] = None,
971
+ z_dim: int = 16,
972
+ dim_mult: Tuple[int] = [1, 2, 4, 4],
973
+ num_res_blocks: int = 2,
974
+ attn_scales: List[float] = [],
975
+ temperal_downsample: List[bool] = [False, True, True],
976
+ dropout: float = 0.0,
977
+ latents_mean: List[float] = [
978
+ -0.7571,
979
+ -0.7089,
980
+ -0.9113,
981
+ 0.1075,
982
+ -0.1745,
983
+ 0.9653,
984
+ -0.1517,
985
+ 1.5508,
986
+ 0.4134,
987
+ -0.0715,
988
+ 0.5517,
989
+ -0.3632,
990
+ -0.1922,
991
+ -0.9497,
992
+ 0.2503,
993
+ -0.2921,
994
+ ],
995
+ latents_std: List[float] = [
996
+ 2.8184,
997
+ 1.4541,
998
+ 2.3275,
999
+ 2.6558,
1000
+ 1.2196,
1001
+ 1.7708,
1002
+ 2.6052,
1003
+ 2.0743,
1004
+ 3.2687,
1005
+ 2.1526,
1006
+ 2.8652,
1007
+ 1.5579,
1008
+ 1.6382,
1009
+ 1.1253,
1010
+ 2.8251,
1011
+ 1.9160,
1012
+ ],
1013
+ is_residual: bool = False,
1014
+ in_channels: int = 3,
1015
+ out_channels: int = 3,
1016
+ patch_size: Optional[int] = None,
1017
+ scale_factor_temporal: Optional[int] = 4,
1018
+ scale_factor_spatial: Optional[int] = 8,
1019
+ ) -> None:
1020
+ super().__init__()
1021
+
1022
+ self.z_dim = z_dim
1023
+ self.temperal_downsample = temperal_downsample
1024
+ self.temperal_upsample = temperal_downsample[::-1]
1025
+
1026
+ if decoder_base_dim is None:
1027
+ decoder_base_dim = base_dim
1028
+
1029
+ self.encoder = WanEncoder3d(
1030
+ in_channels=in_channels,
1031
+ dim=base_dim,
1032
+ z_dim=z_dim * 2,
1033
+ dim_mult=dim_mult,
1034
+ num_res_blocks=num_res_blocks,
1035
+ attn_scales=attn_scales,
1036
+ temperal_downsample=temperal_downsample,
1037
+ dropout=dropout,
1038
+ is_residual=is_residual,
1039
+ )
1040
+ self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1)
1041
+ self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1)
1042
+
1043
+ self.decoder = WanDecoder3d(
1044
+ dim=decoder_base_dim,
1045
+ z_dim=z_dim,
1046
+ dim_mult=dim_mult,
1047
+ num_res_blocks=num_res_blocks,
1048
+ attn_scales=attn_scales,
1049
+ temperal_upsample=self.temperal_upsample,
1050
+ dropout=dropout,
1051
+ out_channels=out_channels,
1052
+ is_residual=is_residual,
1053
+ )
1054
+
1055
+ self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
1056
+
1057
+ # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
1058
+ # to perform decoding of a single video latent at a time.
1059
+ self.use_slicing = False
1060
+
1061
+ # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
1062
+ # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
1063
+ # intermediate tiles together, the memory requirement can be lowered.
1064
+ self.use_tiling = False
1065
+
1066
+ # The minimal tile height and width for spatial tiling to be used
1067
+ self.tile_sample_min_height = 256
1068
+ self.tile_sample_min_width = 256
1069
+
1070
+ # The minimal distance between two spatial tiles
1071
+ self.tile_sample_stride_height = 192
1072
+ self.tile_sample_stride_width = 192
1073
+
1074
+ # Precompute and cache conv counts for encoder and decoder for clear_cache speedup
1075
+ self._cached_conv_counts = {
1076
+ "decoder": sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules())
1077
+ if self.decoder is not None
1078
+ else 0,
1079
+ "encoder": sum(isinstance(m, WanCausalConv3d) for m in self.encoder.modules())
1080
+ if self.encoder is not None
1081
+ else 0,
1082
+ }
1083
+
1084
+ def enable_tiling(
1085
+ self,
1086
+ tile_sample_min_height: Optional[int] = None,
1087
+ tile_sample_min_width: Optional[int] = None,
1088
+ tile_sample_stride_height: Optional[float] = None,
1089
+ tile_sample_stride_width: Optional[float] = None,
1090
+ ) -> None:
1091
+ r"""
1092
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
1093
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and allows
1094
+ processing larger images.
1095
+
1096
+ Args:
1097
+ tile_sample_min_height (`int`, *optional*):
1098
+ The minimum height required for a sample to be separated into tiles across the height dimension.
1099
+ tile_sample_min_width (`int`, *optional*):
1100
+ The minimum width required for a sample to be separated into tiles across the width dimension.
1101
+ tile_sample_stride_height (`int`, *optional*):
1102
+ The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
1103
+ no tiling artifacts produced across the height dimension.
1104
+ tile_sample_stride_width (`int`, *optional*):
1105
+ The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
1106
+ artifacts produced across the width dimension.
1107
+ """
1108
+ self.use_tiling = True
1109
+ self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
1110
+ self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
1111
+ self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
1112
+ self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
1113
+
1114
+ def disable_tiling(self) -> None:
1115
+ r"""
1116
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
1117
+ decoding in one step.
1118
+ """
1119
+ self.use_tiling = False
1120
+
1121
+ def enable_slicing(self) -> None:
1122
+ r"""
1123
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
1124
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
1125
+ """
1126
+ self.use_slicing = True
1127
+
1128
+ def disable_slicing(self) -> None:
1129
+ r"""
1130
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
1131
+ decoding in one step.
1132
+ """
1133
+ self.use_slicing = False
1134
+
1135
+ def clear_cache(self):
1136
+ # Use cached conv counts for decoder and encoder to avoid re-iterating modules each call
1137
+ self._conv_num = self._cached_conv_counts["decoder"]
1138
+ self._conv_idx = [0]
1139
+ self._feat_map = [None] * self._conv_num
1140
+ # cache encode
1141
+ self._enc_conv_num = self._cached_conv_counts["encoder"]
1142
+ self._enc_conv_idx = [0]
1143
+ self._enc_feat_map = [None] * self._enc_conv_num
1144
+
1145
+ def _encode(self, x: torch.Tensor):
1146
+ _, _, num_frame, height, width = x.shape
1147
+
1148
+ if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
1149
+ return self.tiled_encode(x)
1150
+
1151
+ self.clear_cache()
1152
+ if self.config.patch_size is not None:
1153
+ x = patchify(x, patch_size=self.config.patch_size)
1154
+ iter_ = 1 + (num_frame - 1) // 4
1155
+ for i in range(iter_):
1156
+ self._enc_conv_idx = [0]
1157
+ if i == 0:
1158
+ out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
1159
+ else:
1160
+ out_ = self.encoder(
1161
+ x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
1162
+ feat_cache=self._enc_feat_map,
1163
+ feat_idx=self._enc_conv_idx,
1164
+ )
1165
+ out = torch.cat([out, out_], 2)
1166
+
1167
+ enc = self.quant_conv(out)
1168
+ self.clear_cache()
1169
+ return enc
1170
+
1171
+ @apply_forward_hook
1172
+ def encode(
1173
+ self, x: torch.Tensor, return_dict: bool = True
1174
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
1175
+ r"""
1176
+ Encode a batch of images into latents.
1177
+
1178
+ Args:
1179
+ x (`torch.Tensor`): Input batch of images.
1180
+ return_dict (`bool`, *optional*, defaults to `True`):
1181
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
1182
+
1183
+ Returns:
1184
+ The latent representations of the encoded videos. If `return_dict` is True, a
1185
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
1186
+ """
1187
+ if self.use_slicing and x.shape[0] > 1:
1188
+ encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
1189
+ h = torch.cat(encoded_slices)
1190
+ else:
1191
+ h = self._encode(x)
1192
+ posterior = DiagonalGaussianDistribution(h)
1193
+
1194
+ if not return_dict:
1195
+ return (posterior,)
1196
+ return AutoencoderKLOutput(latent_dist=posterior)
1197
+
1198
+ def _decode(self, z: torch.Tensor, return_dict: bool = True):
1199
+ _, _, num_frame, height, width = z.shape
1200
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
1201
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
1202
+
1203
+ if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
1204
+ return self.tiled_decode(z, return_dict=return_dict)
1205
+
1206
+ self.clear_cache()
1207
+ x = self.post_quant_conv(z)
1208
+ for i in range(num_frame):
1209
+ self._conv_idx = [0]
1210
+ if i == 0:
1211
+ out = self.decoder(
1212
+ x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=True
1213
+ )
1214
+ else:
1215
+ out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
1216
+ out = torch.cat([out, out_], 2)
1217
+
1218
+ if self.config.patch_size is not None:
1219
+ out = unpatchify(out, patch_size=self.config.patch_size)
1220
+
1221
+ out = torch.clamp(out, min=-1.0, max=1.0)
1222
+
1223
+ self.clear_cache()
1224
+ if not return_dict:
1225
+ return (out,)
1226
+
1227
+ return DecoderOutput(sample=out)
1228
+
1229
+ @apply_forward_hook
1230
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1231
+ r"""
1232
+ Decode a batch of images.
1233
+
1234
+ Args:
1235
+ z (`torch.Tensor`): Input batch of latent vectors.
1236
+ return_dict (`bool`, *optional*, defaults to `True`):
1237
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
1238
+
1239
+ Returns:
1240
+ [`~models.vae.DecoderOutput`] or `tuple`:
1241
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
1242
+ returned.
1243
+ """
1244
+ if self.use_slicing and z.shape[0] > 1:
1245
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
1246
+ decoded = torch.cat(decoded_slices)
1247
+ else:
1248
+ decoded = self._decode(z).sample
1249
+
1250
+ if not return_dict:
1251
+ return (decoded,)
1252
+ return DecoderOutput(sample=decoded)
1253
+
1254
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
1255
+ blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
1256
+ for y in range(blend_extent):
1257
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
1258
+ y / blend_extent
1259
+ )
1260
+ return b
1261
+
1262
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
1263
+ blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
1264
+ for x in range(blend_extent):
1265
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
1266
+ x / blend_extent
1267
+ )
1268
+ return b
1269
+
1270
+ def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
1271
+ r"""Encode a batch of images using a tiled encoder.
1272
+
1273
+ Args:
1274
+ x (`torch.Tensor`): Input batch of videos.
1275
+
1276
+ Returns:
1277
+ `torch.Tensor`:
1278
+ The latent representation of the encoded videos.
1279
+ """
1280
+ _, _, num_frames, height, width = x.shape
1281
+ latent_height = height // self.spatial_compression_ratio
1282
+ latent_width = width // self.spatial_compression_ratio
1283
+
1284
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
1285
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
1286
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
1287
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
1288
+
1289
+ blend_height = tile_latent_min_height - tile_latent_stride_height
1290
+ blend_width = tile_latent_min_width - tile_latent_stride_width
1291
+
1292
+ # Split x into overlapping tiles and encode them separately.
1293
+ # The tiles have an overlap to avoid seams between tiles.
1294
+ rows = []
1295
+ for i in range(0, height, self.tile_sample_stride_height):
1296
+ row = []
1297
+ for j in range(0, width, self.tile_sample_stride_width):
1298
+ self.clear_cache()
1299
+ time = []
1300
+ frame_range = 1 + (num_frames - 1) // 4
1301
+ for k in range(frame_range):
1302
+ self._enc_conv_idx = [0]
1303
+ if k == 0:
1304
+ tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
1305
+ else:
1306
+ tile = x[
1307
+ :,
1308
+ :,
1309
+ 1 + 4 * (k - 1) : 1 + 4 * k,
1310
+ i : i + self.tile_sample_min_height,
1311
+ j : j + self.tile_sample_min_width,
1312
+ ]
1313
+ tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
1314
+ tile = self.quant_conv(tile)
1315
+ time.append(tile)
1316
+ row.append(torch.cat(time, dim=2))
1317
+ rows.append(row)
1318
+ self.clear_cache()
1319
+
1320
+ result_rows = []
1321
+ for i, row in enumerate(rows):
1322
+ result_row = []
1323
+ for j, tile in enumerate(row):
1324
+ # blend the above tile and the left tile
1325
+ # to the current tile and add the current tile to the result row
1326
+ if i > 0:
1327
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
1328
+ if j > 0:
1329
+ tile = self.blend_h(row[j - 1], tile, blend_width)
1330
+ result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
1331
+ result_rows.append(torch.cat(result_row, dim=-1))
1332
+
1333
+ enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
1334
+ return enc
1335
+
1336
+ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1337
+ r"""
1338
+ Decode a batch of images using a tiled decoder.
1339
+
1340
+ Args:
1341
+ z (`torch.Tensor`): Input batch of latent vectors.
1342
+ return_dict (`bool`, *optional*, defaults to `True`):
1343
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
1344
+
1345
+ Returns:
1346
+ [`~models.vae.DecoderOutput`] or `tuple`:
1347
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
1348
+ returned.
1349
+ """
1350
+ _, _, num_frames, height, width = z.shape
1351
+ sample_height = height * self.spatial_compression_ratio
1352
+ sample_width = width * self.spatial_compression_ratio
1353
+
1354
+ tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
1355
+ tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
1356
+ tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
1357
+ tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
1358
+
1359
+ blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
1360
+ blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
1361
+
1362
+ # Split z into overlapping tiles and decode them separately.
1363
+ # The tiles have an overlap to avoid seams between tiles.
1364
+ rows = []
1365
+ for i in range(0, height, tile_latent_stride_height):
1366
+ row = []
1367
+ for j in range(0, width, tile_latent_stride_width):
1368
+ self.clear_cache()
1369
+ time = []
1370
+ for k in range(num_frames):
1371
+ self._conv_idx = [0]
1372
+ tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
1373
+ tile = self.post_quant_conv(tile)
1374
+ decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
1375
+ time.append(decoded)
1376
+ row.append(torch.cat(time, dim=2))
1377
+ rows.append(row)
1378
+ self.clear_cache()
1379
+
1380
+ result_rows = []
1381
+ for i, row in enumerate(rows):
1382
+ result_row = []
1383
+ for j, tile in enumerate(row):
1384
+ # blend the above tile and the left tile
1385
+ # to the current tile and add the current tile to the result row
1386
+ if i > 0:
1387
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
1388
+ if j > 0:
1389
+ tile = self.blend_h(row[j - 1], tile, blend_width)
1390
+ result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
1391
+ result_rows.append(torch.cat(result_row, dim=-1))
1392
+
1393
+ dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
1394
+
1395
+ if not return_dict:
1396
+ return (dec,)
1397
+ return DecoderOutput(sample=dec)
1398
+
1399
+ def forward(
1400
+ self,
1401
+ sample: torch.Tensor,
1402
+ sample_posterior: bool = False,
1403
+ return_dict: bool = True,
1404
+ generator: Optional[torch.Generator] = None,
1405
+ ) -> Union[DecoderOutput, torch.Tensor]:
1406
+ """
1407
+ Args:
1408
+ sample (`torch.Tensor`): Input sample.
1409
+ return_dict (`bool`, *optional*, defaults to `True`):
1410
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
1411
+ """
1412
+ x = sample
1413
+ posterior = self.encode(x).latent_dist
1414
+ if sample_posterior:
1415
+ z = posterior.sample(generator=generator)
1416
+ else:
1417
+ z = posterior.mode()
1418
+ dec = self.decode(z, return_dict=return_dict)
1419
+ return dec
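
A minimal usage sketch of the VAE defined above (illustration only, not part of the commit; it assumes the class is importable as `AutoencoderKLWan` from this repository and relies only on the `encode`, `decode` and `enable_tiling` methods shown in the diff; frame counts follow the 4x temporal compression, i.e. F = 4k + 1):

    import torch
    # Hypothetical import path; adjust to wherever this file lives in the repo.
    from architecture.autoencoder_kl_wan import AutoencoderKLWan

    vae = AutoencoderKLWan()                 # defaults from the config above: base_dim=96, z_dim=16
    vae.enable_tiling()                      # split large frames into overlapping spatial tiles
    video = torch.randn(1, 3, 9, 256, 256)   # [B, C, F, H, W], F = 4k + 1
    with torch.no_grad():
        latents = vae.encode(video).latent_dist.sample()   # AutoencoderKLOutput -> latents
        recon = vae.decode(latents).sample                 # DecoderOutput -> reconstructed video
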
architecture/cogvideox_transformer_3d.py ADDED
@@ -0,0 +1,563 @@
1
+ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any, Dict, Optional, Tuple, Union
17
+ import os, sys, shutil
18
+ import torch
19
+ from torch import nn
20
+
21
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
22
+ from diffusers.loaders import PeftAdapterMixin
23
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
24
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
25
+ from diffusers.models.attention import Attention, FeedForward
26
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
27
+ from diffusers.models.modeling_utils import ModelMixin
28
+ from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
29
+
30
+
31
+ # Import files from the local folder
32
+ root_path = os.path.abspath('.')
33
+ sys.path.append(root_path)
34
+ from architecture.embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
35
+ from architecture.attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
36
+
37
+
38
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
+
40
+
41
+ @maybe_allow_in_graph
42
+ class CogVideoXBlock(nn.Module):
43
+ r"""
44
+ Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
45
+
46
+ Parameters:
47
+ dim (`int`):
48
+ The number of channels in the input and output.
49
+ num_attention_heads (`int`):
50
+ The number of heads to use for multi-head attention.
51
+ attention_head_dim (`int`):
52
+ The number of channels in each head.
53
+ time_embed_dim (`int`):
54
+ The number of channels in timestep embedding.
55
+ dropout (`float`, defaults to `0.0`):
56
+ The dropout probability to use.
57
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
58
+ Activation function to be used in feed-forward.
59
+ attention_bias (`bool`, defaults to `False`):
60
+ Whether or not to use bias in attention projection layers.
61
+ qk_norm (`bool`, defaults to `True`):
62
+ Whether or not to use normalization after query and key projections in Attention.
63
+ norm_elementwise_affine (`bool`, defaults to `True`):
64
+ Whether to use learnable elementwise affine parameters for normalization.
65
+ norm_eps (`float`, defaults to `1e-5`):
66
+ Epsilon value for normalization layers.
67
+ final_dropout (`bool`, defaults to `True`):
68
+ Whether to apply a final dropout after the last feed-forward layer.
69
+ ff_inner_dim (`int`, *optional*, defaults to `None`):
70
+ Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
71
+ ff_bias (`bool`, defaults to `True`):
72
+ Whether or not to use bias in Feed-forward layer.
73
+ attention_out_bias (`bool`, defaults to `True`):
74
+ Whether or not to use bias in Attention output projection layer.
75
+ """
76
+
77
+ def __init__(
78
+ self,
79
+ dim: int,
80
+ num_attention_heads: int,
81
+ attention_head_dim: int,
82
+ time_embed_dim: int,
83
+ dropout: float = 0.0,
84
+ activation_fn: str = "gelu-approximate",
85
+ attention_bias: bool = False,
86
+ qk_norm: bool = True,
87
+ norm_elementwise_affine: bool = True,
88
+ norm_eps: float = 1e-5,
89
+ final_dropout: bool = True,
90
+ ff_inner_dim: Optional[int] = None,
91
+ ff_bias: bool = True,
92
+ attention_out_bias: bool = True,
93
+ ):
94
+ super().__init__()
95
+
96
+ # 1. Self Attention
97
+ self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
98
+
99
+ self.attn1 = Attention(
100
+ query_dim=dim,
101
+ dim_head=attention_head_dim,
102
+ heads=num_attention_heads,
103
+ qk_norm="layer_norm" if qk_norm else None,
104
+ eps=1e-6,
105
+ bias=attention_bias,
106
+ out_bias=attention_out_bias,
107
+ processor=CogVideoXAttnProcessor2_0(),
108
+ )
109
+
110
+ # 2. Feed Forward
111
+ self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
112
+
113
+ self.ff = FeedForward(
114
+ dim,
115
+ dropout=dropout,
116
+ activation_fn=activation_fn,
117
+ final_dropout=final_dropout,
118
+ inner_dim=ff_inner_dim,
119
+ bias=ff_bias,
120
+ )
121
+
122
+ def forward(
123
+ self,
124
+ hidden_states: torch.Tensor,
125
+ encoder_hidden_states: torch.Tensor,
126
+ temb: torch.Tensor,
127
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
128
+ attention_kwargs: Optional[Dict[str, Any]] = None,
129
+ ) -> torch.Tensor:
130
+ text_seq_length = encoder_hidden_states.size(1)
131
+ attention_kwargs = attention_kwargs or {}
132
+
133
+ # norm & modulate
134
+ norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
135
+ hidden_states, encoder_hidden_states, temb
136
+ )
137
+
138
+ # attention
139
+ attn_hidden_states, attn_encoder_hidden_states = self.attn1(
140
+ hidden_states = norm_hidden_states,
141
+ encoder_hidden_states = norm_encoder_hidden_states,
142
+ image_rotary_emb = image_rotary_emb,
143
+ **attention_kwargs,
144
+ )
145
+
146
+ hidden_states = hidden_states + gate_msa * attn_hidden_states
147
+ encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
148
+
149
+ # norm & modulate
150
+ norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
151
+ hidden_states, encoder_hidden_states, temb
152
+ )
153
+
154
+ # feed-forward
155
+ norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
156
+ ff_output = self.ff(norm_hidden_states)
157
+
158
+ hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
159
+ encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
160
+
161
+ return hidden_states, encoder_hidden_states
162
+
163
+
164
+ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
165
+ """
166
+ A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
167
+
168
+ Parameters:
169
+ num_attention_heads (`int`, defaults to `30`):
170
+ The number of heads to use for multi-head attention.
171
+ attention_head_dim (`int`, defaults to `64`):
172
+ The number of channels in each head.
173
+ in_channels (`int`, defaults to `16`):
174
+ The number of channels in the input.
175
+ out_channels (`int`, *optional*, defaults to `16`):
176
+ The number of channels in the output.
177
+ flip_sin_to_cos (`bool`, defaults to `True`):
178
+ Whether to flip the sin to cos in the time embedding.
179
+ time_embed_dim (`int`, defaults to `512`):
180
+ Output dimension of timestep embeddings.
181
+ ofs_embed_dim (`int`, defaults to `512`):
182
+ Output dimension of "ofs" embeddings used in CogVideoX-5b-I2V in version 1.5
183
+ text_embed_dim (`int`, defaults to `4096`):
184
+ Input dimension of text embeddings from the text encoder.
185
+ num_layers (`int`, defaults to `30`):
186
+ The number of layers of Transformer blocks to use.
187
+ dropout (`float`, defaults to `0.0`):
188
+ The dropout probability to use.
189
+ attention_bias (`bool`, defaults to `True`):
190
+ Whether to use bias in the attention projection layers.
191
+ sample_width (`int`, defaults to `90`):
192
+ The width of the input latents.
193
+ sample_height (`int`, defaults to `60`):
194
+ The height of the input latents.
195
+ sample_frames (`int`, defaults to `49`):
196
+ The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
197
+ instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
198
+ but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
199
+ K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
200
+ patch_size (`int`, defaults to `2`):
201
+ The size of the patches to use in the patch embedding layer.
202
+ temporal_compression_ratio (`int`, defaults to `4`):
203
+ The compression ratio across the temporal dimension. See documentation for `sample_frames`.
204
+ max_text_seq_length (`int`, defaults to `226`):
205
+ The maximum sequence length of the input text embeddings.
206
+ activation_fn (`str`, defaults to `"gelu-approximate"`):
207
+ Activation function to use in feed-forward.
208
+ timestep_activation_fn (`str`, defaults to `"silu"`):
209
+ Activation function to use when generating the timestep embeddings.
210
+ norm_elementwise_affine (`bool`, defaults to `True`):
211
+ Whether to use elementwise affine in normalization layers.
212
+ norm_eps (`float`, defaults to `1e-5`):
213
+ The epsilon value to use in normalization layers.
214
+ spatial_interpolation_scale (`float`, defaults to `1.875`):
215
+ Scaling factor to apply in 3D positional embeddings across spatial dimensions.
216
+ temporal_interpolation_scale (`float`, defaults to `1.0`):
217
+ Scaling factor to apply in 3D positional embeddings across temporal dimensions.
218
+ """
219
+
220
+ _supports_gradient_checkpointing = True
221
+ _no_split_modules = ["CogVideoXBlock", "CogVideoXPatchEmbed"]
222
+
223
+ @register_to_config
224
+ def __init__(
225
+ self,
226
+ num_attention_heads: int = 30,
227
+ attention_head_dim: int = 64,
228
+ in_channels: int = 16,
229
+ out_channels: Optional[int] = 16,
230
+ flip_sin_to_cos: bool = True,
231
+ freq_shift: int = 0,
232
+ time_embed_dim: int = 512,
233
+ ofs_embed_dim: Optional[int] = None,
234
+ text_embed_dim: int = 4096,
235
+ num_layers: int = 30,
236
+ dropout: float = 0.0,
237
+ attention_bias: bool = True,
238
+ sample_width: int = 90,
239
+ sample_height: int = 60,
240
+ sample_frames: int = 49,
241
+ patch_size: int = 2,
242
+ patch_size_t: Optional[int] = None,
243
+ temporal_compression_ratio: int = 4,
244
+ max_text_seq_length: int = 226,
245
+ activation_fn: str = "gelu-approximate",
246
+ timestep_activation_fn: str = "silu",
247
+ norm_elementwise_affine: bool = True,
248
+ norm_eps: float = 1e-5,
249
+ spatial_interpolation_scale: float = 1.875,
250
+ temporal_interpolation_scale: float = 1.0,
251
+ use_rotary_positional_embeddings: bool = False,
252
+ use_learned_positional_embeddings: bool = False,
253
+ patch_bias: bool = True,
254
+ extra_encoder_cond_channels: int = -1,
255
+ use_FrameIn: bool = False,
256
+ ):
257
+ super().__init__()
258
+ inner_dim = num_attention_heads * attention_head_dim
259
+
260
+
261
+ # breakpoint()
262
+ # if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
263
+ # raise ValueError(
264
+ # "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
265
+ # "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
266
+ # "issue at https://github.com/huggingface/diffusers/issues."
267
+ # )
268
+
269
+ # 1. Patch embedding
270
+ self.patch_embed = CogVideoXPatchEmbed(
271
+ patch_size = patch_size,
272
+ patch_size_t = patch_size_t,
273
+ in_channels = in_channels,
274
+ embed_dim = inner_dim,
275
+ text_embed_dim = text_embed_dim,
276
+ bias = patch_bias,
277
+ sample_width = sample_width,
278
+ sample_height = sample_height,
279
+ sample_frames = sample_frames,
280
+ temporal_compression_ratio = temporal_compression_ratio,
281
+ max_text_seq_length = max_text_seq_length,
282
+ spatial_interpolation_scale = spatial_interpolation_scale,
283
+ temporal_interpolation_scale = temporal_interpolation_scale,
284
+ use_positional_embeddings = not use_rotary_positional_embeddings, # HACK: use_positional_embeddings is the inverse of use_rotary_positional_embeddings
285
+ use_learned_positional_embeddings = use_learned_positional_embeddings,
286
+ extra_encoder_cond_channels = extra_encoder_cond_channels,
287
+ use_FrameIn = use_FrameIn,
288
+ )
289
+ self.embedding_dropout = nn.Dropout(dropout)
290
+
291
+ # 2. Time embeddings and ofs embedding (only CogVideoX 1.5-5B I2V has it)
292
+
293
+ self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
294
+ self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
295
+
296
+ self.ofs_proj = None
297
+ self.ofs_embedding = None
298
+ if ofs_embed_dim:
299
+ self.ofs_proj = Timesteps(ofs_embed_dim, flip_sin_to_cos, freq_shift)
300
+ self.ofs_embedding = TimestepEmbedding(
301
+ ofs_embed_dim, ofs_embed_dim, timestep_activation_fn
302
+ ) # same as time embeddings, for ofs
303
+
304
+ # 3. Define spatio-temporal transformers blocks
305
+ self.transformer_blocks = nn.ModuleList(
306
+ [
307
+ CogVideoXBlock(
308
+ dim=inner_dim,
309
+ num_attention_heads=num_attention_heads,
310
+ attention_head_dim=attention_head_dim,
311
+ time_embed_dim=time_embed_dim,
312
+ dropout=dropout,
313
+ activation_fn=activation_fn,
314
+ attention_bias=attention_bias,
315
+ norm_elementwise_affine=norm_elementwise_affine,
316
+ norm_eps=norm_eps,
317
+ )
318
+ for _ in range(num_layers)
319
+ ]
320
+ )
321
+ self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
322
+
323
+ # 4. Output blocks
324
+ self.norm_out = AdaLayerNorm(
325
+ embedding_dim=time_embed_dim,
326
+ output_dim=2 * inner_dim,
327
+ norm_elementwise_affine=norm_elementwise_affine,
328
+ norm_eps=norm_eps,
329
+ chunk_dim=1,
330
+ )
331
+
332
+ if patch_size_t is None:
333
+ # For CogVideox 1.0
334
+ output_dim = patch_size * patch_size * out_channels
335
+ else:
336
+ # For CogVideoX 1.5
337
+ output_dim = patch_size * patch_size * patch_size_t * out_channels
338
+
339
+ self.proj_out = nn.Linear(inner_dim, output_dim)
340
+
341
+ self.gradient_checkpointing = False
342
+
343
+ # def _set_gradient_checkpointing(self, module, value=False):
344
+ # self.gradient_checkpointing = value
345
+
346
+ @property
347
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
348
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
349
+ r"""
350
+ Returns:
351
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
352
+ indexed by their weight names.
353
+ """
354
+ # set recursively
355
+ processors = {}
356
+
357
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
358
+ if hasattr(module, "get_processor"):
359
+ processors[f"{name}.processor"] = module.get_processor()
360
+
361
+ for sub_name, child in module.named_children():
362
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
363
+
364
+ return processors
365
+
366
+ for name, module in self.named_children():
367
+ fn_recursive_add_processors(name, module, processors)
368
+
369
+ return processors
370
+
371
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
372
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
373
+ r"""
374
+ Sets the attention processor to use to compute attention.
375
+
376
+ Parameters:
377
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
378
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
379
+ for **all** `Attention` layers.
380
+
381
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
382
+ processor. This is strongly recommended when setting trainable attention processors.
383
+
384
+ """
385
+ count = len(self.attn_processors.keys())
386
+
387
+ if isinstance(processor, dict) and len(processor) != count:
388
+ raise ValueError(
389
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
390
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
391
+ )
392
+
393
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
394
+ if hasattr(module, "set_processor"):
395
+ if not isinstance(processor, dict):
396
+ module.set_processor(processor)
397
+ else:
398
+ module.set_processor(processor.pop(f"{name}.processor"))
399
+
400
+ for sub_name, child in module.named_children():
401
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
402
+
403
+ for name, module in self.named_children():
404
+ fn_recursive_attn_processor(name, module, processor)
405
+
406
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
407
+ def fuse_qkv_projections(self):
408
+ """
409
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
410
+ are fused. For cross-attention modules, key and value projection matrices are fused.
411
+
412
+ <Tip warning={true}>
413
+
414
+ This API is 🧪 experimental.
415
+
416
+ </Tip>
417
+ """
418
+ self.original_attn_processors = None
419
+
420
+ for _, attn_processor in self.attn_processors.items():
421
+ if "Added" in str(attn_processor.__class__.__name__):
422
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
423
+
424
+ self.original_attn_processors = self.attn_processors
425
+
426
+ for module in self.modules():
427
+ if isinstance(module, Attention):
428
+ module.fuse_projections(fuse=True)
429
+
430
+ self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
431
+
432
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
433
+ def unfuse_qkv_projections(self):
434
+ """Disables the fused QKV projection if enabled.
435
+
436
+ <Tip warning={true}>
437
+
438
+ This API is 🧪 experimental.
439
+
440
+ </Tip>
441
+
442
+ """
443
+ if self.original_attn_processors is not None:
444
+ self.set_attn_processor(self.original_attn_processors)
445
+
446
+ def forward(
447
+ self,
448
+ hidden_states: torch.Tensor,
449
+ encoder_hidden_states: torch.Tensor,
450
+ timestep: Union[int, float, torch.LongTensor],
451
+ timestep_cond: Optional[torch.Tensor] = None,
452
+ ofs: Optional[Union[int, float, torch.LongTensor]] = None,
453
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
454
+ attention_kwargs: Optional[Dict[str, Any]] = None,
455
+ return_dict: bool = True,
456
+ ):
457
+
458
+ if attention_kwargs is not None:
459
+ attention_kwargs = attention_kwargs.copy()
460
+ lora_scale = attention_kwargs.pop("scale", 1.0)
461
+ else:
462
+ lora_scale = 1.0
463
+
464
+ if USE_PEFT_BACKEND:
465
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
466
+ scale_lora_layers(self, lora_scale)
467
+ else:
468
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
469
+ logger.warning(
470
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
471
+ )
472
+
473
+
474
+ batch_size, num_frames, channels, height, width = hidden_states.shape
475
+
476
+ # 1. Time embedding
477
+ timesteps = timestep
478
+ t_emb = self.time_proj(timesteps)
479
+
480
+
481
+ # timesteps does not contain any weights and will always return f32 tensors
482
+ # but time_embedding might actually be running in fp16. so we need to cast here.
483
+ # there might be better ways to encapsulate this.
484
+ t_emb = t_emb.to(dtype=hidden_states.dtype)
485
+ emb = self.time_embedding(t_emb, timestep_cond)
486
+
487
+ if self.ofs_embedding is not None:
488
+ ofs_emb = self.ofs_proj(ofs)
489
+ ofs_emb = ofs_emb.to(dtype = hidden_states.dtype)
490
+ ofs_emb = self.ofs_embedding(ofs_emb)
491
+ emb = emb + ofs_emb
492
+
493
+ # 2. Patch embedding
494
+ hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) # Only use patch embedding at the very beginning
495
+ hidden_states = self.embedding_dropout(hidden_states)
496
+
497
+ # HACK: the patch_embed output is split back into text and video tokens only after the positional embedding has been added
498
+ text_seq_length = encoder_hidden_states.shape[1]
499
+ encoder_hidden_states = hidden_states[:, :text_seq_length] # the merged sequence is split back into its text part
500
+ hidden_states = hidden_states[:, text_seq_length:]
501
+
502
+ # 3. Transformer blocks
503
+ for i, block in enumerate(self.transformer_blocks):
504
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
505
+
506
+ def create_custom_forward(module):
507
+ def custom_forward(*inputs):
508
+ return module(*inputs)
509
+
510
+ return custom_forward
511
+
512
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
513
+ hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
514
+ create_custom_forward(block),
515
+ hidden_states,
516
+ encoder_hidden_states,
517
+ emb,
518
+ image_rotary_emb,
519
+ attention_kwargs,
520
+ **ckpt_kwargs,
521
+ )
522
+ else:
523
+ hidden_states, encoder_hidden_states = block(
524
+ hidden_states = hidden_states,
525
+ encoder_hidden_states = encoder_hidden_states,
526
+ temb = emb,
527
+ image_rotary_emb = image_rotary_emb,
528
+ attention_kwargs = attention_kwargs,
529
+ )
530
+
531
+ if not self.config.use_rotary_positional_embeddings:
532
+ # CogVideoX-2B
533
+ hidden_states = self.norm_final(hidden_states)
534
+ else:
535
+ # CogVideoX-5B
536
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
537
+ hidden_states = self.norm_final(hidden_states)
538
+ hidden_states = hidden_states[:, text_seq_length:]
539
+
540
+ # 4. Final block
541
+ hidden_states = self.norm_out(hidden_states, temb=emb)
542
+ hidden_states = self.proj_out(hidden_states)
543
+
544
+ # 5. Unpatchify
545
+ p = self.config.patch_size
546
+ p_t = self.config.patch_size_t
547
+
548
+ if p_t is None:
549
+ output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
550
+ output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
551
+ else:
552
+ output = hidden_states.reshape(
553
+ batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
554
+ )
555
+ output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
556
+
557
+ if USE_PEFT_BACKEND:
558
+ # remove `lora_scale` from each PEFT layer
559
+ unscale_lora_layers(self, lora_scale)
560
+
561
+ if not return_dict:
562
+ return (output,)
563
+ return Transformer2DModelOutput(sample=output)
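To make the final unpatchify step above easier to follow, here is a minimal standalone sketch (not part of the repository) of the spatial-only branch (the p_t is None case), using hypothetical toy sizes to show how each token of length channels*p*p is folded back onto the latent grid:

import torch

# Hypothetical toy sizes: latent channels/height/width plus a spatial patch size p.
batch_size, num_frames, channels, height, width, p = 1, 2, 3, 4, 6, 2
tokens = torch.randn(batch_size, num_frames * (height // p) * (width // p), channels * p * p)

# Same reshape/permute chain as the "5. Unpatchify" block above (p_t is None).
output = tokens.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

print(output.shape)  # torch.Size([1, 2, 3, 4, 6]) -> [batch, frames, channels, height, width]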
architecture/embeddings.py ADDED
The diff for this file is too large to render. See raw diff
 
architecture/noise_sampler.py ADDED
@@ -0,0 +1,54 @@
1
+
2
+ """Modified from https://github.com/THUDM/CogVideo/blob/3710a612d8760f5cdb1741befeebb65b9e0f2fe0/sat/sgm/modules/diffusionmodules/sigma_sampling.py
3
+ """
4
+ import torch
5
+
6
+ class DiscreteSampling:
7
+
8
+ def __init__(self, num_idx, uniform_sampling=False):
9
+ self.num_idx = num_idx
10
+ self.uniform_sampling = uniform_sampling
11
+ self.is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
12
+
13
+ # print("self.is_distributed status is ", self.is_distributed)
14
+ if self.is_distributed and self.uniform_sampling:
15
+ world_size = torch.distributed.get_world_size()
16
+ self.rank = torch.distributed.get_rank()
17
+
18
+ i = 1
19
+ while True:
20
+ if world_size % i != 0 or num_idx % (world_size // i) != 0:
21
+ i += 1
22
+ else:
23
+ self.group_num = world_size // i
24
+ break
25
+ assert self.group_num > 0
26
+ assert world_size % self.group_num == 0
27
+ # the number of ranks in one group
28
+ self.group_width = world_size // self.group_num
29
+ self.sigma_interval = self.num_idx // self.group_num
30
+ print('rank=%d world_size=%d group_num=%d group_width=%d sigma_interval=%s' % (
31
+ self.rank, world_size, self.group_num,
32
+ self.group_width, self.sigma_interval))
33
+
34
+
35
+ def __call__(self, n_samples, generator=None, device=None):
36
+
37
+
38
+ if self.is_distributed and self.uniform_sampling:
39
+ group_index = self.rank // self.group_width
40
+ idx = torch.randint(
41
+ group_index * self.sigma_interval,
42
+ (group_index + 1) * self.sigma_interval,
43
+ (n_samples,),
44
+ generator=generator, device=device,
45
+ )
46
+ # print('proc[%d] idx=%s' % (self.rank, idx))
47
+ # print("Uniform sample range is ", group_index * self.sigma_interval, (group_index + 1) * self.sigma_interval)
48
+
49
+ else:
50
+ idx = torch.randint(
51
+ 0, self.num_idx, (n_samples,),
52
+ generator=generator, device=device,
53
+ )
54
+ return idx
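A short usage sketch for DiscreteSampling above (an assumed example, not taken from the repository's training scripts): in a single-process run, uniform_sampling has no effect and per-sample timestep indices are drawn uniformly from [0, num_idx). The import path simply mirrors the file location.

import torch
from architecture.noise_sampler import DiscreteSampling

# Non-distributed case: indices are sampled uniformly at random.
timestep_sampler = DiscreteSampling(num_idx=1000, uniform_sampling=False)
generator = torch.Generator().manual_seed(42)

timestep_idx = timestep_sampler(n_samples=4, generator=generator, device="cpu")
print(timestep_idx)  # a tensor of 4 integer timestep indices in [0, 1000)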
architecture/transformer_wan.py ADDED
@@ -0,0 +1,552 @@
1
+ # Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from typing import Any, Dict, Optional, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+
22
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
23
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
24
+ from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
25
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
26
+ from diffusers.models.attention import FeedForward
27
+ from diffusers.models.attention_processor import Attention
28
+ from diffusers.models.cache_utils import CacheMixin
29
+ from diffusers.models.embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
30
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
31
+ from diffusers.models.modeling_utils import ModelMixin
32
+ from diffusers.models.normalization import FP32LayerNorm
33
+
34
+
35
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
36
+
37
+
38
+ class WanAttnProcessor2_0:
39
+ def __init__(self):
40
+ if not hasattr(F, "scaled_dot_product_attention"):
41
+ raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
42
+
43
+ def __call__(
44
+ self,
45
+ attn: Attention,
46
+ hidden_states: torch.Tensor,
47
+ encoder_hidden_states: Optional[torch.Tensor] = None,
48
+ attention_mask: Optional[torch.Tensor] = None,
49
+ rotary_emb: Optional[torch.Tensor] = None,
50
+ ) -> torch.Tensor:
51
+ encoder_hidden_states_img = None
52
+ if attn.add_k_proj is not None:
53
+ # 512 is the context length of the text encoder, hardcoded for now
54
+ image_context_length = encoder_hidden_states.shape[1] - 512
55
+ encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
56
+ encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
57
+ if encoder_hidden_states is None:
58
+ encoder_hidden_states = hidden_states
59
+
60
+ query = attn.to_q(hidden_states)
61
+ key = attn.to_k(encoder_hidden_states)
62
+ value = attn.to_v(encoder_hidden_states)
63
+
64
+ if attn.norm_q is not None:
65
+ query = attn.norm_q(query)
66
+ if attn.norm_k is not None:
67
+ key = attn.norm_k(key)
68
+
69
+ query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
70
+ key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
71
+ value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
72
+
73
+ if rotary_emb is not None:
74
+
75
+ def apply_rotary_emb(
76
+ hidden_states: torch.Tensor,
77
+ freqs_cos: torch.Tensor,
78
+ freqs_sin: torch.Tensor,
79
+ ):
80
+ x = hidden_states.view(*hidden_states.shape[:-1], -1, 2)
81
+ x1, x2 = x[..., 0], x[..., 1]
82
+ cos = freqs_cos[..., 0::2]
83
+ sin = freqs_sin[..., 1::2]
84
+ out = torch.empty_like(hidden_states)
85
+ out[..., 0::2] = x1 * cos - x2 * sin
86
+ out[..., 1::2] = x1 * sin + x2 * cos
87
+ return out.type_as(hidden_states)
88
+
89
+ query = apply_rotary_emb(query, *rotary_emb)
90
+ key = apply_rotary_emb(key, *rotary_emb)
91
+
92
+ # I2V task
93
+ hidden_states_img = None
94
+ if encoder_hidden_states_img is not None:
95
+ key_img = attn.add_k_proj(encoder_hidden_states_img)
96
+ key_img = attn.norm_added_k(key_img)
97
+ value_img = attn.add_v_proj(encoder_hidden_states_img)
98
+
99
+ key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
100
+ value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
101
+
102
+ hidden_states_img = F.scaled_dot_product_attention(
103
+ query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
104
+ )
105
+ hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
106
+ hidden_states_img = hidden_states_img.type_as(query)
107
+
108
+ hidden_states = F.scaled_dot_product_attention(
109
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
110
+ )
111
+ hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
112
+ hidden_states = hidden_states.type_as(query)
113
+
114
+ if hidden_states_img is not None:
115
+ hidden_states = hidden_states + hidden_states_img
116
+
117
+ hidden_states = attn.to_out[0](hidden_states)
118
+ hidden_states = attn.to_out[1](hidden_states)
119
+ return hidden_states
120
+
121
+
122
+ class WanImageEmbedding(torch.nn.Module):
123
+ def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
124
+ super().__init__()
125
+
126
+ self.norm1 = FP32LayerNorm(in_features)
127
+ self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
128
+ self.norm2 = FP32LayerNorm(out_features)
129
+ if pos_embed_seq_len is not None:
130
+ self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
131
+ else:
132
+ self.pos_embed = None
133
+
134
+ def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
135
+ if self.pos_embed is not None:
136
+ batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
137
+ encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim)
138
+ encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed
139
+
140
+ hidden_states = self.norm1(encoder_hidden_states_image)
141
+ hidden_states = self.ff(hidden_states)
142
+ hidden_states = self.norm2(hidden_states)
143
+ return hidden_states
144
+
145
+
146
+ class WanTimeTextImageEmbedding(nn.Module):
147
+ def __init__(
148
+ self,
149
+ dim: int,
150
+ time_freq_dim: int,
151
+ time_proj_dim: int,
152
+ text_embed_dim: int,
153
+ image_embed_dim: Optional[int] = None,
154
+ pos_embed_seq_len: Optional[int] = None,
155
+ ):
156
+ super().__init__()
157
+
158
+ self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
159
+ self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
160
+ self.act_fn = nn.SiLU()
161
+ self.time_proj = nn.Linear(dim, time_proj_dim)
162
+ self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
163
+
164
+ self.image_embedder = None
165
+ if image_embed_dim is not None:
166
+ self.image_embedder = WanImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)
167
+
168
+ def forward(
169
+ self,
170
+ timestep: torch.Tensor,
171
+ encoder_hidden_states: torch.Tensor,
172
+ encoder_hidden_states_image: Optional[torch.Tensor] = None,
173
+ timestep_seq_len: Optional[int] = None,
174
+ ):
175
+ timestep = self.timesteps_proj(timestep)
176
+ if timestep_seq_len is not None:
177
+ timestep = timestep.unflatten(0, (1, timestep_seq_len))
178
+
179
+ time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
180
+ if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
181
+ timestep = timestep.to(time_embedder_dtype)
182
+ temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
183
+ timestep_proj = self.time_proj(self.act_fn(temb))
184
+
185
+ encoder_hidden_states = self.text_embedder(encoder_hidden_states)
186
+ if encoder_hidden_states_image is not None:
187
+ encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
188
+
189
+ return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
190
+
191
+
192
+ class WanRotaryPosEmbed(nn.Module):
193
+ def __init__(
194
+ self,
195
+ attention_head_dim: int,
196
+ patch_size: Tuple[int, int, int],
197
+ max_seq_len: int,
198
+ theta: float = 10000.0,
199
+ ):
200
+ super().__init__()
201
+
202
+ self.attention_head_dim = attention_head_dim
203
+ self.patch_size = patch_size
204
+ self.max_seq_len = max_seq_len
205
+
206
+ h_dim = w_dim = 2 * (attention_head_dim // 6)
207
+ t_dim = attention_head_dim - h_dim - w_dim
208
+ freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
209
+
210
+ freqs_cos = []
211
+ freqs_sin = []
212
+
213
+ for dim in [t_dim, h_dim, w_dim]:
214
+ freq_cos, freq_sin = get_1d_rotary_pos_embed(
215
+ dim,
216
+ max_seq_len,
217
+ theta,
218
+ use_real=True,
219
+ repeat_interleave_real=True,
220
+ freqs_dtype=freqs_dtype,
221
+ )
222
+ freqs_cos.append(freq_cos)
223
+ freqs_sin.append(freq_sin)
224
+
225
+ self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
226
+ self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
227
+
228
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
229
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
230
+ p_t, p_h, p_w = self.patch_size
231
+ ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
232
+
233
+ split_sizes = [
234
+ self.attention_head_dim - 2 * (self.attention_head_dim // 3),
235
+ self.attention_head_dim // 3,
236
+ self.attention_head_dim // 3,
237
+ ]
238
+
239
+ freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
240
+ freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
241
+
242
+ freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
243
+ freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
244
+ freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
245
+
246
+ freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
247
+ freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
248
+ freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
249
+
250
+ freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
251
+ freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
252
+
253
+ return freqs_cos, freqs_sin
254
+
255
+
256
+ @maybe_allow_in_graph
257
+ class WanTransformerBlock(nn.Module):
258
+ def __init__(
259
+ self,
260
+ dim: int,
261
+ ffn_dim: int,
262
+ num_heads: int,
263
+ qk_norm: str = "rms_norm_across_heads",
264
+ cross_attn_norm: bool = False,
265
+ eps: float = 1e-6,
266
+ added_kv_proj_dim: Optional[int] = None,
267
+ ):
268
+ super().__init__()
269
+
270
+ # 1. Self-attention
271
+ self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
272
+ self.attn1 = Attention(
273
+ query_dim=dim,
274
+ heads=num_heads,
275
+ kv_heads=num_heads,
276
+ dim_head=dim // num_heads,
277
+ qk_norm=qk_norm,
278
+ eps=eps,
279
+ bias=True,
280
+ cross_attention_dim=None,
281
+ out_bias=True,
282
+ processor=WanAttnProcessor2_0(),
283
+ )
284
+
285
+ # 2. Cross-attention
286
+ self.attn2 = Attention(
287
+ query_dim=dim,
288
+ heads=num_heads,
289
+ kv_heads=num_heads,
290
+ dim_head=dim // num_heads,
291
+ qk_norm=qk_norm,
292
+ eps=eps,
293
+ bias=True,
294
+ cross_attention_dim=None,
295
+ out_bias=True,
296
+ added_kv_proj_dim=added_kv_proj_dim,
297
+ added_proj_bias=True,
298
+ processor=WanAttnProcessor2_0(),
299
+ )
300
+ self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
301
+
302
+ # 3. Feed-forward
303
+ self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
304
+ self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
305
+
306
+ self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
307
+
308
+ def forward(
309
+ self,
310
+ hidden_states: torch.Tensor,
311
+ encoder_hidden_states: torch.Tensor,
312
+ temb: torch.Tensor,
313
+ rotary_emb: torch.Tensor,
314
+ ) -> torch.Tensor:
315
+ if temb.ndim == 4:
316
+ # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
317
+ shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
318
+ self.scale_shift_table.unsqueeze(0) + temb.float()
319
+ ).chunk(6, dim=2)
320
+ # batch_size, seq_len, 1, inner_dim
321
+ shift_msa = shift_msa.squeeze(2)
322
+ scale_msa = scale_msa.squeeze(2)
323
+ gate_msa = gate_msa.squeeze(2)
324
+ c_shift_msa = c_shift_msa.squeeze(2)
325
+ c_scale_msa = c_scale_msa.squeeze(2)
326
+ c_gate_msa = c_gate_msa.squeeze(2)
327
+ else:
328
+ # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B)
329
+ shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
330
+ self.scale_shift_table + temb.float()
331
+ ).chunk(6, dim=1)
332
+
333
+ # 1. Self-attention
334
+ norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
335
+ attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
336
+ hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
337
+
338
+ # 2. Cross-attention
339
+ norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
340
+ attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
341
+ hidden_states = hidden_states + attn_output
342
+
343
+ # 3. Feed-forward
344
+ norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
345
+ hidden_states
346
+ )
347
+ ff_output = self.ffn(norm_hidden_states)
348
+ hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
349
+
350
+ return hidden_states
351
+
352
+
353
+ class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
354
+ r"""
355
+ A Transformer model for video-like data used in the Wan model.
356
+
357
+ Args:
358
+ patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
359
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
360
+ num_attention_heads (`int`, defaults to `40`):
361
+ The number of heads to use for multi-head attention.
362
+ attention_head_dim (`int`, defaults to `128`):
363
+ The number of channels in each head.
364
+ in_channels (`int`, defaults to `16`):
365
+ The number of channels in the input.
366
+ out_channels (`int`, defaults to `16`):
367
+ The number of channels in the output.
368
+ text_dim (`int`, defaults to `4096`):
369
+ Input dimension for text embeddings.
370
+ freq_dim (`int`, defaults to `256`):
371
+ Dimension for sinusoidal time embeddings.
372
+ ffn_dim (`int`, defaults to `13824`):
373
+ Intermediate dimension in feed-forward network.
374
+ num_layers (`int`, defaults to `40`):
375
+ The number of layers of transformer blocks to use.
376
+ window_size (`Tuple[int]`, defaults to `(-1, -1)`):
377
+ Window size for local attention (-1 indicates global attention).
378
+ cross_attn_norm (`bool`, defaults to `True`):
379
+ Enable cross-attention normalization.
380
+ qk_norm (`bool`, defaults to `True`):
381
+ Enable query/key normalization.
382
+ eps (`float`, defaults to `1e-6`):
383
+ Epsilon value for normalization layers.
384
+ add_img_emb (`bool`, defaults to `False`):
385
+ Whether to use img_emb.
386
+ added_kv_proj_dim (`int`, *optional*, defaults to `None`):
387
+ The number of channels to use for the added key and value projections. If `None`, no projection is used.
388
+ """
389
+
390
+ _supports_gradient_checkpointing = True
391
+ _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
392
+ _no_split_modules = ["WanTransformerBlock"]
393
+ _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
394
+ _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
395
+ _repeated_blocks = ["WanTransformerBlock"]
396
+
397
+ @register_to_config
398
+ def __init__(
399
+ self,
400
+ patch_size: Tuple[int] = (1, 2, 2),
401
+ num_attention_heads: int = 40,
402
+ attention_head_dim: int = 128,
403
+ in_channels: int = 16,
404
+ out_channels: int = 16,
405
+ text_dim: int = 4096,
406
+ freq_dim: int = 256,
407
+ ffn_dim: int = 13824,
408
+ num_layers: int = 40,
409
+ cross_attn_norm: bool = True,
410
+ qk_norm: Optional[str] = "rms_norm_across_heads",
411
+ eps: float = 1e-6,
412
+ image_dim: Optional[int] = None,
413
+ added_kv_proj_dim: Optional[int] = None,
414
+ rope_max_seq_len: int = 1024,
415
+ pos_embed_seq_len: Optional[int] = None,
416
+ ) -> None:
417
+ super().__init__()
418
+
419
+ inner_dim = num_attention_heads * attention_head_dim
420
+ out_channels = out_channels or in_channels
421
+
422
+ # 1. Patch & position embedding
423
+ self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
424
+ self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
425
+
426
+ # 2. Condition embeddings
427
+ # image_embedding_dim=1280 for I2V model
428
+ self.condition_embedder = WanTimeTextImageEmbedding(
429
+ dim=inner_dim,
430
+ time_freq_dim=freq_dim,
431
+ time_proj_dim=inner_dim * 6,
432
+ text_embed_dim=text_dim,
433
+ image_embed_dim=image_dim,
434
+ pos_embed_seq_len=pos_embed_seq_len,
435
+ )
436
+
437
+ # 3. Transformer blocks
438
+ self.blocks = nn.ModuleList(
439
+ [
440
+ WanTransformerBlock(
441
+ inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
442
+ )
443
+ for _ in range(num_layers)
444
+ ]
445
+ )
446
+
447
+ # 4. Output norm & projection
448
+ self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
449
+ self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
450
+ self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
451
+
452
+ self.gradient_checkpointing = False
453
+
454
+ def forward(
455
+ self,
456
+ hidden_states: torch.Tensor,
457
+ timestep: torch.LongTensor,
458
+ encoder_hidden_states: torch.Tensor,
459
+ encoder_hidden_states_image: Optional[torch.Tensor] = None,
460
+ return_dict: bool = True,
461
+ attention_kwargs: Optional[Dict[str, Any]] = None,
462
+ ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
463
+ if attention_kwargs is not None:
464
+ attention_kwargs = attention_kwargs.copy()
465
+ lora_scale = attention_kwargs.pop("scale", 1.0)
466
+ else:
467
+ lora_scale = 1.0
468
+
469
+ if USE_PEFT_BACKEND:
470
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
471
+ scale_lora_layers(self, lora_scale)
472
+ else:
473
+ if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
474
+ logger.warning(
475
+ "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
476
+ )
477
+
478
+ batch_size, num_channels, num_frames, height, width = hidden_states.shape
479
+ p_t, p_h, p_w = self.config.patch_size
480
+ post_patch_num_frames = num_frames // p_t
481
+ post_patch_height = height // p_h
482
+ post_patch_width = width // p_w
483
+
484
+ rotary_emb = self.rope(hidden_states)
485
+
486
+ hidden_states = self.patch_embedding(hidden_states)
487
+ hidden_states = hidden_states.flatten(2).transpose(1, 2)
488
+
489
+ # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v)
490
+ if timestep.ndim == 2:
491
+ ts_seq_len = timestep.shape[1]
492
+ timestep = timestep.flatten() # batch_size * seq_len
493
+ else:
494
+ ts_seq_len = None
495
+
496
+ temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
497
+ timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
498
+ )
499
+ if ts_seq_len is not None:
500
+ # batch_size, seq_len, 6, inner_dim
501
+ timestep_proj = timestep_proj.unflatten(2, (6, -1))
502
+ else:
503
+ # batch_size, 6, inner_dim
504
+ timestep_proj = timestep_proj.unflatten(1, (6, -1))
505
+
506
+ if encoder_hidden_states_image is not None:
507
+ encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
508
+
509
+ # 4. Transformer blocks
510
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
511
+ for block in self.blocks:
512
+ hidden_states = self._gradient_checkpointing_func(
513
+ block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
514
+ )
515
+ else:
516
+ for block in self.blocks:
517
+ hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
518
+
519
+ # 5. Output norm, projection & unpatchify
520
+ if temb.ndim == 3:
521
+ # batch_size, seq_len, inner_dim (wan 2.2 ti2v)
522
+ shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
523
+ shift = shift.squeeze(2)
524
+ scale = scale.squeeze(2)
525
+ else:
526
+ # batch_size, inner_dim
527
+ shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
528
+
529
+ # Move the shift and scale tensors to the same device as hidden_states.
530
+ # When using multi-GPU inference via accelerate these will be on the
531
+ # first device rather than the last device, which hidden_states ends up
532
+ # on.
533
+ shift = shift.to(hidden_states.device)
534
+ scale = scale.to(hidden_states.device)
535
+
536
+ hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
537
+ hidden_states = self.proj_out(hidden_states)
538
+
539
+ hidden_states = hidden_states.reshape(
540
+ batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
541
+ )
542
+ hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
543
+ output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
544
+
545
+ if USE_PEFT_BACKEND:
546
+ # remove `lora_scale` from each PEFT layer
547
+ unscale_lora_layers(self, lora_scale)
548
+
549
+ if not return_dict:
550
+ return (output,)
551
+
552
+ return Transformer2DModelOutput(sample=output)
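A minimal smoke-test sketch for the WanTransformer3DModel defined above (assumed, not part of the repository): it instantiates a deliberately tiny, hypothetical configuration and runs one forward pass on CPU just to illustrate the expected input and output layouts.

import torch
from architecture.transformer_wan import WanTransformer3DModel

# Tiny hypothetical configuration; real checkpoints use much larger values.
model = WanTransformer3DModel(
    patch_size=(1, 2, 2),
    num_attention_heads=2,
    attention_head_dim=8,
    in_channels=4,
    out_channels=4,
    text_dim=16,
    freq_dim=32,
    ffn_dim=32,
    num_layers=1,
    rope_max_seq_len=32,
)

latents = torch.randn(1, 4, 1, 8, 8)             # [batch, channels, frames, height, width]
timestep = torch.tensor([500], dtype=torch.long)
text_states = torch.randn(1, 7, 16)              # [batch, text_seq_len, text_dim]

with torch.no_grad():
    output = model(latents, timestep, text_states, return_dict=False)[0]

print(output.shape)  # torch.Size([1, 4, 1, 8, 8]) -- same latent layout as the input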
config/accelerate_config_4GPU.json ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ "compute_environment": "LOCAL_MACHINE",
3
+ "debug": false,
4
+ "distributed_type": "MULTI_GPU",
5
+ "downcast_bf16": "no",
6
+ "gpu_ids": "all",
7
+ "machine_rank": 0,
8
+ "main_training_function": "main",
9
+ "mixed_precision": "bf16",
10
+ "num_machines": 1,
11
+ "num_processes": 4,
12
+ "rdzv_backend": "static",
13
+ "same_network": true,
14
+ "tpu_env": [],
15
+ "tpu_use_cluster": false,
16
+ "tpu_use_sudo": false,
17
+ "use_cpu": false
18
+ }
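A small sanity-check sketch (assumed, not part of the repository): before launching multi-GPU training with this accelerate config, it can be worth verifying that num_processes matches the GPUs actually visible on the machine.

import json

import torch

with open("config/accelerate_config_4GPU.json") as f:
    accel_cfg = json.load(f)

visible_gpus = torch.cuda.device_count()
if accel_cfg["num_processes"] != visible_gpus:
    print(f"Config expects {accel_cfg['num_processes']} processes, but {visible_gpus} GPU(s) are visible.")
else:
    print("Accelerate config matches the available GPUs.")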
config/train_cogvideox_motion.yaml ADDED
@@ -0,0 +1,88 @@
1
+
2
+ experiment_name: CogVideoX_5B_Motion_480P # Store Folder Name
3
+
4
+
5
+ # Model Setting
6
+ base_model_path: zai-org/CogVideoX-5b-I2V
7
+ pretrained_transformer_path: # No need to set; if you set, this will load transformer model with non-default Wan transformer
8
+ enable_slicing: True
9
+ enable_tiling: True
10
+ use_learned_positional_embeddings: True
11
+ use_rotary_positional_embeddings: True
12
+
13
+
14
+
15
+ # Dataset Setting
16
+ download_folder_path: FrameINO_data/ # Set the downloaded folder path, all the other csv will be read automatically
17
+ train_csv_relative_path: dataset_csv_files/train_sample_short_dataset # No need to change, Fixed
18
+ train_video_relative_path: video_dataset/train_sample_dataset # No need to change, Fixed
19
+ validation_csv_relative_path: dataset_csv_files/val_sample_short_dataset # No need to change, Fixed
20
+ validation_video_relative_path: video_dataset/val_sample_dataset # No need to change, Fixed
21
+ dataloader_num_workers: 4 # This should be per GPU; in debug, we set it to 1
22
+ # height_range: [480, 480] # Height range; by slightly modifying the dataloader code and using this setting, we can enable variable-resolution training
23
+ target_height: 480
24
+ target_width: 720
25
+ sample_accelerate_factor: 2 # Imitates the 12 FPS setting we used before.
26
+ train_frame_num_range: [49, 49] # Number of frames for training, required to be 4N+1
27
+ # min_train_frame_num: 49 # If a video has fewer frames than this, the dataloader will raise an Exception and skip to the next valid one!
28
+
29
+
30
+ # Motion Setting
31
+ dot_radius: 6 # Set with respect to a 384-pixel height; it will be adjusted based on the height change
32
+ point_keep_ratio: 0.4 # The ratio of points kept; drawn by random.choices for each tracking point, so it can be quite versatile; 0.33 is also recommended
33
+ faster_motion_prob: 0.0 # Probability of faster motion (~8 FPS); 0.0 - 0.1 is recommended (0.0 by default).
34
+
35
+
36
+ # Denoise + Text Setting
37
+ noised_image_dropout: 0.05 # Probability of dropping the first-frame condition, which turns the task into T2V
38
+ empty_text_prompt: False # For TI2V, we need to use a text prompt
39
+ text_mask_ratio: 0.05 # Follow InstructPix2Pix
40
+ max_text_seq_length: 226
41
+
42
+
43
+ # Training Setting
44
+ resume_from_checkpoint: False # latest / False; latest will automatically fetch the newest checkpoint
45
+ max_train_steps: 1002 # Adjust as needed; this is just a demo dataset, so long training is not required
46
+ train_batch_size: 1 # batch size per GPU
47
+ gradient_accumulation_steps: 2 # Equivalent to a larger effective batch size across all GPUs
48
+ checkpointing_steps: 2000 # Checkpoint frequency; not recommended to be too frequent
49
+ checkpoints_total_limit: 8 # The transformer is large (~32 GB per checkpoint), so keep this limit moderate
50
+ mixed_precision: bf16 # CogVideoX official code usually uses bf16
51
+ gradient_checkpointing: True # Saves memory but is slower; even with 80GB of GPU memory this still needs to be enabled, otherwise OOM
52
+ seed: # If a seed is set here, every resume reads data in the same order as the first run, so the full dataset cannot be covered in resume mode
53
+ output_folder: checkpoints/
54
+ logging_name: logging
55
+ nccl_timeout: 1800
56
+
57
+
58
+ # Validation Setting
59
+ validation_step: 2000 # Don't set this too frequently, as validation is very resource-consuming
60
+ first_iter_validation: True # Whether we do the first iter validation
61
+ num_inference_steps: 50
62
+
63
+
64
+ # Learning Rate and Optimizer
65
+ optimizer: adamw # Choose between ["adam", "adamw", "prodigy"]
66
+ learning_rate: 2e-5 # 1e-4 might be too big
67
+ scale_lr: False
68
+ lr_scheduler: constant_with_warmup # Most cases should be constant
69
+ adam_beta1: 0.9
70
+ adam_beta2: 0.95 # In the past, this used to be 0.999; smaller than usual
71
+ adam_beta3: 0.98
72
+ lr_power: 1.0
73
+ lr_num_cycles: 1.0
74
+ max_grad_norm: 1.0
75
+ prodigy_beta3: # Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses the value of square root of beta2
76
+ adam_weight_decay: 1e-04
77
+ adam_epsilon: 1e-08
78
+ lr_warmup_steps: 400
79
+
80
+
81
+
82
+ # Other Setting
83
+ report_to: tensorboard
84
+ allow_tf32: True
85
+ revision:
86
+ variant:
87
+ cache_dir:
88
+ tracker_name:
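A brief sketch of how this YAML might be consumed (an assumption; the repository's training script may read it differently): loading it with OmegaConf and checking the 4N+1 frame-count constraint mentioned in the comments above.

from omegaconf import OmegaConf

config = OmegaConf.load("config/train_cogvideox_motion.yaml")

# The frame count must be of the form 4N + 1 (e.g. 49), matching the comment above.
low, high = config.train_frame_num_range
assert (low - 1) % 4 == 0 and (high - 1) % 4 == 0, "train_frame_num_range must contain 4N+1 values"

print(config.experiment_name, config.target_height, config.target_width, low)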
config/train_cogvideox_motion_FrameINO.yaml ADDED
@@ -0,0 +1,96 @@
1
+
2
+ experiment_name: CogVideoX_5B_Motion_FINO_480P
3
+
4
+ # Model Setting
5
+ base_model_path: zai-org/CogVideoX-5b-I2V
6
+ pretrained_transformer_path: uva-cv-lab/FrameINO_CogVideoX_Stage1_Motion_v1.0 # Use the stage-1 weight here; if you use your own trained weight, it should point to the transformer folder (TODO: verify this)
7
+ enable_slicing: True
8
+ enable_tiling: True
9
+ use_learned_positional_embeddings: True
10
+ use_rotary_positional_embeddings: True
11
+
12
+
13
+
14
+ # Dataset Setting
15
+ download_folder_path: FrameINO_data/ # Set the downloaded folder path, all the other csv will be read automatically
16
+ train_csv_relative_path: dataset_csv_files/train_sample_short_dataset # No need to change, Fixed
17
+ train_video_relative_path: video_dataset/train_sample_dataset # No need to change, Fixed
18
+ train_ID_relative_path: video_dataset/train_ID_FrameIn # No need to change, Fixed
19
+ validation_csv_relative_path: dataset_csv_files/val_sample_short_dataset # No need to change, Fixed
20
+ validation_video_relative_path: video_dataset/val_sample_dataset # No need to change, Fixed
21
+ validation_ID_relative_path: video_dataset/val_ID_FrameIn # No need to change, Fixed
22
+ dataloader_num_workers: 4 # This should be per GPU
23
+ # height_range: [480, 704] # Height Range; By slightly modify the dataloader code and use this setting, we can use variable resolution training
24
+ target_height: 480
25
+ target_width: 720
26
+ sample_accelerate_factor: 2 # Imitate 12FPS we have set before.
27
+ train_frame_num_range: [49, 49] # Number of frames for training, required to be 4N+1
28
+ min_train_frame_num: 49 # If a video has fewer frames than this, the dataloader will raise an Exception and skip to the next valid one! We recommend using exactly 49 frames for CogVideoX.
29
+
30
+
31
+ # Motion Setting
32
+ dot_radius: 6 # Set with respect to a 384-pixel height; it will be adjusted based on the height change
33
+ point_keep_ratio_regular: 0.33 # Fewer points than for motion control; the ratio of points kept inside the region box; for non-main object motion
34
+ faster_motion_prob: 0.0 # Probability of faster motion (~8 FPS); 0.0 - 0.1 is recommended (0.0 by default).
35
+
36
+
37
+ # Frame In and Out Setting
38
+ drop_FrameIn_prob: 0.15 # These are the cases where only FrameOut occurs; the FrameIn condition becomes a whole-white placeholder (recommended: 0.15)
39
+ point_keep_ratio_ID: 0.33 # The ratio of points kept for the newly introduced ID
40
+
41
+
42
+ # Denoise + Text Setting
43
+ noised_image_dropout: 0.05 # No First Frame Setting, becomes T2V
44
+ empty_text_prompt: False # For TI2V, we need to use a text prompt
45
+ text_mask_ratio: 0.05 # Follow InstructPix2Pix
46
+ max_text_seq_length: 226
47
+
48
+
49
+ # Training Setting
50
+ resume_from_checkpoint: False # latest / False; latest will automatically fetch the newest checkpoint
51
+ max_train_steps: 1002 # Adjust as needed; this is just a demo dataset, so long training is not required
52
+ train_batch_size: 1 # batch size per GPU
53
+ gradient_accumulation_steps: 2 # This should usually be set to 1.
54
+ checkpointing_steps: 2000 # Checkpoint frequency; not recommended to be too frequent
55
+ checkpoints_total_limit: 8 # The transformer is large (~32 GB per checkpoint), so keep this limit moderate
56
+ mixed_precision: bf16 # CogVideoX official code usually uses bf16
57
+ gradient_checkpointing: True # Saves memory but is slower; even with 80GB of GPU memory this still needs to be enabled, otherwise OOM
58
+ seed: # If a seed is set here, every resume reads data in exactly the same order as before the resume; if not even one epoch has been trained, the same data cycles every time
59
+ output_folder: checkpoints/
60
+ logging_name: logging
61
+ nccl_timeout: 1800
62
+
63
+
64
+ # Validation Setting
65
+ validation_step: 2000 # Don't set this too frequently, as validation is very resource-consuming
66
+ first_iter_validation: True # Whether we do the first iter validation
67
+ num_inference_steps: 50
68
+
69
+
70
+
71
+ # Learning Rate and Optimizer
72
+ optimizer: adamw # Choose between ["adam", "adamw", "prodigy"]
73
+ learning_rate: 2e-5 # 1e-4 might be too big
74
+ scale_lr: False
75
+ lr_scheduler: constant_with_warmup # Most cases should be constant
76
+ adam_beta1: 0.9
77
+ adam_beta2: 0.95 # In the past, this used to be 0.999; smaller than usual
78
+ adam_beta3: 0.98
79
+ lr_power: 1.0
80
+ lr_num_cycles: 1.0
81
+ max_grad_norm: 1.0
82
+ prodigy_beta3: # Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses the value of square root of beta2
83
+ # use_8bit_adam: False # This saves a lot of GPU memory, but slightly slower
84
+ adam_weight_decay: 1e-04
85
+ adam_epsilon: 1e-08
86
+ lr_warmup_steps: 400
87
+
88
+
89
+
90
+ # Other Setting
91
+ report_to: tensorboard
92
+ allow_tf32: True
93
+ revision:
94
+ variant:
95
+ cache_dir:
96
+ tracker_name:
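To make the batch-size-related fields above concrete, a tiny sketch (assumed) that computes the effective batch size implied by this config; the GPU count used here is hypothetical and would normally come from the accelerate launch configuration.

from omegaconf import OmegaConf

config = OmegaConf.load("config/train_cogvideox_motion_FrameINO.yaml")

num_gpus = 4  # hypothetical; e.g. when launching with config/accelerate_config_4GPU.json
effective_batch_size = config.train_batch_size * config.gradient_accumulation_steps * num_gpus
print(f"Effective batch size per optimizer step: {effective_batch_size}")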
config/train_wan_motion.yaml ADDED
@@ -0,0 +1,104 @@
1
+
2
+ experiment_name: Wan_5B_Motion_704P
3
+
4
+
5
+ # Model Setting
6
+ base_model_path: Wan-AI/Wan2.2-TI2V-5B-Diffusers
7
+ pretrained_transformer_path: # No need to set; if you set, this will load transformer model with non-default Wan transformer
8
+ enable_slicing: True
9
+ enable_tiling: True
10
+
11
+
12
+
13
+ # Dataset Setting
14
+ download_folder_path: FrameINO_data/ # Set the downloaded folder path, all the other csv will be read automatically
15
+ train_csv_relative_path: dataset_csv_files/train_sample_short_dataset # No need to change, Fixed
16
+ train_video_relative_path: video_dataset/train_sample_dataset # No need to change, Fixed
17
+ validation_csv_relative_path: dataset_csv_files/val_sample_short_dataset # No need to change, Fixed
18
+ validation_video_relative_path: video_dataset/val_sample_dataset # No need to change, Fixed
19
+ dataloader_num_workers: 4 # This should be per GPU; in debug, we set it to 1
20
+ # height_range: [480, 704] # Height range; by slightly modifying the dataloader code and using this setting, we can enable variable-resolution training
21
+ target_height: 704
22
+ target_width: 1280
23
+ sample_accelerate_factor: 2 # Imitates the 12 FPS setting we used before.
24
+ train_frame_num_range: [81, 81] # Number of frames for training, required to be 4N+1; if the total number of frames is less than the minimum of the range, just use the minimum available; currently set to 81 frames
25
+ # min_train_frame_num: 49 # If a video has fewer frames than this, the dataloader will raise an Exception and skip to the next valid one!
26
+
27
+
28
+ # Motion Setting
29
+ dot_radius: 7 # Due to the Wan VAE, this is slightly larger than for CogVideoX; set with respect to a 384-pixel height, it will be adjusted based on the height change
30
+ point_keep_ratio: 0.4 # The ratio of points kept; drawn by random.choices for each tracking point, so it can be quite versatile; 0.33 is also recommended
31
+ faster_motion_prob: 0.0 # Probability of faster motion (~8 FPS); 0.0 - 0.1 is recommended (0.0 by default).
32
+
33
+
34
+ # Denoise (for flow matching-based models)
35
+ noised_image_dropout: 0.0 # Probability of dropping the first-frame condition, which turns the task into T2V; not used for Wan
36
+ train_sampling_steps: 1000
37
+ noise_scheduler_kwargs:
38
+ num_train_timesteps: 1000 # 1000 is the default value
39
+ shift: 5.0
40
+ use_dynamic_shifting: false # false is the default value
41
+ base_shift: 0.5 # 0.5 is the default value
42
+ max_shift: 1.15 # 1.15 is the default value
43
+ base_image_seq_len: 256 # 256 is the default value
44
+ max_image_seq_len: 4096 # 4096 is the default value
45
+
46
+
47
+ # Text Setting
48
+ text_mask_ratio: 0.0 # Follow InstructPix2Pix
49
+ empty_text_prompt: False # For TI2V, we start using the text prompt
50
+ max_text_seq_length: 512 # For the Wan
51
+
52
+
53
+
54
+ # Training Setting
55
+ resume_from_checkpoint: False # latest / False; latest will automatically fetch the newest checkpoint
56
+ max_train_steps: 1002 # Adjust as needed; this is just a demo dataset, so long training is not required
57
+ train_batch_size: 1 # batch size per GPU
58
+ gradient_accumulation_steps: 2 # Equivalent to a larger effective batch size across all GPUs
59
+ checkpointing_steps: 2000 # Checkpoint frequency; not recommended to be too frequent
60
+ checkpoints_total_limit: 8 # The transformer is large (~32 GB per checkpoint), so keep this limit moderate
61
+ mixed_precision: bf16 # CogVideoX official code usually uses bf16
62
+ gradient_checkpointing: True # Saves memory but is slower; even with 80GB of GPU memory this still needs to be enabled, otherwise OOM
63
+ seed: # If a seed is set here, the data is read in the same order on every resume as in the first run, so the full dataset cannot be covered in resume mode
64
+ output_folder: checkpoints/
65
+ logging_name: logging
66
+ nccl_timeout: 1800
67
+
68
+
69
+
70
+ # Validation Setting
71
+ validation_step: 2000 # Don't set this too frequently, as validation is very resource-consuming
72
+ first_iter_validation: True # Whether we do the first iter validation
73
+ num_inference_steps: 38
74
+
75
+
76
+
77
+ # Learning Rate and Optimizer
78
+ optimizer: adamw # Choose between ["adam", "adamw", "prodigy"]
79
+ learning_rate: 3e-5 # 1e-4 might be too big
80
+ scale_lr: False
81
+ lr_scheduler: constant_with_warmup # Most cases should be constant
82
+ adam_beta1: 0.9 # This Setting is different from CogVideoX, we follow VideoFun
83
+ adam_beta2: 0.999
84
+ # adam_beta3: 0.98
85
+ lr_power: 1.0
86
+ lr_num_cycles: 1.0
87
+ initial_grad_norm_ratio: 5
88
+ abnormal_norm_clip_start: 1000 # Follow VideoFun
89
+ max_grad_norm: 0.05 # Follow VideoFun
90
+ prodigy_beta3: # Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses the value of square root of beta2
91
+ # use_8bit_adam: False # This saves a lot of GPU memory but is slightly slower; recommended to enable
92
+ adam_weight_decay: 1e-4
93
+ adam_epsilon: 1e-10
94
+ lr_warmup_steps: 100
95
+
96
+
97
+
98
+ # Other Setting
99
+ report_to: tensorboard
100
+ allow_tf32: True
101
+ revision:
102
+ variant:
103
+ cache_dir:
104
+ tracker_name:
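The noise_scheduler_kwargs block above maps onto a flow-matching scheduler. As a hedged sketch (the exact scheduler class used by the training script is an assumption here), the kwargs could be passed straight to diffusers' FlowMatchEulerDiscreteScheduler:

from omegaconf import OmegaConf
from diffusers import FlowMatchEulerDiscreteScheduler

config = OmegaConf.load("config/train_wan_motion.yaml")

# Instantiate a flow-matching scheduler directly from the YAML block.
scheduler = FlowMatchEulerDiscreteScheduler(**OmegaConf.to_container(config.noise_scheduler_kwargs))
print(scheduler.config.shift)  # 5.0, the shift value set above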
config/train_wan_motion_FrameINO.yaml ADDED
@@ -0,0 +1,110 @@
1
+
2
+ experiment_name: Wan_5B_Motion_FINO_704P
3
+
4
+ # Model Setting
5
+ base_model_path: Wan-AI/Wan2.2-TI2V-5B-Diffusers
6
+ pretrained_transformer_path: uva-cv-lab/FrameINO_Wan2.2_5B_Stage1_Motion_v1.5 # Use the one trained with the motion
7
+ enable_slicing: True
8
+ enable_tiling: True
9
+
10
+
11
+
12
+ # Dataset Setting
13
+ download_folder_path: FrameINO_data/ # Set the downloaded folder path, all the other csv will be read automatically
14
+ train_csv_relative_path: dataset_csv_files/train_sample_short_dataset # No need to change, Fixed
15
+ train_video_relative_path: video_dataset/train_sample_dataset # No need to change, Fixed
16
+ train_ID_relative_path: video_dataset/train_ID_FrameIn # No need to change, Fixed
17
+ validation_csv_relative_path: dataset_csv_files/val_sample_short_dataset # No need to change, Fixed
18
+ validation_video_relative_path: video_dataset/val_sample_dataset # No need to change, Fixed
19
+ validation_ID_relative_path: video_dataset/val_ID_FrameIn # No need to change, Fixed
20
+ dataloader_num_workers: 4 # This should be per GPU; in debug, we set it to 1
21
+ # height_range: [480, 704] # Height Range; By slightly modify the dataloader code and use this setting, we can use variable resolution training
22
+ target_height: 704 # Recommend 704 x 1280 for the Wan2.2
23
+ target_width: 1280
24
+ sample_accelerate_factor: 2 # Imitate 12FPS we have set before.
25
+ train_frame_num_range: [81, 81] # Number of frames for training, required to be 4N+1
26
+ min_train_frame_num: 49 # If a video has fewer frames than this, the dataloader will raise an Exception and skip to the next valid one!
27
+
28
+
29
+ # Motion Setting
30
+ dot_radius: 7 # Due to the Wan VAE, this is slightly larger than for CogVideoX; set with respect to a 384-pixel height, it will be adjusted based on the height change
31
+ point_keep_ratio_regular: 0.33 # Fewer points than for motion control; the ratio of points kept inside the region box; for non-main object motion
32
+ faster_motion_prob: 0.0 # Probability of faster motion (~8 FPS); 0.0 - 0.1 is recommended (0.0 by default).
33
+
34
+
35
+ # Frame In and Out Setting
36
+ drop_FrameIn_prob: 0.15 # These are the cases where only FrameOut occurs; the ID tokens will be filled with a whole-white placeholder (recommended value: 0.15)
37
+ point_keep_ratio_ID: 0.33 # The ratio of points kept for the newly introduced ID; for main ID object motion
38
+
39
+
40
+ # Denoise
41
+ noised_image_dropout: 0.0 # No First Frame Setting, becomes T2V; not used for Wan
42
+ train_sampling_steps: 1000
43
+ noise_scheduler_kwargs:
44
+ num_train_timesteps: 1000 # 1000 is the default value
45
+ shift: 5.0
46
+ use_dynamic_shifting: false # false is the default value
47
+ base_shift: 0.5 # 0.5 is the default value
48
+ max_shift: 1.15 # 1.15 is the default value
49
+ base_image_seq_len: 256 # 256 is the default value
50
+ max_image_seq_len: 4096 # 4096 is the default value
51
+
52
+
53
+ # Text Setting
54
+ text_mask_ratio: 0.0 # Follows InstructPix2Pix; currently set to 0, and at most 0.05 is recommended
55
+ empty_text_prompt: False # For TI2V, we need to use a text prompt
56
+ max_text_seq_length: 512 # For the Wan
57
+
58
+
59
+
60
+ # Training setting
61
+ resume_from_checkpoint: False # latest / False; latest will automatically fetch the newest checkpoint
62
+ max_train_steps: 1002 # Adjust as needed; this is just a demo dataset, so long training is not required
63
+ train_batch_size: 1 # batch size per GPU
64
+ gradient_accumulation_steps: 2 # This should usually be set to 1.
65
+ checkpointing_steps: 2000 # Checkpoint frequency; not recommended to be too frequent
66
+ checkpoints_total_limit: 8 # The transformer is large (~32 GB per checkpoint), so keep this limit moderate
67
+ mixed_precision: bf16 # CogVideoX official code usually uses bf16
68
+ gradient_checkpointing: True # Saves memory but is slower; even with 80GB of GPU memory this still needs to be enabled, otherwise OOM
69
+ seed: # If a seed is set here, every resume reads data in exactly the same order as before the resume; if not even one epoch has been trained, the same data cycles every time
70
+ output_folder: checkpoints/
71
+ logging_name: logging
72
+ nccl_timeout: 1800
73
+
74
+
75
+
76
+ # Validation Setting
77
+ validation_step: 2000 # Don't set this too frequently, as validation is very resource-consuming
78
+ first_iter_validation: True # Whether we do the first iter validation
79
+ num_inference_steps: 38
80
+
81
+
82
+
83
+ # Learning Rate and Optimizer
84
+ optimizer: adamw # Choose between ["adam", "adamw", "prodigy"]
85
+ learning_rate: 3e-5 # 1e-4 might be too big
86
+ scale_lr: False
87
+ lr_scheduler: constant_with_warmup # Most cases should be constant
88
+ adam_beta1: 0.9 # This Setting is different from CogVideoX, we follow VideoFun
89
+ adam_beta2: 0.999
90
+ # adam_beta3: 0.98
91
+ lr_power: 1.0
92
+ lr_num_cycles: 1.0
93
+ initial_grad_norm_ratio: 5
94
+ abnormal_norm_clip_start: 1000 # Follow VideoFun
95
+ max_grad_norm: 0.05 # Follow VideoFun
96
+ prodigy_beta3: # Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses the value of square root of beta2
97
+ # use_8bit_adam: False # This saves a lot of GPU memory, but slightly slower
98
+ adam_weight_decay: 1e-4
99
+ adam_epsilon: 1e-10
100
+ lr_warmup_steps: 100
101
+
102
+
103
+
104
+ # Other Setting
105
+ report_to: tensorboard
106
+ allow_tf32: True
107
+ revision:
108
+ variant:
109
+ cache_dir:
110
+ tracker_name:
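A small sketch (assumed) of how the relative dataset paths in this config resolve against download_folder_path; the joins mirror what VideoDataset_Motion does with these fields in data_loader/video_dataset_motion.py.

import os

from omegaconf import OmegaConf

config = OmegaConf.load("config/train_wan_motion_FrameINO.yaml")

# Compose the paths the dataloader will actually read from.
train_csv_dir = os.path.join(config.download_folder_path, config.train_csv_relative_path)
train_video_dir = os.path.join(config.download_folder_path, config.train_video_relative_path)
train_id_dir = os.path.join(config.download_folder_path, config.train_ID_relative_path)

for path in (train_csv_dir, train_video_dir, train_id_dir):
    print(path, "exists" if os.path.exists(path) else "missing")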
data_loader/sampler.py ADDED
@@ -0,0 +1,110 @@
1
+ # Last modified: 2024-04-18
2
+ #
3
+ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ # --------------------------------------------------------------------------
17
+ # If you find this code useful, we kindly ask you to cite our paper in your work.
18
+ # Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
19
+ # If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold.
20
+ # More information about the method can be found at https://marigoldmonodepth.github.io
21
+ # --------------------------------------------------------------------------
22
+
23
+ import torch
24
+ from torch.utils.data import (
25
+ BatchSampler,
26
+ RandomSampler,
27
+ SequentialSampler,
28
+ )
29
+
30
+
31
+ class MixedBatchSampler(BatchSampler):
32
+ """Sample one batch from a selected dataset with given probability.
33
+ Compatible with datasets at different resolutions
34
+ """
35
+
36
+ def __init__(
37
+ self, src_dataset_ls, batch_size, drop_last, shuffle, prob=None, generator=None
38
+ ):
39
+ self.base_sampler = None
40
+ self.batch_size = batch_size
41
+ self.shuffle = shuffle
42
+ self.drop_last = drop_last
43
+ self.generator = generator
44
+
45
+ self.src_dataset_ls = src_dataset_ls
46
+ self.n_dataset = len(self.src_dataset_ls)
47
+
48
+ # Dataset length
49
+ self.dataset_length = [len(ds) for ds in self.src_dataset_ls]
50
+ self.cum_dataset_length = [
51
+ sum(self.dataset_length[:i]) for i in range(self.n_dataset)
52
+ ] # cumulative dataset length
53
+
54
+ # BatchSamplers for each source dataset
55
+ if self.shuffle:
56
+ self.src_batch_samplers = [
57
+ BatchSampler(
58
+ sampler=RandomSampler(
59
+ ds, replacement=False, generator=self.generator
60
+ ),
61
+ batch_size=self.batch_size,
62
+ drop_last=self.drop_last,
63
+ )
64
+ for ds in self.src_dataset_ls
65
+ ]
66
+ else:
67
+ self.src_batch_samplers = [
68
+ BatchSampler(
69
+ sampler=SequentialSampler(ds),
70
+ batch_size=self.batch_size,
71
+ drop_last=self.drop_last,
72
+ )
73
+ for ds in self.src_dataset_ls
74
+ ]
75
+ self.raw_batches = [
76
+ list(bs) for bs in self.src_batch_samplers
77
+ ] # index in original dataset
78
+ self.n_batches = [len(b) for b in self.raw_batches]
79
+ self.n_total_batch = sum(self.n_batches)
80
+ # sampling probability
81
+ if prob is None:
82
+ # if not given, decide by dataset length
83
+ self.prob = torch.tensor(self.n_batches) / self.n_total_batch
84
+ else:
85
+ self.prob = torch.as_tensor(prob)
86
+
87
+ def __iter__(self):
88
+ """Yield batches, each drawn from one randomly selected source dataset.
89
+
90
+ Yields:
91
+ list(int): a batch of indices, corresponding to the ConcatDataset of src_dataset_ls
92
+ """
93
+ for _ in range(self.n_total_batch):
94
+ idx_ds = torch.multinomial(
95
+ self.prob, 1, replacement=True, generator=self.generator
96
+ ).item()
97
+ # if batch list is empty, generate new list
98
+ if 0 == len(self.raw_batches[idx_ds]):
99
+ self.raw_batches[idx_ds] = list(self.src_batch_samplers[idx_ds])
100
+ # get a batch from list
101
+ batch_raw = self.raw_batches[idx_ds].pop()
102
+ # shift by cumulative dataset length
103
+ shift = self.cum_dataset_length[idx_ds]
104
+ batch = [n + shift for n in batch_raw]
105
+
106
+ yield batch
107
+
108
+ def __len__(self):
109
+ return self.n_total_batch
110
+
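A usage sketch for MixedBatchSampler above (assumed, not from the repository): two toy datasets of different sizes are wrapped in a ConcatDataset, and the sampler draws each batch entirely from one source according to the given probabilities.

import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

from data_loader.sampler import MixedBatchSampler

ds_a = TensorDataset(torch.arange(10).float())
ds_b = TensorDataset(torch.arange(100, 106).float())

batch_sampler = MixedBatchSampler(
    src_dataset_ls=[ds_a, ds_b],
    batch_size=2,
    drop_last=True,
    shuffle=True,
    prob=[0.7, 0.3],  # 70% of batches from ds_a, 30% from ds_b
    generator=torch.Generator().manual_seed(0),
)

loader = DataLoader(ConcatDataset([ds_a, ds_b]), batch_sampler=batch_sampler)
for (batch,) in loader:
    print(batch)  # each batch comes entirely from one of the two source datasets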
data_loader/video_dataset_motion.py ADDED
@@ -0,0 +1,407 @@
1
+ import os, sys, shutil
2
+ from typing import List, Optional, Tuple, Union
3
+ from pathlib import Path
4
+ import csv
5
+ import random
6
+ import math
7
+ import numpy as np
8
+ import ffmpeg
9
+ import json
10
+ import imageio
11
+ import collections
12
+ import cv2
13
+ import pdb
14
+ csv.field_size_limit(sys.maxsize) # The default limit is 131072; a 100x expansion should be enough
15
+
16
+ import torch
17
+ from torch.utils.data import Dataset
18
+ from torchvision import transforms
19
+
20
+ # Import files from the local folder
21
+ root_path = os.path.abspath('.')
22
+ sys.path.append(root_path)
23
+ from utils.optical_flow_utils import flow_to_image, filter_uv, bivariate_Gaussian
24
+
25
+
26
+ # Init parameters and global shared settings
27
+
28
+ # Blurring Kernel
29
+ blur_kernel = bivariate_Gaussian(45, 3, 3, 0, grid = None, isotropic = True)
30
+
31
+ # Color
32
+ all_color_codes = [(255, 0, 0), (255, 255, 0), (0, 255, 0), (0, 255, 255),
33
+ (255, 0, 255), (0, 0, 255), (128, 128, 128), (64, 224, 208),
34
+ (233, 150, 122)]
35
+ for _ in range(100): # Should not be over 100 colors
36
+ all_color_codes.append((random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)))
37
+
38
+ # Data Transforms
39
+ train_transforms = transforms.Compose(
40
+ [
41
+ transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
42
+ ]
43
+ )
44
+
45
+
46
+
47
+ class VideoDataset_Motion(Dataset):
48
+
49
+ def __init__(
50
+ self,
51
+ config,
52
+ download_folder_path,
53
+ csv_relative_path,
54
+ video_relative_path,
55
+ is_diy_test = False,
56
+ ) -> None:
57
+ super().__init__()
58
+
59
+ # Gen Size Settings
60
+ # self.height_range = config["height_range"]
61
+ # self.max_aspect_ratio = config["max_aspect_ratio"]
62
+ self.target_height = config["target_height"]
63
+ self.target_width = config["target_width"]
64
+ self.sample_accelerate_factor = config["sample_accelerate_factor"]
65
+ self.train_frame_num_range = config["train_frame_num_range"]
66
+
67
+ # Condition Settings (Text, Motion, etc.)
68
+ self.empty_text_prompt = config["empty_text_prompt"]
69
+ self.dot_radius = int(config["dot_radius"])
70
+ self.point_keep_ratio = config["point_keep_ratio"] # Point selection mechanism
71
+ self.faster_motion_prob = config["faster_motion_prob"]
72
+
73
+ # Other Settings
74
+ self.download_folder_path = download_folder_path
75
+ self.is_diy_test = is_diy_test
76
+ self.config = config
77
+ self.video_folder_path = os.path.join(download_folder_path, video_relative_path)
78
+ csv_folder_path = os.path.join(download_folder_path, csv_relative_path)
79
+
80
+
81
+ # Sanity Check
82
+ assert(os.path.exists(csv_folder_path))
83
+ assert(self.point_keep_ratio <= 1.0)
84
+
85
+
86
+
87
+ # Read the CSV files
88
+ info_lists = []
89
+ for csv_file_name in os.listdir(csv_folder_path): # Read all csv files
90
+ csv_file_path = os.path.join(csv_folder_path, csv_file_name)
91
+
92
+ with open(csv_file_path) as file_obj:
93
+ reader_obj = csv.reader(file_obj)
94
+
95
+ # Iterate over each row in the csv
96
+ for idx, row in enumerate(reader_obj):
97
+ if idx == 0:
98
+ elements = dict()
99
+ for element_idx, key in enumerate(row):
100
+ elements[key] = element_idx
101
+ continue
102
+
103
+ # Read the important information
104
+ info_lists.append(row)
105
+
106
+ # Organize
107
+ self.info_lists = info_lists
108
+ self.element_idx_dict = elements
109
+
110
+ # Log
111
+ print("The number of videos for ", csv_folder_path, " is ", len(self.info_lists))
112
+ # print("The memory cost is ", sys.getsizeof(self.info_lists))
113
+
114
+
115
+ def __len__(self):
116
+ return len(self.info_lists)
117
+
118
+
119
+ @staticmethod
120
+ def prepare_traj_tensor(full_pred_tracks, original_height, original_width, selected_frames,
121
+ dot_radius, target_width, target_height, idx = 0, first_frame_img = None):
122
+
123
+ # Prepare the color
124
+ target_color_codes = all_color_codes[:len(full_pred_tracks[0])] # This means how many objects in total we have
125
+
126
+ # Prepare the traj image
127
+ traj_img_lists = []
128
+
129
+ # Set a new dot radius based on the resolution fluctuating
130
+ dot_radius_resize = int( dot_radius * original_height / 384 ) # This is set with respect to the default 384 height and will be adjusted based on the height change
131
+
132
+ # Prepare base draw image if there is
133
+ if first_frame_img is not None:
134
+ img_with_traj = first_frame_img.copy()
135
+
136
+ # Iterate all temporal sequence
137
+ merge_frames = []
138
+ for temporal_idx, points_per_frame in enumerate(full_pred_tracks): # Iterate all downsampled frames, should be 13
139
+
140
+ # Init the base img for the traj figures
141
+ base_img = np.zeros((original_height, original_width, 3)).astype(np.float32) # Use the original image size
142
+ base_img.fill(255) # Whole white frames
143
+
144
+ # Iterate all points in each object
145
+ for obj_idx, points_per_obj in enumerate(points_per_frame):
146
+
147
+ # Basic setting
148
+ color_code = target_color_codes[obj_idx] # Color across frames should be consistent
149
+
150
+ # Process all points in this current object
151
+ for (horizontal, vertical) in points_per_obj:
152
+ if horizontal < 0 or horizontal >= original_width or vertical < 0 or vertical >= original_height:
153
+ continue # If the point is already out of the range, Don't draw
154
+
155
+ # Draw square around the target position
156
+ vertical_start = min(original_height, max(0, vertical - dot_radius_resize))
157
+ vertical_end = min(original_height, max(0, vertical + dot_radius_resize)) # Diameter, used to be 10, but want smaller if there are too many points now
158
+ horizontal_start = min(original_width, max(0, horizontal - dot_radius_resize))
159
+ horizontal_end = min(original_width, max(0, horizontal + dot_radius_resize))
160
+
161
+ # Paint
162
+ base_img[vertical_start:vertical_end, horizontal_start:horizontal_end, :] = color_code
163
+
164
+ # Draw the visual of traj if needed
165
+ if first_frame_img is not None:
166
+ img_with_traj[vertical_start:vertical_end, horizontal_start:horizontal_end, :] = color_code
167
+
168
+ # Resize frames; don't use negative sizes and don't resize while values are in [0, 1]
169
+ base_img = cv2.resize(base_img, (target_width, target_height), interpolation = cv2.INTER_CUBIC)
170
+
171
+ # Dilate (Default to be True)
172
+ base_img = cv2.filter2D(base_img, -1, blur_kernel).astype(np.uint8)
173
+
174
+
175
+ # Append selected_frames and the color together for visualization
176
+ if len(selected_frames) != 0:
177
+ merge_frame = selected_frames[temporal_idx].copy()
178
+ merge_frame[base_img < 250] = base_img[base_img < 250]
179
+ merge_frames.append(merge_frame)
180
+ # cv2.imwrite("Video"+str(idx) + "_traj" + str(temporal_idx).zfill(2) + ".png", cv2.cvtColor(merge_frame, cv2.COLOR_RGB2BGR)) # Comment Out Later
181
+
182
+
183
+ # Append to the temporal index
184
+ traj_img_lists.append(base_img)
185
+
186
+
187
+ # Convert to tensor
188
+ traj_imgs_np = np.array(traj_img_lists)
189
+ traj_tensor = torch.tensor(traj_imgs_np)
190
+
191
+ # Transform
192
+ traj_tensor = traj_tensor.float()
193
+ traj_tensor = torch.stack([train_transforms(traj_frame) for traj_frame in traj_tensor], dim=0)
194
+ traj_tensor = traj_tensor.permute(0, 3, 1, 2).contiguous() # [F, C, H, W]
195
+
196
+
197
+ # Write to video (Comment Out Later)
198
+ # imageio.mimsave("merge_cond" + str(idx) + ".mp4", merge_frames, fps=12)
199
+
200
+
201
+ # Return
202
+ merge_frames = np.array(merge_frames)
203
+ if first_frame_img is not None:
204
+ return traj_tensor, traj_imgs_np, merge_frames, img_with_traj
205
+ else:
206
+ return traj_tensor, traj_imgs_np, merge_frames # Need to return traj_imgs_np for other purpose
207
+
208
+
209
+
210
+ def __getitem__(self, idx):
211
+
212
+ while True: # Iterate until there is a valid video read
213
+
214
+ # try:
215
+
216
+ # Fetch the information
217
+ info = self.info_lists[idx]
218
+ video_path = os.path.join(self.video_folder_path, info[self.element_idx_dict["video_path"]])
219
+ original_height = int(info[self.element_idx_dict["height"]])
220
+ original_width = int(info[self.element_idx_dict["width"]])
221
+ # num_frames = int(info[self.element_idx_dict["num_frames"]]) # Deprecated, this is about the whole frame duration, not just one
222
+
223
+ valid_duration = json.loads(info[self.element_idx_dict["valid_duration"]])
224
+ All_Frame_Panoptic_Segmentation = json.loads(info[self.element_idx_dict["Panoptic_Segmentation"]])
225
+ text_prompt_all = json.loads(info[self.element_idx_dict["Structured_Text_Prompt"]])
226
+ Track_Traj_all = json.loads(info[self.element_idx_dict["Track_Traj"]]) # NOTE: Same as above, only consider the first panoptic segmented frame
227
+ Obj_Info_all = json.loads(info[self.element_idx_dict["Obj_Info"]])
228
+
229
+
230
+ # Sanity check
231
+ if not os.path.exists(video_path):
232
+ raise Exception("This video path", video_path, "doesn't exists!")
233
+
234
+
235
+ ########################################## Manage Resolution and Selected Clip Settings ##########################################
236
+
237
+ # Option1: Variable Resolution Gen
238
+ # # Check the resolution size
239
+ # aspect_ratio = min(self.max_aspect_ratio, original_width / original_height)
240
+ # target_height_raw = min(original_height, random.randint(*self.height_range))
241
+ # target_width_raw = min(original_width, int(target_height_raw * aspect_ratio))
242
+ # # Must be the multiplier of 32
243
+ # target_height = (target_height_raw // 32) * 32
244
+ # target_width = (target_width_raw // 32) * 32
245
+ # print("New Height and Width are ", target_height, target_width)
246
+
247
+ # Option2: Fixed Resolution Gen (Assume that the provided is 32x valid)
248
+ target_width = self.target_width
249
+ target_height = self.target_height
250
+
251
+
252
+ # Only choose the first clip
253
+ Obj_Info = Obj_Info_all[0] # For the Motion training, we have enough data, so we just choose the first panoptic section
254
+ Track_Traj = Track_Traj_all[0]
255
+ text_prompt = text_prompt_all[0]
256
+ resolution = str(target_width) + "x" + str(target_height) # Used for ffmpeg load
257
+ frame_start_idx = Obj_Info[0][1] # NOTE: If there are multiple objects, Obj_Info[X][1] should be the same for all of them
258
+
259
+
260
+ ##############################################################################################################################
261
+
262
+
263
+
264
+ ############################################## Read the video by ffmpeg #################################################
265
+
266
+ # Read the video by ffmpeg in the needed decode fps and resolution
267
+ video_stream, err = ffmpeg.input(
268
+ video_path
269
+ ).output(
270
+ "pipe:", format = "rawvideo", pix_fmt = "rgb24", s = resolution, vsync = 'passthrough',
271
+ ).run(
272
+ capture_stdout = True, capture_stderr = True # If there is a bug, inspect the captured stderr
273
+ ) # The resize is already included
274
+ video_np_full = np.frombuffer(video_stream, np.uint8).reshape(-1, target_height, target_width, 3)
275
+
276
+ # Fetch the valid duration
277
+ video_np = video_np_full[valid_duration[0] : valid_duration[1]]
278
+ valid_num_frames = len(video_np) # Update the number of frames
279
+
280
+
281
+ # Decide the accelerate factor
282
+ train_frame_num_raw = random.randint(*self.train_frame_num_range)
283
+ if frame_start_idx + 3 * train_frame_num_raw < valid_num_frames and random.random() < self.faster_motion_prob: # Requires (1) enough frames and (2) hitting the faster-motion probability
284
+ sample_accelerate_factor = self.sample_accelerate_factor + 1 # Hard Code
285
+ else:
286
+ sample_accelerate_factor = self.sample_accelerate_factor
287
+
288
+
289
+ # Check the number of frames needed this time
290
+ frame_end_idx = min(valid_num_frames, frame_start_idx + sample_accelerate_factor * train_frame_num_raw)
291
+ frame_end_idx = frame_start_idx + 4 * math.floor(( (frame_end_idx-frame_start_idx) - 1) / 4) + 1 # Rounded to the closest 4N + 1 size
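+ # e.g., a raw span of 50 frames is trimmed to 4*12 + 1 = 49; the 4N+1 frame count matches the 4x temporal compression of the CogVideoX 3D VAE (4N+1 frames -> N+1 latent frames)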
292
+
293
+
294
+ # Select Frames and Convert to Tensor
295
+ selected_frames = video_np[ frame_start_idx : frame_end_idx : sample_accelerate_factor] # NOTE: start from the first frame
296
+ video_tensor = torch.tensor(selected_frames) # Convert to tensor
297
+ first_frame_np = selected_frames[0] # Needs to return for Validation
298
+ train_frame_num = len(video_tensor) # Read the actual number of frames from the video (Must be 4N+1)
299
+
300
+
301
+ # Data transforms and shape organize
302
+ video_tensor = video_tensor.float()
303
+ video_tensor = torch.stack([train_transforms(frame) for frame in video_tensor], dim=0)
304
+ video_tensor = video_tensor.permute(0, 3, 1, 2).contiguous() # [F, C, H, W]
305
+
306
+
307
+ #############################################################################################################################
308
+
309
+
310
+
311
+ ######################################### Define the text prompt #######################################################
312
+
313
+ # NOTE: the text prompt is fetched above; here, we just decide whether to use an empty string
314
+ if self.empty_text_prompt or random.random() < self.config["text_mask_ratio"]:
315
+ text_prompt = ""
316
+ # print("Text Prompt for Video", idx, " is ", text_prompt)
317
+
318
+ ########################################################################################################################
319
+
320
+
321
+
322
+ ###################### Prepare the Tracking points for each object (each object has different color) #################################
323
+
324
+ # Iterate all the segmentation info
325
+ full_pred_tracks = [[] for _ in range(train_frame_num)] # The dim should be: (temporal, object, points, xy) The fps should be fixed to 12 fps, which is the same as training decode fps
326
+ for track_obj_idx in range(len(Obj_Info)):
327
+
328
+ # Read the basic info
329
+ text_name, frame_idx_raw = Obj_Info[track_obj_idx] # This is expected to be all the same in the video
330
+
331
+
332
+ # Sanity Check: make sure that the number of frames is consistent
333
+ if track_obj_idx > 0:
334
+ if frame_idx_raw != previous_frame_idx_raw:
335
+ raise Exception("The panoptic_frame_idx cannot pass the sanity check")
336
+
337
+
338
+ # Prepare the trajectory
339
+ pred_tracks_full = Track_Traj[track_obj_idx]
340
+ pred_tracks = pred_tracks_full[ frame_start_idx : frame_end_idx : sample_accelerate_factor]
341
+ if len(pred_tracks) != train_frame_num:
342
+ raise Exception("The length of tracking images does not match the video GT.")
343
+
344
+
345
+ # Randomly select the points based on the given prob; the number of points differs for each object
346
+ kept_point_status = random.choices([True, False], weights = [self.point_keep_ratio, 1 - self.point_keep_ratio], k = len(pred_tracks[0]))
347
+ if len(kept_point_status) != len(pred_tracks[-1]):
348
+ raise Exception("The number of points filterred is not match with the dataset")
349
+
350
+
351
+ # Iterate and add all temporally
352
+ for temporal_idx, pred_track in enumerate(pred_tracks):
353
+
354
+ # Iterate all point one by one
355
+ left_points = []
356
+ for point_idx in range(len(pred_track)):
357
+ if kept_point_status[point_idx]:
358
+ left_points.append(pred_track[point_idx])
359
+ # Append the left points to the list
360
+ full_pred_tracks[temporal_idx].append(left_points) # pred_tracks will be 49 frames, and each entry holds all tracking points for a single object; only one object here
361
+
362
+
363
+ # Other update
364
+ previous_frame_idx_raw = frame_idx_raw
365
+
366
+
367
+ # Draw the dilated traj points
368
+ traj_tensor, traj_imgs_np, merge_frames = self.prepare_traj_tensor(full_pred_tracks, original_height, original_width, selected_frames,
369
+ self.dot_radius, target_width, target_height, idx)
370
+
371
+ # Sanity check to make sure that the traj tensor and the ground truth have the same number of frames
372
+ if len(traj_tensor) != len(video_tensor): # If these two do not match, the torch.cat on latents will fail
373
+ raise Exception("Traj length and Video length does not matched!")
374
+
375
+ #########################################################################################################################################
376
+
377
+
378
+ # except Exception as e: # Note: You can uncomment this part to skip failure cases in mass training.
379
+ # print("The exception is ", e)
380
+ # old_idx = idx
381
+ # idx = (idx + 1) % len(self.info_lists)
382
+ # print("We cannot process the video", old_idx, " and we choose a new idx of ", idx)
383
+ # continue # If any error occurs, we retry with a new idx (the next index, wrapping around)
384
+
385
+
386
+ # If everything is ok, we should break at the end
387
+ break
388
+
389
+
390
+ # Return the information
391
+ return {
392
+ "video_tensor": video_tensor,
393
+ "traj_tensor": traj_tensor,
394
+ "text_prompt": text_prompt,
395
+
396
+ # The rest are auxiliary data for the validation/testing purposes
397
+ "video_gt_np": selected_frames,
398
+ "first_frame_np": first_frame_np,
399
+ "traj_imgs_np": traj_imgs_np,
400
+ "merge_frames": merge_frames,
401
+ "gt_video_path": video_path,
402
+ }
403
+
404
+
405
+
406
+
407
+
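A minimal usage sketch (not part of this commit) of how VideoDataset_Motion might be wrapped in a PyTorch DataLoader; the config keys mirror what __init__ reads above, while the paths and values below are placeholders:

    from torch.utils.data import DataLoader

    # Hypothetical config; keys follow what __init__ reads above.
    config = {
        "target_height": 384, "target_width": 480,
        "sample_accelerate_factor": 1, "train_frame_num_range": [41, 49],
        "empty_text_prompt": False, "dot_radius": 6,
        "point_keep_ratio": 0.3, "faster_motion_prob": 0.1,
        "text_mask_ratio": 0.1,
    }
    dataset = VideoDataset_Motion(
        config,
        download_folder_path="datasets/",   # placeholder paths
        csv_relative_path="csv/",
        video_relative_path="videos/",
    )
    # batch_size=1 because the sampled frame count (4N+1) can vary per video
    loader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4)
    batch = next(iter(loader))
    print(batch["video_tensor"].shape)   # [1, F, C, H, W]
    print(batch["traj_tensor"].shape)    # same F as video_tensor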
data_loader/video_dataset_motion_FrameINO.py ADDED
@@ -0,0 +1,578 @@
1
+ import os, sys, shutil
2
+ from typing import List, Optional, Tuple, Union
3
+ from pathlib import Path
4
+ import csv
5
+ import random
6
+ import numpy as np
7
+ import ffmpeg
8
+ import json
9
+ import imageio
10
+ import collections
11
+ import cv2
12
+ import pdb
13
+ import math
14
+ import PIL.Image as Image
15
+ csv.field_size_limit(sys.maxsize) # Default limit is 131072; raise it so that long JSON columns can be read
16
+
17
+ import torch
18
+ from torch.utils.data import Dataset
19
+ from torchvision import transforms
20
+
21
+ # Import files from the local folder
22
+ root_path = os.path.abspath('.')
23
+ sys.path.append(root_path)
24
+ from utils.optical_flow_utils import flow_to_image, filter_uv, bivariate_Gaussian
25
+
26
+
27
+ # Init parameters and global shared settings
28
+
29
+ # Blurring Kernel
30
+ blur_kernel = bivariate_Gaussian(45, 3, 3, 0, grid = None, isotropic = True)
31
+
32
+ # Color
33
+ all_color_codes = [(255, 0, 0), (255, 255, 0), (0, 255, 0), (0, 255, 255),
34
+ (255, 0, 255), (0, 0, 255), (128, 128, 128), (64, 224, 208),
35
+ (233, 150, 122)]
36
+ for _ in range(100): # Should not be over 100 colors
37
+ all_color_codes.append((random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)))
38
+
39
+ # Data Transforms
40
+ train_transforms = transforms.Compose(
41
+ [
42
+ transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
43
+ ]
44
+ )
45
+
46
+
47
+ class VideoDataset_Motion_FrameINO(Dataset):
48
+ def __init__(
49
+ self,
50
+ config,
51
+ download_folder_path,
52
+ csv_relative_path,
53
+ video_relative_path,
54
+ ID_relative_path,
55
+ FrameOut_only = False,
56
+ one_point_one_obj = False,
57
+ strict_validation_match = False,
58
+ ) -> None:
59
+ super().__init__()
60
+
61
+ # Gen Size Settings
62
+ # self.height_range = config["height_range"]
63
+ # self.max_aspect_ratio = config["max_aspect_ratio"]
64
+ self.target_height = config["target_height"]
65
+ self.target_width = config["target_width"]
66
+ self.sample_accelerate_factor = config["sample_accelerate_factor"]
67
+ self.train_frame_num_range = config["train_frame_num_range"]
68
+ self.min_train_frame_num = config["min_train_frame_num"]
69
+
70
+
71
+ # Condition Settings (Text, Motion, etc.)
72
+ self.empty_text_prompt = config["empty_text_prompt"]
73
+ self.dot_radius = int(config["dot_radius"])
74
+ self.point_keep_ratio_ID = config["point_keep_ratio_ID"]
75
+ self.point_keep_ratio_regular = config["point_keep_ratio_regular"]
76
+ self.faster_motion_prob = config["faster_motion_prob"]
77
+
78
+ # Other Settings
79
+ self.FrameOut_only = FrameOut_only
80
+ self.one_point_one_obj = one_point_one_obj # Currently, this is only enabled when FrameOut_only = True
81
+ self.strict_validation_match = strict_validation_match
82
+ self.config = config
83
+ self.video_folder_path = os.path.join(download_folder_path, video_relative_path)
84
+ self.ID_folder_path = os.path.join(download_folder_path, ID_relative_path)
85
+ csv_folder_path = os.path.join(download_folder_path, csv_relative_path)
86
+
87
+
88
+ # Sanity Check
89
+ assert(os.path.exists(csv_folder_path))
90
+ assert(self.point_keep_ratio_ID <= 1.0)
91
+ assert(self.point_keep_ratio_regular <= 1.0)
92
+
93
+
94
+ # Read the CSV files
95
+ info_lists = []
96
+ for csv_file_name in os.listdir(csv_folder_path): # Read all csv files
97
+ csv_file_path = os.path.join(csv_folder_path, csv_file_name)
98
+
99
+ with open(csv_file_path) as file_obj:
100
+ reader_obj = csv.reader(file_obj)
101
+
102
+ # Iterate over each row in the csv
103
+ for idx, row in enumerate(reader_obj):
104
+ if idx == 0:
105
+ elements = dict()
106
+ for element_idx, key in enumerate(row):
107
+ elements[key] = element_idx
108
+ continue
109
+
110
+ # Read the important information
111
+ info_lists.append(row)
112
+
113
+ # Organize
114
+ self.info_lists = info_lists
115
+ self.element_idx_dict = elements
116
+
117
+ # Log
118
+ print("The number of videos for ", csv_folder_path, " is ", len(self.info_lists))
119
+ # print("The memory cost is ", sys.getsizeof(self.info_lists))
120
+
121
+
122
+ def __len__(self):
123
+ return len(self.info_lists)
124
+
125
+
126
+ @staticmethod
127
+ def prepare_traj_tensor(full_pred_tracks, original_height, original_width, selected_frames,
128
+ dot_radius, target_width, target_height, region_box, idx = 0, first_frame_img = None):
129
+
130
+ # Prepare the color and other stuff
131
+ target_color_codes = all_color_codes[:len(full_pred_tracks[0])] # This means how many objects in total we have
132
+ (top_left_x, top_left_y), (bottom_right_x, bottom_right_y) = region_box
133
+
134
+ # Prepare the traj image
135
+ traj_img_lists = []
136
+
137
+ # Set a new dot radius based on the resolution fluctuating
138
+ dot_radius_resize = int( dot_radius * original_height / 384 ) # This is set with respect to the default 384 height and will be adjusted based on the height change
139
+
140
+ # Prepare base draw image if there is
141
+ if first_frame_img is not None:
142
+ img_with_traj = first_frame_img.copy()
143
+
144
+ # Iterate over all frames (and the objects within each frame)
145
+ merge_frames = []
146
+ for temporal_idx, obj_points in enumerate(full_pred_tracks): # Iterate all downsampled frames, should be 13
147
+
148
+ # Init the base img for the traj figures
149
+ base_img = np.zeros((original_height, original_width, 3)).astype(np.float32) # Use the original image size
150
+ base_img.fill(255) # Whole white frames
151
+
152
+ # Iterate for the per object
153
+ for obj_idx, points in enumerate(obj_points):
154
+
155
+ # Basic setting
156
+ color_code = target_color_codes[obj_idx] # Color across frames should be consistent
157
+
158
+
159
+ # Process all points in this current object
160
+ for (horizontal, vertical) in points:
161
+ if horizontal < 0 or horizontal >= original_width or vertical < 0 or vertical >= original_height:
162
+ continue # If the point is already out of the range, Don't draw
163
+
164
+ # Draw square around the target position
165
+ vertical_start = min(original_height, max(0, vertical - dot_radius_resize))
166
+ vertical_end = min(original_height, max(0, vertical + dot_radius_resize)) # Diameter, used to be 10, but want smaller if there are too many points now
167
+ horizontal_start = min(original_width, max(0, horizontal - dot_radius_resize))
168
+ horizontal_end = min(original_width, max(0, horizontal + dot_radius_resize))
169
+
170
+ # Paint
171
+ base_img[vertical_start:vertical_end, horizontal_start:horizontal_end, :] = color_code
172
+
173
+ # Draw the visual of traj if needed
174
+ if first_frame_img is not None:
175
+ img_with_traj[vertical_start:vertical_end, horizontal_start:horizontal_end, :] = color_code
176
+
177
+ # Resize frames; don't use negative sizes and don't resize while values are in [0, 1]
178
+ base_img = cv2.resize(base_img, (target_width, target_height), interpolation = cv2.INTER_CUBIC)
179
+
180
+ # Dilate (Default to be True)
181
+ base_img = cv2.filter2D(base_img, -1, blur_kernel).astype(np.uint8)
182
+
183
+ # Append selected_frames and the color together for visualization
184
+ merge_frame = selected_frames[temporal_idx].copy()
185
+ merge_frame = cv2.rectangle(merge_frame, (top_left_x, top_left_y), (bottom_right_x, bottom_right_y), (255, 0, 0), 5) # Draw the Region Box Area
186
+ merge_frame[base_img < 250] = base_img[base_img < 250]
187
+ merge_frames.append(merge_frame)
188
+
189
+
190
+ # Append to the temporal index
191
+ traj_img_lists.append(base_img)
192
+
193
+ # Convert to tensor
194
+ traj_imgs_np = np.array(traj_img_lists)
195
+ traj_tensor = torch.tensor(traj_imgs_np)
196
+
197
+ # Transform
198
+ traj_tensor = traj_tensor.float()
199
+ traj_tensor = torch.stack([train_transforms(traj_frame) for traj_frame in traj_tensor], dim=0)
200
+ traj_tensor = traj_tensor.permute(0, 3, 1, 2).contiguous() # [F, C, H, W]
201
+
202
+
203
+ # Write to video (For Debug Purpose)
204
+ # imageio.mimsave("merge_cond" + str(idx) + ".mp4", merge_frames, fps=12)
205
+
206
+
207
+
208
+ # Return
209
+ merge_frames = np.array(merge_frames)
210
+ if first_frame_img is not None:
211
+ return traj_tensor, traj_imgs_np, merge_frames, img_with_traj
212
+ else:
213
+ return traj_tensor, traj_imgs_np, merge_frames # Need to return traj_imgs_np for other purpose
214
+
215
+
216
+
217
+ def __getitem__(self, idx):
218
+
219
+ while True: # Iterate until there is a valid video read
220
+
221
+ # try:
222
+
223
+ # Fetch the information
224
+ info = self.info_lists[idx]
225
+ video_path = os.path.join(self.video_folder_path, info[self.element_idx_dict["video_path"]])
226
+ original_height = int(info[self.element_idx_dict["height"]])
227
+ original_width = int(info[self.element_idx_dict["width"]])
228
+ # num_frames = int(info[self.element_idx_dict["num_frames"]]) # Deprecated, this is about the whole frame duration, not just one
229
+
230
+ valid_duration = json.loads(info[self.element_idx_dict["valid_duration"]])
231
+ All_Frame_Panoptic_Segmentation = json.loads(info[self.element_idx_dict["Panoptic_Segmentation"]])
232
+ text_prompt_all = json.loads(info[self.element_idx_dict["Structured_Text_Prompt"]])
233
+ Track_Traj_all = json.loads(info[self.element_idx_dict["Track_Traj"]])
234
+ Obj_Info_all = json.loads(info[self.element_idx_dict["Obj_Info"]])
235
+ ID_info_all = json.loads(info[self.element_idx_dict["ID_info"]]) # New elements compared to motion data loader
236
+
237
+
238
+ # Sanity check
239
+ if not os.path.exists(video_path):
240
+ raise Exception("This video path", video_path, "doesn't exists!")
241
+
242
+
243
+ ########################################## Manage Resolution and Selected Clip Settings ##########################################
244
+
245
+ # Option1: Variable Resolution Gen
246
+ # # Check the resolution size
247
+ # aspect_ratio = min(self.max_aspect_ratio, original_width / original_height)
248
+ # target_height_raw = min(original_height, random.randint(*self.height_range))
249
+ # target_width_raw = min(original_width, int(target_height_raw * aspect_ratio))
250
+ # # Must be the multiplier of 32
251
+ # target_height = (target_height_raw // 32) * 32
252
+ # target_width = (target_width_raw // 32) * 32
253
+ # print("New Height and Width are ", target_height, target_width)
254
+
255
+ # Option2: Fixed Resolution Gen (Assume that the provided is 32x valid)
256
+ target_width = self.target_width
257
+ target_height = self.target_height
258
+
259
+
260
+ # NOTE: Here, we only choose the first Panoptic choice, to avoid multiple panoptic choices.
261
+ Obj_Info = Obj_Info_all[0] # For panoptic Segmentation
262
+ Track_Traj = Track_Traj_all[0]
263
+ text_prompt = text_prompt_all[0]
264
+ ID_info = ID_info_all[0] # For Frame In ID information, Just one Panoptic Frame
265
+ resolution = str(target_width) + "x" + str(target_height)
266
+ frame_start_idx = Obj_Info[0][1] # NOTE: If there are multiple objects, Obj_Info[X][1] should be the same for all of them
267
+
268
+
269
+ ##############################################################################################################################
270
+
271
+
272
+
273
+ #################################################### Fetch FrameIn ID information ###############################################################
274
+
275
+ # FrameIn drop
276
+ if self.FrameOut_only or random.random() < self.config["drop_FrameIn_prob"]:
277
+ drop_FrameIn = True
278
+ else:
279
+ drop_FrameIn = False
280
+
281
+ # Not all objects are ideal for FrameIn, so we need to select
282
+ if not self.strict_validation_match:
283
+ effective_ID_idxs = []
284
+ for ID_idx, ID_Info_obj in enumerate(ID_info):
285
+ if ID_Info_obj != []:
286
+ effective_ID_idxs.append(ID_idx)
287
+ main_target_ID_idx = random.choice(effective_ID_idxs) # NOTE: I think we should only have one object to be processed for now
288
+ else:
289
+ main_target_ID_idx = 0 # Always choose the first one
290
+
291
+ # Fetch the FrameIn ID info
292
+ segmentation_info, useful_region_box = ID_info[main_target_ID_idx] # There might be multiple ideal objects, but we just randomly choose one
293
+ if not self.FrameOut_only:
294
+ _, first_frame_reference_path, _ = segmentation_info # bbox_info, first_frame_reference_path, store_img_path_lists
295
+ first_frame_reference_path = os.path.join(self.ID_folder_path, first_frame_reference_path)
296
+ if not os.path.exists(first_frame_reference_path):
297
+ raise Exception("Cannot find ID path", first_frame_reference_path)
298
+ ##################################################################################################################################################
299
+
300
+
301
+
302
+ ################ Randomly choose one mask among the multiple choices available (resolution is with respect to the original resolution) #################
303
+
304
+ # Choose one region box
305
+ useful_region_box.sort(key=lambda x: x[0]) # Sort based on the BBox size
306
+ if not self.strict_validation_match:
307
+ mask_region = random.choice(useful_region_box[-5:])[1:] # Choose among the largest 5 BBox available
308
+ else:
309
+ mask_region = useful_region_box[-1][1:] # Choose the last one
310
+
311
+ # Fetch
312
+ (top_left_x_raw, top_left_y_raw), (bottom_right_x_raw, bottom_right_y_raw) = mask_region # As Original Resolution
313
+
314
+ # Resize the mask based on the CURRENT target resolution (now the 384x480 training resolution)
315
+ top_left_x = int(top_left_x_raw * target_width / original_width)
316
+ top_left_y = int(top_left_y_raw * target_height / original_height)
317
+ bottom_right_x = int(bottom_right_x_raw * target_width / original_width)
318
+ bottom_right_y = int(bottom_right_y_raw * target_height / original_height)
319
+ resized_mask_region_box = (top_left_x, top_left_y), (bottom_right_x, bottom_right_y)
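+ # e.g., with a 1920x1080 source and a 480x384 target, a corner at (640, 360) maps to (640*480/1920, 360*384/1080) = (160, 128)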
320
+
321
+
322
+ ###################################################################################################################################################
323
+
324
+
325
+
326
+ ################################################ Read the video by ffmpeg #########################################################################
327
+
328
+ # Read the video by ffmpeg in the needed decode fps and resolution
329
+ video_stream, err = ffmpeg.input(
330
+ video_path
331
+ ).output(
332
+ "pipe:", format = "rawvideo", pix_fmt = "rgb24", s = resolution, vsync = 'passthrough',
333
+ ).run(
334
+ capture_stdout = True, capture_stderr = True # If there is a bug, inspect the captured stderr
335
+ ) # The resize is already included
336
+ video_np_full = np.frombuffer(video_stream, np.uint8).reshape(-1, target_height, target_width, 3)
337
+
338
+ # Fetch the valid duration
339
+ video_np = video_np_full[valid_duration[0] : valid_duration[1]]
340
+ valid_num_frames = len(video_np) # Update the number of frames
341
+
342
+
343
+ # Decide the accelerate factor
344
+ train_frame_num_raw = random.randint(*self.train_frame_num_range)
345
+ if frame_start_idx + 3 * train_frame_num_raw < valid_num_frames and random.random() < self.faster_motion_prob: # Requires (1) enough frames and (2) hitting the faster-motion probability
346
+ sample_accelerate_factor = self.sample_accelerate_factor + 1 # Hard Code
347
+ else:
348
+ sample_accelerate_factor = self.sample_accelerate_factor
349
+
350
+ # Check the number of frames needed this time
351
+ frame_end_idx = min(valid_num_frames, frame_start_idx + sample_accelerate_factor * train_frame_num_raw)
352
+ frame_end_idx = frame_start_idx + 4 * math.floor(( (frame_end_idx-frame_start_idx) - 1) / 4) + 1 # Rounded to the closest 4N + 1 size
353
+
354
+
355
+ # Select Frames based on the start and end idx; then, Convert to Tensor
356
+ selected_frames = video_np[ frame_start_idx : frame_end_idx : sample_accelerate_factor] # NOTE: start from the first frame
357
+ if len(selected_frames) < self.min_train_frame_num:
358
+ print(len(selected_frames), len(video_np), frame_start_idx, frame_end_idx, sample_accelerate_factor)
359
+ raise Exception(f"selected_frames is less than {self.min_train_frame_num} frames preset! We jump to the next valid one!") # 我这里让Number of Frames Exactly = 49
360
+ video_tensor = torch.tensor(selected_frames) # Convert to tensor
361
+ train_frame_num = len(video_tensor) # Read the actual number of frames from the video (Must be 4N+1)
362
+ # print("Number of frames is", train_frame_num)
363
+
364
+
365
+ # Data transforms and shape organize
366
+ video_tensor = video_tensor.float()
367
+ video_tensor = torch.stack([train_transforms(frame) for frame in video_tensor], dim=0)
368
+ video_tensor = video_tensor.permute(0, 3, 1, 2).contiguous() # [F, C, H, W]
369
+
370
+
371
+ # Crop the tensor so that every non-interest region becomes blank (black, 0-valued); the region is at the target training resolution with the VAE step-size adjustment
372
+ video_np_masked = np.zeros(selected_frames.shape, dtype = np.uint8)
373
+ video_np_masked[:, top_left_y:bottom_right_y, top_left_x:bottom_right_x, :] = selected_frames[:, top_left_y:bottom_right_y, top_left_x:bottom_right_x, :]
374
+
375
+
376
+ # Decide the first frame with the masked one instead of the full one.
377
+ first_frame_np = video_np_masked[0] # Needs to return for Validation
378
+ # cv2.imwrite("first_frame"+str(idx)+".png", cv2.cvtColor(first_frame_np, cv2.COLOR_BGR2RGB)) # Comment Out Later
379
+
380
+ # Convert to Tensor and then Transforms
381
+ first_frame_tensor = torch.tensor(first_frame_np)
382
+ first_frame_tensor = train_transforms(first_frame_tensor).permute(2, 0, 1).contiguous()
383
+
384
+ #########################################################################################################################################
385
+
386
+
387
+
388
+ ############################################# Define the text prompt #######################################################
389
+
390
+ # NOTE: the text prompt was already extracted above; here we just decide whether to set it to the empty case
391
+ if self.empty_text_prompt or random.random() < self.config["text_mask_ratio"]:
392
+ text_prompt = ""
393
+ # print("Text Prompt for Video", idx, " is ", text_prompt) # Comment Out Later
394
+
395
+ #############################################################################################################################
396
+
397
+
398
+
399
+ ########################### Prepare the Tracking points for each object (each object has different color) #################################
400
+
401
+ # Iterate all the Segmentation Info
402
+ full_pred_tracks = [[] for _ in range(train_frame_num)] # The dim should be: (temporal, object, points, xy) The fps should be fixed to 12 fps, which is the same as training decode fps
403
+ for track_obj_idx in range(len(Obj_Info)):
404
+
405
+ # Read the basic info
406
+ text_name, frame_idx_raw = Obj_Info[track_obj_idx] # This is expected to be all the same in the video
407
+
408
+ # Sanity Check: make sure that the number of frames is consistent
409
+ if track_obj_idx > 0:
410
+ if frame_idx_raw != previous_frame_idx_raw:
411
+ raise Exception("The panoptic_frame_idx cannot pass the sanity check")
412
+
413
+
414
+ # Prepare the trajectory
415
+ pred_tracks_full = Track_Traj[track_obj_idx]
416
+ pred_tracks = pred_tracks_full[ frame_start_idx : frame_end_idx : sample_accelerate_factor]
417
+ if len(pred_tracks) != train_frame_num:
418
+ raise Exception("The length of tracking images does not match the video GT.")
419
+
420
+
421
+ # FrameINO-specific kept-point setting: for a non-main obj idx, we must ensure all points are inside the region box; if it is the main obj, the ID must be outside the region box
422
+ if track_obj_idx != main_target_ID_idx or self.FrameOut_only: # Non-main obj (Usually, for Frame Out cases)
423
+
424
+ # Randomly select the points based on the given prob; the number of points differs for each object
425
+ kept_point_status = random.choices([True, False], weights = [self.point_keep_ratio_regular, 1 - self.point_keep_ratio_regular], k = len(pred_tracks[0]))
426
+
427
+ # Check whether each point of the object lies inside the region box in the first frame; no need to check the following frames (the FrameOut effect is allowed)
428
+ first_frame_points = pred_tracks[0]
429
+ for point_idx in range(len(first_frame_points)):
430
+ (horizontal, vertical) = first_frame_points[point_idx]
431
+ if horizontal < top_left_x_raw or horizontal >= bottom_right_x_raw or vertical < top_left_y_raw or vertical >= bottom_right_y_raw: # Whether Outside the BBox region
432
+ kept_point_status[point_idx] = False
433
+
434
+ else: # For main object
435
+
436
+ # Randomly select the points based on the given prob; the number of points differs for each object
437
+ if drop_FrameIn:
438
+ # No motion provided on ID for Drop FrameIn cases
439
+ kept_point_status = random.choices([False], k = len(pred_tracks[0]))
440
+
441
+ else: # Regular FrameIn case
442
+ kept_point_status = random.choices([True, False], weights = [self.point_keep_ratio_ID, 1 - self.point_keep_ratio_ID], k = len(pred_tracks[0]))
443
+
444
+
445
+ # Sanity Check
446
+ if len(kept_point_status) != len(pred_tracks[-1]):
447
+ raise Exception("The number of points filterred does not match with the dataset")
448
+
449
+
450
+ # Iterate and add all temporally
451
+ for temporal_idx, pred_track in enumerate(pred_tracks): # The length = number of frames
452
+
453
+ # Iterate all point one by one
454
+ left_points = []
455
+ for point_idx in range(len(pred_track)):
456
+ # Select kept points
457
+ if kept_point_status[point_idx]:
458
+ left_points.append(pred_track[point_idx])
459
+
460
+ # Append the left points to the list
461
+ full_pred_tracks[temporal_idx].append(left_points) # pred_tracks will be 49 frames, and each entry holds all tracking points for a single object; only one object here
462
+
463
+ # Other update
464
+ previous_frame_idx_raw = frame_idx_raw
465
+
466
+
467
+ # Fetch One Point
468
+ if self.one_point_one_obj:
469
+ one_track_point = []
470
+ for full_pred_track_per_frame in full_pred_tracks:
471
+ one_track_point.append( [[full_pred_track_per_frame[0][0]]])
472
+
473
+ #######################################################################################################################################
474
+
475
+
476
+
477
+ ############################### Process the Video Tensor (based on info fetched from traj) ############################################
478
+
479
+
480
+ if drop_FrameIn:
481
+
482
+ ID_img = np.uint8(np.zeros((target_height, target_width, 3))) # Whole Black (0-value) pixel placeholder
483
+
484
+ else:
485
+
486
+ # Fetch the reference and resize
487
+ ID_img = np.asarray(Image.open(first_frame_reference_path))
488
+
489
+ # Resize to the same size as the video
490
+ ref_h, ref_w = ID_img.shape[:2]
491
+ scale_h = target_height / max(ref_h, ref_w)
492
+ scale_w = target_width / max(ref_h, ref_w)
493
+ new_h, new_w = int(ref_h * scale_h), int(ref_w * scale_w)
494
+ ID_img = cv2.resize(ID_img, (new_w, new_h), interpolation = cv2.INTER_AREA)
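+ # Both scales divide by the longer reference side, so the result always fits inside the target frame; e.g., a 1280x720 reference with a 480x384 (width x height) target is resized to 480x216 before padding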
495
+
496
+ # Calculate padding amounts on all direction
497
+ pad_height1 = (target_height - ID_img.shape[0]) // 2
498
+ pad_height2 = target_height - ID_img.shape[0] - pad_height1
499
+ pad_width1 = (target_width - ID_img.shape[1]) // 2
500
+ pad_width2 = target_width - ID_img.shape[1] - pad_width1
501
+
502
+ # Apply padding to reach the same resolution as the training frames
503
+ ID_img = np.pad(
504
+ ID_img,
505
+ ((pad_height1, pad_height2), (pad_width1, pad_width2), (0, 0)),
506
+ mode = 'constant',
507
+ constant_values = 0
508
+ )
509
+
510
+ # Visualize; Comment Out Later
511
+ # cv2.imwrite("ID_img_padded"+str(idx)+".png", cv2.cvtColor(ID_img, cv2.COLOR_BGR2RGB))
512
+
513
+
514
+ # Convert to tensor (Same as others)
515
+ ID_tensor = torch.tensor(ID_img)
516
+ ID_tensor = train_transforms(ID_tensor).permute(2, 0, 1).contiguous()
517
+
518
+ #######################################################################################################################################
519
+
520
+
521
+
522
+ ############################################## Draw the Traj Points and Transform to Tensor #############################################
523
+
524
+ # Draw the dilated points
525
+ if self.one_point_one_obj:
526
+ target_pred_tracks = one_track_point # In this case, we only have one point per object
527
+ else:
528
+ target_pred_tracks = full_pred_tracks
529
+
530
+ traj_tensor, traj_imgs_np, merge_frames = self.prepare_traj_tensor(target_pred_tracks, original_height, original_width, selected_frames,
531
+ self.dot_radius, target_width, target_height, resized_mask_region_box, idx)
532
+
533
+ # Sanity check to make sure that the traj tensor and the ground truth have the same number of frames
534
+ if len(traj_tensor) != len(video_tensor): # If these two do not match, the torch.cat on latents will fail
535
+ raise Exception("Traj length and Video length does not matched!")
536
+
537
+ #########################################################################################################################################
538
+
539
+
540
+ # Write some processed meta data
541
+ processed_meta_data = {
542
+ "full_pred_tracks": full_pred_tracks,
543
+ "original_width": original_width,
544
+ "original_height": original_height,
545
+ "mask_region": mask_region,
546
+ "resized_mask_region_box": resized_mask_region_box,
547
+ }
548
+
549
+ # except Exception as e: # Note: You can uncomment this part to skip failure cases in mass training.
550
+ # print("The exception is ", e)
551
+ # old_idx = idx
552
+ # idx = (idx + 1) % len(self.info_lists)
553
+ # print("We cannot process the video", old_idx, " and we choose a new idx of ", idx)
554
+ # continue # If any error occurs, we retry with a new idx (the next index, wrapping around)
555
+
556
+
557
+ # If everything is ok, we should break at the end
558
+ break
559
+
560
+
561
+ # Return the information
562
+ return {
563
+ "video_tensor": video_tensor,
564
+ "traj_tensor": traj_tensor,
565
+ "first_frame_tensor": first_frame_tensor,
566
+ "ID_tensor": ID_tensor,
567
+ "text_prompt": text_prompt,
568
+
569
+ # The rest are auxiliary data for the validation/testing purposes
570
+ "video_gt_np": selected_frames,
571
+ "first_frame_np": first_frame_np,
572
+ "ID_np": ID_img,
573
+ "processed_meta_data": processed_meta_data,
574
+ "traj_imgs_np": traj_imgs_np,
575
+ "merge_frames" : merge_frames,
576
+ "gt_video_path": video_path,
577
+ }
578
+
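A brief usage note (a sketch, not part of this commit): VideoDataset_Motion_FrameINO is constructed like VideoDataset_Motion above, but additionally takes ID_relative_path pointing at the cropped identity references, and its extra config keys include min_train_frame_num, point_keep_ratio_ID, point_keep_ratio_regular, and drop_FrameIn_prob. Each returned item then also carries the region-masked first frame and the padded identity image:

    # Paths below are placeholders
    dataset = VideoDataset_Motion_FrameINO(
        config,
        download_folder_path="datasets/",
        csv_relative_path="csv/",
        video_relative_path="videos/",
        ID_relative_path="ID_refs/",
    )
    item = dataset[0]
    item["first_frame_tensor"].shape   # [C, H, W], masked outside the sampled region box
    item["ID_tensor"].shape            # [C, H, W], identity reference padded to the target size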
data_loader/video_dataset_motion_FrameINO_old.py ADDED
@@ -0,0 +1,538 @@
1
+ import os, sys, shutil
2
+ from typing import List, Optional, Tuple, Union
3
+ from pathlib import Path
4
+ import csv
5
+ import random
6
+ import numpy as np
7
+ import ffmpeg
8
+ import json
9
+ import imageio
10
+ import collections
11
+ import cv2
12
+ import pdb
13
+ import math
14
+ import PIL.Image as Image
15
+ csv.field_size_limit(13107200) # Default setting is 131072, 100x expand should be enough
16
+
17
+ import torch
18
+ from torch.utils.data import Dataset
19
+ from torchvision import transforms
20
+
21
+ # Import files from the local folder
22
+ root_path = os.path.abspath('.')
23
+ sys.path.append(root_path)
24
+ from utils.optical_flow_utils import flow_to_image, filter_uv, bivariate_Gaussian
25
+
26
+ # Init parameters and global shared settings
27
+
28
+ # Blurring Kernel
29
+ blur_kernel = bivariate_Gaussian(45, 3, 3, 0, grid = None, isotropic = True)
30
+
31
+ # Color
32
+ all_color_codes = [(255, 0, 0), (255, 255, 0), (0, 255, 0), (0, 255, 255),
33
+ (255, 0, 255), (0, 0, 255), (128, 128, 128), (64, 224, 208),
34
+ (233, 150, 122)]
35
+ for _ in range(100): # Should not be over 100 colors
36
+ all_color_codes.append((random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)))
37
+
38
+ # Data Transforms
39
+ train_transforms = transforms.Compose(
40
+ [
41
+ transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
42
+ ]
43
+ )
44
+
45
+
46
+ class VideoDataset_Motion_FrameINO(Dataset):
47
+ def __init__(
48
+ self,
49
+ config,
50
+ csv_folder_path,
51
+ FrameOut_only = False,
52
+ one_point_one_obj = False,
53
+ strict_validation_match = False,
54
+ ) -> None:
55
+ super().__init__()
56
+
57
+ # Fetch the Fundamental Setting
58
+ self.dataset_folder_path = config["dataset_folder_path"]
59
+ if not FrameOut_only: # Frame In mode
60
+ self.ID_folder_path = config["ID_folder_path"]
61
+ self.target_height = config["height"]
62
+ self.target_width = config["width"]
63
+ # self.ref_cond_size = config["ref_cond_size"]
64
+ self.preset_decode_fps = config["preset_decode_fps"] # Set to be 16
65
+ self.train_frame_num = config["train_frame_num"]
66
+ self.empty_text_prompt = config["empty_text_prompt"]
67
+ self.start_skip = config["start_skip"]
68
+ self.end_skip = config["end_skip"]
69
+ self.dot_radius = int(config["dot_radius"]) # Set to be 6
70
+ self.point_keep_ratio_ID = config["point_keep_ratio_ID"]
71
+ self.point_keep_ratio_regular = config["point_keep_ratio_regular"]
72
+ self.faster_motion_prob = config["faster_motion_prob"]
73
+ self.FrameOut_only = FrameOut_only
74
+ self.one_point_one_obj = one_point_one_obj # Currently, this is only enabled when FrameOut_only = True
75
+ self.strict_validation_match = strict_validation_match
76
+ self.config = config
77
+
78
+ # Sanity Check
79
+ assert(self.point_keep_ratio_ID <= 1.0)
80
+ assert(self.point_keep_ratio_regular <= 1.0)
81
+
82
+
83
+ # Read the CSV files
84
+ info_lists = []
85
+ for csv_file_name in os.listdir(csv_folder_path): # Read all csv files
86
+ csv_file_path = os.path.join(csv_folder_path, csv_file_name)
87
+ with open(csv_file_path) as file_obj:
88
+ reader_obj = csv.reader(file_obj)
89
+
90
+ # Iterate over each row in the csv
91
+ for idx, row in enumerate(reader_obj):
92
+ if idx == 0:
93
+ elements = dict()
94
+ for element_idx, key in enumerate(row):
95
+ elements[key] = element_idx
96
+ continue
97
+
98
+ # Read the important information
99
+ info_lists.append(row)
100
+
101
+
102
+ # Organize
103
+ self.info_lists = info_lists
104
+ self.element_idx_dict = elements
105
+
106
+ # Log
107
+ print("The number of videos for ", csv_folder_path, " is ", len(self.info_lists))
108
+ # print("The memory cost is ", sys.getsizeof(self.info_lists))
109
+
110
+
111
+ def __len__(self):
112
+ return len(self.info_lists)
113
+
114
+
115
+ @staticmethod
116
+ def prepare_traj_tensor(full_pred_tracks, original_height, original_width, selected_frames,
117
+ dot_radius, target_width, target_height, region_box, idx = 0, first_frame_img = None):
118
+
119
+ # Prepare the color and other stuff
120
+ target_color_codes = all_color_codes[:len(full_pred_tracks[0])] # This means how many objects in total we have
121
+ (top_left_x, top_left_y), (bottom_right_x, bottom_right_y) = region_box
122
+
123
+ # Prepare the traj image
124
+ traj_img_lists = []
125
+
126
+ # Set a new dot radius based on the resolution fluctuating
127
+ dot_radius_resize = int( dot_radius * original_height / 384 ) # This is set with respect to the default 384 height and will be adjusted based on the height change
128
+
129
+ # Prepare base draw image if there is
130
+ if first_frame_img is not None:
131
+ img_with_traj = first_frame_img.copy()
132
+
133
+ # Iterate over all frames (and the objects within each frame)
134
+ merge_frames = []
135
+ for temporal_idx, obj_points in enumerate(full_pred_tracks): # Iterate all downsampled frames, should be 13
136
+
137
+ # Init the base img for the traj figures
138
+ base_img = np.zeros((original_height, original_width, 3)).astype(np.float32) # Use the original image size
139
+ base_img.fill(255) # Whole white frames
140
+
141
+ # Iterate for the per object
142
+ for obj_idx, points in enumerate(obj_points):
143
+
144
+ # Basic setting
145
+ color_code = target_color_codes[obj_idx] # Color across frames should be consistent
146
+
147
+
148
+ # Process all points in this current object
149
+ for (horizontal, vertical) in points:
150
+ if horizontal < 0 or horizontal >= original_width or vertical < 0 or vertical >= original_height:
151
+ continue # If the point is already out of the range, Don't draw
152
+
153
+ # Draw square around the target position
154
+ vertical_start = min(original_height, max(0, vertical - dot_radius_resize))
155
+ vertical_end = min(original_height, max(0, vertical + dot_radius_resize)) # Diameter, used to be 10, but want smaller if there are too many points now
156
+ horizontal_start = min(original_width, max(0, horizontal - dot_radius_resize))
157
+ horizontal_end = min(original_width, max(0, horizontal + dot_radius_resize))
158
+
159
+ # Paint
160
+ base_img[vertical_start:vertical_end, horizontal_start:horizontal_end, :] = color_code
161
+
162
+ # Draw the visual of traj if needed
163
+ if first_frame_img is not None:
164
+ img_with_traj[vertical_start:vertical_end, horizontal_start:horizontal_end, :] = color_code
165
+
166
+ # Resize frames; don't use negative sizes and don't resize while values are in [0, 1]
167
+ base_img = cv2.resize(base_img, (target_width, target_height), interpolation = cv2.INTER_CUBIC)
168
+
169
+ # Dilate (Default to be True)
170
+ base_img = cv2.filter2D(base_img, -1, blur_kernel).astype(np.uint8)
171
+
172
+ # Append selected_frames and the color together for visualization
173
+ merge_frame = selected_frames[temporal_idx].copy()
174
+ merge_frame = cv2.rectangle(merge_frame, (top_left_x, top_left_y), (bottom_right_x, bottom_right_y), (255, 0, 0), 5) # Draw the Region Box Area
175
+ merge_frame[base_img < 250] = base_img[base_img < 250]
176
+ merge_frames.append(merge_frame)
177
+
178
+
179
+ # Append to the temporal index
180
+ traj_img_lists.append(base_img)
181
+
182
+ # Convert to tensor
183
+ traj_imgs_np = np.array(traj_img_lists)
184
+ traj_tensor = torch.tensor(traj_imgs_np)
185
+
186
+ # Transform
187
+ traj_tensor = traj_tensor.float()
188
+ traj_tensor = torch.stack([train_transforms(traj_frame) for traj_frame in traj_tensor], dim=0)
189
+ traj_tensor = traj_tensor.permute(0, 3, 1, 2).contiguous() # [F, C, H, W]
190
+
191
+
192
+ # Write to video (For Debug Purpose)
193
+ # imageio.mimsave("merge_cond" + str(idx) + ".mp4", merge_frames, fps=12)
194
+
195
+
196
+ # Return
197
+ merge_frames = np.array(merge_frames)
198
+ if first_frame_img is not None:
199
+ return traj_tensor, traj_imgs_np, merge_frames, img_with_traj
200
+ else:
201
+ return traj_tensor, traj_imgs_np, merge_frames # Need to return traj_imgs_np for other purpose
202
+
203
+
204
+
205
+
206
+ def __getitem__(self, idx):
207
+
208
+ while True: # Iterate until there is a valid video read
209
+
210
+ try:
211
+
212
+ # Fetch the information
213
+ info = self.info_lists[idx]
214
+ video_path = os.path.join(self.dataset_folder_path, info[self.element_idx_dict["video_path"]])
215
+ original_height = int(info[self.element_idx_dict["height"]])
216
+ original_width = int(info[self.element_idx_dict["width"]])
217
+ num_frames = int(info[self.element_idx_dict["num_frames"]])
218
+ fps = float(info[self.element_idx_dict["fps"]])
219
+
220
+ # Fetch all panoptic frames
221
+ FrameIN_info_all = json.loads(info[self.element_idx_dict["FrameIN_info"]])
222
+ Track_Traj_all = json.loads(info[self.element_idx_dict["Track_Traj"]])
223
+ text_prompt_all = json.loads(info[self.element_idx_dict["Improved_Text_Prompt"]])
224
+ ID_info_all = json.loads(info[self.element_idx_dict["ID_info"]])
225
+
226
+
227
+ # Randomly Choose one available
228
+ panoptic_idx = random.choice(range(len(FrameIN_info_all)))
229
+ FrameIN_info = FrameIN_info_all[panoptic_idx]
230
+ Track_Traj = Track_Traj_all[panoptic_idx]
231
+ text_prompt = text_prompt_all[panoptic_idx]
232
+ ID_info_panoptic = ID_info_all[panoptic_idx]
233
+
234
+
235
+ # Organize
236
+ resolution = str(self.target_width) + "x" + str(self.target_height)
237
+ fps_scale = self.preset_decode_fps / fps
238
+ downsample_num_frames = int(num_frames * fps_scale)
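+ # e.g., a 30 fps source with 300 frames decoded at 16 fps gives fps_scale = 16/30 and roughly 160 downsampled frames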
239
+
240
+
241
+ # FrameIn drop
242
+ if self.FrameOut_only or random.random() < self.config["drop_FrameIn_prob"]:
243
+ drop_FrameIn = True
244
+ else:
245
+ drop_FrameIn = False
246
+
247
+
248
+
249
+ # Sanity check
250
+ if not os.path.exists(video_path):
251
+ raise Exception("This video path ", video_path, " doesn't exists!")
252
+
253
+
254
+ # Not all objects are ideal for FrameIn, so we need to select
255
+ if not self.strict_validation_match:
256
+ effective_obj_idxs = []
257
+ for obj_idx, obj_info in enumerate(ID_info_panoptic):
258
+ if obj_info != []:
259
+ effective_obj_idxs.append(obj_idx)
260
+ main_target_obj_idx = random.choice(effective_obj_idxs) # NOTE: I think we should only have one object to be processed for now
261
+ else:
262
+ main_target_obj_idx = 0 # Always choose the first one
263
+
264
+ #################################################### Fetch FrameIn ID information ###############################################################
265
+
266
+ # Fetch the FrameIn ID info
267
+ segmentation_info, useful_region_box = ID_info_panoptic[main_target_obj_idx] # There might be multiple ideal objects, but we just randomly choose one
268
+ if not self.FrameOut_only:
269
+ _, first_frame_reference_path, _ = segmentation_info # bbox_info, first_frame_reference_path, store_img_path_lists
270
+ first_frame_reference_path = os.path.join(self.ID_folder_path, first_frame_reference_path)
271
+
272
+ ##################################################################################################################################################
273
+
274
+
275
+
276
+ ############ Randomly choose one mask among the multiple choices available (resolution is with respect to the original resolution) ############
277
+ useful_region_box.sort(key=lambda x: x[0])
278
+
279
+ # Choose one region box
280
+ if not self.strict_validation_match:
281
+ mask_region = random.choice(useful_region_box[-5:])[1:] # Choose among the largest 5 available
282
+ else:
283
+ mask_region = useful_region_box[-1][1:] # Choose the last one
284
+
285
+ # Fetch
286
+ (top_left_x_raw, top_left_y_raw), (bottom_right_x_raw, bottom_right_y_raw) = mask_region # As Original Resolution
287
+
288
+ # Resize the mask based on the CURRENT target resolution (now the 384x480 training resolution)
289
+ top_left_x = int(top_left_x_raw * self.target_width / original_width)
290
+ top_left_y = int(top_left_y_raw * self.target_height / original_height)
291
+ bottom_right_x = int(bottom_right_x_raw * self.target_width / original_width)
292
+ bottom_right_y = int(bottom_right_y_raw * self.target_height / original_height)
293
+ resized_mask_region_box = (top_left_x, top_left_y), (bottom_right_x, bottom_right_y)
294
+
295
+ ###########################################################################################################################################
296
+
297
+
298
+
299
+ ############################################## Read the video by ffmpeg #############################################################
300
+
301
+ # Read the video by ffmpeg in the needed decode fps and resolution
302
+ video_stream, err = ffmpeg.input(
303
+ video_path
304
+ ).filter(
305
+ 'fps', fps = self.preset_decode_fps, round = 'up'
306
+ ).output(
307
+ "pipe:", format = "rawvideo", pix_fmt = "rgb24", s = resolution
308
+ ).run(
309
+ capture_stdout = True, capture_stderr = True
310
+ ) # The resize is already included
311
+ video_np_raw = np.frombuffer(video_stream, np.uint8).reshape(-1, self.target_height, self.target_width, 3)
312
+
313
+ # Sanity Check
314
+ if len(video_np_raw) - self.start_skip - self.end_skip < self.train_frame_num:
315
+ raise Exception("The number of frames from the video is not enough")
316
+
317
+ # Crop the tensor so that every non-interest region becomes blank (black, 0-valued); the region is at the target training resolution with the VAE step-size adjustment
318
+ video_np_masked = np.zeros(video_np_raw.shape, dtype = np.uint8)
319
+ video_np_masked[:, top_left_y:bottom_right_y, top_left_x:bottom_right_x, :] = video_np_raw[:, top_left_y:bottom_right_y, top_left_x:bottom_right_x, :]
320
+
321
+ #########################################################################################################################################
322
+
323
+
324
+
325
+ ######################################### Define the text prompt #######################################################
326
+
327
+ # Whether empty text prompt; Text Prompt already exists above
328
+ if self.empty_text_prompt or random.random() < self.config["text_mask_ratio"]:
329
+ text_prompt = ""
330
+
331
+ ########################################################################################################################
332
+
333
+
334
+
335
+ ###################### Prepare the Tracking points for each object (each object has different color) #################################
336
+
337
+ # Make sure that the frame from the FrameIN_info has enough number of frames
338
+ _, original_start_frame_idx, fps_scale = FrameIN_info[main_target_obj_idx] # This is expected to be all the same in the video
339
+ downsample_start_frame_idx = max(0, int(original_start_frame_idx * fps_scale))
340
+
341
+
342
+ # Check the max number of frames available (NOTE: Recommended to use Full Text Prompt Version)
343
+ max_step_num = (downsample_num_frames - downsample_start_frame_idx) // self.train_frame_num
344
+ if max_step_num == 0:
345
+ print("This video is ", video_path)
346
+ raise Exception("The video is too short!")
347
+ elif max_step_num >= 2 and random.random() < self.faster_motion_prob:
348
+ iter_gap = 2 # Maximum setting now is 2x; otherwise, the VAE might not work well
349
+ else:
350
+ iter_gap = 1
351
+
352
+
353
+ # Iterate all the Segmentation Info
354
+ full_pred_tracks = [[] for _ in range(self.train_frame_num)] # The dim should be: (temporal, object, points, xy) The fps should be fixed to 12 fps, which is the same as training decode fps
355
+
356
+ # Iterate over all objects (the main object is handled separately below)
357
+ for obj_idx in range(len(ID_info_panoptic)):
358
+
359
+ # Prepare the trajectory
360
+ pred_tracks = Track_Traj[obj_idx]
361
+ pred_tracks = pred_tracks[downsample_start_frame_idx : downsample_start_frame_idx + iter_gap * self.train_frame_num : iter_gap]
362
+ if len(pred_tracks) != self.train_frame_num:
363
+ raise Exception("The len of pre_track does not match")
364
+
365
+
366
+ # For a non-main obj idx, we must ensure all points are inside the region box; if it is the main obj, the ID must be outside the region box
367
+ if obj_idx != main_target_obj_idx or self.FrameOut_only:
368
+
369
+ # Randomly select the points based on the given probability; the number of kept points differs for each object
370
+ kept_point_status = random.choices([True, False], weights = [self.point_keep_ratio_regular, 1 - self.point_keep_ratio_regular], k = len(pred_tracks[0]))
371
+
372
+ # Check with the first frame only; no need to check the following frames (the FrameOut effect is allowed)
373
+ first_frame_points = pred_tracks[0]
374
+ for point_idx in range(len(first_frame_points)):
375
+ (horizontal, vertical) = first_frame_points[point_idx]
376
+ if horizontal < top_left_x_raw or horizontal >= bottom_right_x_raw or vertical < top_left_y_raw or vertical >= bottom_right_y_raw:
377
+ kept_point_status[point_idx] = False
378
+
379
+ else: # For main object
380
+
381
+ # Randomly select the points based on the given probability; the number of kept points differs for each object
382
+ if drop_FrameIn:
383
+ # No motion provided on ID for Drop FrameIn cases
384
+ kept_point_status = random.choices([False], k = len(pred_tracks[0]))
385
+
386
+ else: # Regular FrameIn case
387
+ kept_point_status = random.choices([True, False], weights = [self.point_keep_ratio_ID, 1 - self.point_keep_ratio_ID], k = len(pred_tracks[0]))
388
+
389
+
390
+ # Sanity Check
391
+ if len(kept_point_status) != len(pred_tracks[-1]):
392
+ raise Exception("The number of points filterred is not match with the dataset")
393
+
394
+ # Iterate and add all temporally
395
+ for temporal_idx, pred_track in enumerate(pred_tracks):
396
+
397
+ # Iterate all point one by one
398
+ left_points = []
399
+ for point_idx in range(len(pred_track)):
400
+ # Select kept points
401
+ if kept_point_status[point_idx]:
402
+ left_points.append(pred_track[point_idx])
403
+
404
+ # Append the left points to the list
405
+ full_pred_tracks[temporal_idx].append(left_points) # pred_tracks has 49 frames, each entry holding all tracking points of a single object; only one object is appended here
406
+
407
+ # Fetch One Point
408
+ if self.one_point_one_obj:
409
+ one_track_point = []
410
+ for full_pred_track_per_frame in full_pred_tracks:
411
+ one_track_point.append( [[full_pred_track_per_frame[0][0]]])
412
+
413
+ #######################################################################################################################################
414
+
415
+
416
+
417
+ ############################### Process the Video Tensor (based on info fetched from traj) ############################################
418
+
419
+ # Select Frames based on the panoptic range (No Mask here)
420
+ selected_frames = video_np_raw[downsample_start_frame_idx : downsample_start_frame_idx + iter_gap * self.train_frame_num : iter_gap]
421
+
422
+ # Prepare the Video Tensor; NOTE: in this branch, the video tensor is the full image without a mask
423
+ video_tensor = torch.tensor(selected_frames) # Convert to tensor
424
+ if len(video_tensor) != self.train_frame_num:
425
+ raise Exception("The len of train frames does not match")
426
+
427
+ # Training transforms for the Video and condition
428
+ video_tensor = video_tensor.float()
429
+ video_tensor = torch.stack([train_transforms(frame) for frame in video_tensor], dim=0)
430
+ video_tensor = video_tensor.permute(0, 3, 1, 2).contiguous() # [F, C, H, W]
431
+
432
+
433
+
434
+ if drop_FrameIn:
435
+ main_reference_img = np.uint8(np.zeros((self.target_height, self.target_width, 3))) # All-black (0-valued) pixel placeholder
436
+
437
+ else:
438
+
439
+ # Fetch the reference and resize
440
+ main_reference_img = np.asarray(Image.open(first_frame_reference_path))
441
+
442
+ # Resize to the same size as the video
443
+ ref_h, ref_w = main_reference_img.shape[:2]
444
+ scale_h = self.target_height / max(ref_h, ref_w)
445
+ scale_w = self.target_width / max(ref_h, ref_w)
446
+ new_h, new_w = int(ref_h * scale_h), int(ref_w * scale_w)
447
+ main_reference_img = cv2.resize(main_reference_img, (new_w, new_h), interpolation = cv2.INTER_AREA)
448
+
449
+ # Calculate the padding amounts in all directions
450
+ pad_height1 = (self.target_height - main_reference_img.shape[0]) // 2
451
+ pad_height2 = self.target_height - main_reference_img.shape[0] - pad_height1
452
+ pad_width1 = (self.target_width - main_reference_img.shape[1]) // 2
453
+ pad_width2 = self.target_width - main_reference_img.shape[1] - pad_width1
454
+
455
+ # Apply padding to reach the same resolution as the training frames
456
+ main_reference_img = np.pad(
457
+ main_reference_img,
458
+ ((pad_height1, pad_height2), (pad_width1, pad_width2), (0, 0)),
459
+ mode = 'constant',
460
+ constant_values = 0
461
+ )
462
+ # cv2.imwrite("main_reference_img_padded"+str(idx)+".png", cv2.cvtColor(main_reference_img, cv2.COLOR_BGR2RGB))
463
+
464
+
465
+ # Convert to tensor
466
+ main_reference_tensor = torch.tensor(main_reference_img)
467
+ main_reference_tensor = train_transforms(main_reference_tensor).permute(2, 0, 1).contiguous()
468
+
469
+
470
+ # Fetch the first frame and then do ID merge for this branch of training
471
+ first_frame_np = video_np_masked[downsample_start_frame_idx] # Needs to be returned for validation
472
+ # cv2.imwrite("first_frame"+str(idx)+".png", cv2.cvtColor(first_frame_np, cv2.COLOR_BGR2RGB))
473
+
474
+ # Convert to Tensor and then Transforms
475
+ first_frame_tensor = torch.tensor(first_frame_np)
476
+ first_frame_tensor = train_transforms(first_frame_tensor).permute(2, 0, 1).contiguous()
477
+
478
+ #######################################################################################################################################
479
+
480
+
481
+
482
+ ############################################## Draw the Traj Points and Transform to Tensor #############################################
483
+
484
+ # Draw the dilated points
485
+ if self.one_point_one_obj:
486
+ target_pred_tracks = one_track_point # In this case, we only have one point per object
487
+ else:
488
+ target_pred_tracks = full_pred_tracks
489
+
490
+ traj_tensor, traj_imgs_np, merge_frames = self.prepare_traj_tensor(target_pred_tracks, original_height, original_width, selected_frames,
491
+ self.dot_radius, self.target_width, self.target_height, resized_mask_region_box, idx)
492
+
493
+ #########################################################################################################################################
494
+
495
+
496
+ # Write some processed meta data
497
+ processed_meta_data = {
498
+ "full_pred_tracks": full_pred_tracks,
499
+ "original_width": original_width,
500
+ "original_height": original_height,
501
+ "mask_region": mask_region,
502
+ "resized_mask_region_box": resized_mask_region_box,
503
+ }
504
+
505
+ except Exception as e:
506
+ print("The exception is ", e)
507
+ old_idx = idx
508
+ idx = random.randint(0, len(self.info_lists) - 1)
509
+ print("We cannot process the video", old_idx, " and we choose a new idx of ", idx)
510
+ continue # If any error occurs, retry with the newly proposed random idx
511
+
512
+
513
+ # If everything is OK, break out of the retry loop
514
+ break
515
+
516
+
517
+ # Return the information
518
+ return {
519
+ "video_tensor": video_tensor,
520
+ "traj_tensor": traj_tensor,
521
+ "first_frame_tensor": first_frame_tensor,
522
+ "main_reference_tensor": main_reference_tensor,
523
+ "text_prompt": text_prompt,
524
+
525
+ # The rest is auxiliary data for validation/testing purposes
526
+ "video_gt_np": selected_frames,
527
+ "first_frame_np": first_frame_np,
528
+ "main_reference_np": main_reference_img,
529
+ "processed_meta_data": processed_meta_data,
530
+ "traj_imgs_np": traj_imgs_np,
531
+ "merge_frames" : merge_frames,
532
+ "gt_video_path": video_path,
533
+ }
534
+
535
+
536
+
537
+
538
+
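For orientation, the dictionary returned by `__getitem__` above can be consumed like any map-style PyTorch dataset. The sketch below is an assumed minimal usage: `dataset` stands for a `VideoDataset_Motion` instance built elsewhere from the training config, and the trivial `collate_fn` avoids collating the ragged auxiliary fields (numpy frames, nested track lists, strings).

```py
from torch.utils.data import DataLoader

# Return the raw sample dict instead of letting default_collate handle the ragged aux fields.
loader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=lambda batch: batch[0])

for sample in loader:
    video = sample["video_tensor"]               # [F, C, H, W] normalized training frames
    traj = sample["traj_tensor"]                 # trajectory-point condition frames
    first_frame = sample["first_frame_tensor"]   # masked first frame
    reference = sample["main_reference_tensor"]  # padded ID reference image
    prompt = sample["text_prompt"]               # may be "" when the text prompt is dropped
    break
```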
pipelines/pipeline_cogvideox_i2v_motion.py ADDED
@@ -0,0 +1,931 @@
1
+ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os, sys, shutil
17
+ import inspect
18
+ import math
19
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
20
+
21
+ import PIL
22
+ import torch
23
+ from transformers import T5EncoderModel, T5Tokenizer
24
+
25
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
26
+ from diffusers.image_processor import PipelineImageInput
27
+ from diffusers.loaders import CogVideoXLoraLoaderMixin
28
+ from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
29
+ # from diffusers.models.embeddings import get_3d_rotary_pos_embed
30
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
31
+ from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
32
+ from diffusers.utils import (
33
+ is_torch_xla_available,
34
+ logging,
35
+ replace_example_docstring,
36
+ )
37
+ from diffusers.utils.torch_utils import randn_tensor
38
+ from diffusers.video_processor import VideoProcessor
39
+ from diffusers.pipelines.cogvideo.pipeline_output import CogVideoXPipelineOutput
40
+
41
+
42
+ # Import files from the local folder
43
+ root_path = os.path.abspath('.')
44
+ sys.path.append(root_path)
45
+ from architecture.embeddings import get_3d_rotary_pos_embed
46
+
47
+
48
+ if is_torch_xla_available():
49
+ import torch_xla.core.xla_model as xm
50
+
51
+ XLA_AVAILABLE = True
52
+ else:
53
+ XLA_AVAILABLE = False
54
+
55
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
56
+
57
+
58
+ EXAMPLE_DOC_STRING = """
59
+ Examples:
60
+ ```py
61
+ >>> import torch
62
+ >>> from diffusers import CogVideoXImageToVideoPipeline
63
+ >>> from diffusers.utils import export_to_video, load_image
64
+
65
+ >>> pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16)
66
+ >>> pipe.to("cuda")
67
+
68
+ >>> prompt = "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
69
+ >>> image = load_image(
70
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
71
+ ... )
72
+ >>> video = pipe(image, prompt, use_dynamic_cfg=True)
73
+ >>> export_to_video(video.frames[0], "output.mp4", fps=8)
74
+ ```
75
+ """
76
+
77
+
78
+ # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
79
+ def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
80
+
81
+ tw = tgt_width
82
+ th = tgt_height
83
+ h, w = src
84
+ r = h / w
85
+ if r > (th / tw): # NOTE: this aligns the aspect ratio to the target (similar to the Reference Resize method seen earlier)
86
+ resize_height = th
87
+ resize_width = int(round(th / h * w)) # NOTE: this is one of the two branches; there will be redundant positions left over here
88
+ else:
89
+ resize_width = tw
90
+ resize_height = int(round(tw / w * h))
91
+
92
+ crop_top = int(round((th - resize_height) / 2.0))
93
+ crop_left = int(round((tw - resize_width) / 2.0)) # NOTE: this takes the midpoint (center crop)
94
+
95
+
96
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
97
+
98
+
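+ # Worked example: for src = (30, 45) with tgt_width = 45 and tgt_height = 30 the aspect ratios match,
+ # so the function returns ((0, 0), (30, 45)), i.e. no crop offset. For a wider grid such as src = (30, 68)
+ # it returns ((5, 0), (25, 45)): the source aspect ratio is fitted to the base width and centered vertically.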
99
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
100
+ def retrieve_timesteps(
101
+ scheduler,
102
+ num_inference_steps: Optional[int] = None,
103
+ device: Optional[Union[str, torch.device]] = None,
104
+ timesteps: Optional[List[int]] = None,
105
+ sigmas: Optional[List[float]] = None,
106
+ **kwargs,
107
+ ):
108
+ r"""
109
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
110
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
111
+
112
+ Args:
113
+ scheduler (`SchedulerMixin`):
114
+ The scheduler to get timesteps from.
115
+ num_inference_steps (`int`):
116
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
117
+ must be `None`.
118
+ device (`str` or `torch.device`, *optional*):
119
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
120
+ timesteps (`List[int]`, *optional*):
121
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
122
+ `num_inference_steps` and `sigmas` must be `None`.
123
+ sigmas (`List[float]`, *optional*):
124
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
125
+ `num_inference_steps` and `timesteps` must be `None`.
126
+
127
+ Returns:
128
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
129
+ second element is the number of inference steps.
130
+ """
131
+ if timesteps is not None and sigmas is not None:
132
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
133
+ if timesteps is not None:
134
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
135
+ if not accepts_timesteps:
136
+ raise ValueError(
137
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
138
+ f" timestep schedules. Please check whether you are using the correct scheduler."
139
+ )
140
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
141
+ timesteps = scheduler.timesteps
142
+ num_inference_steps = len(timesteps)
143
+ elif sigmas is not None:
144
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
145
+ if not accept_sigmas:
146
+ raise ValueError(
147
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
148
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
149
+ )
150
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
151
+ timesteps = scheduler.timesteps
152
+ num_inference_steps = len(timesteps)
153
+ else:
154
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
155
+ timesteps = scheduler.timesteps
156
+ return timesteps, num_inference_steps
157
+
158
+
159
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
160
+ def retrieve_latents(
161
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
162
+ ):
163
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
164
+ return encoder_output.latent_dist.sample(generator)
165
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
166
+ return encoder_output.latent_dist.mode()
167
+ elif hasattr(encoder_output, "latents"):
168
+ return encoder_output.latents
169
+ else:
170
+ raise AttributeError("Could not access latents of provided encoder_output")
171
+
172
+
173
+ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
174
+ r"""
175
+ Pipeline for image-to-video generation using CogVideoX.
176
+
177
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
178
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
179
+
180
+ Args:
181
+ vae ([`AutoencoderKL`]):
182
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
183
+ text_encoder ([`T5EncoderModel`]):
184
+ Frozen text-encoder. CogVideoX uses
185
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
186
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
187
+ tokenizer (`T5Tokenizer`):
188
+ Tokenizer of class
189
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
190
+ transformer ([`CogVideoXTransformer3DModel`]):
191
+ A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
192
+ scheduler ([`SchedulerMixin`]):
193
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
194
+ """
195
+
196
+ _optional_components = []
197
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
198
+
199
+ _callback_tensor_inputs = [
200
+ "latents",
201
+ "prompt_embeds",
202
+ "negative_prompt_embeds",
203
+ ]
204
+
205
+ def __init__(
206
+ self,
207
+ tokenizer: T5Tokenizer,
208
+ text_encoder: T5EncoderModel,
209
+ vae: AutoencoderKLCogVideoX,
210
+ transformer: CogVideoXTransformer3DModel,
211
+ scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
212
+ ):
213
+ super().__init__()
214
+
215
+ self.register_modules(
216
+ tokenizer=tokenizer,
217
+ text_encoder=text_encoder,
218
+ vae=vae,
219
+ transformer=transformer,
220
+ scheduler=scheduler,
221
+ )
222
+ self.vae_scale_factor_spatial = (
223
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
224
+ )
225
+ self.vae_scale_factor_temporal = (
226
+ self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
227
+ )
228
+ self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.7
229
+
230
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
231
+
232
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
233
+ def _get_t5_prompt_embeds(
234
+ self,
235
+ prompt: Union[str, List[str]] = None,
236
+ num_videos_per_prompt: int = 1,
237
+ max_sequence_length: int = 226,
238
+ device: Optional[torch.device] = None,
239
+ dtype: Optional[torch.dtype] = None,
240
+ ):
241
+ device = device or self._execution_device
242
+ dtype = dtype or self.text_encoder.dtype
243
+
244
+ prompt = [prompt] if isinstance(prompt, str) else prompt
245
+ batch_size = len(prompt)
246
+
247
+ text_inputs = self.tokenizer(
248
+ prompt,
249
+ padding="max_length",
250
+ max_length=max_sequence_length,
251
+ truncation=True,
252
+ add_special_tokens=True,
253
+ return_tensors="pt",
254
+ )
255
+ text_input_ids = text_inputs.input_ids
256
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
257
+
258
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
259
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
260
+ logger.warning(
261
+ "The following part of your input was truncated because `max_sequence_length` is set to "
262
+ f" {max_sequence_length} tokens: {removed_text}"
263
+ )
264
+
265
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
266
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
267
+
268
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
269
+ _, seq_len, _ = prompt_embeds.shape
270
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
271
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
272
+
273
+ return prompt_embeds
274
+
275
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
276
+ def encode_prompt(
277
+ self,
278
+ prompt: Union[str, List[str]],
279
+ negative_prompt: Optional[Union[str, List[str]]] = None,
280
+ do_classifier_free_guidance: bool = True,
281
+ num_videos_per_prompt: int = 1,
282
+ prompt_embeds: Optional[torch.Tensor] = None,
283
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
284
+ max_sequence_length: int = 226,
285
+ device: Optional[torch.device] = None,
286
+ dtype: Optional[torch.dtype] = None,
287
+ ):
288
+ r"""
289
+ Encodes the prompt into text encoder hidden states.
290
+
291
+ Args:
292
+ prompt (`str` or `List[str]`, *optional*):
293
+ prompt to be encoded
294
+ negative_prompt (`str` or `List[str]`, *optional*):
295
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
296
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
297
+ less than `1`).
298
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
299
+ Whether to use classifier free guidance or not.
300
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
301
+ Number of videos that should be generated per prompt.
302
+ prompt_embeds (`torch.Tensor`, *optional*):
303
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
304
+ provided, text embeddings will be generated from `prompt` input argument.
305
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
306
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
307
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
308
+ argument.
309
+ device: (`torch.device`, *optional*):
310
+ torch device
311
+ dtype: (`torch.dtype`, *optional*):
312
+ torch dtype
313
+ """
314
+ device = device or self._execution_device
315
+
316
+ prompt = [prompt] if isinstance(prompt, str) else prompt
317
+ if prompt is not None:
318
+ batch_size = len(prompt)
319
+ else:
320
+ batch_size = prompt_embeds.shape[0]
321
+
322
+ if prompt_embeds is None:
323
+ prompt_embeds = self._get_t5_prompt_embeds(
324
+ prompt=prompt,
325
+ num_videos_per_prompt=num_videos_per_prompt,
326
+ max_sequence_length=max_sequence_length,
327
+ device=device,
328
+ dtype=dtype,
329
+ )
330
+
331
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
332
+ negative_prompt = negative_prompt or ""
333
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
334
+
335
+ if prompt is not None and type(prompt) is not type(negative_prompt):
336
+ raise TypeError(
337
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
338
+ f" {type(prompt)}."
339
+ )
340
+ elif batch_size != len(negative_prompt):
341
+ raise ValueError(
342
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
343
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
344
+ " the batch size of `prompt`."
345
+ )
346
+
347
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
348
+ prompt=negative_prompt,
349
+ num_videos_per_prompt=num_videos_per_prompt,
350
+ max_sequence_length=max_sequence_length,
351
+ device=device,
352
+ dtype=dtype,
353
+ )
354
+
355
+ return prompt_embeds, negative_prompt_embeds
356
+
357
+ def prepare_latents(
358
+ self,
359
+ image: torch.Tensor,
360
+ batch_size: int = 1,
361
+ num_channels_latents: int = 16,
362
+ num_frames: int = 13,
363
+ height: int = 60,
364
+ width: int = 90,
365
+ dtype: Optional[torch.dtype] = None,
366
+ device: Optional[torch.device] = None,
367
+ generator: Optional[torch.Generator] = None,
368
+ latents: Optional[torch.Tensor] = None,
369
+ ):
370
+ if isinstance(generator, list) and len(generator) != batch_size:
371
+ raise ValueError(
372
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
373
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
374
+ )
375
+
376
+ num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
377
+ shape = (
378
+ batch_size,
379
+ num_frames,
380
+ num_channels_latents,
381
+ height // self.vae_scale_factor_spatial,
382
+ width // self.vae_scale_factor_spatial,
383
+ )
384
+
385
+ # For CogVideoX 1.5, the latent should add 1 frame of padding (not used here)
386
+ if self.transformer.config.patch_size_t is not None:
387
+ shape = shape[:1] + (shape[1] + shape[1] % self.transformer.config.patch_size_t,) + shape[2:]
388
+
389
+ image = image.unsqueeze(2) # [B, C, F, H, W]
390
+
391
+ if isinstance(generator, list):
392
+ image_latents = [
393
+ retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
394
+ ]
395
+ else:
396
+ image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]
397
+
398
+ image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
399
+
400
+ if not self.vae.config.invert_scale_latents:
401
+ image_latents = self.vae_scaling_factor_image * image_latents
402
+ else:
403
+ # This is awkward but required because the CogVideoX team forgot to multiply the
404
+ # scaling factor during training :)
405
+ image_latents = 1 / self.vae_scaling_factor_image * image_latents
406
+
407
+ padding_shape = (
408
+ batch_size,
409
+ num_frames - 1,
410
+ num_channels_latents,
411
+ height // self.vae_scale_factor_spatial,
412
+ width // self.vae_scale_factor_spatial,
413
+ )
414
+
415
+ latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype)
416
+ image_latents = torch.cat([image_latents, latent_padding], dim=1)
417
+
418
+ # Select the first frame along the second dimension
419
+ if self.transformer.config.patch_size_t is not None:
420
+ first_frame = image_latents[:, : image_latents.size(1) % self.transformer.config.patch_size_t, ...]
421
+ image_latents = torch.cat([first_frame, image_latents], dim=1)
422
+
423
+ if latents is None:
424
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
425
+ else:
426
+ latents = latents.to(device)
427
+
428
+ # scale the initial noise by the standard deviation required by the scheduler
429
+ latents = latents * self.scheduler.init_noise_sigma
430
+ return latents, image_latents
431
+
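+ # Frame-count example for prepare_latents above: with the usual temporal compression ratio of 4,
+ # num_frames = 49 pixel frames map to (49 - 1) // 4 + 1 = 13 latent frames; the single encoded image
+ # latent is padded with 12 zero latent frames so it lines up with the 13-frame noise latents.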
432
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents
433
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
434
+ latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
435
+ latents = 1 / self.vae_scaling_factor_image * latents
436
+
437
+ frames = self.vae.decode(latents).sample
438
+ return frames
439
+
440
+ # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps
441
+ def get_timesteps(self, num_inference_steps, timesteps, strength, device):
442
+ # get the original timestep using init_timestep
443
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
444
+
445
+ t_start = max(num_inference_steps - init_timestep, 0)
446
+ timesteps = timesteps[t_start * self.scheduler.order :]
447
+
448
+ return timesteps, num_inference_steps - t_start
449
+
450
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
451
+ def prepare_extra_step_kwargs(self, generator, eta):
452
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
453
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
454
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
455
+ # and should be between [0, 1]
456
+
457
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
458
+ extra_step_kwargs = {}
459
+ if accepts_eta:
460
+ extra_step_kwargs["eta"] = eta
461
+
462
+ # check if the scheduler accepts generator
463
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
464
+ if accepts_generator:
465
+ extra_step_kwargs["generator"] = generator
466
+ return extra_step_kwargs
467
+
468
+ def check_inputs(
469
+ self,
470
+ image,
471
+ prompt,
472
+ height,
473
+ width,
474
+ negative_prompt,
475
+ callback_on_step_end_tensor_inputs,
476
+ latents=None,
477
+ prompt_embeds=None,
478
+ negative_prompt_embeds=None,
479
+ ):
480
+ if (
481
+ not isinstance(image, torch.Tensor)
482
+ and not isinstance(image, PIL.Image.Image)
483
+ and not isinstance(image, list)
484
+ ):
485
+ raise ValueError(
486
+ "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
487
+ f" {type(image)}"
488
+ )
489
+
490
+ if height % 8 != 0 or width % 8 != 0:
491
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
492
+
493
+ if callback_on_step_end_tensor_inputs is not None and not all(
494
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
495
+ ):
496
+ raise ValueError(
497
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
498
+ )
499
+ if prompt is not None and prompt_embeds is not None:
500
+ raise ValueError(
501
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
502
+ " only forward one of the two."
503
+ )
504
+ elif prompt is None and prompt_embeds is None:
505
+ raise ValueError(
506
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
507
+ )
508
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
509
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
510
+
511
+ if prompt is not None and negative_prompt_embeds is not None:
512
+ raise ValueError(
513
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
514
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
515
+ )
516
+
517
+ if negative_prompt is not None and negative_prompt_embeds is not None:
518
+ raise ValueError(
519
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
520
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
521
+ )
522
+
523
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
524
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
525
+ raise ValueError(
526
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
527
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
528
+ f" {negative_prompt_embeds.shape}."
529
+ )
530
+
531
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.fuse_qkv_projections
532
+ def fuse_qkv_projections(self) -> None:
533
+ r"""Enables fused QKV projections."""
534
+ self.fusing_transformer = True
535
+ self.transformer.fuse_qkv_projections()
536
+
537
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.unfuse_qkv_projections
538
+ def unfuse_qkv_projections(self) -> None:
539
+ r"""Disable QKV projection fusion if enabled."""
540
+ if not self.fusing_transformer:
541
+ logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
542
+ else:
543
+ self.transformer.unfuse_qkv_projections()
544
+ self.fusing_transformer = False
545
+
546
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings
547
+ def _prepare_rotary_positional_embeddings(
548
+ self,
549
+ height: int,
550
+ width: int,
551
+ num_frames: int,
552
+ device: torch.device,
553
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
554
+
555
+
556
+ grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
557
+ grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
558
+
559
+ p = self.transformer.config.patch_size
560
+ p_t = self.transformer.config.patch_size_t
561
+
562
+ base_size_width = self.transformer.config.sample_width // p
563
+ base_size_height = self.transformer.config.sample_height // p
564
+
565
+ # RoPE extrapolation factor in NTK
566
+ # token_factor_ratio = (grid_height * grid_width) / (base_size_width * base_size_height)
567
+ # if token_factor_ratio > 1.0:
568
+ # ntk_factor = token_factor_ratio
569
+ # else:
570
+ # ntk_factor = 1.0
571
+
572
+
573
+ if p_t is None: # HACK: this is the branch taken here
574
+ # CogVideoX 1.0
575
+ grid_crops_coords = get_resize_crop_region_for_grid(
576
+ (grid_height, grid_width), base_size_width, base_size_height
577
+ )
578
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
579
+ embed_dim=self.transformer.config.attention_head_dim,
580
+ crops_coords=grid_crops_coords, # ((0, 0), (30, 45))
581
+ grid_size=(grid_height, grid_width), # (30, 45)
582
+ # ntk_factor = ntk_factor, # For the extrapolation
583
+ temporal_size=num_frames,
584
+ device=device,
585
+ )
586
+ else:
587
+ # CogVideoX 1.5
588
+ base_num_frames = (num_frames + p_t - 1) // p_t
589
+
590
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
591
+ embed_dim=self.transformer.config.attention_head_dim,
592
+ crops_coords=None,
593
+ grid_size=(grid_height, grid_width),
594
+ temporal_size=base_num_frames,
595
+ grid_type="slice",
596
+ max_size=(base_size_height, base_size_width),
597
+ device=device,
598
+ )
599
+
600
+ return freqs_cos, freqs_sin
601
+
602
+ @property
603
+ def guidance_scale(self):
604
+ return self._guidance_scale
605
+
606
+ @property
607
+ def num_timesteps(self):
608
+ return self._num_timesteps
609
+
610
+ @property
611
+ def attention_kwargs(self):
612
+ return self._attention_kwargs
613
+
614
+ @property
615
+ def interrupt(self):
616
+ return self._interrupt
617
+
618
+ @torch.no_grad()
619
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
620
+ def __call__(
621
+ self,
622
+ image: PipelineImageInput,
623
+ traj_tensor = None,
624
+ prompt: Optional[Union[str, List[str]]] = None,
625
+ negative_prompt: Optional[Union[str, List[str]]] = None,
626
+ height: Optional[int] = None,
627
+ width: Optional[int] = None,
628
+ num_frames: int = 49,
629
+ num_inference_steps: int = 50,
630
+ timesteps: Optional[List[int]] = None,
631
+ guidance_scale: float = 6,
632
+ use_dynamic_cfg: bool = False,
633
+ num_videos_per_prompt: int = 1,
634
+ eta: float = 0.0,
635
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
636
+ latents: Optional[torch.FloatTensor] = None,
637
+ prompt_embeds: Optional[torch.FloatTensor] = None,
638
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
639
+ output_type: str = "pil",
640
+ return_dict: bool = True,
641
+ attention_kwargs: Optional[Dict[str, Any]] = None,
642
+ callback_on_step_end: Optional[
643
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
644
+ ] = None,
645
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
646
+ max_sequence_length: int = 226,
647
+ ) -> Union[CogVideoXPipelineOutput, Tuple]:
648
+ """
649
+ Function invoked when calling the pipeline for generation.
650
+
651
+ Args:
652
+ image (`PipelineImageInput`):
653
+ The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
654
+ prompt (`str` or `List[str]`, *optional*):
655
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
656
+ instead.
657
+ negative_prompt (`str` or `List[str]`, *optional*):
658
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
659
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
660
+ less than `1`).
661
+ height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
662
+ The height in pixels of the generated image. This is set to 480 by default for the best results.
663
+ width (`int`, *optional*, defaults to self.transformer.config.sample_width * self.vae_scale_factor_spatial):
664
+ The width in pixels of the generated image. This is set to 720 by default for the best results.
665
+ num_frames (`int`, defaults to `49`):
666
+ Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
667
+ contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
668
+ num_seconds is 6 and fps is 8. However, since videos can be saved at any fps, the only condition that
669
+ needs to be satisfied is that of divisibility mentioned above.
670
+ num_inference_steps (`int`, *optional*, defaults to 50):
671
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
672
+ expense of slower inference.
673
+ timesteps (`List[int]`, *optional*):
674
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
675
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
676
+ passed will be used. Must be in descending order.
677
+ guidance_scale (`float`, *optional*, defaults to 6):
678
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
679
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
680
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
681
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
682
+ usually at the expense of lower image quality.
683
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
684
+ The number of videos to generate per prompt.
685
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
686
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
687
+ to make generation deterministic.
688
+ latents (`torch.FloatTensor`, *optional*):
689
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
690
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
691
+ tensor will be generated by sampling using the supplied random `generator`.
692
+ prompt_embeds (`torch.FloatTensor`, *optional*):
693
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
694
+ provided, text embeddings will be generated from `prompt` input argument.
695
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
696
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
697
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
698
+ argument.
699
+ output_type (`str`, *optional*, defaults to `"pil"`):
700
+ The output format of the generate image. Choose between
701
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
702
+ return_dict (`bool`, *optional*, defaults to `True`):
703
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
704
+ of a plain tuple.
705
+ attention_kwargs (`dict`, *optional*):
706
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
707
+ `self.processor` in
708
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
709
+ callback_on_step_end (`Callable`, *optional*):
710
+ A function that is called at the end of each denoising step during inference. The function is called
711
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
712
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
713
+ `callback_on_step_end_tensor_inputs`.
714
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
715
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
716
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
717
+ `._callback_tensor_inputs` attribute of your pipeline class.
718
+ max_sequence_length (`int`, defaults to `226`):
719
+ Maximum sequence length in encoded prompt. Must be consistent with
720
+ `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
721
+
722
+ Examples:
723
+
724
+ Returns:
725
+ [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] or `tuple`:
726
+ [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
727
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
728
+ """
729
+
730
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
731
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
732
+
733
+ height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
734
+ width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
735
+ num_frames = num_frames or self.transformer.config.sample_frames
736
+
737
+ num_videos_per_prompt = 1
738
+
739
+ # 1. Check inputs. Raise error if not correct
740
+ self.check_inputs(
741
+ image=image,
742
+ prompt=prompt,
743
+ height=height,
744
+ width=width,
745
+ negative_prompt=negative_prompt,
746
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
747
+ latents=latents,
748
+ prompt_embeds=prompt_embeds,
749
+ negative_prompt_embeds=negative_prompt_embeds,
750
+ )
751
+ self._guidance_scale = guidance_scale
752
+ self._attention_kwargs = attention_kwargs
753
+ self._interrupt = False
754
+
755
+ # 2. Default call parameters
756
+ if prompt is not None and isinstance(prompt, str):
757
+ batch_size = 1
758
+ elif prompt is not None and isinstance(prompt, list):
759
+ batch_size = len(prompt)
760
+ else:
761
+ batch_size = prompt_embeds.shape[0]
762
+
763
+ device = self._execution_device
764
+
765
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
766
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
767
+ # corresponds to doing no classifier free guidance.
768
+ do_classifier_free_guidance = guidance_scale > 1.0
769
+
770
+ # 3. Encode input prompt
771
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
772
+ prompt=prompt,
773
+ negative_prompt=negative_prompt,
774
+ do_classifier_free_guidance=do_classifier_free_guidance,
775
+ num_videos_per_prompt=num_videos_per_prompt,
776
+ prompt_embeds=prompt_embeds,
777
+ negative_prompt_embeds=negative_prompt_embeds,
778
+ max_sequence_length=max_sequence_length,
779
+ device=device,
780
+ )
781
+ if do_classifier_free_guidance:
782
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
783
+
784
+ # 4. Prepare timesteps
785
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
786
+ self._num_timesteps = len(timesteps)
787
+
788
+ # 5. Prepare latents
789
+ latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
790
+
791
+ # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
792
+ patch_size_t = self.transformer.config.patch_size_t
793
+ additional_frames = 0
794
+ if patch_size_t is not None and latent_frames % patch_size_t != 0:
795
+ additional_frames = patch_size_t - latent_frames % patch_size_t
796
+ num_frames += additional_frames * self.vae_scale_factor_temporal
797
+
798
+ image = self.video_processor.preprocess(image, height=height, width=width).to(
799
+ device, dtype=prompt_embeds.dtype
800
+ )
801
+
802
+ latent_channels = 16 # self.transformer.config.in_channels // 2
803
+ latents, image_latents = self.prepare_latents(
804
+ image,
805
+ batch_size * num_videos_per_prompt,
806
+ latent_channels,
807
+ num_frames,
808
+ height,
809
+ width,
810
+ prompt_embeds.dtype,
811
+ device,
812
+ generator,
813
+ latents,
814
+ )
815
+
816
+
817
+ # 5.5. Traj Preprocess
818
+ traj_tensor = traj_tensor.to(device, dtype = self.vae.dtype)[None] #.unsqueeze(0)
819
+ traj_tensor = traj_tensor.permute(0, 2, 1, 3, 4)
820
+ traj_latents = self.vae.encode(traj_tensor).latent_dist
821
+
822
+ # Scale, Permute, and other conversion
823
+ traj_latents = traj_latents.sample() * self.vae.config.scaling_factor
824
+ traj_latents = traj_latents.permute(0, 2, 1, 3, 4)
825
+ traj_latents = traj_latents.to(memory_format = torch.contiguous_format).float().to(dtype = prompt_embeds.dtype) # [B, F, C, H, W]
826
+
827
+
828
+
829
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
830
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
831
+
832
+ # 7. Create rotary embeds if required
833
+ image_rotary_emb = (
834
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
835
+ if self.transformer.config.use_rotary_positional_embeddings
836
+ else None
837
+ )
838
+
839
+ # 8. Create ofs embeds if required
840
+ ofs_emb = None if self.transformer.config.ofs_embed_dim is None else latents.new_full((1,), fill_value=2.0)
841
+
842
+ # 8. Denoising loop
843
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
844
+
845
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
846
+ # for DPM-solver++
847
+ old_pred_original_sample = None
848
+ for i, t in enumerate(timesteps):
849
+ if self.interrupt:
850
+ continue
851
+
852
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
853
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
854
+
855
+ latent_traj = torch.cat([traj_latents] * 2) if do_classifier_free_guidance else traj_latents
856
+
857
+ latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents
858
+ latent_model_input = torch.cat([latent_model_input, latent_image_input, latent_traj], dim=2) # The third (channel) dim grows from 16 to 48
859
+
860
+
861
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
862
+ timestep = t.expand(latent_model_input.shape[0])
863
+
864
+ # predict noise model_output
865
+ noise_pred = self.transformer(
866
+ hidden_states=latent_model_input,
867
+ encoder_hidden_states=prompt_embeds,
868
+ timestep=timestep,
869
+ ofs=ofs_emb,
870
+ image_rotary_emb=image_rotary_emb,
871
+ attention_kwargs=attention_kwargs,
872
+ return_dict=False,
873
+ )[0]
874
+ noise_pred = noise_pred.float()
875
+
876
+ # perform guidance
877
+ if use_dynamic_cfg:
878
+ self._guidance_scale = 1 + guidance_scale * (
879
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
880
+ )
881
+ if do_classifier_free_guidance:
882
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
883
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
884
+
885
+ # compute the previous noisy sample x_t -> x_t-1
886
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
887
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
888
+ else:
889
+ latents, old_pred_original_sample = self.scheduler.step(
890
+ noise_pred,
891
+ old_pred_original_sample,
892
+ t,
893
+ timesteps[i - 1] if i > 0 else None,
894
+ latents,
895
+ **extra_step_kwargs,
896
+ return_dict=False,
897
+ )
898
+ latents = latents.to(prompt_embeds.dtype)
899
+
900
+ # call the callback, if provided
901
+ if callback_on_step_end is not None:
902
+ callback_kwargs = {}
903
+ for k in callback_on_step_end_tensor_inputs:
904
+ callback_kwargs[k] = locals()[k]
905
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
906
+
907
+ latents = callback_outputs.pop("latents", latents)
908
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
909
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
910
+
911
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
912
+ progress_bar.update()
913
+
914
+ if XLA_AVAILABLE:
915
+ xm.mark_step()
916
+
917
+ if not output_type == "latent":
918
+ # Discard any padding frames that were added for CogVideoX 1.5
919
+ latents = latents[:, additional_frames:]
920
+ video = self.decode_latents(latents)
921
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
922
+ else:
923
+ video = latents
924
+
925
+ # Offload all models
926
+ self.maybe_free_model_hooks()
927
+
928
+ if not return_dict:
929
+ return (video,)
930
+
931
+ return CogVideoXPipelineOutput(frames=video)
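As a rough orientation, a hedged usage sketch of this customized pipeline follows. The checkpoint paths, the prompt, and the zero-filled `traj_tensor` are placeholders: the base CogVideoX-5b-I2V transformer does not accept the extra trajectory latent channels, so the fine-tuned transformer from this repo has to be loaded in its place, and the real trajectory frames come from the data/preprocessing pipeline as a `[F, C, H, W]` tensor in the VAE's input range (roughly [-1, 1]).

```py
import torch
from diffusers.utils import export_to_video, load_image

from pipelines.pipeline_cogvideox_i2v_motion import CogVideoXImageToVideoPipeline
from architecture.cogvideox_transformer_3d import CogVideoXTransformer3DModel

# Placeholder paths: swap in the repo's fine-tuned transformer checkpoint and your own first frame.
transformer = CogVideoXTransformer3DModel.from_pretrained("checkpoints/transformer", torch_dtype=torch.bfloat16)
pipe = CogVideoXImageToVideoPipeline.from_pretrained(
    "THUDM/CogVideoX-5b-I2V", transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.to("cuda")

image = load_image("first_frame.png")
traj_tensor = torch.zeros(49, 3, 480, 720)  # placeholder trajectory video (no motion hint)

video = pipe(image=image, traj_tensor=traj_tensor, prompt="an object moving across the scene",
             num_frames=49, use_dynamic_cfg=True)
export_to_video(video.frames[0], "output.mp4", fps=8)
```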
pipelines/pipeline_cogvideox_i2v_motion_FrameINO.py ADDED
@@ -0,0 +1,960 @@
1
+ # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
+ # All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import inspect
17
+ import math
18
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
19
+
20
+ import PIL
21
+ import torch
22
+ from transformers import T5EncoderModel, T5Tokenizer
23
+
24
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
25
+ from diffusers.image_processor import PipelineImageInput
26
+ from diffusers.loaders import CogVideoXLoraLoaderMixin
27
+ from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
28
+ from diffusers.models.embeddings import get_3d_rotary_pos_embed
29
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
30
+ from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
31
+ from diffusers.utils import (
32
+ is_torch_xla_available,
33
+ logging,
34
+ replace_example_docstring,
35
+ )
36
+ from diffusers.utils.torch_utils import randn_tensor
37
+ from diffusers.video_processor import VideoProcessor
38
+ from diffusers.pipelines.cogvideo.pipeline_output import CogVideoXPipelineOutput
39
+
40
+
41
+ if is_torch_xla_available():
42
+ import torch_xla.core.xla_model as xm
43
+
44
+ XLA_AVAILABLE = True
45
+ else:
46
+ XLA_AVAILABLE = False
47
+
48
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
49
+
50
+
51
+ EXAMPLE_DOC_STRING = """
52
+ Examples:
53
+ ```py
54
+ >>> import torch
55
+ >>> from diffusers import CogVideoXImageToVideoPipeline
56
+ >>> from diffusers.utils import export_to_video, load_image
57
+
58
+ >>> pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16)
59
+ >>> pipe.to("cuda")
60
+
61
+ >>> prompt = "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
62
+ >>> image = load_image(
63
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
64
+ ... )
65
+ >>> video = pipe(image, prompt, use_dynamic_cfg=True)
66
+ >>> export_to_video(video.frames[0], "output.mp4", fps=8)
67
+ ```
68
+ """
69
+
70
+
71
+ # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
72
+ def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
73
+
74
+ tw = tgt_width
75
+ th = tgt_height
76
+ h, w = src
77
+ r = h / w
78
+ if r > (th / tw):
79
+ resize_height = th
80
+ resize_width = int(round(th / h * w))
81
+ else:
82
+ resize_width = tw
83
+ resize_height = int(round(tw / w * h))
84
+
85
+ crop_top = int(round((th - resize_height) / 2.0))
86
+ crop_left = int(round((tw - resize_width) / 2.0))
87
+
88
+
89
+ return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
90
+
91
+
92
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
93
+ def retrieve_timesteps(
94
+ scheduler,
95
+ num_inference_steps: Optional[int] = None,
96
+ device: Optional[Union[str, torch.device]] = None,
97
+ timesteps: Optional[List[int]] = None,
98
+ sigmas: Optional[List[float]] = None,
99
+ **kwargs,
100
+ ):
101
+ r"""
102
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
103
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
104
+
105
+ Args:
106
+ scheduler (`SchedulerMixin`):
107
+ The scheduler to get timesteps from.
108
+ num_inference_steps (`int`):
109
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
110
+ must be `None`.
111
+ device (`str` or `torch.device`, *optional*):
112
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
113
+ timesteps (`List[int]`, *optional*):
114
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
115
+ `num_inference_steps` and `sigmas` must be `None`.
116
+ sigmas (`List[float]`, *optional*):
117
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
118
+ `num_inference_steps` and `timesteps` must be `None`.
119
+
120
+ Returns:
121
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
122
+ second element is the number of inference steps.
123
+ """
124
+ if timesteps is not None and sigmas is not None:
125
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
126
+ if timesteps is not None:
127
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
128
+ if not accepts_timesteps:
129
+ raise ValueError(
130
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
131
+ f" timestep schedules. Please check whether you are using the correct scheduler."
132
+ )
133
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
134
+ timesteps = scheduler.timesteps
135
+ num_inference_steps = len(timesteps)
136
+ elif sigmas is not None:
137
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
138
+ if not accept_sigmas:
139
+ raise ValueError(
140
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
141
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
142
+ )
143
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
144
+ timesteps = scheduler.timesteps
145
+ num_inference_steps = len(timesteps)
146
+ else:
147
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
148
+ timesteps = scheduler.timesteps
149
+ return timesteps, num_inference_steps
150
+
151
+
152
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
153
+ def retrieve_latents(
154
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
155
+ ):
156
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
157
+ return encoder_output.latent_dist.sample(generator)
158
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
159
+ return encoder_output.latent_dist.mode()
160
+ elif hasattr(encoder_output, "latents"):
161
+ return encoder_output.latents
162
+ else:
163
+ raise AttributeError("Could not access latents of provided encoder_output")
164
+
165
+
166
+ class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
167
+ r"""
168
+ Pipeline for image-to-video generation using CogVideoX.
169
+
170
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
171
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
172
+
173
+ Args:
174
+ vae ([`AutoencoderKL`]):
175
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
176
+ text_encoder ([`T5EncoderModel`]):
177
+ Frozen text-encoder. CogVideoX uses
178
+ [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
179
+ [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
180
+ tokenizer (`T5Tokenizer`):
181
+ Tokenizer of class
182
+ [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
183
+ transformer ([`CogVideoXTransformer3DModel`]):
184
+ A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
185
+ scheduler ([`SchedulerMixin`]):
186
+ A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
187
+ """
188
+
189
+ _optional_components = []
190
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
191
+
192
+ _callback_tensor_inputs = [
193
+ "latents",
194
+ "prompt_embeds",
195
+ "negative_prompt_embeds",
196
+ ]
197
+
198
+ def __init__(
199
+ self,
200
+ tokenizer: T5Tokenizer,
201
+ text_encoder: T5EncoderModel,
202
+ vae: AutoencoderKLCogVideoX,
203
+ transformer: CogVideoXTransformer3DModel,
204
+ scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
205
+ ):
206
+ super().__init__()
207
+
208
+ self.register_modules(
209
+ tokenizer=tokenizer,
210
+ text_encoder=text_encoder,
211
+ vae=vae,
212
+ transformer=transformer,
213
+ scheduler=scheduler,
214
+ )
215
+ self.vae_scale_factor_spatial = (
216
+ 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
217
+ )
218
+ self.vae_scale_factor_temporal = (
219
+ self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
220
+ )
221
+ self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.7
222
+
223
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
224
+
225
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
226
+ def _get_t5_prompt_embeds(
227
+ self,
228
+ prompt: Union[str, List[str]] = None,
229
+ num_videos_per_prompt: int = 1,
230
+ max_sequence_length: int = 226,
231
+ device: Optional[torch.device] = None,
232
+ dtype: Optional[torch.dtype] = None,
233
+ ):
234
+ device = device or self._execution_device
235
+ dtype = dtype or self.text_encoder.dtype
236
+
237
+ prompt = [prompt] if isinstance(prompt, str) else prompt
238
+ batch_size = len(prompt)
239
+
240
+ text_inputs = self.tokenizer(
241
+ prompt,
242
+ padding="max_length",
243
+ max_length=max_sequence_length,
244
+ truncation=True,
245
+ add_special_tokens=True,
246
+ return_tensors="pt",
247
+ )
248
+ text_input_ids = text_inputs.input_ids
249
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
250
+
251
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
252
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
253
+ logger.warning(
254
+ "The following part of your input was truncated because `max_sequence_length` is set to "
255
+ f" {max_sequence_length} tokens: {removed_text}"
256
+ )
257
+
258
+ prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
259
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
260
+
261
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
262
+ _, seq_len, _ = prompt_embeds.shape
263
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
264
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
265
+
266
+ return prompt_embeds
267
+
268
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
269
+ def encode_prompt(
270
+ self,
271
+ prompt: Union[str, List[str]],
272
+ negative_prompt: Optional[Union[str, List[str]]] = None,
273
+ do_classifier_free_guidance: bool = True,
274
+ num_videos_per_prompt: int = 1,
275
+ prompt_embeds: Optional[torch.Tensor] = None,
276
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
277
+ max_sequence_length: int = 226,
278
+ device: Optional[torch.device] = None,
279
+ dtype: Optional[torch.dtype] = None,
280
+ ):
281
+ r"""
282
+ Encodes the prompt into text encoder hidden states.
283
+
284
+ Args:
285
+ prompt (`str` or `List[str]`, *optional*):
286
+ prompt to be encoded
287
+ negative_prompt (`str` or `List[str]`, *optional*):
288
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
289
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
290
+ less than `1`).
291
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
292
+ Whether to use classifier free guidance or not.
293
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
294
+ Number of videos that should be generated per prompt.
295
+ prompt_embeds (`torch.Tensor`, *optional*):
296
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
297
+ provided, text embeddings will be generated from `prompt` input argument.
298
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
299
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
300
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
301
+ argument.
302
+ device: (`torch.device`, *optional*):
303
+ torch device
304
+ dtype: (`torch.dtype`, *optional*):
305
+ torch dtype
306
+ """
307
+ device = device or self._execution_device
308
+
309
+ prompt = [prompt] if isinstance(prompt, str) else prompt
310
+ if prompt is not None:
311
+ batch_size = len(prompt)
312
+ else:
313
+ batch_size = prompt_embeds.shape[0]
314
+
315
+ if prompt_embeds is None:
316
+ prompt_embeds = self._get_t5_prompt_embeds(
317
+ prompt=prompt,
318
+ num_videos_per_prompt=num_videos_per_prompt,
319
+ max_sequence_length=max_sequence_length,
320
+ device=device,
321
+ dtype=dtype,
322
+ )
323
+
324
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
325
+ negative_prompt = negative_prompt or ""
326
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
327
+
328
+ if prompt is not None and type(prompt) is not type(negative_prompt):
329
+ raise TypeError(
330
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
331
+ f" {type(prompt)}."
332
+ )
333
+ elif batch_size != len(negative_prompt):
334
+ raise ValueError(
335
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
336
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
337
+ " the batch size of `prompt`."
338
+ )
339
+
340
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
341
+ prompt=negative_prompt,
342
+ num_videos_per_prompt=num_videos_per_prompt,
343
+ max_sequence_length=max_sequence_length,
344
+ device=device,
345
+ dtype=dtype,
346
+ )
347
+
348
+ return prompt_embeds, negative_prompt_embeds
349
+
350
+ def prepare_latents(
351
+ self,
352
+ image: torch.Tensor,
353
+ batch_size: int = 1,
354
+ num_channels_latents: int = 16,
355
+ num_frames: int = 13,
356
+ height: int = 60,
357
+ width: int = 90,
358
+ dtype: Optional[torch.dtype] = None,
359
+ device: Optional[torch.device] = None,
360
+ generator: Optional[torch.Generator] = None,
361
+ latents: Optional[torch.Tensor] = None,
362
+ ):
363
+ if isinstance(generator, list) and len(generator) != batch_size:
364
+ raise ValueError(
365
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
366
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
367
+ )
368
+
369
+ num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
370
+ shape = (
371
+ batch_size,
372
+ num_frames,
373
+ num_channels_latents,
374
+ height // self.vae_scale_factor_spatial,
375
+ width // self.vae_scale_factor_spatial,
376
+ )
377
+
378
+ # For CogVideoX 1.5, the latent frame count should be padded by 1 (not used here)
379
+ if self.transformer.config.patch_size_t is not None:
380
+ shape = shape[:1] + (shape[1] + shape[1] % self.transformer.config.patch_size_t,) + shape[2:]
381
+
382
+ image = image.unsqueeze(2) # [B, C, F, H, W]
383
+
384
+ if isinstance(generator, list):
385
+ image_latents = [
386
+ retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
387
+ ]
388
+ else:
389
+ image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]
390
+
391
+ image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
392
+
393
+ if not self.vae.config.invert_scale_latents:
394
+ image_latents = self.vae_scaling_factor_image * image_latents
395
+ else:
396
+ # This is awkward but required because the CogVideoX team forgot to multiply the
397
+ # scaling factor during training :)
398
+ image_latents = 1 / self.vae_scaling_factor_image * image_latents
399
+
400
+ padding_shape = (
401
+ batch_size,
402
+ num_frames - 1,
403
+ num_channels_latents,
404
+ height // self.vae_scale_factor_spatial,
405
+ width // self.vae_scale_factor_spatial,
406
+ )
407
+
408
+ latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype)
409
+ image_latents = torch.cat([image_latents, latent_padding], dim=1)
410
+
411
+ # Select the first frame along the second dimension
412
+ if self.transformer.config.patch_size_t is not None:
413
+ first_frame = image_latents[:, : image_latents.size(1) % self.transformer.config.patch_size_t, ...]
414
+ image_latents = torch.cat([first_frame, image_latents], dim=1)
415
+
416
+ if latents is None:
417
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
418
+ else:
419
+ latents = latents.to(device)
420
+
421
+ # scale the initial noise by the standard deviation required by the scheduler
422
+ latents = latents * self.scheduler.init_noise_sigma
423
+ return latents, image_latents
424
+
425
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents
426
+ def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
427
+ latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
428
+ latents = 1 / self.vae_scaling_factor_image * latents
429
+
430
+ frames = self.vae.decode(latents).sample
431
+ return frames
432
+
433
+ # Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps
434
+ def get_timesteps(self, num_inference_steps, timesteps, strength, device):
435
+ # get the original timestep using init_timestep
436
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
437
+
438
+ t_start = max(num_inference_steps - init_timestep, 0)
439
+ timesteps = timesteps[t_start * self.scheduler.order :]
440
+
441
+ return timesteps, num_inference_steps - t_start
442
+
443
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
444
+ def prepare_extra_step_kwargs(self, generator, eta):
445
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
446
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
447
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
448
+ # and should be between [0, 1]
449
+
450
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
451
+ extra_step_kwargs = {}
452
+ if accepts_eta:
453
+ extra_step_kwargs["eta"] = eta
454
+
455
+ # check if the scheduler accepts generator
456
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
457
+ if accepts_generator:
458
+ extra_step_kwargs["generator"] = generator
459
+ return extra_step_kwargs
460
+
461
+ def check_inputs(
462
+ self,
463
+ image,
464
+ prompt,
465
+ height,
466
+ width,
467
+ negative_prompt,
468
+ callback_on_step_end_tensor_inputs,
469
+ latents=None,
470
+ prompt_embeds=None,
471
+ negative_prompt_embeds=None,
472
+ ):
473
+ if (
474
+ not isinstance(image, torch.Tensor)
475
+ and not isinstance(image, PIL.Image.Image)
476
+ and not isinstance(image, list)
477
+ ):
478
+ raise ValueError(
479
+ "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
480
+ f" {type(image)}"
481
+ )
482
+
483
+ if height % 8 != 0 or width % 8 != 0:
484
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
485
+
486
+ if callback_on_step_end_tensor_inputs is not None and not all(
487
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
488
+ ):
489
+ raise ValueError(
490
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
491
+ )
492
+ if prompt is not None and prompt_embeds is not None:
493
+ raise ValueError(
494
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
495
+ " only forward one of the two."
496
+ )
497
+ elif prompt is None and prompt_embeds is None:
498
+ raise ValueError(
499
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
500
+ )
501
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
502
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
503
+
504
+ if prompt is not None and negative_prompt_embeds is not None:
505
+ raise ValueError(
506
+ f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
507
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
508
+ )
509
+
510
+ if negative_prompt is not None and negative_prompt_embeds is not None:
511
+ raise ValueError(
512
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
513
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
514
+ )
515
+
516
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
517
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
518
+ raise ValueError(
519
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
520
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
521
+ f" {negative_prompt_embeds.shape}."
522
+ )
523
+
524
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.fuse_qkv_projections
525
+ def fuse_qkv_projections(self) -> None:
526
+ r"""Enables fused QKV projections."""
527
+ self.fusing_transformer = True
528
+ self.transformer.fuse_qkv_projections()
529
+
530
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.unfuse_qkv_projections
531
+ def unfuse_qkv_projections(self) -> None:
532
+ r"""Disable QKV projection fusion if enabled."""
533
+ if not self.fusing_transformer:
534
+ logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
535
+ else:
536
+ self.transformer.unfuse_qkv_projections()
537
+ self.fusing_transformer = False
538
+
539
+ # Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings
540
+ def _prepare_rotary_positional_embeddings(
541
+ self,
542
+ height: int,
543
+ width: int,
544
+ num_frames: int,
545
+ device: torch.device,
546
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
547
+
548
+
549
+ grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
550
+ grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
551
+
552
+ p = self.transformer.config.patch_size
553
+ p_t = self.transformer.config.patch_size_t
554
+
555
+ base_size_width = self.transformer.config.sample_width // p
556
+ base_size_height = self.transformer.config.sample_height // p
557
+
558
+ if p_t is None: # HACK: this is the branch taken here (CogVideoX 1.0)
559
+ # CogVideoX 1.0
560
+ grid_crops_coords = get_resize_crop_region_for_grid(
561
+ (grid_height, grid_width), base_size_width, base_size_height
562
+ )
563
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
564
+ embed_dim=self.transformer.config.attention_head_dim,
565
+ crops_coords=grid_crops_coords, # ((0, 0), (30, 45))
566
+ grid_size=(grid_height, grid_width), # (30, 45)
567
+ temporal_size=num_frames,
568
+ device=device,
569
+ )
570
+ else:
571
+ # CogVideoX 1.5
572
+ base_num_frames = (num_frames + p_t - 1) // p_t
573
+
574
+ freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
575
+ embed_dim=self.transformer.config.attention_head_dim,
576
+ crops_coords=None,
577
+ grid_size=(grid_height, grid_width),
578
+ temporal_size=base_num_frames,
579
+ grid_type="slice",
580
+ max_size=(base_size_height, base_size_width),
581
+ device=device,
582
+ )
583
+
584
+ return freqs_cos, freqs_sin
585
+
586
+ @property
587
+ def guidance_scale(self):
588
+ return self._guidance_scale
589
+
590
+ @property
591
+ def num_timesteps(self):
592
+ return self._num_timesteps
593
+
594
+ @property
595
+ def attention_kwargs(self):
596
+ return self._attention_kwargs
597
+
598
+ @property
599
+ def interrupt(self):
600
+ return self._interrupt
601
+
602
+ @torch.no_grad()
603
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
604
+ def __call__(
605
+ self,
606
+ image: PipelineImageInput,
607
+ traj_tensor = None,
608
+ ID_tensor = None,
609
+ prompt: Optional[Union[str, List[str]]] = None,
610
+ negative_prompt: Optional[Union[str, List[str]]] = None,
611
+ height: Optional[int] = None,
612
+ width: Optional[int] = None,
613
+ num_frames: int = 49,
614
+ num_inference_steps: int = 50,
615
+ timesteps: Optional[List[int]] = None,
616
+ guidance_scale: float = 6,
617
+ use_dynamic_cfg: bool = False,
618
+ add_ID_reference_augment_noise: bool = True,
619
+ num_videos_per_prompt: int = 1,
620
+ eta: float = 0.0,
621
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
622
+ latents: Optional[torch.FloatTensor] = None,
623
+ prompt_embeds: Optional[torch.FloatTensor] = None,
624
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
625
+ output_type: str = "pil",
626
+ return_dict: bool = True,
627
+ attention_kwargs: Optional[Dict[str, Any]] = None,
628
+ callback_on_step_end: Optional[
629
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
630
+ ] = None,
631
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
632
+ max_sequence_length: int = 226,
633
+ ) -> Union[CogVideoXPipelineOutput, Tuple]:
634
+ """
635
+ Function invoked when calling the pipeline for generation.
636
+
637
+ Args:
638
+ image (`PipelineImageInput`):
639
+ The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
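+ traj_tensor (`torch.Tensor`, *optional*):
+ Trajectory-conditioning video tensor. It is VAE-encoded and channel-concatenated with the noisy latents
+ and the first-frame latents at every denoising step.
+ ID_tensor (`torch.Tensor`, *optional*):
+ ID-reference image tensor. Its latent is appended as one extra frame of tokens for identity conditioning
+ and is discarded from the noise prediction.
+ add_ID_reference_augment_noise (`bool`, *optional*, defaults to `True`):
+ Whether to add augment noise to the ID-reference latent before it is appended.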
640
+ prompt (`str` or `List[str]`, *optional*):
641
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
642
+ instead.
643
+ negative_prompt (`str` or `List[str]`, *optional*):
644
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
645
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
646
+ less than `1`).
647
+ height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
648
+ The height in pixels of the generated image. This is set to 480 by default for the best results.
649
+ width (`int`, *optional*, defaults to self.transformer.config.sample_width * self.vae_scale_factor_spatial):
650
+ The width in pixels of the generated image. This is set to 720 by default for the best results.
651
+ num_frames (`int`, defaults to `49`):
652
+ Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
653
+ contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
654
+ num_seconds is 6 and fps is 8. However, since videos can be saved at any fps, the only condition that
655
+ needs to be satisfied is that of divisibility mentioned above.
656
+ num_inference_steps (`int`, *optional*, defaults to 50):
657
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
658
+ expense of slower inference.
659
+ timesteps (`List[int]`, *optional*):
660
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
661
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
662
+ passed will be used. Must be in descending order.
663
+ guidance_scale (`float`, *optional*, defaults to 6):
664
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
665
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
666
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
667
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
668
+ usually at the expense of lower image quality.
669
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
670
+ The number of videos to generate per prompt.
671
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
672
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
673
+ to make generation deterministic.
674
+ latents (`torch.FloatTensor`, *optional*):
675
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
676
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
677
+ tensor will be generated by sampling using the supplied random `generator`.
678
+ prompt_embeds (`torch.FloatTensor`, *optional*):
679
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
680
+ provided, text embeddings will be generated from `prompt` input argument.
681
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
682
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
683
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
684
+ argument.
685
+ output_type (`str`, *optional*, defaults to `"pil"`):
686
+ The output format of the generated image. Choose between
687
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
688
+ return_dict (`bool`, *optional*, defaults to `True`):
689
+ Whether or not to return a [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] instead
690
+ of a plain tuple.
691
+ attention_kwargs (`dict`, *optional*):
692
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
693
+ `self.processor` in
694
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
695
+ callback_on_step_end (`Callable`, *optional*):
696
+ A function that is called at the end of each denoising step during inference. The function is called
697
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
698
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
699
+ `callback_on_step_end_tensor_inputs`.
700
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
701
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
702
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
703
+ `._callback_tensor_inputs` attribute of your pipeline class.
704
+ max_sequence_length (`int`, defaults to `226`):
705
+ Maximum sequence length in encoded prompt. Must be consistent with
706
+ `self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
707
+
708
+ Examples:
709
+
710
+ Returns:
711
+ [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] or `tuple`:
712
+ [`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
713
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
714
+ """
715
+
716
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
717
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
718
+
719
+ height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
720
+ width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
721
+ num_frames = num_frames or self.transformer.config.sample_frames
722
+
723
+ num_videos_per_prompt = 1
724
+
725
+ # 1. Check inputs. Raise error if not correct
726
+ self.check_inputs(
727
+ image=image,
728
+ prompt=prompt,
729
+ height=height,
730
+ width=width,
731
+ negative_prompt=negative_prompt,
732
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
733
+ latents=latents,
734
+ prompt_embeds=prompt_embeds,
735
+ negative_prompt_embeds=negative_prompt_embeds,
736
+ )
737
+ self._guidance_scale = guidance_scale
738
+ self._attention_kwargs = attention_kwargs
739
+ self._interrupt = False
740
+
741
+ # 2. Default call parameters
742
+ if prompt is not None and isinstance(prompt, str):
743
+ batch_size = 1
744
+ elif prompt is not None and isinstance(prompt, list):
745
+ batch_size = len(prompt)
746
+ else:
747
+ batch_size = prompt_embeds.shape[0]
748
+
749
+ device = self._execution_device
750
+
751
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
752
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
753
+ # corresponds to doing no classifier free guidance.
754
+ do_classifier_free_guidance = guidance_scale > 1.0
755
+
756
+ # 3. Encode input prompt
757
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
758
+ prompt=prompt,
759
+ negative_prompt=negative_prompt,
760
+ do_classifier_free_guidance=do_classifier_free_guidance,
761
+ num_videos_per_prompt=num_videos_per_prompt,
762
+ prompt_embeds=prompt_embeds,
763
+ negative_prompt_embeds=negative_prompt_embeds,
764
+ max_sequence_length=max_sequence_length,
765
+ device=device,
766
+ )
767
+ if do_classifier_free_guidance:
768
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
769
+
770
+ # 4. Prepare timesteps
771
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
772
+ self._num_timesteps = len(timesteps)
773
+
774
+ # 5. Prepare latents
775
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
776
+
777
+ # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
778
+ patch_size_t = self.transformer.config.patch_size_t
779
+ additional_frames = 0
780
+ if patch_size_t is not None and num_latent_frames % patch_size_t != 0:
781
+ additional_frames = patch_size_t - num_latent_frames % patch_size_t
782
+ num_frames += additional_frames * self.vae_scale_factor_temporal
783
+
784
+ image = self.video_processor.preprocess(image, height=height, width=width).to(
785
+ device, dtype=prompt_embeds.dtype
786
+ )
787
+
788
+ latent_channels = 16 # self.transformer.config.in_channels // 2
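+ # NOTE: hard-coded because `in_channels // 2` no longer equals the noisy-latent channel count once the
+ # trajectory latents are also concatenated (presumably in_channels is 48 = 16 noisy + 16 first-frame + 16 traj).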
789
+ latents, image_latents = self.prepare_latents(
790
+ image,
791
+ batch_size * num_videos_per_prompt,
792
+ latent_channels,
793
+ num_frames,
794
+ height,
795
+ width,
796
+ prompt_embeds.dtype,
797
+ device,
798
+ generator,
799
+ latents,
800
+ )
801
+
802
+
803
+ # 5.1. Traj Preprocess
804
+ traj_tensor = traj_tensor.to(device, dtype = self.vae.dtype)[None] # add the batch dimension (same as .unsqueeze(0))
805
+ traj_tensor = traj_tensor.permute(0, 2, 1, 3, 4)
806
+ traj_latents = self.vae.encode(traj_tensor).latent_dist
807
+
808
+ # Scale, Permute, and other conversion
809
+ traj_latents = traj_latents.sample() * self.vae.config.scaling_factor
810
+ traj_latents = traj_latents.permute(0, 2, 1, 3, 4)
811
+ traj_latents = traj_latents.to(memory_format = torch.contiguous_format).float().to(dtype = prompt_embeds.dtype) # [B, F, C, H, W]
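+ # The trajectory video is VAE-encoded, scaled by the VAE scaling factor, and permuted to [B, F, C, H, W]
+ # so it can be channel-concatenated with the noisy latents and first-frame latents in the denoising loop.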
812
+
813
+
814
+ # 5.2. ID Reference Preprocess
815
+ if ID_tensor is not None:
816
+ from train_code.train_cogvideox_motion_FrameINO import img_tensor_to_vae_latent # Put it here to avoid circular import
817
+
818
+ # TODO: verify whether augment noise should also be added at test time
819
+ ID_latent = img_tensor_to_vae_latent(ID_tensor.unsqueeze(0), self.vae, traj_latents.device, add_augment_noise = add_ID_reference_augment_noise)
820
+ ID_latent = ID_latent.unsqueeze(1).to(dtype = prompt_embeds.dtype)
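+ # The ID reference becomes a single extra latent frame ([B, 1, C, H, W]); it is appended along the frame
+ # dimension during denoising as conditioning only, and its prediction is discarded afterwards.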
821
+
822
+
823
+
824
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
825
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
826
+
827
+ # 7. Create rotary embeds if required
828
+ image_rotary_emb = (
829
+ self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
830
+ if self.transformer.config.use_rotary_positional_embeddings
831
+ else None
832
+ )
833
+
834
+ # Append rotary PE for a 14th latent frame (the ID reference) by copying the first frame's PE
835
+ freqs_cos, freqs_sin = image_rotary_emb
836
+ first_frame_token_num = freqs_cos.shape[0] // num_latent_frames
837
+ freqs_cos = torch.cat([freqs_cos, freqs_cos[:first_frame_token_num]], dim=0) # Hard Code
838
+ freqs_sin = torch.cat([freqs_sin, freqs_sin[:first_frame_token_num]], dim=0)
839
+ image_rotary_emb = (freqs_cos, freqs_sin)
840
+
841
+
842
+ # 8. Create ofs embeds if required
843
+ ofs_emb = None if self.transformer.config.ofs_embed_dim is None else latents.new_full((1,), fill_value=2.0)
844
+
845
+ # 8. Denoising loop
846
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
847
+
848
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
849
+ # for DPM-solver++
850
+ old_pred_original_sample = None
851
+ for i, t in enumerate(timesteps):
852
+ if self.interrupt:
853
+ continue
854
+
855
+ # Noisy latents prepare
856
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
857
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
858
+
859
+ # First Frame latents prepare
860
+ latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents
861
+
862
+ # Traj latents prepare
863
+ latent_traj = torch.cat([traj_latents] * 2) if do_classifier_free_guidance else traj_latents
864
+
865
+ # ID Reference prepare
866
+ if ID_tensor is not None:
867
+
868
+ # CFG Double Batch Size
869
+ latent_ID = torch.cat([ID_latent] * 2) if do_classifier_free_guidance else ID_latent
870
+
871
+ # Frame-Wise Token Increase
872
+ latent_model_input = torch.cat([latent_model_input, latent_ID], dim = 1)
873
+
874
+ # Increase the frame dimension of the Traj latents and the first frame latent
875
+ latent_ID_padding = latent_model_input.new_zeros(latent_ID.shape) # Zero latent values
876
+ latent_image_input = torch.cat([latent_image_input, latent_ID_padding], dim=1)
877
+ latent_traj = torch.cat([latent_traj, latent_ID_padding], dim=1)
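+ # The first-frame and trajectory conditions are zero-padded on the extra ID frame so all three tensors
+ # keep the same number of latent frames before the channel-wise concatenation below.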
878
+
879
+
880
+ # Dimension-Wise Concatenation
881
+ latent_model_input = torch.cat([latent_model_input, latent_image_input, latent_traj], dim=2) # the third (channel) dim grows from 16 to 48
882
+
883
+
884
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
885
+ timestep = t.expand(latent_model_input.shape[0])
886
+
887
+ # predict noise model_output
888
+ noise_pred = self.transformer(
889
+ hidden_states=latent_model_input,
890
+ encoder_hidden_states=prompt_embeds,
891
+ timestep=timestep,
892
+ ofs=ofs_emb,
893
+ image_rotary_emb=image_rotary_emb,
894
+ attention_kwargs=attention_kwargs,
895
+ return_dict=False,
896
+ )[0]
897
+ noise_pred = noise_pred.float()
898
+
899
+
900
+ # Discard the Extra ID tokens in the Noise Prediction
901
+ if ID_tensor is not None:
902
+ noise_pred = noise_pred[:, :num_latent_frames, :, :, :]
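+ # Keep only the first num_latent_frames frames; the appended ID-reference tokens are conditioning only.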
903
+
904
+
905
+ # perform guidance
906
+ if use_dynamic_cfg:
907
+ self._guidance_scale = 1 + guidance_scale * (
908
+ (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
909
+ )
910
+ if do_classifier_free_guidance:
911
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
912
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
913
+
914
+ # compute the previous noisy sample x_t -> x_t-1
915
+ if not isinstance(self.scheduler, CogVideoXDPMScheduler):
916
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
917
+ else:
918
+ latents, old_pred_original_sample = self.scheduler.step(
919
+ noise_pred,
920
+ old_pred_original_sample,
921
+ t,
922
+ timesteps[i - 1] if i > 0 else None,
923
+ latents,
924
+ **extra_step_kwargs,
925
+ return_dict=False,
926
+ )
927
+ latents = latents.to(prompt_embeds.dtype)
928
+
929
+ # call the callback, if provided
930
+ if callback_on_step_end is not None:
931
+ callback_kwargs = {}
932
+ for k in callback_on_step_end_tensor_inputs:
933
+ callback_kwargs[k] = locals()[k]
934
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
935
+
936
+ latents = callback_outputs.pop("latents", latents)
937
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
938
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
939
+
940
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
941
+ progress_bar.update()
942
+
943
+ if XLA_AVAILABLE:
944
+ xm.mark_step()
945
+
946
+ if not output_type == "latent":
947
+ # Discard any padding frames that were added for CogVideoX 1.5
948
+ latents = latents[:, additional_frames:]
949
+ video = self.decode_latents(latents)
950
+ video = self.video_processor.postprocess_video(video=video, output_type=output_type)
951
+ else:
952
+ video = latents
953
+
954
+ # Offload all models
955
+ self.maybe_free_model_hooks()
956
+
957
+ if not return_dict:
958
+ return (video,)
959
+
960
+ return CogVideoXPipelineOutput(frames=video)
pipelines/pipeline_wan_i2v_motion.py ADDED
@@ -0,0 +1,861 @@
1
+ # Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import html
16
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
+
18
+ import os, sys, shutil
19
+ import PIL
20
+ import regex as re
21
+ import torch
22
+ from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
23
+
24
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
25
+ from diffusers.image_processor import PipelineImageInput
26
+ from diffusers.loaders import WanLoraLoaderMixin
27
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
28
+ from diffusers.utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
29
+ from diffusers.utils.torch_utils import randn_tensor
30
+ from diffusers.video_processor import VideoProcessor
31
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
32
+ from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
33
+
34
+
35
+ # Import files from the local folder
36
+ root_path = os.path.abspath('.')
37
+ sys.path.append(root_path)
38
+ from architecture.transformer_wan import WanTransformer3DModel
39
+ from architecture.autoencoder_kl_wan import AutoencoderKLWan
40
+
41
+
42
+ if is_torch_xla_available():
43
+ import torch_xla.core.xla_model as xm
44
+
45
+ XLA_AVAILABLE = True
46
+ else:
47
+ XLA_AVAILABLE = False
48
+
49
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
50
+
51
+ if is_ftfy_available():
52
+ import ftfy
53
+
54
+ EXAMPLE_DOC_STRING = """
55
+ Examples:
56
+ ```python
57
+ >>> import torch
58
+ >>> import numpy as np
59
+ >>> from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
60
+ >>> from diffusers.utils import export_to_video, load_image
61
+ >>> from transformers import CLIPVisionModel
62
+
63
+ >>> # Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
64
+ >>> model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
65
+ >>> image_encoder = CLIPVisionModel.from_pretrained(
66
+ ... model_id, subfolder="image_encoder", torch_dtype=torch.float32
67
+ ... )
68
+ >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
69
+ >>> pipe = WanImageToVideoPipeline.from_pretrained(
70
+ ... model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
71
+ ... )
72
+ >>> pipe.to("cuda")
73
+
74
+ >>> image = load_image(
75
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
76
+ ... )
77
+ >>> max_area = 480 * 832
78
+ >>> aspect_ratio = image.height / image.width
79
+ >>> mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
80
+ >>> height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
81
+ >>> width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
82
+ >>> image = image.resize((width, height))
83
+ >>> prompt = (
84
+ ... "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
85
+ ... "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
86
+ ... )
87
+ >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
88
+
89
+ >>> output = pipe(
90
+ ... image=image,
91
+ ... prompt=prompt,
92
+ ... negative_prompt=negative_prompt,
93
+ ... height=height,
94
+ ... width=width,
95
+ ... num_frames=81,
96
+ ... guidance_scale=5.0,
97
+ ... ).frames[0]
98
+ >>> export_to_video(output, "output.mp4", fps=16)
99
+ ```
100
+ """
101
+
102
+
103
+ def basic_clean(text):
104
+ text = ftfy.fix_text(text)
105
+ text = html.unescape(html.unescape(text))
106
+ return text.strip()
107
+
108
+
109
+ def whitespace_clean(text):
110
+ text = re.sub(r"\s+", " ", text)
111
+ text = text.strip()
112
+ return text
113
+
114
+
115
+ def prompt_clean(text):
116
+ text = whitespace_clean(basic_clean(text))
117
+ return text
118
+
119
+
120
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
121
+ def retrieve_latents(
122
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
123
+ ):
124
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
125
+ return encoder_output.latent_dist.sample(generator)
126
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
127
+ return encoder_output.latent_dist.mode()
128
+ elif hasattr(encoder_output, "latents"):
129
+ return encoder_output.latents
130
+ else:
131
+ raise AttributeError("Could not access latents of provided encoder_output")
132
+
133
+
134
+ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
135
+ r"""
136
+ Pipeline for image-to-video generation using Wan.
137
+
138
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
139
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
140
+
141
+ Args:
142
+ tokenizer ([`T5Tokenizer`]):
143
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
144
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
145
+ text_encoder ([`T5EncoderModel`]):
146
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
147
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
148
+ image_encoder ([`CLIPVisionModel`]):
149
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel), specifically
150
+ the
151
+ [clip-vit-huge-patch14](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md#vit-h14-xlm-roberta-large)
152
+ variant.
153
+ transformer ([`WanTransformer3DModel`]):
154
+ Conditional Transformer to denoise the input latents.
155
+ scheduler ([`UniPCMultistepScheduler`]):
156
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
157
+ vae ([`AutoencoderKLWan`]):
158
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
159
+ transformer_2 ([`WanTransformer3DModel`], *optional*):
160
+ Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
161
+ `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only
162
+ `transformer` is used.
163
+ boundary_ratio (`float`, *optional*, defaults to `None`):
164
+ Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
165
+ The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
166
+ `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
167
+ boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
168
+ """
169
+
170
+ model_cpu_offload_seq = "text_encoder->image_encoder->transformer->transformer_2->vae"
171
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
172
+ _optional_components = ["transformer", "transformer_2", "image_encoder", "image_processor"]
173
+
174
+ def __init__(
175
+ self,
176
+ tokenizer: AutoTokenizer,
177
+ text_encoder: UMT5EncoderModel,
178
+ vae: AutoencoderKLWan,
179
+ scheduler: FlowMatchEulerDiscreteScheduler,
180
+ image_processor: CLIPImageProcessor = None,
181
+ image_encoder: CLIPVisionModel = None,
182
+ transformer: WanTransformer3DModel = None,
183
+ transformer_2: WanTransformer3DModel = None,
184
+ boundary_ratio: Optional[float] = None,
185
+ expand_timesteps: bool = False,
186
+ ):
187
+ super().__init__()
188
+
189
+ self.register_modules(
190
+ vae=vae,
191
+ text_encoder=text_encoder,
192
+ tokenizer=tokenizer,
193
+ image_encoder=image_encoder,
194
+ transformer=transformer,
195
+ scheduler=scheduler,
196
+ image_processor=image_processor,
197
+ transformer_2=transformer_2,
198
+ )
199
+ self.register_to_config(boundary_ratio=boundary_ratio, expand_timesteps=expand_timesteps)
200
+
201
+ self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
202
+ self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
203
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
204
+ self.image_processor = image_processor
205
+
206
+ def _get_t5_prompt_embeds(
207
+ self,
208
+ prompt: Union[str, List[str]] = None,
209
+ num_videos_per_prompt: int = 1,
210
+ max_sequence_length: int = 512,
211
+ device: Optional[torch.device] = None,
212
+ dtype: Optional[torch.dtype] = None,
213
+ ):
214
+ device = device or self._execution_device
215
+ dtype = dtype or self.text_encoder.dtype
216
+
217
+ prompt = [prompt] if isinstance(prompt, str) else prompt
218
+ prompt = [prompt_clean(u) for u in prompt]
219
+ batch_size = len(prompt)
220
+
221
+ text_inputs = self.tokenizer(
222
+ prompt,
223
+ padding="max_length",
224
+ max_length=max_sequence_length,
225
+ truncation=True,
226
+ add_special_tokens=True,
227
+ return_attention_mask=True,
228
+ return_tensors="pt",
229
+ )
230
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
231
+ seq_lens = mask.gt(0).sum(dim=1).long()
232
+
233
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
234
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
235
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
236
+ prompt_embeds = torch.stack(
237
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
238
+ )
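+ # Embeddings are trimmed to each prompt's true token length and then zero-padded back to max_sequence_length.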
239
+
240
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
241
+ _, seq_len, _ = prompt_embeds.shape
242
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
243
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
244
+
245
+ return prompt_embeds
246
+
247
+ def encode_image(
248
+ self,
249
+ image: PipelineImageInput,
250
+ device: Optional[torch.device] = None,
251
+ ):
252
+ device = device or self._execution_device
253
+ image = self.image_processor(images=image, return_tensors="pt").to(device)
254
+ image_embeds = self.image_encoder(**image, output_hidden_states=True)
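+ # The penultimate hidden state is used as the CLIP image conditioning.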
255
+ return image_embeds.hidden_states[-2]
256
+
257
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
258
+ def encode_prompt(
259
+ self,
260
+ prompt: Union[str, List[str]],
261
+ negative_prompt: Optional[Union[str, List[str]]] = None,
262
+ do_classifier_free_guidance: bool = True,
263
+ num_videos_per_prompt: int = 1,
264
+ prompt_embeds: Optional[torch.Tensor] = None,
265
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
266
+ max_sequence_length: int = 226,
267
+ device: Optional[torch.device] = None,
268
+ dtype: Optional[torch.dtype] = None,
269
+ ):
270
+ r"""
271
+ Encodes the prompt into text encoder hidden states.
272
+
273
+ Args:
274
+ prompt (`str` or `List[str]`, *optional*):
275
+ prompt to be encoded
276
+ negative_prompt (`str` or `List[str]`, *optional*):
277
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
278
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
279
+ less than `1`).
280
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
281
+ Whether to use classifier free guidance or not.
282
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
283
+ Number of videos that should be generated per prompt.
284
+ prompt_embeds (`torch.Tensor`, *optional*):
285
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
286
+ provided, text embeddings will be generated from `prompt` input argument.
287
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
288
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
289
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
290
+ argument.
291
+ device: (`torch.device`, *optional*):
292
+ torch device
293
+ dtype: (`torch.dtype`, *optional*):
294
+ torch dtype
295
+ """
296
+ device = device or self._execution_device
297
+
298
+ prompt = [prompt] if isinstance(prompt, str) else prompt
299
+ if prompt is not None:
300
+ batch_size = len(prompt)
301
+ else:
302
+ batch_size = prompt_embeds.shape[0]
303
+
304
+ if prompt_embeds is None:
305
+ prompt_embeds = self._get_t5_prompt_embeds(
306
+ prompt=prompt,
307
+ num_videos_per_prompt=num_videos_per_prompt,
308
+ max_sequence_length=max_sequence_length,
309
+ device=device,
310
+ dtype=dtype,
311
+ )
312
+
313
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
314
+ negative_prompt = negative_prompt or ""
315
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
316
+
317
+ if prompt is not None and type(prompt) is not type(negative_prompt):
318
+ raise TypeError(
319
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
320
+ f" {type(prompt)}."
321
+ )
322
+ elif batch_size != len(negative_prompt):
323
+ raise ValueError(
324
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
325
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
326
+ " the batch size of `prompt`."
327
+ )
328
+
329
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
330
+ prompt=negative_prompt,
331
+ num_videos_per_prompt=num_videos_per_prompt,
332
+ max_sequence_length=max_sequence_length,
333
+ device=device,
334
+ dtype=dtype,
335
+ )
336
+
337
+ return prompt_embeds, negative_prompt_embeds
338
+
339
+ def check_inputs(
340
+ self,
341
+ prompt,
342
+ negative_prompt,
343
+ image,
344
+ height,
345
+ width,
346
+ prompt_embeds=None,
347
+ negative_prompt_embeds=None,
348
+ image_embeds=None,
349
+ callback_on_step_end_tensor_inputs=None,
350
+ guidance_scale_2=None,
351
+ ):
352
+ if image is not None and image_embeds is not None:
353
+ raise ValueError(
354
+ f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to"
355
+ " only forward one of the two."
356
+ )
357
+ if image is None and image_embeds is None:
358
+ raise ValueError(
359
+ "Provide either `image` or `image_embeds`. Cannot leave both `image` and `image_embeds` undefined."
360
+ )
361
+ if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
362
+ raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}")
363
+ if height % 16 != 0 or width % 16 != 0:
364
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
365
+
366
+ if callback_on_step_end_tensor_inputs is not None and not all(
367
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
368
+ ):
369
+ raise ValueError(
370
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
371
+ )
372
+
373
+ if prompt is not None and prompt_embeds is not None:
374
+ raise ValueError(
375
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
376
+ " only forward one of the two."
377
+ )
378
+ elif negative_prompt is not None and negative_prompt_embeds is not None:
379
+ raise ValueError(
380
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
381
+ " only forward one of the two."
382
+ )
383
+ elif prompt is None and prompt_embeds is None:
384
+ raise ValueError(
385
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
386
+ )
387
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
388
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
389
+ elif negative_prompt is not None and (
390
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
391
+ ):
392
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
393
+
394
+ if self.config.boundary_ratio is None and guidance_scale_2 is not None:
395
+ raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
396
+
397
+ if self.config.boundary_ratio is not None and image_embeds is not None:
398
+ raise ValueError("Cannot forward `image_embeds` when the pipeline's `boundary_ratio` is configured (not None).")
399
+
400
+ def prepare_latents(
401
+ self,
402
+ image: PipelineImageInput,
403
+ traj_tensor,
404
+ batch_size: int,
405
+ num_channels_latents: int = 16,
406
+ height: int = 480,
407
+ width: int = 832,
408
+ num_frames: int = 81,
409
+ dtype: Optional[torch.dtype] = None,
410
+ device: Optional[torch.device] = None,
411
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
412
+ latents: Optional[torch.Tensor] = None,
413
+ last_image: Optional[torch.Tensor] = None,
414
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
415
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
416
+ latent_height = height // self.vae_scale_factor_spatial
417
+ latent_width = width // self.vae_scale_factor_spatial
418
+
419
+ shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
420
+ if isinstance(generator, list) and len(generator) != batch_size:
421
+ raise ValueError(
422
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
423
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
424
+ )
425
+
426
+ if latents is None:
427
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
428
+ else:
429
+ latents = latents.to(device=device, dtype=dtype)
430
+
431
+ image = image.unsqueeze(2) # [batch_size, channels, 1, height, width]
432
+
433
+ if self.config.expand_timesteps:
434
+ video_condition = image
435
+
436
+ elif last_image is None:
437
+ video_condition = torch.cat(
438
+ [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
439
+ )
440
+ else:
441
+ last_image = last_image.unsqueeze(2)
442
+ video_condition = torch.cat(
443
+ [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image],
444
+ dim=2,
445
+ )
446
+ video_condition = video_condition.to(device=device, dtype=self.vae.dtype)
447
+
448
+ latents_mean = (
449
+ torch.tensor(self.vae.config.latents_mean)
450
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
451
+ .to(latents.device, latents.dtype)
452
+ )
453
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
454
+ latents.device, latents.dtype
455
+ )
456
+
457
+ if isinstance(generator, list):
458
+ latent_condition = [
459
+ retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator
460
+ ]
461
+ latent_condition = torch.cat(latent_condition)
462
+ else:
463
+ latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax")
464
+ latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
465
+
466
+ latent_condition = latent_condition.to(dtype)
467
+ latent_condition = (latent_condition - latents_mean) * latents_std
468
+
469
+
470
+
471
+ # Prepare the traj latent
472
+ traj_tensor = traj_tensor.to(device, dtype=self.vae.dtype) #.unsqueeze(0)
473
+ traj_tensor = traj_tensor.unsqueeze(0)
474
+ traj_tensor = traj_tensor.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
475
+
476
+ # VAE encode
477
+ traj_latents = retrieve_latents(self.vae.encode(traj_tensor), sample_mode="argmax")
478
+
479
+ # Normalize with the VAE latent mean and std
480
+ traj_latents = (traj_latents - latents_mean) * latents_std
481
+
482
+ # Final Convert
483
+ traj_latents = traj_latents.to(memory_format = torch.contiguous_format).float()
484
+
485
+
486
+
487
+ if self.config.expand_timesteps:
488
+ first_frame_mask = torch.ones(
489
+ 1, 1, num_latent_frames, latent_height, latent_width, dtype=dtype, device=device
490
+ )
491
+ first_frame_mask[:, :, 0] = 0
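+ # The mask is 0 only at the first latent frame; later, (1 - first_frame_mask) * condition + first_frame_mask * latents keeps the clean first-frame latent fixed while the remaining frames are denoised (expand_timesteps / Wan2.2-5B path)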
492
+ return latents, latent_condition, traj_latents, first_frame_mask
493
+
494
+
495
+
496
+ mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
497
+
498
+ if last_image is None:
499
+ mask_lat_size[:, :, list(range(1, num_frames))] = 0
500
+ else:
501
+ mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0
502
+ first_frame_mask = mask_lat_size[:, :, 0:1]
503
+ first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal)
504
+ mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
505
+ mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width)
506
+ mask_lat_size = mask_lat_size.transpose(1, 2)
507
+ mask_lat_size = mask_lat_size.to(latent_condition.device)
508
+
509
+ return latents, torch.concat([mask_lat_size, latent_condition], dim=1)
510
+
511
+ @property
512
+ def guidance_scale(self):
513
+ return self._guidance_scale
514
+
515
+ @property
516
+ def do_classifier_free_guidance(self):
517
+ return self._guidance_scale > 1
518
+
519
+ @property
520
+ def num_timesteps(self):
521
+ return self._num_timesteps
522
+
523
+ @property
524
+ def current_timestep(self):
525
+ return self._current_timestep
526
+
527
+ @property
528
+ def interrupt(self):
529
+ return self._interrupt
530
+
531
+ @property
532
+ def attention_kwargs(self):
533
+ return self._attention_kwargs
534
+
535
+ @torch.no_grad()
536
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
537
+ def __call__(
538
+ self,
539
+ image: PipelineImageInput,
540
+ prompt: Union[str, List[str]] = None,
541
+ negative_prompt: Union[str, List[str]] = None,
542
+ traj_tensor = None,
543
+ height: int = 480,
544
+ width: int = 832,
545
+ num_frames: int = 81,
546
+ num_inference_steps: int = 50,
547
+ guidance_scale: float = 5.0,
548
+ guidance_scale_2: Optional[float] = None,
549
+ num_videos_per_prompt: Optional[int] = 1,
550
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
551
+ latents: Optional[torch.Tensor] = None,
552
+ prompt_embeds: Optional[torch.Tensor] = None,
553
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
554
+ image_embeds: Optional[torch.Tensor] = None,
555
+ last_image: Optional[torch.Tensor] = None,
556
+ output_type: Optional[str] = "np",
557
+ return_dict: bool = True,
558
+ attention_kwargs: Optional[Dict[str, Any]] = None,
559
+ callback_on_step_end: Optional[
560
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
561
+ ] = None,
562
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
563
+ max_sequence_length: int = 512,
564
+ ):
565
+ r"""
566
+ The call function to the pipeline for generation.
567
+
568
+ Args:
569
+ image (`PipelineImageInput`):
570
+ The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
571
+ prompt (`str` or `List[str]`, *optional*):
572
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
573
+ instead.
574
+ negative_prompt (`str` or `List[str]`, *optional*):
575
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
576
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
577
+ less than `1`).
578
+ height (`int`, defaults to `480`):
579
+ The height of the generated video.
580
+ width (`int`, defaults to `832`):
581
+ The width of the generated video.
582
+ num_frames (`int`, defaults to `81`):
583
+ The number of frames in the generated video.
584
+ num_inference_steps (`int`, defaults to `50`):
585
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
586
+ expense of slower inference.
587
+ guidance_scale (`float`, defaults to `5.0`):
588
+ Guidance scale as defined in [Classifier-Free Diffusion
589
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
590
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
591
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
592
+ the text `prompt`, usually at the expense of lower image quality.
593
+ guidance_scale_2 (`float`, *optional*, defaults to `None`):
594
+ Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
595
+ `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
596
+ and the pipeline's `boundary_ratio` are not None.
597
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
598
+ The number of images to generate per prompt.
599
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
600
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
601
+ generation deterministic.
602
+ latents (`torch.Tensor`, *optional*):
603
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
604
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
605
+ tensor is generated by sampling using the supplied random `generator`.
606
+ prompt_embeds (`torch.Tensor`, *optional*):
607
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
608
+ provided, text embeddings are generated from the `prompt` input argument.
609
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
610
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
611
+ provided, text embeddings are generated from the `negative_prompt` input argument.
612
+ image_embeds (`torch.Tensor`, *optional*):
613
+ Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided,
614
+ image embeddings are generated from the `image` input argument.
615
+ output_type (`str`, *optional*, defaults to `"np"`):
616
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
617
+ return_dict (`bool`, *optional*, defaults to `True`):
618
+ Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
619
+ attention_kwargs (`dict`, *optional*):
620
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
621
+ `self.processor` in
622
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
623
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
624
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
625
+ each denoising step during inference with the following arguments: `callback_on_step_end(self:
626
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
627
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
628
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
629
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
630
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
631
+ `._callback_tensor_inputs` attribute of your pipeline class.
632
+ max_sequence_length (`int`, defaults to `512`):
633
+ The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
634
+ truncated. If the prompt is shorter, it will be padded to this length.
635
+
636
+ Examples:
637
+
638
+ Returns:
639
+ [`~WanPipelineOutput`] or `tuple`:
640
+ If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where
641
+ the first element is a list with the generated images and the second element is a list of `bool`s
642
+ indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
643
+ """
644
+
645
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
646
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
647
+
648
+ # 1. Check inputs. Raise error if not correct
649
+ self.check_inputs(
650
+ prompt,
651
+ negative_prompt,
652
+ image,
653
+ height,
654
+ width,
655
+ prompt_embeds,
656
+ negative_prompt_embeds,
657
+ image_embeds,
658
+ callback_on_step_end_tensor_inputs,
659
+ guidance_scale_2,
660
+ )
661
+
662
+ if num_frames % self.vae_scale_factor_temporal != 1:
663
+ logger.warning(
664
+ f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
665
+ )
666
+ num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
667
+ num_frames = max(num_frames, 1)
668
+
669
+ if self.config.boundary_ratio is not None and guidance_scale_2 is None:
670
+ guidance_scale_2 = guidance_scale
671
+
672
+ self._guidance_scale = guidance_scale
673
+ self._guidance_scale_2 = guidance_scale_2
674
+ self._attention_kwargs = attention_kwargs
675
+ self._current_timestep = None
676
+ self._interrupt = False
677
+
678
+ device = self._execution_device
679
+
680
+ # 2. Define call parameters
681
+ if prompt is not None and isinstance(prompt, str):
682
+ batch_size = 1
683
+ elif prompt is not None and isinstance(prompt, list):
684
+ batch_size = len(prompt)
685
+ else:
686
+ batch_size = prompt_embeds.shape[0]
687
+
688
+ # 3. Encode input prompt
689
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
690
+ prompt=prompt,
691
+ negative_prompt=negative_prompt,
692
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
693
+ num_videos_per_prompt=num_videos_per_prompt,
694
+ prompt_embeds=prompt_embeds,
695
+ negative_prompt_embeds=negative_prompt_embeds,
696
+ max_sequence_length=max_sequence_length,
697
+ device=device,
698
+ )
699
+
700
+ # Encode image embedding
701
+ transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
702
+ prompt_embeds = prompt_embeds.to(transformer_dtype)
703
+ if negative_prompt_embeds is not None:
704
+ negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
705
+
706
+ # only wan 2.1 i2v transformer accepts image_embeds
707
+ if self.transformer is not None and self.transformer.config.image_dim is not None:
708
+ if image_embeds is None:
709
+ if last_image is None:
710
+ image_embeds = self.encode_image(image, device)
711
+ else:
712
+ image_embeds = self.encode_image([image, last_image], device)
713
+ image_embeds = image_embeds.repeat(batch_size, 1, 1)
714
+ image_embeds = image_embeds.to(transformer_dtype)
715
+
716
+ # 4. Prepare timesteps
717
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
718
+ timesteps = self.scheduler.timesteps
719
+
720
+ # 5. Prepare latent variables
721
+ num_channels_latents = self.vae.config.z_dim
722
+ image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32)
723
+ if last_image is not None:
724
+ last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
725
+ device, dtype=torch.float32
726
+ )
727
+
728
+ latents_outputs = self.prepare_latents(
729
+ image,
730
+ traj_tensor,
731
+ batch_size * num_videos_per_prompt,
732
+ num_channels_latents,
733
+ height,
734
+ width,
735
+ num_frames,
736
+ torch.float32,
737
+ device,
738
+ generator,
739
+ latents,
740
+ last_image,
741
+ )
742
+ if self.config.expand_timesteps:
743
+ # wan 2.2 5b i2v uses first_frame_mask to mask timesteps
744
+ latents, condition, traj_latents, first_frame_mask = latents_outputs
745
+ else:
746
+ latents, condition = latents_outputs
747
+
748
+
749
+ # 6. Denoising loop
750
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
751
+ self._num_timesteps = len(timesteps)
752
+
753
+ if self.config.boundary_ratio is not None:
754
+ boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
755
+ else:
756
+ boundary_timestep = None
757
+
758
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
759
+ for i, t in enumerate(timesteps):
760
+ if self.interrupt:
761
+ continue
762
+
763
+ self._current_timestep = t
764
+
765
+ if boundary_timestep is None or t >= boundary_timestep:
766
+ # wan2.1 or high-noise stage in wan2.2
767
+ current_model = self.transformer
768
+ current_guidance_scale = guidance_scale
769
+ else:
770
+ # low-noise stage in wan2.2
771
+ current_model = self.transformer_2
772
+ current_guidance_scale = guidance_scale_2
773
+
774
+ if self.config.expand_timesteps:
775
+ latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents
776
+ latent_model_input = latent_model_input.to(transformer_dtype)
777
+
778
+ # seq_len: num_latent_frames * (latent_height // patch_size) * (latent_width // patch_size)
779
+ temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
780
+ # batch_size, seq_len
781
+ timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
782
+ else:
783
+ latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
784
+ timestep = t.expand(latents.shape[0])
785
+
786
+
787
+ # Concat the traj latents in channel dimension
788
+ latent_model_input = torch.cat([latent_model_input, traj_latents], dim=1).to(transformer_dtype)
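+ # The trajectory latents ride along as extra input channels on top of the noisy latents and the image condition, so the transformer's input projection must be sized for the larger channel count (presumably handled by the modified WanTransformer3DModel in architecture/transformer_wan.py)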
789
+
790
+
791
+ # Predict the noise according to the timestep
792
+ with current_model.cache_context("cond"):
793
+ noise_pred = current_model(
794
+ hidden_states=latent_model_input,
795
+ timestep=timestep,
796
+ encoder_hidden_states=prompt_embeds,
797
+ encoder_hidden_states_image=image_embeds,
798
+ attention_kwargs=attention_kwargs,
799
+ return_dict=False,
800
+ )[0]
801
+
802
+ if self.do_classifier_free_guidance:
803
+ with current_model.cache_context("uncond"):
804
+ noise_uncond = current_model(
805
+ hidden_states=latent_model_input,
806
+ timestep=timestep,
807
+ encoder_hidden_states=negative_prompt_embeds,
808
+ encoder_hidden_states_image=image_embeds,
809
+ attention_kwargs=attention_kwargs,
810
+ return_dict=False,
811
+ )[0]
812
+ noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
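+ # Standard classifier-free guidance: start from the unconditional prediction and move along the (conditional - unconditional) direction scaled by the guidance weight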
813
+
814
+ # compute the previous noisy sample x_t -> x_t-1
815
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
816
+
817
+ if callback_on_step_end is not None:
818
+ callback_kwargs = {}
819
+ for k in callback_on_step_end_tensor_inputs:
820
+ callback_kwargs[k] = locals()[k]
821
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
822
+
823
+ latents = callback_outputs.pop("latents", latents)
824
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
825
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
826
+
827
+ # update the progress bar
828
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
829
+ progress_bar.update()
830
+
831
+ if XLA_AVAILABLE:
832
+ xm.mark_step()
833
+
834
+ self._current_timestep = None
835
+
836
+ if self.config.expand_timesteps:
837
+ latents = (1 - first_frame_mask) * condition + first_frame_mask * latents
838
+
839
+ if not output_type == "latent":
840
+ latents = latents.to(self.vae.dtype)
841
+ latents_mean = (
842
+ torch.tensor(self.vae.config.latents_mean)
843
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
844
+ .to(latents.device, latents.dtype)
845
+ )
846
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
847
+ latents.device, latents.dtype
848
+ )
849
+ latents = latents / latents_std + latents_mean
850
+ video = self.vae.decode(latents, return_dict=False)[0]
851
+ video = self.video_processor.postprocess_video(video, output_type=output_type)
852
+ else:
853
+ video = latents
854
+
855
+ # Offload all models
856
+ self.maybe_free_model_hooks()
857
+
858
+ if not return_dict:
859
+ return (video,)
860
+
861
+ return WanPipelineOutput(frames=video)
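
The file above only defines the trajectory-conditioned pipeline; nothing in the commit shows it being driven. Below is a minimal usage sketch, not part of the commit. The module path, checkpoint id, file names, and the `traj_tensor` contents are assumptions for illustration: the pipeline expects `traj_tensor` shaped `[F, C, H, W]` (it adds the batch dimension and permutes to `[B, C, F, H, W]` itself), and the traj-aware transformer would need its own finetuned weights rather than the stock Wan checkpoint.

```python
# Hypothetical usage sketch -- module path, checkpoint id, and inputs are assumptions, not part of this commit.
import torch
from diffusers.utils import export_to_video, load_image

# Assumed import path for the pipeline class defined above.
from pipelines.pipeline_wan_i2v_motion import WanImageToVideoPipeline

model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"  # stand-in; the modified transformer needs matching finetuned weights
pipe = WanImageToVideoPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("cuda")

image = load_image("first_frame.png")  # placeholder input image
num_frames, height, width = 81, 480, 832

# Trajectory conditioning video, one frame per output frame, in [F, C, H, W];
# zeros here only illustrate the expected shape and dtype.
traj_tensor = torch.zeros(num_frames, 3, height, width)

video = pipe(
    image=image,
    prompt="a subject following the drawn trajectory",
    traj_tensor=traj_tensor,
    height=height,
    width=width,
    num_frames=num_frames,
    guidance_scale=5.0,
).frames[0]
export_to_video(video, "output.mp4", fps=16)
```
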
pipelines/pipeline_wan_i2v_motion_FrameINO.py ADDED
@@ -0,0 +1,937 @@
1
+ # Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import html
16
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
+
18
+ import os, sys, shutil
19
+ import PIL
20
+ import regex as re
21
+ import torch
22
+ from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
23
+
24
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
25
+ from diffusers.image_processor import PipelineImageInput
26
+ from diffusers.loaders import WanLoraLoaderMixin
27
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
28
+ from diffusers.utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
29
+ from diffusers.utils.torch_utils import randn_tensor
30
+ from diffusers.video_processor import VideoProcessor
31
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
32
+ from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
33
+
34
+
35
+ # Import files from the local folder
36
+ root_path = os.path.abspath('.')
37
+ sys.path.append(root_path)
38
+ from architecture.transformer_wan import WanTransformer3DModel
39
+ from architecture.autoencoder_kl_wan import AutoencoderKLWan
40
+
41
+
42
+ if is_torch_xla_available():
43
+ import torch_xla.core.xla_model as xm
44
+
45
+ XLA_AVAILABLE = True
46
+ else:
47
+ XLA_AVAILABLE = False
48
+
49
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
50
+
51
+ if is_ftfy_available():
52
+ import ftfy
53
+
54
+ EXAMPLE_DOC_STRING = """
55
+ Examples:
56
+ ```python
57
+ >>> import torch
58
+ >>> import numpy as np
59
+ >>> from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
60
+ >>> from diffusers.utils import export_to_video, load_image
61
+ >>> from transformers import CLIPVisionModel
62
+
63
+ >>> # Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
64
+ >>> model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
65
+ >>> image_encoder = CLIPVisionModel.from_pretrained(
66
+ ... model_id, subfolder="image_encoder", torch_dtype=torch.float32
67
+ ... )
68
+ >>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
69
+ >>> pipe = WanImageToVideoPipeline.from_pretrained(
70
+ ... model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
71
+ ... )
72
+ >>> pipe.to("cuda")
73
+
74
+ >>> image = load_image(
75
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
76
+ ... )
77
+ >>> max_area = 480 * 832
78
+ >>> aspect_ratio = image.height / image.width
79
+ >>> mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
80
+ >>> height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
81
+ >>> width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
82
+ >>> image = image.resize((width, height))
83
+ >>> prompt = (
84
+ ... "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
85
+ ... "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
86
+ ... )
87
+ >>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
88
+
89
+ >>> output = pipe(
90
+ ... image=image,
91
+ ... prompt=prompt,
92
+ ... negative_prompt=negative_prompt,
93
+ ... height=height,
94
+ ... width=width,
95
+ ... num_frames=81,
96
+ ... guidance_scale=5.0,
97
+ ... ).frames[0]
98
+ >>> export_to_video(output, "output.mp4", fps=16)
99
+ ```
100
+ """
101
+
102
+
103
+ def basic_clean(text):
104
+ text = ftfy.fix_text(text)
105
+ text = html.unescape(html.unescape(text))
106
+ return text.strip()
107
+
108
+
109
+ def whitespace_clean(text):
110
+ text = re.sub(r"\s+", " ", text)
111
+ text = text.strip()
112
+ return text
113
+
114
+
115
+ def prompt_clean(text):
116
+ text = whitespace_clean(basic_clean(text))
117
+ return text
118
+
119
+
120
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
121
+ def retrieve_latents(
122
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
123
+ ):
124
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
125
+ return encoder_output.latent_dist.sample(generator)
126
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
127
+ return encoder_output.latent_dist.mode()
128
+ elif hasattr(encoder_output, "latents"):
129
+ return encoder_output.latents
130
+ else:
131
+ raise AttributeError("Could not access latents of provided encoder_output")
132
+
133
+
134
+ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
135
+ r"""
136
+ Pipeline for image-to-video generation using Wan.
137
+
138
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
139
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
140
+
141
+ Args:
142
+ tokenizer ([`T5Tokenizer`]):
143
+ Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
144
+ specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
145
+ text_encoder ([`T5EncoderModel`]):
146
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
147
+ the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
148
+ image_encoder ([`CLIPVisionModel`]):
149
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel), specifically
150
+ the
151
+ [clip-vit-huge-patch14](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md#vit-h14-xlm-roberta-large)
152
+ variant.
153
+ transformer ([`WanTransformer3DModel`]):
154
+ Conditional Transformer to denoise the input latents.
155
+ scheduler ([`UniPCMultistepScheduler`]):
156
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
157
+ vae ([`AutoencoderKLWan`]):
158
+ Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
159
+ transformer_2 ([`WanTransformer3DModel`], *optional*):
160
+ Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
161
+ `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only
162
+ `transformer` is used.
163
+ boundary_ratio (`float`, *optional*, defaults to `None`):
164
+ Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
165
+ The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
166
+ `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
167
+ boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
168
+ """
169
+
170
+ model_cpu_offload_seq = "text_encoder->image_encoder->transformer->transformer_2->vae"
171
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
172
+ _optional_components = ["transformer", "transformer_2", "image_encoder", "image_processor"]
173
+
174
+ def __init__(
175
+ self,
176
+ tokenizer: AutoTokenizer,
177
+ text_encoder: UMT5EncoderModel,
178
+ vae: AutoencoderKLWan,
179
+ scheduler: FlowMatchEulerDiscreteScheduler,
180
+ image_processor: CLIPImageProcessor = None,
181
+ image_encoder: CLIPVisionModel = None,
182
+ transformer: WanTransformer3DModel = None,
183
+ transformer_2: WanTransformer3DModel = None,
184
+ boundary_ratio: Optional[float] = None,
185
+ expand_timesteps: bool = False,
186
+ ):
187
+ super().__init__()
188
+
189
+ self.register_modules(
190
+ vae=vae,
191
+ text_encoder=text_encoder,
192
+ tokenizer=tokenizer,
193
+ image_encoder=image_encoder,
194
+ transformer=transformer,
195
+ scheduler=scheduler,
196
+ image_processor=image_processor,
197
+ transformer_2=transformer_2,
198
+ )
199
+ self.register_to_config(boundary_ratio=boundary_ratio, expand_timesteps=expand_timesteps)
200
+
201
+ self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
202
+ self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
203
+ self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
204
+ self.image_processor = image_processor
205
+
206
+ def _get_t5_prompt_embeds(
207
+ self,
208
+ prompt: Union[str, List[str]] = None,
209
+ num_videos_per_prompt: int = 1,
210
+ max_sequence_length: int = 512,
211
+ device: Optional[torch.device] = None,
212
+ dtype: Optional[torch.dtype] = None,
213
+ ):
214
+ device = device or self._execution_device
215
+ dtype = dtype or self.text_encoder.dtype
216
+
217
+ prompt = [prompt] if isinstance(prompt, str) else prompt
218
+ prompt = [prompt_clean(u) for u in prompt]
219
+ batch_size = len(prompt)
220
+
221
+ text_inputs = self.tokenizer(
222
+ prompt,
223
+ padding="max_length",
224
+ max_length=max_sequence_length,
225
+ truncation=True,
226
+ add_special_tokens=True,
227
+ return_attention_mask=True,
228
+ return_tensors="pt",
229
+ )
230
+ text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
231
+ seq_lens = mask.gt(0).sum(dim=1).long()
232
+
233
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
234
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
235
+ prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
236
+ prompt_embeds = torch.stack(
237
+ [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
238
+ )
239
+
240
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
241
+ _, seq_len, _ = prompt_embeds.shape
242
+ prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
243
+ prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
244
+
245
+ return prompt_embeds
246
+
247
+ def encode_image(
248
+ self,
249
+ image: PipelineImageInput,
250
+ device: Optional[torch.device] = None,
251
+ ):
252
+ device = device or self._execution_device
253
+ image = self.image_processor(images=image, return_tensors="pt").to(device)
254
+ image_embeds = self.image_encoder(**image, output_hidden_states=True)
255
+ return image_embeds.hidden_states[-2]
256
+
257
+ # Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
258
+ def encode_prompt(
259
+ self,
260
+ prompt: Union[str, List[str]],
261
+ negative_prompt: Optional[Union[str, List[str]]] = None,
262
+ do_classifier_free_guidance: bool = True,
263
+ num_videos_per_prompt: int = 1,
264
+ prompt_embeds: Optional[torch.Tensor] = None,
265
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
266
+ max_sequence_length: int = 226,
267
+ device: Optional[torch.device] = None,
268
+ dtype: Optional[torch.dtype] = None,
269
+ ):
270
+ r"""
271
+ Encodes the prompt into text encoder hidden states.
272
+
273
+ Args:
274
+ prompt (`str` or `List[str]`, *optional*):
275
+ prompt to be encoded
276
+ negative_prompt (`str` or `List[str]`, *optional*):
277
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
278
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
279
+ less than `1`).
280
+ do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
281
+ Whether to use classifier free guidance or not.
282
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
283
+ Number of videos that should be generated per prompt.
284
+ prompt_embeds (`torch.Tensor`, *optional*):
285
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
286
+ provided, text embeddings will be generated from `prompt` input argument.
287
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
288
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
289
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
290
+ argument.
291
+ device: (`torch.device`, *optional*):
292
+ torch device
293
+ dtype: (`torch.dtype`, *optional*):
294
+ torch dtype
295
+ """
296
+ device = device or self._execution_device
297
+
298
+ prompt = [prompt] if isinstance(prompt, str) else prompt
299
+ if prompt is not None:
300
+ batch_size = len(prompt)
301
+ else:
302
+ batch_size = prompt_embeds.shape[0]
303
+
304
+ if prompt_embeds is None:
305
+ prompt_embeds = self._get_t5_prompt_embeds(
306
+ prompt=prompt,
307
+ num_videos_per_prompt=num_videos_per_prompt,
308
+ max_sequence_length=max_sequence_length,
309
+ device=device,
310
+ dtype=dtype,
311
+ )
312
+
313
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
314
+ negative_prompt = negative_prompt or ""
315
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
316
+
317
+ if prompt is not None and type(prompt) is not type(negative_prompt):
318
+ raise TypeError(
319
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
320
+ f" {type(prompt)}."
321
+ )
322
+ elif batch_size != len(negative_prompt):
323
+ raise ValueError(
324
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
325
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
326
+ " the batch size of `prompt`."
327
+ )
328
+
329
+ negative_prompt_embeds = self._get_t5_prompt_embeds(
330
+ prompt=negative_prompt,
331
+ num_videos_per_prompt=num_videos_per_prompt,
332
+ max_sequence_length=max_sequence_length,
333
+ device=device,
334
+ dtype=dtype,
335
+ )
336
+
337
+ return prompt_embeds, negative_prompt_embeds
338
+
339
+ def check_inputs(
340
+ self,
341
+ prompt,
342
+ negative_prompt,
343
+ image,
344
+ height,
345
+ width,
346
+ prompt_embeds=None,
347
+ negative_prompt_embeds=None,
348
+ image_embeds=None,
349
+ callback_on_step_end_tensor_inputs=None,
350
+ guidance_scale_2=None,
351
+ ):
352
+ if image is not None and image_embeds is not None:
353
+ raise ValueError(
354
+ f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to"
355
+ " only forward one of the two."
356
+ )
357
+ if image is None and image_embeds is None:
358
+ raise ValueError(
359
+ "Provide either `image` or `image_embeds`. Cannot leave both `image` and `image_embeds` undefined."
360
+ )
361
+ if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
362
+ raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}")
363
+ if height % 16 != 0 or width % 16 != 0:
364
+ raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
365
+
366
+ if callback_on_step_end_tensor_inputs is not None and not all(
367
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
368
+ ):
369
+ raise ValueError(
370
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
371
+ )
372
+
373
+ if prompt is not None and prompt_embeds is not None:
374
+ raise ValueError(
375
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
376
+ " only forward one of the two."
377
+ )
378
+ elif negative_prompt is not None and negative_prompt_embeds is not None:
379
+ raise ValueError(
380
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
381
+ " only forward one of the two."
382
+ )
383
+ elif prompt is None and prompt_embeds is None:
384
+ raise ValueError(
385
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
386
+ )
387
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
388
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
389
+ elif negative_prompt is not None and (
390
+ not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
391
+ ):
392
+ raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
393
+
394
+ if self.config.boundary_ratio is None and guidance_scale_2 is not None:
395
+ raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
396
+
397
+ if self.config.boundary_ratio is not None and image_embeds is not None:
398
+ raise ValueError("Cannot forward `image_embeds` when the pipeline's `boundary_ratio` is configured (not None).")
399
+
400
+ def prepare_latents(
401
+ self,
402
+ image: PipelineImageInput,
403
+ traj_tensor,
404
+ ID_tensor,
405
+ batch_size: int,
406
+ num_channels_latents: int = 16,
407
+ height: int = 480,
408
+ width: int = 832,
409
+ num_frames: int = 81,
410
+ dtype: Optional[torch.dtype] = None,
411
+ device: Optional[torch.device] = None,
412
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
413
+ latents: Optional[torch.Tensor] = None,
414
+ last_image: Optional[torch.Tensor] = None,
415
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
416
+ num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
417
+ latent_height = height // self.vae_scale_factor_spatial
418
+ latent_width = width // self.vae_scale_factor_spatial
419
+
420
+ shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
421
+ if isinstance(generator, list) and len(generator) != batch_size:
422
+ raise ValueError(
423
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
424
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
425
+ )
426
+
427
+ if latents is None:
428
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
429
+ else:
430
+ latents = latents.to(device=device, dtype=dtype)
431
+
432
+ image = image.unsqueeze(2) # [batch_size, channels, 1, height, width]
433
+
434
+ if self.config.expand_timesteps:
435
+ video_condition = image
436
+
437
+ elif last_image is None:
438
+ video_condition = torch.cat(
439
+ [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
440
+ )
441
+ else:
442
+ last_image = last_image.unsqueeze(2)
443
+ video_condition = torch.cat(
444
+ [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image],
445
+ dim=2,
446
+ )
447
+ video_condition = video_condition.to(device=device, dtype=self.vae.dtype)
448
+
449
+ latents_mean = (
450
+ torch.tensor(self.vae.config.latents_mean)
451
+ .view(1, self.vae.config.z_dim, 1, 1, 1)
452
+ .to(latents.device, latents.dtype)
453
+ )
454
+ latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
455
+ latents.device, latents.dtype
456
+ )
457
+
458
+ if isinstance(generator, list):
459
+ latent_condition = [
460
+ retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator
461
+ ]
462
+ latent_condition = torch.cat(latent_condition)
463
+ else:
464
+ latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax")
465
+ latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
466
+
467
+ latent_condition = latent_condition.to(dtype)
468
+ latent_condition = (latent_condition - latents_mean) * latents_std
469
+
470
+
471
+
472
+ # Prepare the traj latent
473
+ traj_tensor = traj_tensor.to(device, dtype=self.vae.dtype) #.unsqueeze(0)
474
+ traj_tensor = traj_tensor.unsqueeze(0)
475
+ traj_tensor = traj_tensor.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
476
+
477
+ # VAE Encode
478
+ traj_latents = retrieve_latents(self.vae.encode(traj_tensor), sample_mode="argmax")
479
+
480
+ # Normalize with the VAE latent mean and std
481
+ traj_latents = (traj_latents - latents_mean) * latents_std
482
+
483
+ # Final Convert
484
+ traj_latents = traj_latents.to(memory_format = torch.contiguous_format).float()
485
+
486
+
487
+
488
+ # Prepare the ID latents
489
+ if ID_tensor.shape[2] != 0: # Must have at least one ID frame; the tensor can sometimes be empty
490
+
491
+ # Transform
492
+ ID_tensor = ID_tensor.to(device=device, dtype=self.vae.dtype)
493
+
494
+ # VAE-encode each ID frame one by one
495
+ ID_latents = []
496
+ for frame_idx in range(ID_tensor.shape[2]):
497
+
498
+ # Fetch
499
+ ID_frame = ID_tensor[:, :, frame_idx].unsqueeze(2)  # take one frame without overwriting ID_tensor inside the loop
500
+
501
+ # Encode the single frame, which becomes one frame of latent tokens
502
+ ID_latent = retrieve_latents(self.vae.encode(ID_frame), sample_mode="argmax")
503
+ ID_latent = ID_latent.repeat(batch_size, 1, 1, 1, 1)
504
+
505
+ # Convert
506
+ ID_latent = ID_latent.to(dtype)
507
+ ID_latent = (ID_latent - latents_mean) * latents_std
508
+
509
+ # Append
510
+ ID_latents.append(ID_latent)
511
+
512
+ # Final Convert
513
+ ID_latent_condition = torch.cat(ID_latents, dim = 2)
514
+
515
+ # Add padding to the traj latents
516
+ ID_latent_padding = torch.zeros_like(ID_latent_condition)
517
+ traj_latents = torch.cat([traj_latents, ID_latent_padding], dim=2)
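+ # Zero-padding extends traj_latents along the frame axis by the number of ID latent frames, presumably so it keeps the same temporal length as the latent sequence once the ID reference latents are appended during denoising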
518
+
519
+ # Update the number of latents frames for the first frame mask
520
+ # num_latent_frames = num_latent_frames + len(ID_latents)
521
+
522
+ else:
523
+ # Return an empty one
524
+ ID_latent_condition = None
525
+
526
+
527
+
528
+ if self.config.expand_timesteps: # For Wan2.2
529
+ first_frame_mask = torch.ones(
530
+ 1, 1, num_latent_frames, latent_height, latent_width, dtype=dtype, device=device
531
+ )
532
+ first_frame_mask[:, :, 0] = 0
533
+
534
+ # Return all condition information needed
535
+ return latents, latent_condition, traj_latents, ID_latent_condition, first_frame_mask
536
+
537
+
538
+
539
+ # The rest is for Wan2.1
540
+ mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
541
+
542
+ if last_image is None:
543
+ mask_lat_size[:, :, list(range(1, num_frames))] = 0
544
+ else:
545
+ mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0
546
+ first_frame_mask = mask_lat_size[:, :, 0:1]
547
+ first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal)
548
+ mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
549
+ mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width)
550
+ mask_lat_size = mask_lat_size.transpose(1, 2)
551
+ mask_lat_size = mask_lat_size.to(latent_condition.device)
552
+
553
+ return latents, torch.concat([mask_lat_size, latent_condition], dim=1)
554
+
555
+ @property
556
+ def guidance_scale(self):
557
+ return self._guidance_scale
558
+
559
+ @property
560
+ def do_classifier_free_guidance(self):
561
+ return self._guidance_scale > 1
562
+
563
+ @property
564
+ def num_timesteps(self):
565
+ return self._num_timesteps
566
+
567
+ @property
568
+ def current_timestep(self):
569
+ return self._current_timestep
570
+
571
+ @property
572
+ def interrupt(self):
573
+ return self._interrupt
574
+
575
+ @property
576
+ def attention_kwargs(self):
577
+ return self._attention_kwargs
578
+
579
+ @torch.no_grad()
580
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
581
+ def __call__(
582
+ self,
583
+ image: PipelineImageInput,
584
+ prompt: Union[str, List[str]] = None,
585
+ negative_prompt: Union[str, List[str]] = None,
586
+ traj_tensor = None,
587
+ ID_tensor = None,
588
+ height: int = 480,
589
+ width: int = 832,
590
+ num_frames: int = 81,
591
+ num_inference_steps: int = 50,
592
+ guidance_scale: float = 5.0,
593
+ guidance_scale_2: Optional[float] = None,
594
+ num_videos_per_prompt: Optional[int] = 1,
595
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
596
+ latents: Optional[torch.Tensor] = None,
597
+ prompt_embeds: Optional[torch.Tensor] = None,
598
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
599
+ image_embeds: Optional[torch.Tensor] = None,
600
+ last_image: Optional[torch.Tensor] = None,
601
+ output_type: Optional[str] = "np",
602
+ return_dict: bool = True,
603
+ attention_kwargs: Optional[Dict[str, Any]] = None,
604
+ callback_on_step_end: Optional[
605
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
606
+ ] = None,
607
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
608
+ max_sequence_length: int = 512,
609
+ ):
610
+ r"""
611
+ The call function to the pipeline for generation.
612
+
613
+ Args:
614
+ image (`PipelineImageInput`):
615
+ The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
616
+ prompt (`str` or `List[str]`, *optional*):
617
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
618
+ instead.
619
+ negative_prompt (`str` or `List[str]`, *optional*):
620
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
621
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
622
+ less than `1`).
623
+ height (`int`, defaults to `480`):
624
+ The height of the generated video.
625
+ width (`int`, defaults to `832`):
626
+ The width of the generated video.
627
+ num_frames (`int`, defaults to `81`):
628
+ The number of frames in the generated video.
629
+ num_inference_steps (`int`, defaults to `50`):
630
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
631
+ expense of slower inference.
632
+ guidance_scale (`float`, defaults to `5.0`):
633
+ Guidance scale as defined in [Classifier-Free Diffusion
634
+ Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
635
+ of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
636
+ `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
637
+ the text `prompt`, usually at the expense of lower image quality.
638
+ guidance_scale_2 (`float`, *optional*, defaults to `None`):
639
+ Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
640
+ `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
641
+ and the pipeline's `boundary_ratio` are not None.
642
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
643
+ The number of images to generate per prompt.
644
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
645
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
646
+ generation deterministic.
647
+ latents (`torch.Tensor`, *optional*):
648
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
649
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
650
+ tensor is generated by sampling using the supplied random `generator`.
651
+ prompt_embeds (`torch.Tensor`, *optional*):
652
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
653
+ provided, text embeddings are generated from the `prompt` input argument.
654
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
655
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
656
+ provided, text embeddings are generated from the `negative_prompt` input argument.
657
+ image_embeds (`torch.Tensor`, *optional*):
658
+ Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided,
659
+ image embeddings are generated from the `image` input argument.
660
+ output_type (`str`, *optional*, defaults to `"np"`):
661
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
662
+ return_dict (`bool`, *optional*, defaults to `True`):
663
+ Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
664
+ attention_kwargs (`dict`, *optional*):
665
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
666
+ `self.processor` in
667
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
668
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
669
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
670
+ each denoising step during inference with the following arguments: `callback_on_step_end(self:
671
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
672
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
673
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
674
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
675
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
676
+ `._callback_tensor_inputs` attribute of your pipeline class.
677
+ max_sequence_length (`int`, defaults to `512`):
678
+ The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
679
+ truncated. If the prompt is shorter, it will be padded to this length.
680
+
681
+ Examples:
682
+
683
+ Returns:
684
+ [`~WanPipelineOutput`] or `tuple`:
685
+ If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where
686
+ the first element is a list with the generated images and the second element is a list of `bool`s
687
+ indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
688
+ """
689
+
+         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
+         # 1. Check inputs. Raise error if not correct
+         self.check_inputs(
+             prompt,
+             negative_prompt,
+             image,
+             height,
+             width,
+             prompt_embeds,
+             negative_prompt_embeds,
+             image_embeds,
+             callback_on_step_end_tensor_inputs,
+             guidance_scale_2,
+         )
+
+         if num_frames % self.vae_scale_factor_temporal != 1:
+             logger.warning(
+                 f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
+             )
+             num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
+         num_frames = max(num_frames, 1)
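+         # Illustrative note (added for clarity, not part of the original commit): with a temporal VAE
+         # scale factor of 4, num_frames = 49 already satisfies (49 - 1) % 4 == 0 and is kept, while
+         # num_frames = 50 is rounded down to 49 before generation.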
+
+         if self.config.boundary_ratio is not None and guidance_scale_2 is None:
+             guidance_scale_2 = guidance_scale
+
+         self._guidance_scale = guidance_scale
+         self._guidance_scale_2 = guidance_scale_2
+         self._attention_kwargs = attention_kwargs
+         self._current_timestep = None
+         self._interrupt = False
+
+         device = self._execution_device
+
+         # 2. Define call parameters
+         if prompt is not None and isinstance(prompt, str):
+             batch_size = 1
+         elif prompt is not None and isinstance(prompt, list):
+             batch_size = len(prompt)
+         else:
+             batch_size = prompt_embeds.shape[0]
+
+         # 3. Encode input prompt
+         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+             prompt=prompt,
+             negative_prompt=negative_prompt,
+             do_classifier_free_guidance=self.do_classifier_free_guidance,
+             num_videos_per_prompt=num_videos_per_prompt,
+             prompt_embeds=prompt_embeds,
+             negative_prompt_embeds=negative_prompt_embeds,
+             max_sequence_length=max_sequence_length,
+             device=device,
+         )
+
+         # Encode image embedding
+         transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
+         prompt_embeds = prompt_embeds.to(transformer_dtype)
+         if negative_prompt_embeds is not None:
+             negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
+
+         # Only the Wan 2.1 i2v transformer accepts image_embeds
+         if self.transformer is not None and self.transformer.config.image_dim is not None:
+             if image_embeds is None:
+                 if last_image is None:
+                     image_embeds = self.encode_image(image, device)
+                 else:
+                     image_embeds = self.encode_image([image, last_image], device)
+                 image_embeds = image_embeds.repeat(batch_size, 1, 1)
+             image_embeds = image_embeds.to(transformer_dtype)
+
+         # 4. Prepare timesteps
+         self.scheduler.set_timesteps(num_inference_steps, device=device)
+         timesteps = self.scheduler.timesteps
+
+         # 5. Prepare latent variables
+         num_channels_latents = self.vae.config.z_dim
+         image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32)
+         if last_image is not None:
+             last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
+                 device, dtype=torch.float32
+             )
+
+         latents_outputs = self.prepare_latents(
+             image,
+             traj_tensor,
+             ID_tensor,
+             batch_size * num_videos_per_prompt,
+             num_channels_latents,
+             height,
+             width,
+             num_frames,
+             torch.float32,
+             device,
+             generator,
+             latents,
+             last_image,
+         )
+         if self.config.expand_timesteps:
+             # Wan 2.2 5B i2v uses first_frame_mask to mask the timesteps
+             latents, condition, traj_latents, ID_latent_condition, first_frame_mask = latents_outputs
+         else:
+             latents, condition = latents_outputs
+
+
+         # 5.5. For the ID reference, record the latent shape so padding can be added at each step
+         _, channel_num, num_gen_frames, latent_height, latent_width = latents.shape
+
+
+
+         # 6. Denoising loop
+         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+         self._num_timesteps = len(timesteps)
+
+         if self.config.boundary_ratio is not None:
+             boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
+         else:
+             boundary_timestep = None
+
+         with self.progress_bar(total=num_inference_steps) as progress_bar:
+             for i, t in enumerate(timesteps):
+                 if self.interrupt:
+                     continue
+
+                 self._current_timestep = t
+
+                 if boundary_timestep is None or t >= boundary_timestep:
+                     # Wan 2.1, or the high-noise stage in Wan 2.2
+                     current_model = self.transformer
+                     current_guidance_scale = guidance_scale
+                 else:
+                     # Low-noise stage in Wan 2.2
+                     current_model = self.transformer_2
+                     current_guidance_scale = guidance_scale_2
+
+
+                 if self.config.expand_timesteps:
+
+                     # Blend with the mask so the first-frame latent of the model input is the clean latent of the first-frame condition (here, for Frame INO, the first frame is the masked, outpainting-style condition)
+                     latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents  # NOTE: the first frame should be set to the masked first frame (outpainting-style)
+                     latent_model_input = latent_model_input.to(transformer_dtype)
+
+                     # Pad first_frame_mask with the length of the ID tokens
+                     if ID_latent_condition is not None:
+                         mask_padding = torch.ones(
+                             1, 1, ID_latent_condition.shape[2], latent_height, latent_width, dtype=transformer_dtype, device=device
+                         )
+                         first_frame_mask_adjust = torch.cat([first_frame_mask, mask_padding], dim = 2)
+                     else:
+                         first_frame_mask_adjust = first_frame_mask
+
+                     # Reshape to num_latent_frames * (latent_height // patch_size) * (latent_width // patch_size)
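+                     # Descriptive note (added for clarity, not part of the original commit): first_frame_mask_adjust
+                     # is 0 over the clean first-frame condition and 1 everywhere else (the padding appended for the
+                     # ID reference frames is all ones), so each token below receives timestep 0 if it comes from the
+                     # clean condition and the current timestep t otherwise.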
+                     temp_ts = (first_frame_mask_adjust[0][0][:, ::2, ::2] * t).flatten()
+                     timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+
+                 else:
+
+                     latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
+                     timestep = t.expand(latents.shape[0])
+                     # TODO: not yet fully sure whether this timestep is aligned with the one used in training
+
+
+                 # Frame-wise concatenation of the ID tokens
+                 if ID_latent_condition is not None:
+                     latent_model_input = torch.cat([latent_model_input, ID_latent_condition], dim = 2)
+
+
+                 # Concatenate the trajectory latents along the channel dimension
+                 latent_model_input = torch.cat([latent_model_input, traj_latents], dim = 1).to(transformer_dtype)
+
+
+                 # Predict the noise according to the timestep
+                 with current_model.cache_context("cond"):
+                     noise_pred = current_model(
+                         hidden_states = latent_model_input,
+                         timestep = timestep,
+                         encoder_hidden_states = prompt_embeds,
+                         encoder_hidden_states_image = image_embeds,
+                         attention_kwargs = attention_kwargs,
+                         return_dict = False,
+                     )[0]
+
+                 if self.do_classifier_free_guidance:
+                     with current_model.cache_context("uncond"):
+                         noise_uncond = current_model(
+                             hidden_states = latent_model_input,
+                             timestep = timestep,
+                             encoder_hidden_states = negative_prompt_embeds,
+                             encoder_hidden_states_image = image_embeds,
+                             attention_kwargs = attention_kwargs,
+                             return_dict = False,
+                         )[0]
+                     noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
+
+
+                 # Discard the extra ID tokens from the noise prediction
+                 noise_pred = noise_pred[:, :, :num_gen_frames]
+
+
+
+                 # compute the previous noisy sample x_t -> x_t-1
+                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+
+                 if callback_on_step_end is not None:
+                     callback_kwargs = {}
+                     for k in callback_on_step_end_tensor_inputs:
+                         callback_kwargs[k] = locals()[k]
+                     callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                     latents = callback_outputs.pop("latents", latents)
+                     prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                     negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+                 # call the callback, if provided
+                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                     progress_bar.update()
+
+                 if XLA_AVAILABLE:
+                     xm.mark_step()
+
+         self._current_timestep = None
+
+         if self.config.expand_timesteps:
+             latents = (1 - first_frame_mask) * condition + first_frame_mask * latents
+
+         if not output_type == "latent":
+             latents = latents.to(self.vae.dtype)
+             latents_mean = (
+                 torch.tensor(self.vae.config.latents_mean)
+                 .view(1, self.vae.config.z_dim, 1, 1, 1)
+                 .to(latents.device, latents.dtype)
+             )
+             latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
+                 latents.device, latents.dtype
+             )
+             latents = latents / latents_std + latents_mean
+             video = self.vae.decode(latents, return_dict=False)[0]
+             video = self.video_processor.postprocess_video(video, output_type=output_type)
+         else:
+             video = latents
+
+         # Offload all models
+         self.maybe_free_model_hooks()
+
+         if not return_dict:
+             return (video,)
+
+         return WanPipelineOutput(frames=video)
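The expand_timesteps branch above is the densest part of the denoising loop. The following standalone sketch, with made-up latent sizes, reproduces only the first-frame masking and the per-token timestep construction from that branch so the shapes are easier to follow; everything except the two expressions copied from the code above is illustrative:

    import torch

    # Assumed toy latent shape: batch, channels, latent frames, latent height, latent width
    B, C, F, H, W = 1, 48, 13, 30, 52
    latents = torch.randn(B, C, F, H, W)            # noisy latents being denoised
    condition = torch.zeros(B, C, F, H, W)          # clean (masked, outpainting-style) first-frame condition
    first_frame_mask = torch.ones(B, 1, F, H, W)
    first_frame_mask[:, :, 0] = 0                   # 0 = keep the clean condition, 1 = denoise

    t = torch.tensor(981.0)                         # current scheduler timestep

    # Same blend as in the pipeline: condition where the mask is 0, noisy latents elsewhere
    latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents

    # One timestep per transformer token: latent frames x (H // 2) x (W // 2) spatial patches
    temp_ts = (first_frame_mask[0, 0][:, ::2, ::2] * t).flatten()
    timestep = temp_ts.unsqueeze(0).expand(B, -1)
    print(latent_model_input.shape, timestep.shape)  # (1, 48, 13, 30, 52) and (1, 13 * 15 * 26)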
requirements.txt ADDED
@@ -0,0 +1,23 @@
+ pandas
+ tqdm
+ opencv-python
+ pyiqa
+ numpy==1.26.0
+ ffmpeg-python
+ bitsandbytes
+ pyarrow
+ omegaconf
+ peft>=0.15.0
+ transformers>=4.56.2  # install a recent release
+ git+https://github.com/huggingface/diffusers.git
+ sentencepiece
+ qwen-vl-utils[decord]==0.0.8
+ scikit-learn
+ matplotlib
+ gradio
+ imageio-ffmpeg
+ git+https://github.com/facebookresearch/segment-anything.git
+ git+https://github.com/facebookresearch/sam2.git
+ accelerate
+ hf-transfer
+
utils/optical_flow_utils.py ADDED
@@ -0,0 +1,219 @@
+ import numpy as np
+
+
+ def make_colorwheel():
+     """
+     Generates a color wheel for optical flow visualization as presented in:
+     Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
+     URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
+
+     Code follows the original C++ source code of Daniel Scharstein.
+     Code follows the Matlab source code of Deqing Sun.
+
+     Returns:
+         np.ndarray: Color wheel
+     """
+
+     RY = 15
+     YG = 6
+     GC = 4
+     CB = 11
+     BM = 13
+     MR = 6
+
+     ncols = RY + YG + GC + CB + BM + MR
+     colorwheel = np.zeros((ncols, 3))
+     col = 0
+
+     # RY
+     colorwheel[0:RY, 0] = 255
+     colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
+     col = col+RY
+     # YG
+     colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
+     colorwheel[col:col+YG, 1] = 255
+     col = col+YG
+     # GC
+     colorwheel[col:col+GC, 1] = 255
+     colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
+     col = col+GC
+     # CB
+     colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
+     colorwheel[col:col+CB, 2] = 255
+     col = col+CB
+     # BM
+     colorwheel[col:col+BM, 2] = 255
+     colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
+     col = col+BM
+     # MR
+     colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
+     colorwheel[col:col+MR, 0] = 255
+     return colorwheel
+
+
+ def flow_uv_to_colors(u, v, convert_to_bgr=False):
+     """
+     Applies the flow color wheel to (possibly clipped) flow components u and v.
+
+     According to the C++ source code of Daniel Scharstein
+     According to the Matlab source code of Deqing Sun
+
+     Args:
+         u (np.ndarray): Input horizontal flow of shape [H,W]
+         v (np.ndarray): Input vertical flow of shape [H,W]
+         convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
+
+     Returns:
+         np.ndarray: Flow visualization image of shape [H,W,3] in range [0, 255]
+     """
+     flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
+     colorwheel = make_colorwheel()  # shape [55x3]
+     ncols = colorwheel.shape[0]
+     rad = np.sqrt(np.square(u) + np.square(v))
+     a = np.arctan2(-v, -u)/np.pi
+     fk = (a+1) / 2*(ncols-1)
+     k0 = np.floor(fk).astype(np.int32)
+     k1 = k0 + 1
+     k1[k1 == ncols] = 0
+     f = fk - k0
+     for i in range(colorwheel.shape[1]):
+         tmp = colorwheel[:,i]
+         col0 = tmp[k0] / 255.0
+         col1 = tmp[k1] / 255.0
+         col = (1-f)*col0 + f*col1
+         idx = (rad <= 1)
+         col[idx] = 1 - rad[idx] * (1-col[idx])
+         col[~idx] = col[~idx] * 0.75  # out of range
+         # Note the 2-i => BGR instead of RGB
+         ch_idx = 2-i if convert_to_bgr else i
+         flow_image[:,:,ch_idx] = np.floor(255 * col)
+     return flow_image
+
+
+ def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
+     """
+     Expects a two-dimensional flow field of shape [H,W,2].
+
+     Args:
+         flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
+         clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
+         convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
+
+     Returns:
+         np.ndarray: Flow visualization image of shape [H,W,3]
+     """
+     assert flow_uv.ndim == 3, 'input flow must have three dimensions'
+     assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
+
+     if clip_flow is not None:
+         flow_uv = np.clip(flow_uv, 0, clip_flow)
+
+     u = flow_uv[:,:,0]
+     v = flow_uv[:,:,1]
+     rad = np.sqrt(np.square(u) + np.square(v))
+     rad_max = np.max(rad)
+     epsilon = 1e-5
+     u = u / (rad_max + epsilon)
+     v = v / (rad_max + epsilon)
+     return flow_uv_to_colors(u, v, convert_to_bgr)
+
+
+
+ def filter_uv(flow, threshold_factor = 0.1, sample_prob = 1.0):
+     '''
+     Args:
+         flow (np.ndarray): Array of shape [H,W,2] storing the x and y flow components
+         threshold_factor (float): Fraction of the maximum flow magnitude below which vectors are zeroed out
+         sample_prob (float): Proportion of the remaining flow vectors to keep
+     '''
+     u = flow[:,:,0]
+     v = flow[:,:,1]
+
+     # Zero out vectors whose magnitude is below the threshold
+     rad = np.sqrt(np.square(u) + np.square(v))
+     rad_max = np.max(rad)
+
+     threshold = threshold_factor * rad_max
+     flow[:,:,0][rad < threshold] = 0
+     flow[:,:,1][rad < threshold] = 0
+
+
+     # Randomly sample based on sample_prob
+     zero_prob = 1 - sample_prob
+     random_array = np.random.rand(*flow.shape)  # uniform in [0, 1) so the keep rate matches sample_prob
+     random_array[random_array < zero_prob] = 0
+     random_array[random_array >= zero_prob] = 1
+     flow = flow * random_array
+
+
+     return flow
+
+
+
+ ############################################# The following is for the dilation method in optical flow ######################################
+ def sigma_matrix2(sig_x, sig_y, theta):
+     """Calculate the rotated sigma matrix (two dimensional matrix).
+     Args:
+         sig_x (float):
+         sig_y (float):
+         theta (float): Radian measurement.
+     Returns:
+         ndarray: Rotated sigma matrix.
+     """
+     d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]])
+     u_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
+     return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))
+
+
+ def mesh_grid(kernel_size):
+     """Generate the mesh grid, centering at zero.
+     Args:
+         kernel_size (int):
+     Returns:
+         xy (ndarray): with the shape (kernel_size, kernel_size, 2)
+         xx (ndarray): with the shape (kernel_size, kernel_size)
+         yy (ndarray): with the shape (kernel_size, kernel_size)
+     """
+     ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
+     xx, yy = np.meshgrid(ax, ax)
+     xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), yy.reshape(kernel_size * kernel_size,
+                     1))).reshape(kernel_size, kernel_size, 2)
+     return xy, xx, yy
+
+
+ def pdf2(sigma_matrix, grid):
+     """Calculate PDF of the bivariate Gaussian distribution.
+     Args:
+         sigma_matrix (ndarray): with the shape (2, 2)
+         grid (ndarray): generated by :func:`mesh_grid`,
+             with the shape (K, K, 2), K is the kernel size.
+     Returns:
+         kernel (ndarray): un-normalized kernel.
+     """
+     inverse_sigma = np.linalg.inv(sigma_matrix)
+     kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
+     return kernel
+
+ def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
+     """Generate a bivariate isotropic or anisotropic Gaussian kernel.
+     In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` are ignored.
+     Args:
+         kernel_size (int):
+         sig_x (float):
+         sig_y (float):
+         theta (float): Radian measurement.
+         grid (ndarray, optional): generated by :func:`mesh_grid`,
+             with the shape (K, K, 2), K is the kernel size. Default: None
+         isotropic (bool):
+     Returns:
+         kernel (ndarray): normalized kernel.
+     """
+     if grid is None:
+         grid, _, _ = mesh_grid(kernel_size)
+     if isotropic:
+         sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
+     else:
+         sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
+     kernel = pdf2(sigma_matrix, grid)
+     kernel = kernel / np.sum(kernel)
+     return kernel
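As a usage sketch for these helpers, the snippet below computes a dense flow field with OpenCV's Farneback method, filters and visualizes it, and builds a Gaussian kernel of the kind used to dilate sparse flow or trajectory maps. The input file names, the Farneback parameters, and the sigma values are placeholders chosen for illustration, not values taken from this repository:

    import cv2
    import numpy as np

    from utils.optical_flow_utils import filter_uv, flow_to_image, bivariate_Gaussian

    # Placeholder frames: any two grayscale frames of equal size will do.
    frame0 = cv2.imread("frame_000.png", cv2.IMREAD_GRAYSCALE)
    frame1 = cv2.imread("frame_001.png", cv2.IMREAD_GRAYSCALE)

    # Dense optical flow of shape [H, W, 2] (x and y displacement per pixel).
    flow = cv2.calcOpticalFlowFarneback(frame0, frame1, None, 0.5, 3, 15, 3, 5, 1.2, 0)

    # Zero out weak vectors, randomly keep about 30% of the rest, then render the color-wheel visualization.
    flow = filter_uv(flow, threshold_factor=0.1, sample_prob=0.3)
    vis = flow_to_image(flow, convert_to_bgr=True)
    cv2.imwrite("flow_vis.png", vis)

    # A 21x21 isotropic Gaussian kernel, e.g. for dilating sparse flow/trajectory maps.
    kernel = bivariate_Gaussian(kernel_size=21, sig_x=5.0, sig_y=5.0, theta=0.0, isotropic=True)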