Spaces: Running on Zero
Commit · 5000b0a
1 Parent(s): 1577493
feat: initial push
- .gitignore +33 -0
- app.py +893 -0
- architecture/attention_processor.py +0 -0
- architecture/autoencoder_kl_wan.py +1419 -0
- architecture/cogvideox_transformer_3d.py +563 -0
- architecture/embeddings.py +0 -0
- architecture/noise_sampler.py +54 -0
- architecture/transformer_wan.py +552 -0
- config/accelerate_config_4GPU.json +18 -0
- config/train_cogvideox_motion.yaml +88 -0
- config/train_cogvideox_motion_FrameINO.yaml +96 -0
- config/train_wan_motion.yaml +104 -0
- config/train_wan_motion_FrameINO.yaml +110 -0
- data_loader/sampler.py +110 -0
- data_loader/video_dataset_motion.py +407 -0
- data_loader/video_dataset_motion_FrameINO.py +578 -0
- data_loader/video_dataset_motion_FrameINO_old.py +538 -0
- pipelines/pipeline_cogvideox_i2v_motion.py +931 -0
- pipelines/pipeline_cogvideox_i2v_motion_FrameINO.py +960 -0
- pipelines/pipeline_wan_i2v_motion.py +861 -0
- pipelines/pipeline_wan_i2v_motion_FrameINO.py +937 -0
- requirements.txt +24 -0
- utils/optical_flow_utils.py +219 -0
.gitignore
ADDED
@@ -0,0 +1,33 @@
+*.csv
+*.mp4
+*.png
+*.jpg
+*.err
+*.txt
+*.log
+*.pyc
+*.pth
+*.DS_Store*
+*.o
+*.so
+*.egg*
+*.json
+*.zip
+*.jpeg
+*.pkl
+*.gif
+*.pem
+*.npy
+*.sh
+
+
+pretrained/*
+checkpoints/*
+preprocess/sam2_code
+
+!preprocess/oneformer_code/oneformer/data/bpe_simple_vocab_16e6.txt
+!config/*.json
+!requirements.txt
+!requirements/*
+!__assets__/*
+!__assets__/page/*
app.py
ADDED
@@ -0,0 +1,893 @@
| 1 |
+
import os, sys, shutil
|
| 2 |
+
import csv
|
| 3 |
+
import numpy as np
|
| 4 |
+
import ffmpeg
|
| 5 |
+
import cv2
|
| 6 |
+
import collections
|
| 7 |
+
import json
|
| 8 |
+
import math
|
| 9 |
+
import time
|
| 10 |
+
import imageio
|
| 11 |
+
import random
|
| 12 |
+
import ast
|
| 13 |
+
import gradio as gr
|
| 14 |
+
from omegaconf import OmegaConf
|
| 15 |
+
from PIL import Image
|
| 16 |
+
from segment_anything import SamPredictor, sam_model_registry
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
| 20 |
+
import torch
|
| 21 |
+
from torch.utils.data import DataLoader, Dataset
|
| 22 |
+
from torchvision import transforms
|
| 23 |
+
from diffusers import AutoencoderKLCogVideoX
|
| 24 |
+
from transformers import T5EncoderModel
|
| 25 |
+
from diffusers.utils import export_to_video, load_image
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# Import files from the local folder
|
| 29 |
+
root_path = os.path.abspath('.')
|
| 30 |
+
sys.path.append(root_path)
|
| 31 |
+
from pipelines.pipeline_cogvideox_i2v_motion_FrameINO import CogVideoXImageToVideoPipeline
|
| 32 |
+
from architecture.cogvideox_transformer_3d import CogVideoXTransformer3DModel
|
| 33 |
+
from data_loader.video_dataset_motion import VideoDataset_Motion
|
| 34 |
+
from architecture.transformer_wan import WanTransformer3DModel
|
| 35 |
+
from pipelines.pipeline_wan_i2v_motion_FrameINO import WanImageToVideoPipeline
|
| 36 |
+
from architecture.autoencoder_kl_wan import AutoencoderKLWan
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
MARKDOWN = \
|
| 41 |
+
"""
|
| 42 |
+
<div align='center'>
|
| 43 |
+
<h1> Frame In-N-Out </h1> \
|
| 44 |
+
<h2 style='font-weight: 450; font-size: 1rem; margin-bottom: 1rem;'>\
|
| 45 |
+
<a href='https://kiteretsu77.github.io/BoyangWang/'>Boyang Wang</a>, <a href='https://xuweiyichen.github.io/'>Xuweiyi Chen</a>, <a href='http://mgadelha.me/'>Matheus Gadelha</a>, <a href='https://sites.google.com/site/zezhoucheng/'>Zezhou Cheng</a>\
|
| 46 |
+
</h2> \
|
| 47 |
+
|
| 48 |
+
<div style="display: flex; flex-wrap: wrap; justify-content: center; gap: 2rem; margin-bottom: 1rem;">
|
| 49 |
+
<!-- First row of buttons -->
|
| 50 |
+
<a href="https://arxiv.org/abs/2505.21491" target="_blank"
|
| 51 |
+
style="display: inline-flex; align-items: center; padding: 0.5rem 1rem; background-color: #f0f0f0; /* light gray background */ color: #333; /* dark text */ text-decoration: none; border-radius: 9999px; font-weight: 500; transition: background-color 0.3s;">
|
| 52 |
+
<span style="margin-right: 0.5rem;">📄</span> <!-- document icon -->
|
| 53 |
+
<span>Paper</span>
|
| 54 |
+
</a>
|
| 55 |
+
<a href="https://github.com/UVA-Computer-Vision-Lab/FrameINO" target="_blank"
|
| 56 |
+
style="display: inline-flex; align-items: center; padding: 0.5rem 1rem; background-color: #f0f0f0; color: #333; text-decoration: none; border-radius: 9999px; font-weight: 500; transition: background-color 0.3s;">
|
| 57 |
+
<span style="margin-right: 0.5rem;">💻</span> <!-- computer icon -->
|
| 58 |
+
<span>GitHub</span>
|
| 59 |
+
</a>
|
| 60 |
+
<a href="https://uva-computer-vision-lab.github.io/Frame-In-N-Out" target="_blank"
|
| 61 |
+
style="display: inline-flex; align-items: center; padding: 0.5rem 1rem; background-color: #f0f0f0; color: #333; text-decoration: none; border-radius: 9999px; font-weight: 500; transition: background-color 0.3s;">
|
| 62 |
+
<span style="margin-right: 0.5rem;">🤖</span>
|
| 63 |
+
<span>Project Page</span>
|
| 64 |
+
</a>
|
| 65 |
+
<a href="https://huggingface.co/collections/uva-cv-lab/frame-in-n-out" target="_blank"
|
| 66 |
+
style="display: inline-flex; align-items: center; padding: 0.5rem 1rem; background-color: #f0f0f0; color: #333; text-decoration: none; border-radius: 9999px; font-weight: 500; transition: background-color 0.3s;">
|
| 67 |
+
<span style="margin-right: 0.5rem;">🤗</span>
|
| 68 |
+
<span>HF Model and Data</span>
|
| 69 |
+
</a>
|
| 70 |
+
</div>
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
</div>
|
| 74 |
+
|
| 75 |
+
Frame In-N-Out expands the first-frame condition to a broader canvas region by setting the top-left and bottom-right expansion amounts,
|
| 76 |
+
and users can provide motion trajectories for existing objects, introduce a brand-new identity that enters the scene along a motion trajectory, or do both. <br>
|
| 77 |
+
The model used here is <b>Wan2.2-5B</b> trained with our Frame In-N-Out control mechanism.
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
<br>
|
| 81 |
+
<b>Easiest way:</b>
|
| 82 |
+
Choose one example and then simply click <b>Generate</b>.
|
| 83 |
+
|
| 84 |
+
<br>
|
| 85 |
+
<br>
|
| 86 |
+
❗️❗️❗️Instruction Steps:<br>
|
| 87 |
+
1️⃣ Upload your first frame image. Set the target resize dimensions with <b>Resized Height for Input Image</b> and <b>Resized Width for Input Image</b>. <br>
|
| 88 |
+
2️⃣ Set your <b>canvas top-left</b> and <b>bottom-right expansion</b>. The combined height and the combined width must each be a multiple of 32. <br>
|
| 89 |
+
PLEASE ENSURE that <b>Canvas HEIGHT = 704</b> and <b>Canvas WIDTH = 1280</b> for the best performance (current training resolution). <br>
|
| 90 |
+
3️⃣ Click <b>Build the Canvas</b>. <br>
|
| 91 |
+
4️⃣ Provide the trajectory of the main object in the canvas by clicking on the <b>Expanded Canvas</b>. <br>
|
| 92 |
+
5️⃣ Provide the ID reference image and its trajectory (optional). Also, write a detailed <b>text prompt</b>. <br>
|
| 93 |
+
Click the <b>Generate</b> button to start the Video Generation. <br>
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
If **Frame In-N-Out** is helpful, please help star the [GitHub Repo](https://github.com/UVA-Computer-Vision-Lab/FrameINO?tab=readme-ov-file). Thanks!
|
| 97 |
+
|
| 98 |
+
"""
|
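# Illustrative note (not part of the original script): for the bundled horse example below,
# the canvas size works out to
#   height = 480 (resized) + 128 (top-left expand) + 96 (bottom-right expand)  = 704
#   width  = 736 (resized) + 224 (top-left expand) + 320 (bottom-right expand) = 1280
# Both are multiples of 32 and match the training resolution mentioned in the instructions above.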
| 99 |
+
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# Color
|
| 103 |
+
all_color_codes = [(255, 0, 0), (255, 255, 0), (0, 255, 0), (0, 255, 255),
|
| 104 |
+
(255, 0, 255), (0, 0, 255), (128, 128, 128), (64, 224, 208),
|
| 105 |
+
(233, 150, 122)]
|
| 106 |
+
for _ in range(100): # Should not be over 100 colors
|
| 107 |
+
all_color_codes.append((random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)))
|
| 108 |
+
|
| 109 |
+
# Data Transforms
|
| 110 |
+
train_transforms = transforms.Compose(
|
| 111 |
+
[
|
| 112 |
+
transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
|
| 113 |
+
]
|
| 114 |
+
)
|
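# Quick sanity check (illustrative, not in the original file): the Lambda above maps
# uint8 pixel values in [0, 255] to [-1, 1]:
#   train_transforms(torch.tensor([0., 127.5, 255.]))  ->  tensor([-1., 0., 1.])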
| 115 |
+
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
######################################################## CogVideoX #################################################################
|
| 121 |
+
|
| 122 |
+
# Path Setting
|
| 123 |
+
model_code_name = "CogVideox"
|
| 124 |
+
base_model_id = "zai-org/CogVideoX-5b-I2V"
|
| 125 |
+
transformer_ckpt_path = "uva-cv-lab/FrameINO_CogVideoX_Stage2_MotionINO_v1.0"
|
| 126 |
+
|
| 127 |
+
# Load Model
|
| 128 |
+
transformer = CogVideoXTransformer3DModel.from_pretrained(transformer_ckpt_path, torch_dtype=torch.float16)
|
| 129 |
+
text_encoder = T5EncoderModel.from_pretrained(base_model_id, subfolder="text_encoder", torch_dtype=torch.float16)
|
| 130 |
+
vae = AutoencoderKLCogVideoX.from_pretrained(base_model_id, subfolder="vae", torch_dtype=torch.float16)
|
| 131 |
+
|
| 132 |
+
# Create pipeline and run inference
|
| 133 |
+
pipe = CogVideoXImageToVideoPipeline.from_pretrained(
|
| 134 |
+
base_model_id,
|
| 135 |
+
text_encoder = text_encoder,
|
| 136 |
+
transformer = transformer,
|
| 137 |
+
vae = vae,
|
| 138 |
+
torch_dtype = torch.float16,
|
| 139 |
+
)
|
| 140 |
+
pipe.enable_model_cpu_offload()
|
| 141 |
+
|
| 142 |
+
#####################################################################################################################################
|
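# Note: the Wan2.2-5B block below re-assigns model_code_name, transformer, vae, and pipe,
# so on this Space the CogVideoX pipeline constructed above is effectively superseded at runtime.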
| 143 |
+
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
######################################################## Wan2.2 5B #################################################################
|
| 148 |
+
|
| 149 |
+
# Path Setting
|
| 150 |
+
model_code_name = "Wan"
|
| 151 |
+
base_model_id = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"
|
| 152 |
+
transformer_ckpt_path = "uva-cv-lab/FrameINO_Wan2.2_5B_Stage2_MotionINO_v1.5"
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
# Load model
|
| 156 |
+
print("Loading the model!")
|
| 157 |
+
transformer = WanTransformer3DModel.from_pretrained(transformer_ckpt_path, torch_dtype=torch.float16)
|
| 158 |
+
vae = AutoencoderKLWan.from_pretrained(base_model_id, subfolder="vae", torch_dtype=torch.float32)
|
| 159 |
+
|
| 160 |
+
# Create the pipeline
|
| 161 |
+
print("Loading the pipeline!")
|
| 162 |
+
pipe = WanImageToVideoPipeline.from_pretrained(base_model_id, transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)
|
| 163 |
+
pipe.to("cuda")
|
| 164 |
+
pipe.enable_model_cpu_offload()
|
| 165 |
+
|
| 166 |
+
#####################################################################################################################################
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
########################################################## Other Auxiliary Func #################################################################
|
| 172 |
+
|
| 173 |
+
# # Init SAM model
|
| 174 |
+
model_type = "vit_h" #vit-h has the most number of paramter
|
| 175 |
+
sam_pretrained_path = "pretrained/sam_vit_h_4b8939.pth"
|
| 176 |
+
if not os.path.exists(sam_pretrained_path):
|
| 177 |
+
os.system("wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -P pretrained/")
|
| 178 |
+
sam = sam_model_registry[model_type](checkpoint = sam_pretrained_path).to(device="cuda")
|
| 179 |
+
sam_predictor = SamPredictor(sam)  # Wraps SAM with its image-embedding and prompt-prediction settings
|
| 180 |
+
|
| 181 |
+
#####################################################################################################################################
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
# Examples Sample
|
| 187 |
+
def get_example():
|
| 188 |
+
case = [
|
| 189 |
+
[
|
| 190 |
+
'__assets__/horse.jpg',
|
| 191 |
+
480,
|
| 192 |
+
736,
|
| 193 |
+
128,
|
| 194 |
+
224,
|
| 195 |
+
96,
|
| 196 |
+
320,
|
| 197 |
+
'__assets__/sheep.png',
|
| 198 |
+
"A brown horse with a black mane walks to the right on a wooden path in a green forest, and then a white sheep enters from the left and walks toward it. Natural daylight, realistic texture, smooth motion, cinematic focus, 4K detail.",
|
| 199 |
+
[[[[299, 241], [390, 236], [461, 245], [521, 249], [565, 240], [612, 246], [666, 245]], [[449, 224], [488, 212], [512, 206], [531, 209], [552, 202], [581, 204], [609, 210], [657, 206], [703, 202], [716, 211]]], [[[24, 305], [104, 300], [167, 299], [219, 303], [270, 296], [295, 304]]]],
|
| 200 |
+
],
|
| 201 |
+
|
| 202 |
+
[
|
| 203 |
+
'__assets__/cup.jpg',
|
| 204 |
+
448,
|
| 205 |
+
736,
|
| 206 |
+
256,
|
| 207 |
+
64,
|
| 208 |
+
0,
|
| 209 |
+
480,
|
| 210 |
+
'__assets__/hand2.png',
|
| 211 |
+
"A human hand reaches into the frame, gently grabbing the black metal cup with a golden character design on the front, lifting it off the table and taking it away.",
|
| 212 |
+
[[[[565, 324], [473, 337], [386, 345], [346, 340], [339, 324], [352, 212], [328, 114], [328, 18], [348, 0]]]],
|
| 213 |
+
],
|
| 214 |
+
|
| 215 |
+
[
|
| 216 |
+
'__assets__/grass.jpg',
|
| 217 |
+
512,
|
| 218 |
+
800,
|
| 219 |
+
64,
|
| 220 |
+
64,
|
| 221 |
+
160,
|
| 222 |
+
416,
|
| 223 |
+
'__assets__/dog.png',
|
| 224 |
+
"A fluffy, adorable puppy joyfully sprints onto the bright green grass, its fur bouncing with each step as sunlight highlights its soft coat. The scene takes place in a peaceful park filled with tall trees casting gentle shadows across the lawn. After dashing forward with enthusiasm, the puppy slows to a happy trot, continuing farther ahead into the deeper area of the park, disappearing toward the more shaded grass beneath the trees.",
|
| 225 |
+
[[[[600, 412], [512, 394], [408, 358], [333, 336], [270, 313], [259, 260], [236, 222], [231, 180]], [[592, 392], [295, 305], [256, 217], [243, 163]]]],
|
| 226 |
+
],
|
| 227 |
+
|
| 228 |
+
[
|
| 229 |
+
'__assets__/man_scene.jpg',
|
| 230 |
+
576,
|
| 231 |
+
1024,
|
| 232 |
+
64,
|
| 233 |
+
32,
|
| 234 |
+
64,
|
| 235 |
+
224,
|
| 236 |
+
None,
|
| 237 |
+
"A single hiker, equipped with a backpack, walks toward the right side of a rugged mountainside trail. The bright sunlight highlights the pale rocky terrain around him, while massive stone cliffs loom in the background. Sparse patches of grass and scattered boulders sit along the path, emphasizing the isolation and vastness of the mountain environment as he steadily continues his journey.",
|
| 238 |
+
[[[[342, 247], [415, 247], [478, 262], [518, 271], [570, 275], [613, 283], [646, 308], [690, 307], [705, 325]], [[349, 227], [461, 232], [536, 254], [595, 252], [638, 269], [691, 289], [715, 291]], [[341, 283], [415, 291], [500, 316], [590, 317], [632, 354], [675, 362], [711, 372]]]],
|
| 239 |
+
]
|
| 240 |
+
|
| 241 |
+
]
|
| 242 |
+
return case
|
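# Each example row above follows the input order wired into gr.Examples at the bottom of this file:
# (input_image, resized_height, resized_width, top_left_height, top_left_width,
#  bottom_right_height, bottom_right_width, identity_image, text_prompt, traj_lists)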
| 243 |
+
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def on_example_click(
|
| 248 |
+
input_image, resized_height, resized_width,
|
| 249 |
+
top_left_height, top_left_width, bottom_right_height, bottom_right_width,
|
| 250 |
+
identity_image, text_prompt, traj_lists,
|
| 251 |
+
):
|
| 252 |
+
|
| 253 |
+
# Convert
|
| 254 |
+
traj_lists = ast.literal_eval(traj_lists)
|
| 255 |
+
# Note: the remaining inputs (resized_width, resized_height, etc.) are not converted here because build_canvas handles that itself
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
# Sequentially build the canvas (We don't accept the empty traj_lists & traj_instance_idx returned by build_canvas)
|
| 259 |
+
visual_canvas, initial_visual_canvas, inference_canvas, _, _ = build_canvas(input_image, resized_height, resized_width, top_left_height, top_left_width, bottom_right_height, bottom_right_width)
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
# Sequentially load the Trajs of all instances on the canvas
|
| 263 |
+
visual_canvas, traj_instance_idx = fn_vis_all_instance_traj(visual_canvas, traj_lists)
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
return visual_canvas, initial_visual_canvas, inference_canvas, traj_instance_idx
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def build_canvas(input_image_path, resized_height, resized_width, top_left_height, top_left_width, bottom_right_height, bottom_right_width):
|
| 271 |
+
|
| 272 |
+
# Init
|
| 273 |
+
canvas_color = (250, 249, 246)  # Off-white, similar to the white of painting paper
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
# Convert the string to integer
|
| 277 |
+
if not resized_height.isdigit():
|
| 278 |
+
raise gr.Error("resized_height must be integer input!")
|
| 279 |
+
resized_height = int(resized_height)
|
| 280 |
+
|
| 281 |
+
if not resized_width.isdigit():
|
| 282 |
+
raise gr.Error("resized_width must be integer input!")
|
| 283 |
+
resized_width = int(resized_width)
|
| 284 |
+
|
| 285 |
+
if not top_left_height.isdigit():
|
| 286 |
+
raise gr.Error("top_left_height must be integer input!")
|
| 287 |
+
top_left_height = int(top_left_height)
|
| 288 |
+
|
| 289 |
+
if not top_left_width.isdigit():
|
| 290 |
+
raise gr.Error("top_left_width must be integer input!")
|
| 291 |
+
top_left_width = int(top_left_width)
|
| 292 |
+
|
| 293 |
+
if not bottom_right_height.isdigit():
|
| 294 |
+
raise gr.Error("bottom_right_height must be integer input!")
|
| 295 |
+
bottom_right_height = int(bottom_right_height)
|
| 296 |
+
|
| 297 |
+
if not bottom_right_width.isdigit():
|
| 298 |
+
raise gr.Error("bottom_right_width must be integer input!")
|
| 299 |
+
bottom_right_width = int(bottom_right_width)
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
# Read the original image and prepare the placeholder
|
| 304 |
+
first_frame_img = np.uint8(np.asarray(Image.open(input_image_path)))  # NOTE: PIL loads in RGB order; be careful in the later cropping process for the ID Reference
|
| 305 |
+
# print("first_frame_img shape is ", first_frame_img.shape)
|
| 306 |
+
|
| 307 |
+
# Resize to a uniform resolution
|
| 308 |
+
first_frame_img = cv2.resize(first_frame_img, (resized_width, resized_height), interpolation = cv2.INTER_AREA)
|
| 309 |
+
|
| 310 |
+
# Expand to Outside Region to form the Canvas
|
| 311 |
+
expand_height = resized_height + top_left_height + bottom_right_height
|
| 312 |
+
expand_width = resized_width + top_left_width + bottom_right_width
|
| 313 |
+
inference_canvas = np.uint8(np.zeros((expand_height, expand_width, 3))) # Whole Black Canvas, same as other inference
|
| 314 |
+
visual_canvas = np.full((expand_height, expand_width, 3), canvas_color, dtype=np.uint8)
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
# Sanity Check
|
| 318 |
+
if expand_height % 32 != 0:
|
| 319 |
+
raise gr.Error("The Height of resized_height + top_left_height + bottom_right_height must be divisible by 32!")
|
| 320 |
+
if expand_width % 32 != 0:
|
| 321 |
+
raise gr.Error("The Width of resized_width + top_left_width + bottom_right_width must be divisible by 32!")
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
# Draw the Region Box Region (Original Resolution)
|
| 325 |
+
bottom_len = inference_canvas.shape[0] - bottom_right_height
|
| 326 |
+
right_len = inference_canvas.shape[1] - bottom_right_width
|
| 327 |
+
inference_canvas[top_left_height:bottom_len, top_left_width:right_len, :] = first_frame_img
|
| 328 |
+
visual_canvas[top_left_height:bottom_len, top_left_width:right_len, :] = first_frame_img
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
# Resize to the uniform height and width
|
| 332 |
+
visual_canvas = cv2.resize(visual_canvas, (uniform_width, uniform_height), interpolation = cv2.INTER_AREA)
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
# Return the visual_canvas (for visualization) and canvas map
|
| 337 |
+
# Corresponds to: visual_canvas, initial_visual_canvas, inference_canvas, traj_instance_idx, traj_lists
|
| 338 |
+
return visual_canvas, visual_canvas.copy(), inference_canvas, 0, [ [ [] ] ]  # The last two are the initial trajectory instance idx and trajectory list
|
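# Shape sketch (illustrative; values from the horse example, assuming the app's 720x480 preview size):
#   vis, init_vis, infer_canvas, idx, trajs = build_canvas(
#       "__assets__/horse.jpg", "480", "736", "128", "224", "96", "320")
#   infer_canvas.shape -> (704, 1280, 3); vis is resized to 480x720 for display.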
| 339 |
+
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def process_points(traj_list, num_frames=49):
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
if len(traj_list) < 2: # First point
|
| 347 |
+
return [traj_list[0]] * num_frames
|
| 348 |
+
|
| 349 |
+
elif len(traj_list) >= num_frames:
|
| 350 |
+
gr.Info("The number of trajectory points exceeds the 49-point limit; the trajectory will be subsampled.")  # gr.Info is a notification call, not an exception to raise
|
| 351 |
+
skip = len(traj_list) // num_frames
|
| 352 |
+
return traj_list[::skip][: num_frames - 1] + traj_list[-1:]
|
| 353 |
+
|
| 354 |
+
else:
|
| 355 |
+
|
| 356 |
+
insert_num = num_frames - len(traj_list)
|
| 357 |
+
insert_num_dict = {}
|
| 358 |
+
interval = len(traj_list) - 1
|
| 359 |
+
n = insert_num // interval
|
| 360 |
+
m = insert_num % interval
|
| 361 |
+
|
| 362 |
+
for i in range(interval):
|
| 363 |
+
insert_num_dict[i] = n
|
| 364 |
+
|
| 365 |
+
for i in range(m):
|
| 366 |
+
insert_num_dict[i] += 1
|
| 367 |
+
|
| 368 |
+
res = []
|
| 369 |
+
for i in range(interval):
|
| 370 |
+
insert_points = []
|
| 371 |
+
x0, y0 = traj_list[i]
|
| 372 |
+
x1, y1 = traj_list[i + 1]
|
| 373 |
+
|
| 374 |
+
delta_x = x1 - x0
|
| 375 |
+
delta_y = y1 - y0
|
| 376 |
+
for j in range(insert_num_dict[i]):
|
| 377 |
+
x = x0 + (j + 1) / (insert_num_dict[i] + 1) * delta_x
|
| 378 |
+
y = y0 + (j + 1) / (insert_num_dict[i] + 1) * delta_y
|
| 379 |
+
insert_points.append([int(x), int(y)])
|
| 380 |
+
|
| 381 |
+
res += traj_list[i : i + 1] + insert_points
|
| 382 |
+
res += traj_list[-1:]
|
| 383 |
+
|
| 384 |
+
# return
|
| 385 |
+
return res
|
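# Worked example (illustrative): with 3 clicked points and num_frames=49, process_points
# linearly inserts 46 intermediate points so every frame gets one coordinate:
#   len(process_points([[0, 0], [10, 0], [20, 0]], num_frames=49))  ->  49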
| 386 |
+
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def fn_vis_realtime_traj(visual_canvas, traj_list, traj_instance_idx): # Visualize the traj on canvas
|
| 390 |
+
|
| 391 |
+
# Process Points
|
| 392 |
+
points = process_points(traj_list)
|
| 393 |
+
|
| 394 |
+
# Draw straight line to connect
|
| 395 |
+
for i in range(len(points) - 1):
|
| 396 |
+
p = points[i]
|
| 397 |
+
p1 = points[i + 1]
|
| 398 |
+
cv2.line(visual_canvas, p, p1, all_color_codes[traj_instance_idx], 5)
|
| 399 |
+
|
| 400 |
+
return visual_canvas
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
def fn_vis_all_instance_traj(visual_canvas, traj_lists): # Visualize all traj from all instances on canvas
|
| 404 |
+
|
| 405 |
+
for traj_instance_idx, traj_list_instance in enumerate(traj_lists):
|
| 406 |
+
for traj_list_line in traj_list_instance:
|
| 407 |
+
visual_canvas = fn_vis_realtime_traj(visual_canvas, traj_list_line, traj_instance_idx)
|
| 408 |
+
|
| 409 |
+
return visual_canvas, traj_instance_idx # Also return the instance idx
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
def add_traj_point(
|
| 413 |
+
visual_canvas,
|
| 414 |
+
traj_lists,
|
| 415 |
+
traj_instance_idx,
|
| 416 |
+
evt: gr.SelectData,
|
| 417 |
+
): # Add new Traj and then visualize
|
| 418 |
+
|
| 419 |
+
# Convert
|
| 420 |
+
traj_lists = ast.literal_eval(traj_lists)
|
| 421 |
+
|
| 422 |
+
# Mark New Trajectory Key Point
|
| 423 |
+
horizontal, vertical = evt.index
|
| 424 |
+
|
| 425 |
+
# traj_lists data structure is: (Num of Instances, Num of Trajectories, Num of Points, [X, Y])
|
| 426 |
+
traj_lists[-1][-1].append( [int(horizontal), int(vertical)] )
|
| 427 |
+
|
| 428 |
+
# Draw new trajectory on the Canvas image
|
| 429 |
+
visual_canvas = fn_vis_realtime_traj(visual_canvas, traj_lists[-1][-1], traj_instance_idx)
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
# Return New Traj Marked Canvas image
|
| 433 |
+
return visual_canvas, traj_lists
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
def clear_traj_points(initial_visual_canvas):
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
return initial_visual_canvas.copy(), 0, [ [ [] ] ]  # 1st is the initial-state canvas; 2nd is the traj instance idx; 3rd is the traj list (with the same data structure)
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
def traj_point_update(traj_lists):
|
| 444 |
+
|
| 445 |
+
# Convert
|
| 446 |
+
traj_lists = ast.literal_eval(traj_lists)
|
| 447 |
+
|
| 448 |
+
# Start a new trajectory line under the last (current) instance
|
| 449 |
+
traj_lists[-1].append([])
|
| 450 |
+
|
| 451 |
+
return traj_lists
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
def traj_instance_update(traj_instance_idx, traj_lists):
|
| 456 |
+
|
| 457 |
+
# Convert
|
| 458 |
+
traj_lists = ast.literal_eval(traj_lists)
|
| 459 |
+
|
| 460 |
+
# Update one index
|
| 461 |
+
if traj_instance_idx >= len(all_color_codes):
|
| 462 |
+
raise gr.Error("The trajectory instance number is over the limit!")
|
| 463 |
+
|
| 464 |
+
# Add one for the traj instance
|
| 465 |
+
traj_instance_idx = traj_instance_idx + 1
|
| 466 |
+
|
| 467 |
+
# Append a new empty list to the traj lists
|
| 468 |
+
traj_lists.append([[]])
|
| 469 |
+
|
| 470 |
+
# Return
|
| 471 |
+
return traj_instance_idx, traj_lists
|
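# Illustrative sketch of how the nested traj_lists structure evolves:
#   start / after clearing:      [ [ [] ] ]                 (one instance, one empty line)
#   after traj_point_update:     [ [ [...], [] ] ]          (new line for the same instance)
#   after traj_instance_update:  [ [ [...], [] ], [ [] ] ]  (new instance with its own empty line)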
| 472 |
+
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
def sample_traj_by_length(points, num_samples):
|
| 476 |
+
# Sample points evenly from traj based on the euclidean distance
|
| 477 |
+
|
| 478 |
+
pts = np.array(points, dtype=float) # shape (M, 2)
|
| 479 |
+
|
| 480 |
+
# 1) Length of each segment
|
| 481 |
+
seg = pts[1:] - pts[:-1]
|
| 482 |
+
seg_len = np.sqrt((seg**2).sum(axis=1)) # shape (M-1,)
|
| 483 |
+
|
| 484 |
+
# 2) Cumulative length
|
| 485 |
+
cum = np.cumsum(seg_len)
|
| 486 |
+
total_length = cum[-1]
|
| 487 |
+
|
| 488 |
+
# 3) Target positions at equal arc-length spacing
|
| 489 |
+
target = np.linspace(0, total_length, num_samples)
|
| 490 |
+
|
| 491 |
+
res = []
|
| 492 |
+
for t in target:
|
| 493 |
+
# 4) Find which segment it falls into
|
| 494 |
+
idx = np.searchsorted(cum, t)
|
| 495 |
+
if idx == 0:
|
| 496 |
+
prev = 0.
|
| 497 |
+
else:
|
| 498 |
+
prev = cum[idx-1]
|
| 499 |
+
|
| 500 |
+
# 5) Interpolate within that segment
|
| 501 |
+
ratio = (t - prev) / seg_len[idx]
|
| 502 |
+
p = pts[idx] * (1 - ratio) + pts[idx + 1] * ratio  # start * (1 - ratio) + end * ratio
|
| 505 |
+
res.append(p)
|
| 506 |
+
return np.array(res)
|
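# Worked example (illustrative): sampling a straight segment by arc length gives evenly spaced points:
#   sample_traj_by_length([[0, 0], [10, 0]], 5)
#   ->  [[0., 0.], [2.5, 0.], [5., 0.], [7.5, 0.], [10., 0.]]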
| 507 |
+
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
def inference(inference_canvas, visual_canvas, text_prompt, traj_lists, main_reference_img,
|
| 511 |
+
resized_height, resized_width, top_left_height, top_left_width, bottom_right_height, bottom_right_width):
|
| 512 |
+
|
| 513 |
+
# TODO: enhance the text prompt by Qwen3-VL-32B?
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
# Convert
|
| 517 |
+
resized_height = int(resized_height)
|
| 518 |
+
resized_width = int(resized_width)
|
| 519 |
+
top_left_height = int(top_left_height)
|
| 520 |
+
top_left_width = int(top_left_width)
|
| 521 |
+
bottom_right_height = int(bottom_right_height)
|
| 522 |
+
bottom_right_width = int(bottom_right_width)
|
| 523 |
+
traj_lists = ast.literal_eval(traj_lists)
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
# Init Some Fixed Setting
|
| 528 |
+
if model_code_name == "Wan":
|
| 529 |
+
config_path = "config/train_wan_motion_FrameINO.yaml"
|
| 530 |
+
dot_radius = 7
|
| 531 |
+
num_frames = 81
|
| 532 |
+
elif model_code_name == "CogVideoX":
|
| 533 |
+
config_path = "config/train_cogvideox_i2v_motion_FrameINO.yaml"
|
| 534 |
+
dot_radius = 6
|
| 535 |
+
num_frames = 49
|
| 536 |
+
config = OmegaConf.load(config_path)
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
# Prepare tmp folders
|
| 540 |
+
print()
|
| 541 |
+
store_folder_path = "tmp_app_example_" + str(int(time.time()))
|
| 542 |
+
if os.path.exists(store_folder_path):
|
| 543 |
+
shutil.rmtree(store_folder_path)
|
| 544 |
+
os.makedirs(store_folder_path)
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
# Write the visual canvas
|
| 548 |
+
visual_canvas_store_path = os.path.join(store_folder_path, "visual_canvas.png")
|
| 549 |
+
cv2.imwrite( visual_canvas_store_path, cv2.cvtColor(visual_canvas, cv2.COLOR_BGR2RGB) )
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
# Resize the map
|
| 554 |
+
canvas_width = resized_width + top_left_width + bottom_right_width
|
| 555 |
+
canvas_height = resized_height + top_left_height + bottom_right_height
|
| 556 |
+
# inference_canvas = cv2.resize(visual_canvas, (canvas_width, canvas_height), interpolation = cv2.INTER_AREA)
|
| 557 |
+
print("Canvas Shape is", str(canvas_height) + "x" + str(canvas_width) )
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
# TODO: also enhance this text prompt so its complexity stays consistent with the Qwen-style captions...
|
| 561 |
+
|
| 562 |
+
# Save the text prompt
|
| 563 |
+
print("Text Prompt is", text_prompt)
|
| 564 |
+
with open(os.path.join(store_folder_path, 'text_prompt.txt'), 'w') as file:
|
| 565 |
+
file.write(text_prompt)
|
| 566 |
+
|
| 567 |
+
|
| 568 |
+
################################################## Motion Trajectory Condition #####################################################
|
| 569 |
+
|
| 570 |
+
# #Prepare the points in the linear way
|
| 571 |
+
full_pred_tracks = [[] for _ in range(num_frames)]
|
| 572 |
+
ID_tensor = None
|
| 573 |
+
|
| 574 |
+
# Iterate all tracking information for all objects
|
| 575 |
+
print("traj_lists is", traj_lists)
|
| 576 |
+
for instance_idx, traj_list_per_object in enumerate(traj_lists):
|
| 577 |
+
|
| 578 |
+
# Iterate all trajectory lines in one instance
|
| 579 |
+
for traj_idx, single_trajectory in enumerate(traj_list_per_object):
|
| 580 |
+
|
| 581 |
+
# Sanity Check
|
| 582 |
+
if len(single_trajectory) < 2:
|
| 583 |
+
raise gr.Error("One of the trajectory provided is too short!")
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
# Sampled the point based on the Euclidean distance
|
| 587 |
+
sampled_points = sample_traj_by_length(single_trajectory, num_frames)
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
# Iterate all points
|
| 591 |
+
temporal_idx = 0
|
| 592 |
+
for (raw_point_x, raw_point_y) in sampled_points:
|
| 593 |
+
|
| 594 |
+
# Scale the point coordinate to the Inference Size (real canvas size)
|
| 595 |
+
point_x, point_y = int(raw_point_x * canvas_width / uniform_width), int(raw_point_y * canvas_height / uniform_height) # Clicking on the board is with respect to the Uniform Preset Height and Width
|
| 596 |
+
|
| 597 |
+
if traj_idx == 0: # Needs to init the list in list
|
| 598 |
+
full_pred_tracks[temporal_idx].append( [] )
|
| 599 |
+
full_pred_tracks[temporal_idx][-1].append( (point_x, point_y) ) # [-1] and [instance_idx] should have the same effect
|
| 600 |
+
temporal_idx += 1
|
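# Shape note (descriptive): after the loops above, full_pred_tracks has length num_frames,
# and full_pred_tracks[t] holds one list per instance, each containing one (x, y) point per
# trajectory line of that instance at frame t.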
| 601 |
+
|
| 602 |
+
|
| 603 |
+
# Create the traj tensor
|
| 604 |
+
traj_tensor, traj_imgs_np, _, img_with_traj = VideoDataset_Motion.prepare_traj_tensor(
|
| 605 |
+
full_pred_tracks, canvas_height, canvas_width,
|
| 606 |
+
[], dot_radius, canvas_width, canvas_height,
|
| 607 |
+
idx=0, first_frame_img = inference_canvas
|
| 608 |
+
)
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
# Store Trajectory
|
| 612 |
+
imageio.mimsave(os.path.join(store_folder_path, "traj_video.mp4"), traj_imgs_np, fps=8)
|
| 613 |
+
|
| 614 |
+
######################################################################################################################################################
|
| 615 |
+
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
########################################## Prepare the Identity Reference Condition #####################################################
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
# ID reference preparation
|
| 622 |
+
if main_reference_img is not None:
|
| 623 |
+
print("We have an ID reference being used!")
|
| 624 |
+
|
| 625 |
+
# Fetch
|
| 626 |
+
ref_h, ref_w, _ = main_reference_img.shape
|
| 627 |
+
|
| 628 |
+
|
| 629 |
+
# Set the image so SAM can extract a mask from prompt points
|
| 630 |
+
sam_predictor.set_image(np.uint8(main_reference_img))
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
# Define the sample point
|
| 634 |
+
sam_points = [(ref_w//2, ref_h//2)] # We don't need that many points to express [:len(traj_points)//2]
|
| 635 |
+
|
| 636 |
+
|
| 637 |
+
# Build the positive point prompts for SAM
|
| 638 |
+
positive_point_cords = np.array(sam_points)
|
| 639 |
+
positive_point_labels = np.ones(len(positive_point_cords))
|
| 640 |
+
|
| 641 |
+
# Predict the mask based on the point and bounding box designed
|
| 642 |
+
masks, scores, logits = sam_predictor.predict(
|
| 643 |
+
point_coords = positive_point_cords,
|
| 644 |
+
point_labels = positive_point_labels,
|
| 645 |
+
multimask_output = False,
|
| 646 |
+
)
|
| 647 |
+
mask = masks[0]
|
| 648 |
+
main_reference_img[mask == False] = 0  # Zero out everything outside the predicted mask
|
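# Note (assumes the standard segment-anything API): with multimask_output=False,
# SamPredictor.predict returns masks of shape (1, H, W), so mask = masks[0] is a boolean
# array aligned with main_reference_img.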
| 649 |
+
|
| 650 |
+
|
| 651 |
+
# Resize to the same resolution as the first frame
|
| 652 |
+
scale_h = canvas_height / max(ref_h, ref_w)
|
| 653 |
+
scale_w = canvas_width / max(ref_h, ref_w)
|
| 654 |
+
new_h, new_w = int(ref_h * scale_h), int(ref_w * scale_w)
|
| 655 |
+
main_reference_img = cv2.resize(main_reference_img, (new_w, new_h), interpolation = cv2.INTER_AREA)
|
| 656 |
+
|
| 657 |
+
# Calculate padding amounts on all direction
|
| 658 |
+
pad_height1 = (canvas_height - main_reference_img.shape[0]) // 2
|
| 659 |
+
pad_height2 = canvas_height - main_reference_img.shape[0] - pad_height1
|
| 660 |
+
pad_width1 = (canvas_width - main_reference_img.shape[1]) // 2
|
| 661 |
+
pad_width2 = canvas_width - main_reference_img.shape[1] - pad_width1
|
| 662 |
+
|
| 663 |
+
# Apply padding to the same resolution as the training frames
|
| 664 |
+
main_reference_img = np.pad(
|
| 665 |
+
main_reference_img,
|
| 666 |
+
((pad_height1, pad_height2), (pad_width1, pad_width2), (0, 0)),
|
| 667 |
+
mode = 'constant',
|
| 668 |
+
constant_values = 0
|
| 669 |
+
)
|
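# Worked example (illustrative): if the masked reference resizes to 352x640 on a 704x1280 canvas,
# the pads are (176, 176) vertically and (320, 320) horizontally, centering the object.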
| 670 |
+
|
| 671 |
+
cv2.imwrite(os.path.join(store_folder_path, "ID.png"), cv2.cvtColor(main_reference_img, cv2.COLOR_BGR2RGB))
|
| 672 |
+
|
| 673 |
+
elif main_reference_img is None:
|
| 674 |
+
# Whole Black Color placeholder
|
| 675 |
+
main_reference_img = np.uint8(np.zeros((canvas_height, canvas_width, 3)))
|
| 676 |
+
|
| 677 |
+
|
| 678 |
+
# Convert to tensor
|
| 679 |
+
ID_tensor = torch.tensor(main_reference_img)
|
| 680 |
+
ID_tensor = train_transforms(ID_tensor).permute(2, 0, 1).contiguous()
|
| 681 |
+
|
| 682 |
+
if model_code_name == "Wan": # Needs to be the shape (B, C, F, H, W)
|
| 683 |
+
ID_tensor = ID_tensor.unsqueeze(0).unsqueeze(2)
|
| 684 |
+
|
| 685 |
+
###############################################################################################################################################
|
| 686 |
+
|
| 687 |
+
|
| 688 |
+
|
| 689 |
+
############################################# Call the Inference Pipeline ##########################################################
|
| 690 |
+
|
| 691 |
+
image = Image.fromarray(inference_canvas)
|
| 692 |
+
|
| 693 |
+
if model_code_name == "Wan":
|
| 694 |
+
video = pipe(
|
| 695 |
+
image = image,
|
| 696 |
+
prompt = text_prompt, negative_prompt = "", # Empty string as negative text prompt
|
| 697 |
+
traj_tensor = traj_tensor, # Should be shape (F, C, H, W)
|
| 698 |
+
ID_tensor = ID_tensor, # Should be shape (B, C, F, H, W)
|
| 699 |
+
height = canvas_height, width = canvas_width, num_frames = num_frames,
|
| 700 |
+
num_inference_steps = 50, # 38 is also ok
|
| 701 |
+
guidance_scale = 5.0,
|
| 702 |
+
).frames[0]
|
| 703 |
+
|
| 704 |
+
elif model_code_name == "CogVideoX":
|
| 705 |
+
video = pipe(
|
| 706 |
+
image = image,
|
| 707 |
+
prompt = text_prompt,
|
| 708 |
+
traj_tensor = traj_tensor,
|
| 709 |
+
ID_tensor = ID_tensor,
|
| 710 |
+
height = canvas_height, width = canvas_width, num_frames = len(traj_tensor),
|
| 711 |
+
guidance_scale = 6, use_dynamic_cfg = False,
|
| 712 |
+
num_inference_steps = 50,
|
| 713 |
+
add_ID_reference_augment_noise = True,
|
| 714 |
+
).frames[0]
|
| 715 |
+
|
| 716 |
+
|
| 717 |
+
|
| 718 |
+
# Store the result
|
| 719 |
+
export_to_video(video, os.path.join(store_folder_path, "generated_video_padded.mp4"), fps=8)
|
| 720 |
+
|
| 721 |
+
|
| 722 |
+
|
| 723 |
+
# Save frames
|
| 724 |
+
print("Writing as Frames")
|
| 725 |
+
video_file_path = os.path.join(store_folder_path, "generated_video.mp4")
|
| 726 |
+
writer = imageio.get_writer(video_file_path, fps = 8)
|
| 727 |
+
for frame_idx, frame in enumerate(video):
|
| 728 |
+
|
| 729 |
+
# Extract Unpadded version
|
| 730 |
+
# frame = np.uint8(frame)
|
| 731 |
+
if model_code_name == "CogVideoX":
|
| 732 |
+
frame = np.asarray(frame) # PIL to RGB
|
| 733 |
+
bottom_right_y = frame.shape[0] - bottom_right_height
|
| 734 |
+
bottom_right_x = frame.shape[1] - bottom_right_width
|
| 735 |
+
cropped_region_frame = np.uint8(frame[top_left_height: bottom_right_y, top_left_width : bottom_right_x] * 255)
|
| 736 |
+
writer.append_data(cropped_region_frame)
|
| 737 |
+
|
| 738 |
+
writer.close()
|
| 739 |
+
|
| 740 |
+
#####################################################################################################################################
|
| 741 |
+
|
| 742 |
+
|
| 743 |
+
return gr.update(value = video_file_path, width = uniform_width, height = uniform_height)
|
| 744 |
+
|
| 745 |
+
|
| 746 |
+
|
| 747 |
+
|
| 748 |
+
if __name__ == '__main__':
|
| 749 |
+
|
| 750 |
+
|
| 751 |
+
# Global Setting
|
| 752 |
+
uniform_height = 480 # Visual Canvas as 480x720 is decent
|
| 753 |
+
uniform_width = 720
|
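# Note (descriptive): uniform_height and uniform_width are read as module-level globals by
# build_canvas and inference, so clicks on the 720x480 preview are rescaled to the real
# canvas resolution before inference.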
| 754 |
+
|
| 755 |
+
|
| 756 |
+
# Draw the Website
|
| 757 |
+
block = gr.Blocks().queue(max_size=10)
|
| 758 |
+
with block:
|
| 759 |
+
|
| 760 |
+
|
| 761 |
+
with gr.Row():
|
| 762 |
+
gr.Markdown(MARKDOWN)
|
| 763 |
+
|
| 764 |
+
with gr.Row(elem_classes=["container"]):
|
| 765 |
+
|
| 766 |
+
with gr.Column(scale=2):
|
| 767 |
+
# Input image
|
| 768 |
+
input_image = gr.Image(type="filepath", label="Input Image 🖼️ ")
|
| 769 |
+
# uploaded_files = gr.Gallery(label="Your images", visible=False, columns=5, rows=1, height=200)
|
| 770 |
+
|
| 771 |
+
with gr.Column(scale=2):
|
| 772 |
+
|
| 773 |
+
# Input image
|
| 774 |
+
resized_height = gr.Textbox(label="Resized Height for Input Image")
|
| 775 |
+
resized_width = gr.Textbox(label="Resized Width for Input Image")
|
| 776 |
+
# gr.Number(value=unit_height, label="Fixed", interactive=False)
|
| 777 |
+
# gr.Number(value=unit_height * 1.77777, label="Fixed", interactive=False)
|
| 778 |
+
|
| 779 |
+
# Input the expansion factor
|
| 780 |
+
top_left_height = gr.Textbox(label="Top-Left Expand Height")
|
| 781 |
+
top_left_width = gr.Textbox(label="Top-Left Expand Width")
|
| 782 |
+
bottom_right_height = gr.Textbox(label="Bottom-Right Expand Height")
|
| 783 |
+
bottom_right_width = gr.Textbox(label="Bottom-Right Expand Width")
|
| 784 |
+
|
| 785 |
+
# Button
|
| 786 |
+
build_canvas_btn = gr.Button(value="Build the Canvas")
|
| 787 |
+
|
| 788 |
+
|
| 789 |
+
with gr.Row():
|
| 790 |
+
|
| 791 |
+
with gr.Column(scale=3):
|
| 792 |
+
with gr.Row(scale=3):
|
| 793 |
+
visual_canvas = gr.Image(height = uniform_height, width = uniform_width, type="numpy", label='Expanded Canvas 🖼️ ')
|
| 794 |
+
# inference_canvas = gr.Image(height = uniform_height, width = uniform_width, type="numpy")
|
| 795 |
+
# inference_canvas = None
|
| 796 |
+
|
| 797 |
+
with gr.Row(scale=1):
|
| 798 |
+
# TODO: still missing an option to clear a single trajectory
|
| 799 |
+
add_point = gr.Button(value = "Add New Traj Line (Same Obj)", visible = True) # Add new trajectory for the same instance
|
| 800 |
+
add_traj = gr.Button(value = "Add New Instance (New Obj, including new ID)", visible = True)
|
| 801 |
+
clear_traj_button = gr.Button("Clear All Traj", visible=True)
|
| 802 |
+
|
| 803 |
+
with gr.Column(scale=2):
|
| 804 |
+
|
| 805 |
+
with gr.Row(scale=2):
|
| 806 |
+
identity_image = gr.Image(type="numpy", label="Identity Reference (SAM on center point only) 🖼️ ")
|
| 807 |
+
|
| 808 |
+
with gr.Row(scale=2):
|
| 809 |
+
text_prompt = gr.Textbox(label="Text Prompt", lines=3)
|
| 810 |
+
|
| 811 |
+
|
| 812 |
+
with gr.Row():
|
| 813 |
+
|
| 814 |
+
# Button
|
| 815 |
+
generation_btn = gr.Button(value="Generate")
|
| 816 |
+
|
| 817 |
+
|
| 818 |
+
with gr.Row():
|
| 819 |
+
generated_video = gr.Video(value = None, label="Generated Video", show_label = True, height = uniform_height, width = uniform_width)
|
| 820 |
+
|
| 821 |
+
|
| 822 |
+
|
| 823 |
+
################################################################## Click + Select + Any Effect Area ###########################################################################
|
| 824 |
+
|
| 825 |
+
# Init some states that will be supporting purposes
|
| 826 |
+
traj_lists = gr.Textbox(label="Trajectory", visible = False)  # (alternative: gr.State(None)) Data structure: (Number of Instances, Number of Trajectories, Points); initialized as [ [ [] ] ]
|
| 827 |
+
inference_canvas = gr.State(None)
|
| 828 |
+
traj_instance_idx = gr.State(0)
|
| 829 |
+
initial_visual_canvas = gr.State(None)  # The initial visual canvas, loaded back when the trajectories are cleared
|
| 830 |
+
|
| 831 |
+
|
| 832 |
+
# Canvas Click
|
| 833 |
+
build_canvas_btn.click(
|
| 834 |
+
build_canvas,
|
| 835 |
+
inputs = [input_image, resized_height, resized_width, top_left_height, top_left_width, bottom_right_height, bottom_right_width],
|
| 836 |
+
outputs = [visual_canvas, initial_visual_canvas, inference_canvas, traj_instance_idx, traj_lists] # inference_canvas is used for inference; visual_canvas is for gradio visualization
|
| 837 |
+
)
|
| 838 |
+
|
| 839 |
+
|
| 840 |
+
# Draw Trajectory for each click on the canvas
|
| 841 |
+
visual_canvas.select(
|
| 842 |
+
fn = add_traj_point,
|
| 843 |
+
inputs = [visual_canvas, traj_lists, traj_instance_idx],
|
| 844 |
+
outputs = [visual_canvas, traj_lists]
|
| 845 |
+
)
|
| 846 |
+
|
| 847 |
+
|
| 848 |
+
# Add new Trajectory
|
| 849 |
+
add_point.click(
|
| 850 |
+
fn = traj_point_update,
|
| 851 |
+
inputs = [traj_lists],
|
| 852 |
+
outputs = [traj_lists],
|
| 853 |
+
)
|
| 854 |
+
add_traj.click(
|
| 855 |
+
fn = traj_instance_update,
|
| 856 |
+
inputs = [traj_instance_idx, traj_lists],
|
| 857 |
+
outputs = [traj_instance_idx, traj_lists],
|
| 858 |
+
)
|
| 859 |
+
|
| 860 |
+
# Clean all the traj points
|
| 861 |
+
clear_traj_button.click(
|
| 862 |
+
clear_traj_points,
|
| 863 |
+
[initial_visual_canvas],
|
| 864 |
+
[visual_canvas, traj_instance_idx, traj_lists],
|
| 865 |
+
)
|
| 866 |
+
|
| 867 |
+
|
| 868 |
+
# Inference Generation
|
| 869 |
+
generation_btn.click(
|
| 870 |
+
inference,
|
| 871 |
+
inputs = [inference_canvas, visual_canvas, text_prompt, traj_lists, identity_image, resized_height, resized_width, top_left_height, top_left_width, bottom_right_height, bottom_right_width],
|
| 872 |
+
outputs = [generated_video],
|
| 873 |
+
)
|
| 874 |
+
|
| 875 |
+
|
| 876 |
+
|
| 877 |
+
|
| 878 |
+
# Load Examples
|
| 879 |
+
with gr.Row(elem_classes=["container"]):
|
| 880 |
+
gr.Examples(
|
| 881 |
+
examples = get_example(),
|
| 882 |
+
inputs = [input_image, resized_height, resized_width, top_left_height, top_left_width, bottom_right_height, bottom_right_width, identity_image, text_prompt, traj_lists],
|
| 883 |
+
run_on_click = True,
|
| 884 |
+
fn = on_example_click,
|
| 885 |
+
outputs = [visual_canvas, initial_visual_canvas, inference_canvas, traj_instance_idx],
|
| 886 |
+
)
|
| 887 |
+
|
| 888 |
+
|
| 889 |
+
block.launch(share=True)
|
| 890 |
+
|
| 891 |
+
|
| 892 |
+
|
| 893 |
+
|
architecture/attention_processor.py
ADDED
The diff for this file is too large to render. See raw diff
architecture/autoencoder_kl_wan.py
ADDED
@@ -0,0 +1,1419 @@
| 1 |
+
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from typing import List, Optional, Tuple, Union
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
import torch.utils.checkpoint
|
| 21 |
+
|
| 22 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 23 |
+
from diffusers.loaders import FromOriginalModelMixin
|
| 24 |
+
from diffusers.utils import logging
|
| 25 |
+
from diffusers.utils.accelerate_utils import apply_forward_hook
|
| 26 |
+
from diffusers.models.activations import get_activation
|
| 27 |
+
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
| 28 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 29 |
+
from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 33 |
+
|
| 34 |
+
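# Number of trailing frames each causal conv keeps from the previous chunk so that chunked
# encoding/decoding matches a single full-sequence pass.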
CACHE_T = 2
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class AvgDown3D(nn.Module):
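    # Downsamples by factor_t x factor_s x factor_s: each (factor_t, factor_s, factor_s) block is
    # folded into the channel dimension, then channel groups are mean-pooled down to `out_channels`.
    # Used as the shortcut path of the residual down blocks (Wan 2.2 VAE).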
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
in_channels,
|
| 41 |
+
out_channels,
|
| 42 |
+
factor_t,
|
| 43 |
+
factor_s=1,
|
| 44 |
+
):
|
| 45 |
+
super().__init__()
|
| 46 |
+
self.in_channels = in_channels
|
| 47 |
+
self.out_channels = out_channels
|
| 48 |
+
self.factor_t = factor_t
|
| 49 |
+
self.factor_s = factor_s
|
| 50 |
+
self.factor = self.factor_t * self.factor_s * self.factor_s
|
| 51 |
+
|
| 52 |
+
assert in_channels * self.factor % out_channels == 0
|
| 53 |
+
self.group_size = in_channels * self.factor // out_channels
|
| 54 |
+
|
| 55 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 56 |
+
pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
|
| 57 |
+
pad = (0, 0, 0, 0, pad_t, 0)
|
| 58 |
+
x = F.pad(x, pad)
|
| 59 |
+
B, C, T, H, W = x.shape
|
| 60 |
+
x = x.view(
|
| 61 |
+
B,
|
| 62 |
+
C,
|
| 63 |
+
T // self.factor_t,
|
| 64 |
+
self.factor_t,
|
| 65 |
+
H // self.factor_s,
|
| 66 |
+
self.factor_s,
|
| 67 |
+
W // self.factor_s,
|
| 68 |
+
self.factor_s,
|
| 69 |
+
)
|
| 70 |
+
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
|
| 71 |
+
x = x.view(
|
| 72 |
+
B,
|
| 73 |
+
C * self.factor,
|
| 74 |
+
T // self.factor_t,
|
| 75 |
+
H // self.factor_s,
|
| 76 |
+
W // self.factor_s,
|
| 77 |
+
)
|
| 78 |
+
x = x.view(
|
| 79 |
+
B,
|
| 80 |
+
self.out_channels,
|
| 81 |
+
self.group_size,
|
| 82 |
+
T // self.factor_t,
|
| 83 |
+
H // self.factor_s,
|
| 84 |
+
W // self.factor_s,
|
| 85 |
+
)
|
| 86 |
+
x = x.mean(dim=2)
|
| 87 |
+
return x
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class DupUp3D(nn.Module):
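    # Inverse of AvgDown3D: channels are repeated and unfolded back into (factor_t, factor_s,
    # factor_s) blocks. For the first chunk, the duplicated leading frames are dropped so the
    # frame count lines up with causal decoding.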
|
| 91 |
+
def __init__(
|
| 92 |
+
self,
|
| 93 |
+
in_channels: int,
|
| 94 |
+
out_channels: int,
|
| 95 |
+
factor_t,
|
| 96 |
+
factor_s=1,
|
| 97 |
+
):
|
| 98 |
+
super().__init__()
|
| 99 |
+
self.in_channels = in_channels
|
| 100 |
+
self.out_channels = out_channels
|
| 101 |
+
|
| 102 |
+
self.factor_t = factor_t
|
| 103 |
+
self.factor_s = factor_s
|
| 104 |
+
self.factor = self.factor_t * self.factor_s * self.factor_s
|
| 105 |
+
|
| 106 |
+
assert out_channels * self.factor % in_channels == 0
|
| 107 |
+
self.repeats = out_channels * self.factor // in_channels
|
| 108 |
+
|
| 109 |
+
def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
|
| 110 |
+
x = x.repeat_interleave(self.repeats, dim=1)
|
| 111 |
+
x = x.view(
|
| 112 |
+
x.size(0),
|
| 113 |
+
self.out_channels,
|
| 114 |
+
self.factor_t,
|
| 115 |
+
self.factor_s,
|
| 116 |
+
self.factor_s,
|
| 117 |
+
x.size(2),
|
| 118 |
+
x.size(3),
|
| 119 |
+
x.size(4),
|
| 120 |
+
)
|
| 121 |
+
x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
|
| 122 |
+
x = x.view(
|
| 123 |
+
x.size(0),
|
| 124 |
+
self.out_channels,
|
| 125 |
+
x.size(2) * self.factor_t,
|
| 126 |
+
x.size(4) * self.factor_s,
|
| 127 |
+
x.size(6) * self.factor_s,
|
| 128 |
+
)
|
| 129 |
+
if first_chunk:
|
| 130 |
+
x = x[:, :, self.factor_t - 1 :, :, :]
|
| 131 |
+
return x
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class WanCausalConv3d(nn.Conv3d):
|
| 135 |
+
r"""
|
| 136 |
+
A custom 3D causal convolution layer with feature caching support.
|
| 137 |
+
|
| 138 |
+
This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature
|
| 139 |
+
caching for efficient inference.
|
| 140 |
+
|
| 141 |
+
Args:
|
| 142 |
+
in_channels (int): Number of channels in the input image
|
| 143 |
+
out_channels (int): Number of channels produced by the convolution
|
| 144 |
+
kernel_size (int or tuple): Size of the convolving kernel
|
| 145 |
+
stride (int or tuple, optional): Stride of the convolution. Default: 1
|
| 146 |
+
padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0
|
| 147 |
+
"""
|
| 148 |
+
|
| 149 |
+
def __init__(
|
| 150 |
+
self,
|
| 151 |
+
in_channels: int,
|
| 152 |
+
out_channels: int,
|
| 153 |
+
kernel_size: Union[int, Tuple[int, int, int]],
|
| 154 |
+
stride: Union[int, Tuple[int, int, int]] = 1,
|
| 155 |
+
padding: Union[int, Tuple[int, int, int]] = 0,
|
| 156 |
+
) -> None:
|
| 157 |
+
super().__init__(
|
| 158 |
+
in_channels=in_channels,
|
| 159 |
+
out_channels=out_channels,
|
| 160 |
+
kernel_size=kernel_size,
|
| 161 |
+
stride=stride,
|
| 162 |
+
padding=padding,
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
# Set up causal padding
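        # All temporal padding (2 * pad_t) is applied on the left of the time axis and none on the
        # right, so an output frame never depends on future frames; spatial padding stays symmetric.
        # nn.Conv3d's own padding is zeroed out because padding is applied manually in forward(),
        # where cached frames from the previous chunk can take the place of part of the left pad.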
|
| 166 |
+
self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
|
| 167 |
+
self.padding = (0, 0, 0)
|
| 168 |
+
|
| 169 |
+
def forward(self, x, cache_x=None):
|
| 170 |
+
padding = list(self._padding)
|
| 171 |
+
if cache_x is not None and self._padding[4] > 0:
|
| 172 |
+
cache_x = cache_x.to(x.device)
|
| 173 |
+
x = torch.cat([cache_x, x], dim=2)
|
| 174 |
+
padding[4] -= cache_x.shape[2]
|
| 175 |
+
x = F.pad(x, padding)
|
| 176 |
+
return super().forward(x)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
class WanRMS_norm(nn.Module):
|
| 180 |
+
r"""
|
| 181 |
+
A custom RMS normalization layer.
|
| 182 |
+
|
| 183 |
+
Args:
|
| 184 |
+
dim (int): The number of channels to normalize over.
|
| 185 |
+
channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
|
| 186 |
+
Default is True.
|
| 187 |
+
images (bool, optional): Whether the input represents image data. Default is True.
|
| 188 |
+
bias (bool, optional): Whether to include a learnable bias term. Default is False.
|
| 189 |
+
"""
|
| 190 |
+
|
| 191 |
+
def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
|
| 192 |
+
super().__init__()
|
| 193 |
+
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
|
| 194 |
+
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
|
| 195 |
+
|
| 196 |
+
self.channel_first = channel_first
|
| 197 |
+
self.scale = dim**0.5
|
| 198 |
+
self.gamma = nn.Parameter(torch.ones(shape))
|
| 199 |
+
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
|
| 200 |
+
|
| 201 |
+
def forward(self, x):
|
| 202 |
+
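        # F.normalize divides by the L2 norm along the channel axis; multiplying by scale = sqrt(dim)
        # turns this into RMS normalization, followed by the learnable gain (and optional bias).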
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
class WanUpsample(nn.Upsample):
|
| 206 |
+
r"""
|
| 207 |
+
Perform upsampling while ensuring the output tensor has the same data type as the input.
|
| 208 |
+
|
| 209 |
+
Args:
|
| 210 |
+
x (torch.Tensor): Input tensor to be upsampled.
|
| 211 |
+
|
| 212 |
+
Returns:
|
| 213 |
+
torch.Tensor: Upsampled tensor with the same data type as the input.
|
| 214 |
+
"""
|
| 215 |
+
|
| 216 |
+
def forward(self, x):
|
| 217 |
+
return super().forward(x.float()).type_as(x)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class WanResample(nn.Module):
|
| 221 |
+
r"""
|
| 222 |
+
A custom resampling module for 2D and 3D data.
|
| 223 |
+
|
| 224 |
+
Args:
|
| 225 |
+
dim (int): The number of input/output channels.
|
| 226 |
+
mode (str): The resampling mode. Must be one of:
|
| 227 |
+
- 'none': No resampling (identity operation).
|
| 228 |
+
- 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.
|
| 229 |
+
- 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.
|
| 230 |
+
- 'downsample2d': 2D downsampling with zero-padding and convolution.
|
| 231 |
+
- 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
|
| 232 |
+
"""
|
| 233 |
+
|
| 234 |
+
def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None:
|
| 235 |
+
super().__init__()
|
| 236 |
+
self.dim = dim
|
| 237 |
+
self.mode = mode
|
| 238 |
+
|
| 239 |
+
# default the upsample output channels to dim // 2
|
| 240 |
+
if upsample_out_dim is None:
|
| 241 |
+
upsample_out_dim = dim // 2
|
| 242 |
+
|
| 243 |
+
# layers
|
| 244 |
+
if mode == "upsample2d":
|
| 245 |
+
self.resample = nn.Sequential(
|
| 246 |
+
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
|
| 247 |
+
nn.Conv2d(dim, upsample_out_dim, 3, padding=1),
|
| 248 |
+
)
|
| 249 |
+
elif mode == "upsample3d":
|
| 250 |
+
self.resample = nn.Sequential(
|
| 251 |
+
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
|
| 252 |
+
nn.Conv2d(dim, upsample_out_dim, 3, padding=1),
|
| 253 |
+
)
|
| 254 |
+
self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
| 255 |
+
|
| 256 |
+
elif mode == "downsample2d":
|
| 257 |
+
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
| 258 |
+
elif mode == "downsample3d":
|
| 259 |
+
self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
| 260 |
+
self.time_conv = WanCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
|
| 261 |
+
|
| 262 |
+
else:
|
| 263 |
+
self.resample = nn.Identity()
|
| 264 |
+
|
| 265 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
| 266 |
+
b, c, t, h, w = x.size()
|
| 267 |
+
if self.mode == "upsample3d":
|
| 268 |
+
if feat_cache is not None:
|
| 269 |
+
idx = feat_idx[0]
|
| 270 |
+
if feat_cache[idx] is None:
|
| 271 |
+
feat_cache[idx] = "Rep"
|
| 272 |
+
feat_idx[0] += 1
|
| 273 |
+
else:
|
| 274 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 275 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
|
| 276 |
+
# cache the last frame of each of the last two chunks
|
| 277 |
+
cache_x = torch.cat(
|
| 278 |
+
[feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
|
| 279 |
+
)
|
| 280 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
|
| 281 |
+
cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
|
| 282 |
+
if feat_cache[idx] == "Rep":
|
| 283 |
+
x = self.time_conv(x)
|
| 284 |
+
else:
|
| 285 |
+
x = self.time_conv(x, feat_cache[idx])
|
| 286 |
+
feat_cache[idx] = cache_x
|
| 287 |
+
feat_idx[0] += 1
|
| 288 |
+
|
| 289 |
+
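            # time_conv produced 2*c channels per frame: split them into two temporal sub-frames and
            # interleave them, doubling the frame count (temporal upsampling by a factor of 2).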
x = x.reshape(b, 2, c, t, h, w)
|
| 290 |
+
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
|
| 291 |
+
x = x.reshape(b, c, t * 2, h, w)
|
| 292 |
+
t = x.shape[2]
|
| 293 |
+
x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
|
| 294 |
+
x = self.resample(x)
|
| 295 |
+
x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)
|
| 296 |
+
|
| 297 |
+
if self.mode == "downsample3d":
|
| 298 |
+
if feat_cache is not None:
|
| 299 |
+
idx = feat_idx[0]
|
| 300 |
+
if feat_cache[idx] is None:
|
| 301 |
+
feat_cache[idx] = x.clone()
|
| 302 |
+
feat_idx[0] += 1
|
| 303 |
+
else:
|
| 304 |
+
cache_x = x[:, :, -1:, :, :].clone()
|
| 305 |
+
x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
| 306 |
+
feat_cache[idx] = cache_x
|
| 307 |
+
feat_idx[0] += 1
|
| 308 |
+
return x
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
class WanResidualBlock(nn.Module):
|
| 312 |
+
r"""
|
| 313 |
+
A custom residual block module.
|
| 314 |
+
|
| 315 |
+
Args:
|
| 316 |
+
in_dim (int): Number of input channels.
|
| 317 |
+
out_dim (int): Number of output channels.
|
| 318 |
+
dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0.
|
| 319 |
+
non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
|
| 320 |
+
"""
|
| 321 |
+
|
| 322 |
+
def __init__(
|
| 323 |
+
self,
|
| 324 |
+
in_dim: int,
|
| 325 |
+
out_dim: int,
|
| 326 |
+
dropout: float = 0.0,
|
| 327 |
+
non_linearity: str = "silu",
|
| 328 |
+
) -> None:
|
| 329 |
+
super().__init__()
|
| 330 |
+
self.in_dim = in_dim
|
| 331 |
+
self.out_dim = out_dim
|
| 332 |
+
self.nonlinearity = get_activation(non_linearity)
|
| 333 |
+
|
| 334 |
+
# layers
|
| 335 |
+
self.norm1 = WanRMS_norm(in_dim, images=False)
|
| 336 |
+
self.conv1 = WanCausalConv3d(in_dim, out_dim, 3, padding=1)
|
| 337 |
+
self.norm2 = WanRMS_norm(out_dim, images=False)
|
| 338 |
+
self.dropout = nn.Dropout(dropout)
|
| 339 |
+
self.conv2 = WanCausalConv3d(out_dim, out_dim, 3, padding=1)
|
| 340 |
+
self.conv_shortcut = WanCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
|
| 341 |
+
|
| 342 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
| 343 |
+
# Apply shortcut connection
|
| 344 |
+
h = self.conv_shortcut(x)
|
| 345 |
+
|
| 346 |
+
# First normalization and activation
|
| 347 |
+
x = self.norm1(x)
|
| 348 |
+
x = self.nonlinearity(x)
|
| 349 |
+
|
| 350 |
+
if feat_cache is not None:
|
| 351 |
+
idx = feat_idx[0]
|
| 352 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 353 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 354 |
+
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
| 355 |
+
|
| 356 |
+
x = self.conv1(x, feat_cache[idx])
|
| 357 |
+
feat_cache[idx] = cache_x
|
| 358 |
+
feat_idx[0] += 1
|
| 359 |
+
else:
|
| 360 |
+
x = self.conv1(x)
|
| 361 |
+
|
| 362 |
+
# Second normalization and activation
|
| 363 |
+
x = self.norm2(x)
|
| 364 |
+
x = self.nonlinearity(x)
|
| 365 |
+
|
| 366 |
+
# Dropout
|
| 367 |
+
x = self.dropout(x)
|
| 368 |
+
|
| 369 |
+
if feat_cache is not None:
|
| 370 |
+
idx = feat_idx[0]
|
| 371 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 372 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 373 |
+
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
| 374 |
+
|
| 375 |
+
x = self.conv2(x, feat_cache[idx])
|
| 376 |
+
feat_cache[idx] = cache_x
|
| 377 |
+
feat_idx[0] += 1
|
| 378 |
+
else:
|
| 379 |
+
x = self.conv2(x)
|
| 380 |
+
|
| 381 |
+
# Add residual connection
|
| 382 |
+
return x + h
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
class WanAttentionBlock(nn.Module):
|
| 386 |
+
r"""
|
| 387 |
+
Causal self-attention with a single head.
|
| 388 |
+
|
| 389 |
+
Args:
|
| 390 |
+
dim (int): The number of channels in the input tensor.
|
| 391 |
+
"""
|
| 392 |
+
|
| 393 |
+
def __init__(self, dim):
|
| 394 |
+
super().__init__()
|
| 395 |
+
self.dim = dim
|
| 396 |
+
|
| 397 |
+
# layers
|
| 398 |
+
self.norm = WanRMS_norm(dim)
|
| 399 |
+
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
| 400 |
+
self.proj = nn.Conv2d(dim, dim, 1)
|
| 401 |
+
|
| 402 |
+
def forward(self, x):
|
| 403 |
+
identity = x
|
| 404 |
+
batch_size, channels, time, height, width = x.size()
|
| 405 |
+
|
| 406 |
+
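        # Fold time into the batch dimension so attention is computed independently for each frame
        # over its height*width spatial positions (single head).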
x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width)
|
| 407 |
+
x = self.norm(x)
|
| 408 |
+
|
| 409 |
+
# compute query, key, value
|
| 410 |
+
qkv = self.to_qkv(x)
|
| 411 |
+
qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
|
| 412 |
+
qkv = qkv.permute(0, 1, 3, 2).contiguous()
|
| 413 |
+
q, k, v = qkv.chunk(3, dim=-1)
|
| 414 |
+
|
| 415 |
+
# apply attention
|
| 416 |
+
x = F.scaled_dot_product_attention(q, k, v)
|
| 417 |
+
|
| 418 |
+
x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width)
|
| 419 |
+
|
| 420 |
+
# output projection
|
| 421 |
+
x = self.proj(x)
|
| 422 |
+
|
| 423 |
+
# Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w]
|
| 424 |
+
x = x.view(batch_size, time, channels, height, width)
|
| 425 |
+
x = x.permute(0, 2, 1, 3, 4)
|
| 426 |
+
|
| 427 |
+
return x + identity
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
class WanMidBlock(nn.Module):
|
| 431 |
+
"""
|
| 432 |
+
Middle block for WanVAE encoder and decoder.
|
| 433 |
+
|
| 434 |
+
Args:
|
| 435 |
+
dim (int): Number of input/output channels.
|
| 436 |
+
dropout (float): Dropout rate.
|
| 437 |
+
non_linearity (str): Type of non-linearity to use.
|
| 438 |
+
"""
|
| 439 |
+
|
| 440 |
+
def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1):
|
| 441 |
+
super().__init__()
|
| 442 |
+
self.dim = dim
|
| 443 |
+
|
| 444 |
+
# Create the components
|
| 445 |
+
resnets = [WanResidualBlock(dim, dim, dropout, non_linearity)]
|
| 446 |
+
attentions = []
|
| 447 |
+
for _ in range(num_layers):
|
| 448 |
+
attentions.append(WanAttentionBlock(dim))
|
| 449 |
+
resnets.append(WanResidualBlock(dim, dim, dropout, non_linearity))
|
| 450 |
+
self.attentions = nn.ModuleList(attentions)
|
| 451 |
+
self.resnets = nn.ModuleList(resnets)
|
| 452 |
+
|
| 453 |
+
self.gradient_checkpointing = False
|
| 454 |
+
|
| 455 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
| 456 |
+
# First residual block
|
| 457 |
+
x = self.resnets[0](x, feat_cache, feat_idx)
|
| 458 |
+
|
| 459 |
+
# Process through attention and residual blocks
|
| 460 |
+
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
| 461 |
+
if attn is not None:
|
| 462 |
+
x = attn(x)
|
| 463 |
+
|
| 464 |
+
x = resnet(x, feat_cache, feat_idx)
|
| 465 |
+
|
| 466 |
+
return x
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
class WanResidualDownBlock(nn.Module):
|
| 470 |
+
def __init__(self, in_dim, out_dim, dropout, num_res_blocks, temperal_downsample=False, down_flag=False):
|
| 471 |
+
super().__init__()
|
| 472 |
+
|
| 473 |
+
# Shortcut path with downsample
|
| 474 |
+
self.avg_shortcut = AvgDown3D(
|
| 475 |
+
in_dim,
|
| 476 |
+
out_dim,
|
| 477 |
+
factor_t=2 if temperal_downsample else 1,
|
| 478 |
+
factor_s=2 if down_flag else 1,
|
| 479 |
+
)
|
| 480 |
+
|
| 481 |
+
# Main path with residual blocks and downsample
|
| 482 |
+
resnets = []
|
| 483 |
+
for _ in range(num_res_blocks):
|
| 484 |
+
resnets.append(WanResidualBlock(in_dim, out_dim, dropout))
|
| 485 |
+
in_dim = out_dim
|
| 486 |
+
self.resnets = nn.ModuleList(resnets)
|
| 487 |
+
|
| 488 |
+
# Add the final downsample block
|
| 489 |
+
if down_flag:
|
| 490 |
+
mode = "downsample3d" if temperal_downsample else "downsample2d"
|
| 491 |
+
self.downsampler = WanResample(out_dim, mode=mode)
|
| 492 |
+
else:
|
| 493 |
+
self.downsampler = None
|
| 494 |
+
|
| 495 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
| 496 |
+
x_copy = x.clone()
|
| 497 |
+
for resnet in self.resnets:
|
| 498 |
+
x = resnet(x, feat_cache, feat_idx)
|
| 499 |
+
if self.downsampler is not None:
|
| 500 |
+
x = self.downsampler(x, feat_cache, feat_idx)
|
| 501 |
+
|
| 502 |
+
return x + self.avg_shortcut(x_copy)
|
| 503 |
+
|
| 504 |
+
|
| 505 |
+
class WanEncoder3d(nn.Module):
|
| 506 |
+
r"""
|
| 507 |
+
A 3D encoder module.
|
| 508 |
+
|
| 509 |
+
Args:
|
| 510 |
+
dim (int): The base number of channels in the first layer.
|
| 511 |
+
z_dim (int): The dimensionality of the latent space.
|
| 512 |
+
dim_mult (list of int): Multipliers for the number of channels in each block.
|
| 513 |
+
num_res_blocks (int): Number of residual blocks in each block.
|
| 514 |
+
attn_scales (list of float): Scales at which to apply attention mechanisms.
|
| 515 |
+
temperal_downsample (list of bool): Whether to downsample temporally in each block.
|
| 516 |
+
dropout (float): Dropout rate for the dropout layers.
|
| 517 |
+
non_linearity (str): Type of non-linearity to use.
|
| 518 |
+
"""
|
| 519 |
+
|
| 520 |
+
def __init__(
|
| 521 |
+
self,
|
| 522 |
+
in_channels: int = 3,
|
| 523 |
+
dim=128,
|
| 524 |
+
z_dim=4,
|
| 525 |
+
dim_mult=[1, 2, 4, 4],
|
| 526 |
+
num_res_blocks=2,
|
| 527 |
+
attn_scales=[],
|
| 528 |
+
temperal_downsample=[True, True, False],
|
| 529 |
+
dropout=0.0,
|
| 530 |
+
non_linearity: str = "silu",
|
| 531 |
+
is_residual: bool = False,  # the Wan 2.2 VAE uses a residual down block
|
| 532 |
+
):
|
| 533 |
+
super().__init__()
|
| 534 |
+
self.dim = dim
|
| 535 |
+
self.z_dim = z_dim
|
| 536 |
+
self.dim_mult = dim_mult
|
| 537 |
+
self.num_res_blocks = num_res_blocks
|
| 538 |
+
self.attn_scales = attn_scales
|
| 539 |
+
self.temperal_downsample = temperal_downsample
|
| 540 |
+
self.nonlinearity = get_activation(non_linearity)
|
| 541 |
+
|
| 542 |
+
# dimensions
|
| 543 |
+
dims = [dim * u for u in [1] + dim_mult]
|
| 544 |
+
scale = 1.0
|
| 545 |
+
|
| 546 |
+
# init block
|
| 547 |
+
self.conv_in = WanCausalConv3d(in_channels, dims[0], 3, padding=1)
|
| 548 |
+
|
| 549 |
+
# downsample blocks
|
| 550 |
+
self.down_blocks = nn.ModuleList([])
|
| 551 |
+
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
| 552 |
+
# residual (+attention) blocks
|
| 553 |
+
if is_residual:
|
| 554 |
+
self.down_blocks.append(
|
| 555 |
+
WanResidualDownBlock(
|
| 556 |
+
in_dim,
|
| 557 |
+
out_dim,
|
| 558 |
+
dropout,
|
| 559 |
+
num_res_blocks,
|
| 560 |
+
temperal_downsample=temperal_downsample[i] if i != len(dim_mult) - 1 else False,
|
| 561 |
+
down_flag=i != len(dim_mult) - 1,
|
| 562 |
+
)
|
| 563 |
+
)
|
| 564 |
+
else:
|
| 565 |
+
for _ in range(num_res_blocks):
|
| 566 |
+
self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
|
| 567 |
+
if scale in attn_scales:
|
| 568 |
+
self.down_blocks.append(WanAttentionBlock(out_dim))
|
| 569 |
+
in_dim = out_dim
|
| 570 |
+
|
| 571 |
+
# downsample block
|
| 572 |
+
if i != len(dim_mult) - 1:
|
| 573 |
+
mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
|
| 574 |
+
self.down_blocks.append(WanResample(out_dim, mode=mode))
|
| 575 |
+
scale /= 2.0
|
| 576 |
+
|
| 577 |
+
# middle blocks
|
| 578 |
+
self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)
|
| 579 |
+
|
| 580 |
+
# output blocks
|
| 581 |
+
self.norm_out = WanRMS_norm(out_dim, images=False)
|
| 582 |
+
self.conv_out = WanCausalConv3d(out_dim, z_dim, 3, padding=1)
|
| 583 |
+
|
| 584 |
+
self.gradient_checkpointing = False
|
| 585 |
+
|
| 586 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
| 587 |
+
if feat_cache is not None:
|
| 588 |
+
idx = feat_idx[0]
|
| 589 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 590 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 591 |
+
# cache the last frame of each of the last two chunks
|
| 592 |
+
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
| 593 |
+
x = self.conv_in(x, feat_cache[idx])
|
| 594 |
+
feat_cache[idx] = cache_x
|
| 595 |
+
feat_idx[0] += 1
|
| 596 |
+
else:
|
| 597 |
+
x = self.conv_in(x)
|
| 598 |
+
|
| 599 |
+
## downsamples
|
| 600 |
+
for layer in self.down_blocks:
|
| 601 |
+
if feat_cache is not None:
|
| 602 |
+
x = layer(x, feat_cache, feat_idx)
|
| 603 |
+
else:
|
| 604 |
+
x = layer(x)
|
| 605 |
+
|
| 606 |
+
## middle
|
| 607 |
+
x = self.mid_block(x, feat_cache, feat_idx)
|
| 608 |
+
|
| 609 |
+
## head
|
| 610 |
+
x = self.norm_out(x)
|
| 611 |
+
x = self.nonlinearity(x)
|
| 612 |
+
if feat_cache is not None:
|
| 613 |
+
idx = feat_idx[0]
|
| 614 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 615 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 616 |
+
# cache the last frame of each of the last two chunks
|
| 617 |
+
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
| 618 |
+
x = self.conv_out(x, feat_cache[idx])
|
| 619 |
+
feat_cache[idx] = cache_x
|
| 620 |
+
feat_idx[0] += 1
|
| 621 |
+
else:
|
| 622 |
+
x = self.conv_out(x)
|
| 623 |
+
return x
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
class WanResidualUpBlock(nn.Module):
|
| 627 |
+
"""
|
| 628 |
+
A block that handles upsampling for the WanVAE decoder.
|
| 629 |
+
|
| 630 |
+
Args:
|
| 631 |
+
in_dim (int): Input dimension
|
| 632 |
+
out_dim (int): Output dimension
|
| 633 |
+
num_res_blocks (int): Number of residual blocks
|
| 634 |
+
dropout (float): Dropout rate
|
| 635 |
+
temperal_upsample (bool): Whether to upsample on temporal dimension
|
| 636 |
+
up_flag (bool): Whether to upsample or not
|
| 637 |
+
non_linearity (str): Type of non-linearity to use
|
| 638 |
+
"""
|
| 639 |
+
|
| 640 |
+
def __init__(
|
| 641 |
+
self,
|
| 642 |
+
in_dim: int,
|
| 643 |
+
out_dim: int,
|
| 644 |
+
num_res_blocks: int,
|
| 645 |
+
dropout: float = 0.0,
|
| 646 |
+
temperal_upsample: bool = False,
|
| 647 |
+
up_flag: bool = False,
|
| 648 |
+
non_linearity: str = "silu",
|
| 649 |
+
):
|
| 650 |
+
super().__init__()
|
| 651 |
+
self.in_dim = in_dim
|
| 652 |
+
self.out_dim = out_dim
|
| 653 |
+
|
| 654 |
+
if up_flag:
|
| 655 |
+
self.avg_shortcut = DupUp3D(
|
| 656 |
+
in_dim,
|
| 657 |
+
out_dim,
|
| 658 |
+
factor_t=2 if temperal_upsample else 1,
|
| 659 |
+
factor_s=2,
|
| 660 |
+
)
|
| 661 |
+
else:
|
| 662 |
+
self.avg_shortcut = None
|
| 663 |
+
|
| 664 |
+
# create residual blocks
|
| 665 |
+
resnets = []
|
| 666 |
+
current_dim = in_dim
|
| 667 |
+
for _ in range(num_res_blocks + 1):
|
| 668 |
+
resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
|
| 669 |
+
current_dim = out_dim
|
| 670 |
+
|
| 671 |
+
self.resnets = nn.ModuleList(resnets)
|
| 672 |
+
|
| 673 |
+
# Add upsampling layer if needed
|
| 674 |
+
if up_flag:
|
| 675 |
+
upsample_mode = "upsample3d" if temperal_upsample else "upsample2d"
|
| 676 |
+
self.upsampler = WanResample(out_dim, mode=upsample_mode, upsample_out_dim=out_dim)
|
| 677 |
+
else:
|
| 678 |
+
self.upsampler = None
|
| 679 |
+
|
| 680 |
+
self.gradient_checkpointing = False
|
| 681 |
+
|
| 682 |
+
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
|
| 683 |
+
"""
|
| 684 |
+
Forward pass through the upsampling block.
|
| 685 |
+
|
| 686 |
+
Args:
|
| 687 |
+
x (torch.Tensor): Input tensor
|
| 688 |
+
feat_cache (list, optional): Feature cache for causal convolutions
|
| 689 |
+
feat_idx (list, optional): Feature index for cache management
|
| 690 |
+
|
| 691 |
+
Returns:
|
| 692 |
+
torch.Tensor: Output tensor
|
| 693 |
+
"""
|
| 694 |
+
x_copy = x.clone()
|
| 695 |
+
|
| 696 |
+
for resnet in self.resnets:
|
| 697 |
+
if feat_cache is not None:
|
| 698 |
+
x = resnet(x, feat_cache, feat_idx)
|
| 699 |
+
else:
|
| 700 |
+
x = resnet(x)
|
| 701 |
+
|
| 702 |
+
if self.upsampler is not None:
|
| 703 |
+
if feat_cache is not None:
|
| 704 |
+
x = self.upsampler(x, feat_cache, feat_idx)
|
| 705 |
+
else:
|
| 706 |
+
x = self.upsampler(x)
|
| 707 |
+
|
| 708 |
+
if self.avg_shortcut is not None:
|
| 709 |
+
x = x + self.avg_shortcut(x_copy, first_chunk=first_chunk)
|
| 710 |
+
|
| 711 |
+
return x
|
| 712 |
+
|
| 713 |
+
|
| 714 |
+
class WanUpBlock(nn.Module):
|
| 715 |
+
"""
|
| 716 |
+
A block that handles upsampling for the WanVAE decoder.
|
| 717 |
+
|
| 718 |
+
Args:
|
| 719 |
+
in_dim (int): Input dimension
|
| 720 |
+
out_dim (int): Output dimension
|
| 721 |
+
num_res_blocks (int): Number of residual blocks
|
| 722 |
+
dropout (float): Dropout rate
|
| 723 |
+
upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')
|
| 724 |
+
non_linearity (str): Type of non-linearity to use
|
| 725 |
+
"""
|
| 726 |
+
|
| 727 |
+
def __init__(
|
| 728 |
+
self,
|
| 729 |
+
in_dim: int,
|
| 730 |
+
out_dim: int,
|
| 731 |
+
num_res_blocks: int,
|
| 732 |
+
dropout: float = 0.0,
|
| 733 |
+
upsample_mode: Optional[str] = None,
|
| 734 |
+
non_linearity: str = "silu",
|
| 735 |
+
):
|
| 736 |
+
super().__init__()
|
| 737 |
+
self.in_dim = in_dim
|
| 738 |
+
self.out_dim = out_dim
|
| 739 |
+
|
| 740 |
+
# Create layers list
|
| 741 |
+
resnets = []
|
| 742 |
+
# Add residual blocks and attention if needed
|
| 743 |
+
current_dim = in_dim
|
| 744 |
+
for _ in range(num_res_blocks + 1):
|
| 745 |
+
resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
|
| 746 |
+
current_dim = out_dim
|
| 747 |
+
|
| 748 |
+
self.resnets = nn.ModuleList(resnets)
|
| 749 |
+
|
| 750 |
+
# Add upsampling layer if needed
|
| 751 |
+
self.upsamplers = None
|
| 752 |
+
if upsample_mode is not None:
|
| 753 |
+
self.upsamplers = nn.ModuleList([WanResample(out_dim, mode=upsample_mode)])
|
| 754 |
+
|
| 755 |
+
self.gradient_checkpointing = False
|
| 756 |
+
|
| 757 |
+
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=None):
|
| 758 |
+
"""
|
| 759 |
+
Forward pass through the upsampling block.
|
| 760 |
+
|
| 761 |
+
Args:
|
| 762 |
+
x (torch.Tensor): Input tensor
|
| 763 |
+
feat_cache (list, optional): Feature cache for causal convolutions
|
| 764 |
+
feat_idx (list, optional): Feature index for cache management
|
| 765 |
+
|
| 766 |
+
Returns:
|
| 767 |
+
torch.Tensor: Output tensor
|
| 768 |
+
"""
|
| 769 |
+
for resnet in self.resnets:
|
| 770 |
+
if feat_cache is not None:
|
| 771 |
+
x = resnet(x, feat_cache, feat_idx)
|
| 772 |
+
else:
|
| 773 |
+
x = resnet(x)
|
| 774 |
+
|
| 775 |
+
if self.upsamplers is not None:
|
| 776 |
+
if feat_cache is not None:
|
| 777 |
+
x = self.upsamplers[0](x, feat_cache, feat_idx)
|
| 778 |
+
else:
|
| 779 |
+
x = self.upsamplers[0](x)
|
| 780 |
+
return x
|
| 781 |
+
|
| 782 |
+
|
| 783 |
+
class WanDecoder3d(nn.Module):
|
| 784 |
+
r"""
|
| 785 |
+
A 3D decoder module.
|
| 786 |
+
|
| 787 |
+
Args:
|
| 788 |
+
dim (int): The base number of channels in the first layer.
|
| 789 |
+
z_dim (int): The dimensionality of the latent space.
|
| 790 |
+
dim_mult (list of int): Multipliers for the number of channels in each block.
|
| 791 |
+
num_res_blocks (int): Number of residual blocks in each block.
|
| 792 |
+
attn_scales (list of float): Scales at which to apply attention mechanisms.
|
| 793 |
+
temperal_upsample (list of bool): Whether to upsample temporally in each block.
|
| 794 |
+
dropout (float): Dropout rate for the dropout layers.
|
| 795 |
+
non_linearity (str): Type of non-linearity to use.
|
| 796 |
+
"""
|
| 797 |
+
|
| 798 |
+
def __init__(
|
| 799 |
+
self,
|
| 800 |
+
dim=128,
|
| 801 |
+
z_dim=4,
|
| 802 |
+
dim_mult=[1, 2, 4, 4],
|
| 803 |
+
num_res_blocks=2,
|
| 804 |
+
attn_scales=[],
|
| 805 |
+
temperal_upsample=[False, True, True],
|
| 806 |
+
dropout=0.0,
|
| 807 |
+
non_linearity: str = "silu",
|
| 808 |
+
out_channels: int = 3,
|
| 809 |
+
is_residual: bool = False,
|
| 810 |
+
):
|
| 811 |
+
super().__init__()
|
| 812 |
+
self.dim = dim
|
| 813 |
+
self.z_dim = z_dim
|
| 814 |
+
self.dim_mult = dim_mult
|
| 815 |
+
self.num_res_blocks = num_res_blocks
|
| 816 |
+
self.attn_scales = attn_scales
|
| 817 |
+
self.temperal_upsample = temperal_upsample
|
| 818 |
+
|
| 819 |
+
self.nonlinearity = get_activation(non_linearity)
|
| 820 |
+
|
| 821 |
+
# dimensions
|
| 822 |
+
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
| 823 |
+
|
| 824 |
+
# init block
|
| 825 |
+
self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1)
|
| 826 |
+
|
| 827 |
+
# middle blocks
|
| 828 |
+
self.mid_block = WanMidBlock(dims[0], dropout, non_linearity, num_layers=1)
|
| 829 |
+
|
| 830 |
+
# upsample blocks
|
| 831 |
+
self.up_blocks = nn.ModuleList([])
|
| 832 |
+
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
| 833 |
+
# residual (+attention) blocks
|
| 834 |
+
if i > 0 and not is_residual:
|
| 835 |
+
# Wan 2.1 VAE: the previous block's upsampler halves the channels, so halve in_dim here
|
| 836 |
+
in_dim = in_dim // 2
|
| 837 |
+
|
| 838 |
+
# determine if we need upsampling
|
| 839 |
+
up_flag = i != len(dim_mult) - 1
|
| 840 |
+
# determine upsampling mode, if not upsampling, set to None
|
| 841 |
+
upsample_mode = None
|
| 842 |
+
if up_flag and temperal_upsample[i]:
|
| 843 |
+
upsample_mode = "upsample3d"
|
| 844 |
+
elif up_flag:
|
| 845 |
+
upsample_mode = "upsample2d"
|
| 846 |
+
# Create and add the upsampling block
|
| 847 |
+
if is_residual:
|
| 848 |
+
up_block = WanResidualUpBlock(
|
| 849 |
+
in_dim=in_dim,
|
| 850 |
+
out_dim=out_dim,
|
| 851 |
+
num_res_blocks=num_res_blocks,
|
| 852 |
+
dropout=dropout,
|
| 853 |
+
temperal_upsample=temperal_upsample[i] if up_flag else False,
|
| 854 |
+
up_flag=up_flag,
|
| 855 |
+
non_linearity=non_linearity,
|
| 856 |
+
)
|
| 857 |
+
else:
|
| 858 |
+
up_block = WanUpBlock(
|
| 859 |
+
in_dim=in_dim,
|
| 860 |
+
out_dim=out_dim,
|
| 861 |
+
num_res_blocks=num_res_blocks,
|
| 862 |
+
dropout=dropout,
|
| 863 |
+
upsample_mode=upsample_mode,
|
| 864 |
+
non_linearity=non_linearity,
|
| 865 |
+
)
|
| 866 |
+
self.up_blocks.append(up_block)
|
| 867 |
+
|
| 868 |
+
# output blocks
|
| 869 |
+
self.norm_out = WanRMS_norm(out_dim, images=False)
|
| 870 |
+
self.conv_out = WanCausalConv3d(out_dim, out_channels, 3, padding=1)
|
| 871 |
+
|
| 872 |
+
self.gradient_checkpointing = False
|
| 873 |
+
|
| 874 |
+
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
|
| 875 |
+
## conv1
|
| 876 |
+
if feat_cache is not None:
|
| 877 |
+
idx = feat_idx[0]
|
| 878 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 879 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 880 |
+
# cache the last frame of each of the last two chunks
|
| 881 |
+
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
| 882 |
+
x = self.conv_in(x, feat_cache[idx])
|
| 883 |
+
feat_cache[idx] = cache_x
|
| 884 |
+
feat_idx[0] += 1
|
| 885 |
+
else:
|
| 886 |
+
x = self.conv_in(x)
|
| 887 |
+
|
| 888 |
+
## middle
|
| 889 |
+
x = self.mid_block(x, feat_cache, feat_idx)
|
| 890 |
+
|
| 891 |
+
## upsamples
|
| 892 |
+
for up_block in self.up_blocks:
|
| 893 |
+
x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk)
|
| 894 |
+
|
| 895 |
+
## head
|
| 896 |
+
x = self.norm_out(x)
|
| 897 |
+
x = self.nonlinearity(x)
|
| 898 |
+
if feat_cache is not None:
|
| 899 |
+
idx = feat_idx[0]
|
| 900 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 901 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 902 |
+
# cache the last frame of each of the last two chunks
|
| 903 |
+
cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
| 904 |
+
x = self.conv_out(x, feat_cache[idx])
|
| 905 |
+
feat_cache[idx] = cache_x
|
| 906 |
+
feat_idx[0] += 1
|
| 907 |
+
else:
|
| 908 |
+
x = self.conv_out(x)
|
| 909 |
+
return x
|
| 910 |
+
|
| 911 |
+
|
| 912 |
+
def patchify(x, patch_size):
|
| 913 |
+
if patch_size == 1:
|
| 914 |
+
return x
|
| 915 |
+
|
| 916 |
+
if x.dim() != 5:
|
| 917 |
+
raise ValueError(f"Invalid input shape: {x.shape}")
|
| 918 |
+
# x shape: [batch_size, channels, frames, height, width]
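    # e.g. with patch_size=2: [B, C, F, H, W] -> [B, 4*C, F, H//2, W//2]; every 2x2 spatial patch
    # is folded into the channel dimension.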
|
| 919 |
+
batch_size, channels, frames, height, width = x.shape
|
| 920 |
+
|
| 921 |
+
# Ensure height and width are divisible by patch_size
|
| 922 |
+
if height % patch_size != 0 or width % patch_size != 0:
|
| 923 |
+
raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})")
|
| 924 |
+
|
| 925 |
+
# Reshape to [batch_size, channels, frames, height//patch_size, patch_size, width//patch_size, patch_size]
|
| 926 |
+
x = x.view(batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size)
|
| 927 |
+
|
| 928 |
+
# Rearrange to [batch_size, channels * patch_size * patch_size, frames, height//patch_size, width//patch_size]
|
| 929 |
+
x = x.permute(0, 1, 6, 4, 2, 3, 5).contiguous()
|
| 930 |
+
x = x.view(batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size)
|
| 931 |
+
|
| 932 |
+
return x
|
| 933 |
+
|
| 934 |
+
|
| 935 |
+
def unpatchify(x, patch_size):
|
| 936 |
+
if patch_size == 1:
|
| 937 |
+
return x
|
| 938 |
+
|
| 939 |
+
if x.dim() != 5:
|
| 940 |
+
raise ValueError(f"Invalid input shape: {x.shape}")
|
| 941 |
+
# x shape: [batch_size, (channels * patch_size * patch_size), frame, height, width]
|
| 942 |
+
batch_size, c_patches, frames, height, width = x.shape
|
| 943 |
+
channels = c_patches // (patch_size * patch_size)
|
| 944 |
+
|
| 945 |
+
# Reshape to [b, c, patch_size, patch_size, f, h, w]
|
| 946 |
+
x = x.view(batch_size, channels, patch_size, patch_size, frames, height, width)
|
| 947 |
+
|
| 948 |
+
# Rearrange to [b, c, f, h * patch_size, w * patch_size]
|
| 949 |
+
x = x.permute(0, 1, 4, 5, 3, 6, 2).contiguous()
|
| 950 |
+
x = x.view(batch_size, channels, frames, height * patch_size, width * patch_size)
|
| 951 |
+
|
| 952 |
+
return x
|
| 953 |
+
|
| 954 |
+
|
| 955 |
+
class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
| 956 |
+
r"""
|
| 957 |
+
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
|
| 958 |
+
Introduced in [Wan 2.1].
|
| 959 |
+
|
| 960 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
|
| 961 |
+
for all models (such as downloading or saving).
|
| 962 |
+
"""
|
| 963 |
+
|
| 964 |
+
_supports_gradient_checkpointing = False
|
| 965 |
+
|
| 966 |
+
@register_to_config
|
| 967 |
+
def __init__(
|
| 968 |
+
self,
|
| 969 |
+
base_dim: int = 96,
|
| 970 |
+
decoder_base_dim: Optional[int] = None,
|
| 971 |
+
z_dim: int = 16,
|
| 972 |
+
dim_mult: Tuple[int] = [1, 2, 4, 4],
|
| 973 |
+
num_res_blocks: int = 2,
|
| 974 |
+
attn_scales: List[float] = [],
|
| 975 |
+
temperal_downsample: List[bool] = [False, True, True],
|
| 976 |
+
dropout: float = 0.0,
|
| 977 |
+
latents_mean: List[float] = [
|
| 978 |
+
-0.7571,
|
| 979 |
+
-0.7089,
|
| 980 |
+
-0.9113,
|
| 981 |
+
0.1075,
|
| 982 |
+
-0.1745,
|
| 983 |
+
0.9653,
|
| 984 |
+
-0.1517,
|
| 985 |
+
1.5508,
|
| 986 |
+
0.4134,
|
| 987 |
+
-0.0715,
|
| 988 |
+
0.5517,
|
| 989 |
+
-0.3632,
|
| 990 |
+
-0.1922,
|
| 991 |
+
-0.9497,
|
| 992 |
+
0.2503,
|
| 993 |
+
-0.2921,
|
| 994 |
+
],
|
| 995 |
+
latents_std: List[float] = [
|
| 996 |
+
2.8184,
|
| 997 |
+
1.4541,
|
| 998 |
+
2.3275,
|
| 999 |
+
2.6558,
|
| 1000 |
+
1.2196,
|
| 1001 |
+
1.7708,
|
| 1002 |
+
2.6052,
|
| 1003 |
+
2.0743,
|
| 1004 |
+
3.2687,
|
| 1005 |
+
2.1526,
|
| 1006 |
+
2.8652,
|
| 1007 |
+
1.5579,
|
| 1008 |
+
1.6382,
|
| 1009 |
+
1.1253,
|
| 1010 |
+
2.8251,
|
| 1011 |
+
1.9160,
|
| 1012 |
+
],
|
| 1013 |
+
is_residual: bool = False,
|
| 1014 |
+
in_channels: int = 3,
|
| 1015 |
+
out_channels: int = 3,
|
| 1016 |
+
patch_size: Optional[int] = None,
|
| 1017 |
+
scale_factor_temporal: Optional[int] = 4,
|
| 1018 |
+
scale_factor_spatial: Optional[int] = 8,
|
| 1019 |
+
) -> None:
|
| 1020 |
+
super().__init__()
|
| 1021 |
+
|
| 1022 |
+
self.z_dim = z_dim
|
| 1023 |
+
self.temperal_downsample = temperal_downsample
|
| 1024 |
+
self.temperal_upsample = temperal_downsample[::-1]
|
| 1025 |
+
|
| 1026 |
+
if decoder_base_dim is None:
|
| 1027 |
+
decoder_base_dim = base_dim
|
| 1028 |
+
|
| 1029 |
+
self.encoder = WanEncoder3d(
|
| 1030 |
+
in_channels=in_channels,
|
| 1031 |
+
dim=base_dim,
|
| 1032 |
+
z_dim=z_dim * 2,
|
| 1033 |
+
dim_mult=dim_mult,
|
| 1034 |
+
num_res_blocks=num_res_blocks,
|
| 1035 |
+
attn_scales=attn_scales,
|
| 1036 |
+
temperal_downsample=temperal_downsample,
|
| 1037 |
+
dropout=dropout,
|
| 1038 |
+
is_residual=is_residual,
|
| 1039 |
+
)
|
| 1040 |
+
self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1)
|
| 1041 |
+
self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1)
|
| 1042 |
+
|
| 1043 |
+
self.decoder = WanDecoder3d(
|
| 1044 |
+
dim=decoder_base_dim,
|
| 1045 |
+
z_dim=z_dim,
|
| 1046 |
+
dim_mult=dim_mult,
|
| 1047 |
+
num_res_blocks=num_res_blocks,
|
| 1048 |
+
attn_scales=attn_scales,
|
| 1049 |
+
temperal_upsample=self.temperal_upsample,
|
| 1050 |
+
dropout=dropout,
|
| 1051 |
+
out_channels=out_channels,
|
| 1052 |
+
is_residual=is_residual,
|
| 1053 |
+
)
|
| 1054 |
+
|
| 1055 |
+
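        # With the default three down blocks this evaluates to 2**3 = 8, matching `scale_factor_spatial`.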
self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
|
| 1056 |
+
|
| 1057 |
+
# When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
|
| 1058 |
+
# to perform decoding of a single video latent at a time.
|
| 1059 |
+
self.use_slicing = False
|
| 1060 |
+
|
| 1061 |
+
# When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
|
| 1062 |
+
# frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
|
| 1063 |
+
# intermediate tiles together, the memory requirement can be lowered.
|
| 1064 |
+
self.use_tiling = False
|
| 1065 |
+
|
| 1066 |
+
# The minimal tile height and width for spatial tiling to be used
|
| 1067 |
+
self.tile_sample_min_height = 256
|
| 1068 |
+
self.tile_sample_min_width = 256
|
| 1069 |
+
|
| 1070 |
+
# The minimal distance between two spatial tiles
|
| 1071 |
+
self.tile_sample_stride_height = 192
|
| 1072 |
+
self.tile_sample_stride_width = 192
|
| 1073 |
+
|
| 1074 |
+
# Precompute and cache conv counts for encoder and decoder for clear_cache speedup
|
| 1075 |
+
self._cached_conv_counts = {
|
| 1076 |
+
"decoder": sum(isinstance(m, WanCausalConv3d) for m in self.decoder.modules())
|
| 1077 |
+
if self.decoder is not None
|
| 1078 |
+
else 0,
|
| 1079 |
+
"encoder": sum(isinstance(m, WanCausalConv3d) for m in self.encoder.modules())
|
| 1080 |
+
if self.encoder is not None
|
| 1081 |
+
else 0,
|
| 1082 |
+
}
|
| 1083 |
+
|
| 1084 |
+
def enable_tiling(
|
| 1085 |
+
self,
|
| 1086 |
+
tile_sample_min_height: Optional[int] = None,
|
| 1087 |
+
tile_sample_min_width: Optional[int] = None,
|
| 1088 |
+
tile_sample_stride_height: Optional[float] = None,
|
| 1089 |
+
tile_sample_stride_width: Optional[float] = None,
|
| 1090 |
+
) -> None:
|
| 1091 |
+
r"""
|
| 1092 |
+
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
| 1093 |
+
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
|
| 1094 |
+
processing larger images.
|
| 1095 |
+
|
| 1096 |
+
Args:
|
| 1097 |
+
tile_sample_min_height (`int`, *optional*):
|
| 1098 |
+
The minimum height required for a sample to be separated into tiles across the height dimension.
|
| 1099 |
+
tile_sample_min_width (`int`, *optional*):
|
| 1100 |
+
The minimum width required for a sample to be separated into tiles across the width dimension.
|
| 1101 |
+
tile_sample_stride_height (`int`, *optional*):
|
| 1102 |
+
The stride between two consecutive vertical tiles. This is to ensure that there are
|
| 1103 |
+
no tiling artifacts produced across the height dimension.
|
| 1104 |
+
tile_sample_stride_width (`int`, *optional*):
|
| 1105 |
+
The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
|
| 1106 |
+
artifacts produced across the width dimension.
|
| 1107 |
+
"""
|
| 1108 |
+
self.use_tiling = True
|
| 1109 |
+
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
|
| 1110 |
+
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
|
| 1111 |
+
self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
|
| 1112 |
+
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
|
| 1113 |
+
|
| 1114 |
+
def disable_tiling(self) -> None:
|
| 1115 |
+
r"""
|
| 1116 |
+
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
| 1117 |
+
decoding in one step.
|
| 1118 |
+
"""
|
| 1119 |
+
self.use_tiling = False
|
| 1120 |
+
|
| 1121 |
+
def enable_slicing(self) -> None:
|
| 1122 |
+
r"""
|
| 1123 |
+
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
| 1124 |
+
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
| 1125 |
+
"""
|
| 1126 |
+
self.use_slicing = True
|
| 1127 |
+
|
| 1128 |
+
def disable_slicing(self) -> None:
|
| 1129 |
+
r"""
|
| 1130 |
+
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
| 1131 |
+
decoding in one step.
|
| 1132 |
+
"""
|
| 1133 |
+
self.use_slicing = False
|
| 1134 |
+
|
| 1135 |
+
def clear_cache(self):
|
| 1136 |
+
# Use cached conv counts for decoder and encoder to avoid re-iterating modules each call
|
| 1137 |
+
self._conv_num = self._cached_conv_counts["decoder"]
|
| 1138 |
+
self._conv_idx = [0]
|
| 1139 |
+
self._feat_map = [None] * self._conv_num
|
| 1140 |
+
# cache encode
|
| 1141 |
+
self._enc_conv_num = self._cached_conv_counts["encoder"]
|
| 1142 |
+
self._enc_conv_idx = [0]
|
| 1143 |
+
self._enc_feat_map = [None] * self._enc_conv_num
|
| 1144 |
+
|
| 1145 |
+
def _encode(self, x: torch.Tensor):
|
| 1146 |
+
_, _, num_frame, height, width = x.shape
|
| 1147 |
+
|
| 1148 |
+
if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
|
| 1149 |
+
return self.tiled_encode(x)
|
| 1150 |
+
|
| 1151 |
+
self.clear_cache()
|
| 1152 |
+
if self.config.patch_size is not None:
|
| 1153 |
+
x = patchify(x, patch_size=self.config.patch_size)
|
| 1154 |
+
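        # The first frame is encoded on its own; the remaining frames are processed in chunks of 4,
        # matching the 4x temporal compression, with the causal-conv caches carried across chunks.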
iter_ = 1 + (num_frame - 1) // 4
|
| 1155 |
+
for i in range(iter_):
|
| 1156 |
+
self._enc_conv_idx = [0]
|
| 1157 |
+
if i == 0:
|
| 1158 |
+
out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
|
| 1159 |
+
else:
|
| 1160 |
+
out_ = self.encoder(
|
| 1161 |
+
x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
|
| 1162 |
+
feat_cache=self._enc_feat_map,
|
| 1163 |
+
feat_idx=self._enc_conv_idx,
|
| 1164 |
+
)
|
| 1165 |
+
out = torch.cat([out, out_], 2)
|
| 1166 |
+
|
| 1167 |
+
enc = self.quant_conv(out)
|
| 1168 |
+
self.clear_cache()
|
| 1169 |
+
return enc
|
| 1170 |
+
|
| 1171 |
+
@apply_forward_hook
|
| 1172 |
+
def encode(
|
| 1173 |
+
self, x: torch.Tensor, return_dict: bool = True
|
| 1174 |
+
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
|
| 1175 |
+
r"""
|
| 1176 |
+
Encode a batch of images into latents.
|
| 1177 |
+
|
| 1178 |
+
Args:
|
| 1179 |
+
x (`torch.Tensor`): Input batch of images.
|
| 1180 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 1181 |
+
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
|
| 1182 |
+
|
| 1183 |
+
Returns:
|
| 1184 |
+
The latent representations of the encoded videos. If `return_dict` is True, a
|
| 1185 |
+
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
|
| 1186 |
+
"""
|
| 1187 |
+
if self.use_slicing and x.shape[0] > 1:
|
| 1188 |
+
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
|
| 1189 |
+
h = torch.cat(encoded_slices)
|
| 1190 |
+
else:
|
| 1191 |
+
h = self._encode(x)
|
| 1192 |
+
posterior = DiagonalGaussianDistribution(h)
|
| 1193 |
+
|
| 1194 |
+
if not return_dict:
|
| 1195 |
+
return (posterior,)
|
| 1196 |
+
return AutoencoderKLOutput(latent_dist=posterior)
|
| 1197 |
+
|
| 1198 |
+
def _decode(self, z: torch.Tensor, return_dict: bool = True):
|
| 1199 |
+
_, _, num_frame, height, width = z.shape
|
| 1200 |
+
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
| 1201 |
+
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
| 1202 |
+
|
| 1203 |
+
if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
|
| 1204 |
+
return self.tiled_decode(z, return_dict=return_dict)
|
| 1205 |
+
|
| 1206 |
+
self.clear_cache()
|
| 1207 |
+
x = self.post_quant_conv(z)
|
| 1208 |
+
for i in range(num_frame):
|
| 1209 |
+
self._conv_idx = [0]
|
| 1210 |
+
if i == 0:
|
| 1211 |
+
out = self.decoder(
|
| 1212 |
+
x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=True
|
| 1213 |
+
)
|
| 1214 |
+
else:
|
| 1215 |
+
out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
|
| 1216 |
+
out = torch.cat([out, out_], 2)
|
| 1217 |
+
|
| 1218 |
+
if self.config.patch_size is not None:
|
| 1219 |
+
out = unpatchify(out, patch_size=self.config.patch_size)
|
| 1220 |
+
|
| 1221 |
+
out = torch.clamp(out, min=-1.0, max=1.0)
|
| 1222 |
+
|
| 1223 |
+
self.clear_cache()
|
| 1224 |
+
if not return_dict:
|
| 1225 |
+
return (out,)
|
| 1226 |
+
|
| 1227 |
+
return DecoderOutput(sample=out)
|
| 1228 |
+
|
| 1229 |
+
@apply_forward_hook
|
| 1230 |
+
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
| 1231 |
+
r"""
|
| 1232 |
+
Decode a batch of images.
|
| 1233 |
+
|
| 1234 |
+
Args:
|
| 1235 |
+
z (`torch.Tensor`): Input batch of latent vectors.
|
| 1236 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 1237 |
+
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
|
| 1238 |
+
|
| 1239 |
+
Returns:
|
| 1240 |
+
[`~models.vae.DecoderOutput`] or `tuple`:
|
| 1241 |
+
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
|
| 1242 |
+
returned.
|
| 1243 |
+
"""
|
| 1244 |
+
if self.use_slicing and z.shape[0] > 1:
|
| 1245 |
+
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
|
| 1246 |
+
decoded = torch.cat(decoded_slices)
|
| 1247 |
+
else:
|
| 1248 |
+
decoded = self._decode(z).sample
|
| 1249 |
+
|
| 1250 |
+
if not return_dict:
|
| 1251 |
+
return (decoded,)
|
| 1252 |
+
return DecoderOutput(sample=decoded)
|
| 1253 |
+
|
| 1254 |
+
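    # blend_v / blend_h linearly cross-fade the overlapping rows / columns of two neighbouring tiles
    # (the weight ramps from tile `a` to tile `b` over `blend_extent`), hiding seams between tiles.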
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
| 1255 |
+
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
|
| 1256 |
+
for y in range(blend_extent):
|
| 1257 |
+
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
|
| 1258 |
+
y / blend_extent
|
| 1259 |
+
)
|
| 1260 |
+
return b
|
| 1261 |
+
|
| 1262 |
+
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
|
| 1263 |
+
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
|
| 1264 |
+
for x in range(blend_extent):
|
| 1265 |
+
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
|
| 1266 |
+
x / blend_extent
|
| 1267 |
+
)
|
| 1268 |
+
return b
|
| 1269 |
+
|
| 1270 |
+
def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
|
| 1271 |
+
r"""Encode a batch of images using a tiled encoder.
|
| 1272 |
+
|
| 1273 |
+
Args:
|
| 1274 |
+
x (`torch.Tensor`): Input batch of videos.
|
| 1275 |
+
|
| 1276 |
+
Returns:
|
| 1277 |
+
`torch.Tensor`:
|
| 1278 |
+
The latent representation of the encoded videos.
|
| 1279 |
+
"""
|
| 1280 |
+
_, _, num_frames, height, width = x.shape
|
| 1281 |
+
latent_height = height // self.spatial_compression_ratio
|
| 1282 |
+
latent_width = width // self.spatial_compression_ratio
|
| 1283 |
+
|
| 1284 |
+
tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
|
| 1285 |
+
tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
|
| 1286 |
+
tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
|
| 1287 |
+
tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
|
| 1288 |
+
|
| 1289 |
+
blend_height = tile_latent_min_height - tile_latent_stride_height
|
| 1290 |
+
blend_width = tile_latent_min_width - tile_latent_stride_width
|
| 1291 |
+
|
| 1292 |
+
# Split x into overlapping tiles and encode them separately.
|
| 1293 |
+
# The tiles have an overlap to avoid seams between tiles.
|
| 1294 |
+
rows = []
|
| 1295 |
+
for i in range(0, height, self.tile_sample_stride_height):
|
| 1296 |
+
row = []
|
| 1297 |
+
for j in range(0, width, self.tile_sample_stride_width):
|
| 1298 |
+
self.clear_cache()
|
| 1299 |
+
time = []
|
| 1300 |
+
frame_range = 1 + (num_frames - 1) // 4
|
| 1301 |
+
for k in range(frame_range):
|
| 1302 |
+
self._enc_conv_idx = [0]
|
| 1303 |
+
if k == 0:
|
| 1304 |
+
tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
|
| 1305 |
+
else:
|
| 1306 |
+
tile = x[
|
| 1307 |
+
:,
|
| 1308 |
+
:,
|
| 1309 |
+
1 + 4 * (k - 1) : 1 + 4 * k,
|
| 1310 |
+
i : i + self.tile_sample_min_height,
|
| 1311 |
+
j : j + self.tile_sample_min_width,
|
| 1312 |
+
]
|
| 1313 |
+
tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
|
| 1314 |
+
tile = self.quant_conv(tile)
                    time.append(tile)
                row.append(torch.cat(time, dim=2))
            rows.append(row)
        self.clear_cache()

        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_width)
                result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
            result_rows.append(torch.cat(result_row, dim=-1))

        enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
        return enc

    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
        r"""
        Decode a batch of images using a tiled decoder.

        Args:
            z (`torch.Tensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        _, _, num_frames, height, width = z.shape
        sample_height = height * self.spatial_compression_ratio
        sample_width = width * self.spatial_compression_ratio

        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio

        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width

        # Split z into overlapping tiles and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, height, tile_latent_stride_height):
            row = []
            for j in range(0, width, tile_latent_stride_width):
                self.clear_cache()
                time = []
                for k in range(num_frames):
                    self._conv_idx = [0]
                    tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
                    tile = self.post_quant_conv(tile)
                    decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
                    time.append(decoded)
                row.append(torch.cat(time, dim=2))
            rows.append(row)
        self.clear_cache()

        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_width)
                result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
            result_rows.append(torch.cat(result_row, dim=-1))

        dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]

        if not return_dict:
            return (dec,)
        return DecoderOutput(sample=dec)

    def forward(
        self,
        sample: torch.Tensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, torch.Tensor]:
        """
        Args:
            sample (`torch.Tensor`): Input sample.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z, return_dict=return_dict)
        return dec
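Reader's note: the tile blending above relies on the blend_v / blend_h helpers defined earlier in this file. As a reference, here is a minimal standalone sketch of the linear cross-fade that such helpers perform along one spatial axis for a (B, C, T, H, W) tensor; the function name and exact clamping here are assumptions for illustration, not a copy of the repository code.

import torch

def blend_vertical(top: torch.Tensor, bottom: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Cross-fade the last `blend_extent` rows of the tile above into the first rows of the tile below.
    blend_extent = min(top.shape[-2], bottom.shape[-2], blend_extent)
    for y in range(blend_extent):
        w = y / blend_extent  # 0 -> keep the top tile, 1 -> keep the bottom tile
        bottom[:, :, :, y, :] = top[:, :, :, -blend_extent + y, :] * (1 - w) + bottom[:, :, :, y, :] * w
    return bottom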
architecture/cogvideox_transformer_3d.py ADDED
@@ -0,0 +1,563 @@
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional, Tuple, Union
import os, sys, shutil
import torch
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import PeftAdapterMixin
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero


# Import files from the local folder
root_path = os.path.abspath('.')
sys.path.append(root_path)
from architecture.embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
from architecture.attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@maybe_allow_in_graph
class CogVideoXBlock(nn.Module):
    r"""
    Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.

    Parameters:
        dim (`int`):
            The number of channels in the input and output.
        num_attention_heads (`int`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`):
            The number of channels in each head.
        time_embed_dim (`int`):
            The number of channels in timestep embedding.
        dropout (`float`, defaults to `0.0`):
            The dropout probability to use.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to be used in feed-forward.
        attention_bias (`bool`, defaults to `False`):
            Whether or not to use bias in attention projection layers.
        qk_norm (`bool`, defaults to `True`):
            Whether or not to use normalization after query and key projections in Attention.
        norm_elementwise_affine (`bool`, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_eps (`float`, defaults to `1e-5`):
            Epsilon value for normalization layers.
        final_dropout (`bool`, defaults to `False`):
            Whether to apply a final dropout after the last feed-forward layer.
        ff_inner_dim (`int`, *optional*, defaults to `None`):
            Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
        ff_bias (`bool`, defaults to `True`):
            Whether or not to use bias in Feed-forward layer.
        attention_out_bias (`bool`, defaults to `True`):
            Whether or not to use bias in Attention output projection layer.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        time_embed_dim: int,
        dropout: float = 0.0,
        activation_fn: str = "gelu-approximate",
        attention_bias: bool = False,
        qk_norm: bool = True,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        final_dropout: bool = True,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
    ):
        super().__init__()

        # 1. Self Attention
        self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.attn1 = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=attention_bias,
            out_bias=attention_out_bias,
            processor=CogVideoXAttnProcessor2_0(),
        )

        # 2. Feed Forward
        self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)

        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        text_seq_length = encoder_hidden_states.size(1)
        attention_kwargs = attention_kwargs or {}

        # norm & modulate
        norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
            hidden_states, encoder_hidden_states, temb
        )

        # attention
        attn_hidden_states, attn_encoder_hidden_states = self.attn1(
            hidden_states = norm_hidden_states,
            encoder_hidden_states = norm_encoder_hidden_states,
            image_rotary_emb = image_rotary_emb,
            **attention_kwargs,
        )

        hidden_states = hidden_states + gate_msa * attn_hidden_states
        encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states

        # norm & modulate
        norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
            hidden_states, encoder_hidden_states, temb
        )

        # feed-forward
        norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
        ff_output = self.ff(norm_hidden_states)

        hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
        encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]

        return hidden_states, encoder_hidden_states


class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
    """
    A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).

    Parameters:
        num_attention_heads (`int`, defaults to `30`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, defaults to `64`):
            The number of channels in each head.
        in_channels (`int`, defaults to `16`):
            The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `16`):
            The number of channels in the output.
        flip_sin_to_cos (`bool`, defaults to `True`):
            Whether to flip the sin to cos in the time embedding.
        time_embed_dim (`int`, defaults to `512`):
            Output dimension of timestep embeddings.
        ofs_embed_dim (`int`, defaults to `512`):
            Output dimension of "ofs" embeddings used in CogVideoX-5b-I2V in version 1.5
        text_embed_dim (`int`, defaults to `4096`):
            Input dimension of text embeddings from the text encoder.
        num_layers (`int`, defaults to `30`):
            The number of layers of Transformer blocks to use.
        dropout (`float`, defaults to `0.0`):
            The dropout probability to use.
        attention_bias (`bool`, defaults to `True`):
            Whether to use bias in the attention projection layers.
        sample_width (`int`, defaults to `90`):
            The width of the input latents.
        sample_height (`int`, defaults to `60`):
            The height of the input latents.
        sample_frames (`int`, defaults to `49`):
            The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
            instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
            but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
            K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
        patch_size (`int`, defaults to `2`):
            The size of the patches to use in the patch embedding layer.
        temporal_compression_ratio (`int`, defaults to `4`):
            The compression ratio across the temporal dimension. See documentation for `sample_frames`.
        max_text_seq_length (`int`, defaults to `226`):
            The maximum sequence length of the input text embeddings.
        activation_fn (`str`, defaults to `"gelu-approximate"`):
            Activation function to use in feed-forward.
        timestep_activation_fn (`str`, defaults to `"silu"`):
            Activation function to use when generating the timestep embeddings.
        norm_elementwise_affine (`bool`, defaults to `True`):
            Whether to use elementwise affine in normalization layers.
        norm_eps (`float`, defaults to `1e-5`):
            The epsilon value to use in normalization layers.
        spatial_interpolation_scale (`float`, defaults to `1.875`):
            Scaling factor to apply in 3D positional embeddings across spatial dimensions.
        temporal_interpolation_scale (`float`, defaults to `1.0`):
            Scaling factor to apply in 3D positional embeddings across temporal dimensions.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["CogVideoXBlock", "CogVideoXPatchEmbed"]

    @register_to_config
    def __init__(
        self,
        num_attention_heads: int = 30,
        attention_head_dim: int = 64,
        in_channels: int = 16,
        out_channels: Optional[int] = 16,
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,
        ofs_embed_dim: Optional[int] = None,
        text_embed_dim: int = 4096,
        num_layers: int = 30,
        dropout: float = 0.0,
        attention_bias: bool = True,
        sample_width: int = 90,
        sample_height: int = 60,
        sample_frames: int = 49,
        patch_size: int = 2,
        patch_size_t: Optional[int] = None,
        temporal_compression_ratio: int = 4,
        max_text_seq_length: int = 226,
        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.875,
        temporal_interpolation_scale: float = 1.0,
        use_rotary_positional_embeddings: bool = False,
        use_learned_positional_embeddings: bool = False,
        patch_bias: bool = True,
        extra_encoder_cond_channels: int = -1,
        use_FrameIn: bool = False,
    ):
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim


        # breakpoint()
        # if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
        #     raise ValueError(
        #         "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
        #         "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
        #         "issue at https://github.com/huggingface/diffusers/issues."
        #     )

        # 1. Patch embedding
        self.patch_embed = CogVideoXPatchEmbed(
            patch_size = patch_size,
            patch_size_t = patch_size_t,
            in_channels = in_channels,
            embed_dim = inner_dim,
            text_embed_dim = text_embed_dim,
            bias = patch_bias,
            sample_width = sample_width,
            sample_height = sample_height,
            sample_frames = sample_frames,
            temporal_compression_ratio = temporal_compression_ratio,
            max_text_seq_length = max_text_seq_length,
            spatial_interpolation_scale = spatial_interpolation_scale,
            temporal_interpolation_scale = temporal_interpolation_scale,
            use_positional_embeddings = not use_rotary_positional_embeddings,  # HACK: use_positional_embeddings is the inverse of use_rotary_positional_embeddings
            use_learned_positional_embeddings = use_learned_positional_embeddings,
            extra_encoder_cond_channels = extra_encoder_cond_channels,
            use_FrameIn = use_FrameIn,
        )
        self.embedding_dropout = nn.Dropout(dropout)

        # 2. Time embeddings and ofs embedding (only CogVideoX1.5-5B I2V has this)

        self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)

        self.ofs_proj = None
        self.ofs_embedding = None
        if ofs_embed_dim:
            self.ofs_proj = Timesteps(ofs_embed_dim, flip_sin_to_cos, freq_shift)
            self.ofs_embedding = TimestepEmbedding(
                ofs_embed_dim, ofs_embed_dim, timestep_activation_fn
            )  # same as time embeddings, for ofs

        # 3. Define spatio-temporal transformer blocks
        self.transformer_blocks = nn.ModuleList(
            [
                CogVideoXBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)

        # 4. Output blocks
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=2 * inner_dim,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )

        if patch_size_t is None:
            # For CogVideoX 1.0
            output_dim = patch_size * patch_size * out_channels
        else:
            # For CogVideoX 1.5
            output_dim = patch_size * patch_size * patch_size_t * out_channels

        self.proj_out = nn.Linear(inner_dim, output_dim)

        self.gradient_checkpointing = False

    # def _set_gradient_checkpointing(self, module, value=False):
    #     self.gradient_checkpointing = value

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
    def fuse_qkv_projections(self):
        """
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>
        """
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

        <Tip warning={true}>

        This API is 🧪 experimental.

        </Tip>

        """
        if self.original_attn_processors is not None:
            self.set_attn_processor(self.original_attn_processors)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        ofs: Optional[Union[int, float, torch.LongTensor]] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ):

        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )


        batch_size, num_frames, channels, height, width = hidden_states.shape

        # 1. Time embedding
        timesteps = timestep
        t_emb = self.time_proj(timesteps)


        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        if self.ofs_embedding is not None:
            ofs_emb = self.ofs_proj(ofs)
            ofs_emb = ofs_emb.to(dtype = hidden_states.dtype)
            ofs_emb = self.ofs_embedding(ofs_emb)
            emb = emb + ofs_emb

        # 2. Patch embedding
        hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)  # Only use patch embedding at the very beginning
        hidden_states = self.embedding_dropout(hidden_states)

        # HACK: the patch_embed output is split only after the positional embedding has been added
        text_seq_length = encoder_hidden_states.shape[1]
        encoder_hidden_states = hidden_states[:, :text_seq_length]  # The merged encoder hidden states are split out again
        hidden_states = hidden_states[:, text_seq_length:]

        # 3. Transformer blocks
        for i, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    emb,
                    image_rotary_emb,
                    attention_kwargs,
                    **ckpt_kwargs,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states = hidden_states,
                    encoder_hidden_states = encoder_hidden_states,
                    temb = emb,
                    image_rotary_emb = image_rotary_emb,
                    attention_kwargs = attention_kwargs,
                )

        if not self.config.use_rotary_positional_embeddings:
            # CogVideoX-2B
            hidden_states = self.norm_final(hidden_states)
        else:
            # CogVideoX-5B
            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
            hidden_states = self.norm_final(hidden_states)
            hidden_states = hidden_states[:, text_seq_length:]

        # 4. Final block
        hidden_states = self.norm_out(hidden_states, temb=emb)
        hidden_states = self.proj_out(hidden_states)

        # 5. Unpatchify
        p = self.config.patch_size
        p_t = self.config.patch_size_t

        if p_t is None:
            output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
        else:
            output = hidden_states.reshape(
                batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
            )
            output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
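Reader's note: the unpatchify step in forward (section 5 above) is easiest to check with concrete shapes. A small shape-only sanity check for the CogVideoX 1.0 branch (patch_size_t is None), using made-up toy dimensions; it only demonstrates the reshape/permute bookkeeping, not the model's learned channel layout.

import torch

batch_size, num_frames, height, width = 1, 2, 8, 8
p, out_channels = 2, 4
seq_len = num_frames * (height // p) * (width // p)

tokens = torch.randn(batch_size, seq_len, p * p * out_channels)  # shape of proj_out output
output = tokens.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
print(output.shape)  # torch.Size([1, 2, 4, 8, 8]) -> (B, F, C, H, W)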
architecture/embeddings.py ADDED
The diff for this file is too large to render. See raw diff.
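Reader's note: embeddings.py is not rendered above, but the transformer shows how its CogVideoXPatchEmbed is used: it receives the text embeddings and the video latents, returns one merged token sequence (text tokens first), and the caller later splits that sequence back apart by max_text_seq_length. A rough, illustrative sketch of that interface only; the class name, conv-based patch projection, and text projection here are assumptions, not the repository implementation.

import torch
import torch.nn as nn

class ToyPatchEmbed(nn.Module):
    # Illustrative only: project text and per-frame video patches into a single token sequence.
    def __init__(self, patch_size: int, in_channels: int, embed_dim: int, text_embed_dim: int):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.text_proj = nn.Linear(text_embed_dim, embed_dim)

    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor) -> torch.Tensor:
        b, f, c, h, w = image_embeds.shape
        image_tokens = self.proj(image_embeds.reshape(-1, c, h, w))          # (B*F, D, H/p, W/p)
        image_tokens = image_tokens.flatten(2).transpose(1, 2)               # (B*F, HW/p^2, D)
        image_tokens = image_tokens.reshape(b, -1, image_tokens.shape[-1])   # (B, F*HW/p^2, D)
        text_tokens = self.text_proj(text_embeds)                            # (B, L, D)
        return torch.cat([text_tokens, image_tokens], dim=1)                 # text first, then video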
architecture/noise_sampler.py ADDED
@@ -0,0 +1,54 @@

"""Modified from https://github.com/THUDM/CogVideo/blob/3710a612d8760f5cdb1741befeebb65b9e0f2fe0/sat/sgm/modules/diffusionmodules/sigma_sampling.py
"""
import torch

class DiscreteSampling:

    def __init__(self, num_idx, uniform_sampling=False):
        self.num_idx = num_idx
        self.uniform_sampling = uniform_sampling
        self.is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()

        # print("self.is_distributed status is ", self.is_distributed)
        if self.is_distributed and self.uniform_sampling:
            world_size = torch.distributed.get_world_size()
            self.rank = torch.distributed.get_rank()

            i = 1
            while True:
                if world_size % i != 0 or num_idx % (world_size // i) != 0:
                    i += 1
                else:
                    self.group_num = world_size // i
                    break
            assert self.group_num > 0
            assert world_size % self.group_num == 0
            # the number of ranks in one group
            self.group_width = world_size // self.group_num
            self.sigma_interval = self.num_idx // self.group_num
            print('rank=%d world_size=%d group_num=%d group_width=%d sigma_interval=%s' % (
                self.rank, world_size, self.group_num,
                self.group_width, self.sigma_interval))


    def __call__(self, n_samples, generator=None, device=None):


        if self.is_distributed and self.uniform_sampling:
            group_index = self.rank // self.group_width
            idx = torch.randint(
                group_index * self.sigma_interval,
                (group_index + 1) * self.sigma_interval,
                (n_samples,),
                generator=generator, device=device,
            )
            # print('proc[%d] idx=%s' % (self.rank, idx))
            # print("Uniform sample range is ", group_index * self.sigma_interval, (group_index + 1) * self.sigma_interval)

        else:
            idx = torch.randint(
                0, self.num_idx, (n_samples,),
                generator=generator, device=device,
            )
        return idx
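Reader's note: a minimal single-process usage sketch of DiscreteSampling. Outside an initialized torch.distributed group it falls back to plain uniform sampling over [0, num_idx); the num_idx=1000 value below is just an illustrative choice for a 1000-step noise schedule, not a value taken from the configs.

import torch
from architecture.noise_sampler import DiscreteSampling

sampler = DiscreteSampling(num_idx=1000, uniform_sampling=False)
generator = torch.Generator().manual_seed(0)
timestep_indices = sampler(n_samples=4, generator=generator, device="cpu")
print(timestep_indices.shape)  # torch.Size([4]); integer timestep indices in [0, 1000)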
architecture/transformer_wan.py ADDED
@@ -0,0 +1,552 @@
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.attention import FeedForward
from diffusers.models.attention_processor import Attention
from diffusers.models.cache_utils import CacheMixin
from diffusers.models.embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import FP32LayerNorm


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class WanAttnProcessor2_0:
    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        encoder_hidden_states_img = None
        if attn.add_k_proj is not None:
            # 512 is the context length of the text encoder, hardcoded for now
            image_context_length = encoder_hidden_states.shape[1] - 512
            encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
            encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)

        if rotary_emb is not None:

            def apply_rotary_emb(
                hidden_states: torch.Tensor,
                freqs_cos: torch.Tensor,
                freqs_sin: torch.Tensor,
            ):
                x = hidden_states.view(*hidden_states.shape[:-1], -1, 2)
                x1, x2 = x[..., 0], x[..., 1]
                cos = freqs_cos[..., 0::2]
                sin = freqs_sin[..., 1::2]
                out = torch.empty_like(hidden_states)
                out[..., 0::2] = x1 * cos - x2 * sin
                out[..., 1::2] = x1 * sin + x2 * cos
                return out.type_as(hidden_states)

            query = apply_rotary_emb(query, *rotary_emb)
            key = apply_rotary_emb(key, *rotary_emb)

        # I2V task
        hidden_states_img = None
        if encoder_hidden_states_img is not None:
            key_img = attn.add_k_proj(encoder_hidden_states_img)
            key_img = attn.norm_added_k(key_img)
            value_img = attn.add_v_proj(encoder_hidden_states_img)

            key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
            value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)

            hidden_states_img = F.scaled_dot_product_attention(
                query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
            )
            hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
            hidden_states_img = hidden_states_img.type_as(query)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
        hidden_states = hidden_states.type_as(query)

        if hidden_states_img is not None:
            hidden_states = hidden_states + hidden_states_img

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states


class WanImageEmbedding(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
        super().__init__()

        self.norm1 = FP32LayerNorm(in_features)
        self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
        self.norm2 = FP32LayerNorm(out_features)
        if pos_embed_seq_len is not None:
            self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
        else:
            self.pos_embed = None

    def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
        if self.pos_embed is not None:
            batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
            encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim)
            encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed

        hidden_states = self.norm1(encoder_hidden_states_image)
        hidden_states = self.ff(hidden_states)
        hidden_states = self.norm2(hidden_states)
        return hidden_states


class WanTimeTextImageEmbedding(nn.Module):
    def __init__(
        self,
        dim: int,
        time_freq_dim: int,
        time_proj_dim: int,
        text_embed_dim: int,
        image_embed_dim: Optional[int] = None,
        pos_embed_seq_len: Optional[int] = None,
    ):
        super().__init__()

        self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
        self.act_fn = nn.SiLU()
        self.time_proj = nn.Linear(dim, time_proj_dim)
        self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")

        self.image_embedder = None
        if image_embed_dim is not None:
            self.image_embedder = WanImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)

    def forward(
        self,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
        timestep_seq_len: Optional[int] = None,
    ):
        timestep = self.timesteps_proj(timestep)
        if timestep_seq_len is not None:
            timestep = timestep.unflatten(0, (1, timestep_seq_len))

        time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
        if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
            timestep = timestep.to(time_embedder_dtype)
        temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
        timestep_proj = self.time_proj(self.act_fn(temb))

        encoder_hidden_states = self.text_embedder(encoder_hidden_states)
        if encoder_hidden_states_image is not None:
            encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)

        return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image


class WanRotaryPosEmbed(nn.Module):
    def __init__(
        self,
        attention_head_dim: int,
        patch_size: Tuple[int, int, int],
        max_seq_len: int,
        theta: float = 10000.0,
    ):
        super().__init__()

        self.attention_head_dim = attention_head_dim
        self.patch_size = patch_size
        self.max_seq_len = max_seq_len

        h_dim = w_dim = 2 * (attention_head_dim // 6)
        t_dim = attention_head_dim - h_dim - w_dim
        freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64

        freqs_cos = []
        freqs_sin = []

        for dim in [t_dim, h_dim, w_dim]:
            freq_cos, freq_sin = get_1d_rotary_pos_embed(
                dim,
                max_seq_len,
                theta,
                use_real=True,
                repeat_interleave_real=True,
                freqs_dtype=freqs_dtype,
            )
            freqs_cos.append(freq_cos)
            freqs_sin.append(freq_sin)

        self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
        self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.patch_size
        ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w

        split_sizes = [
            self.attention_head_dim - 2 * (self.attention_head_dim // 3),
            self.attention_head_dim // 3,
            self.attention_head_dim // 3,
        ]

        freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
        freqs_sin = self.freqs_sin.split(split_sizes, dim=1)

        freqs_cos_f = freqs_cos[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_cos_h = freqs_cos[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_cos_w = freqs_cos[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)

        freqs_sin_f = freqs_sin[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_sin_h = freqs_sin[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_sin_w = freqs_sin[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)

        freqs_cos = torch.cat([freqs_cos_f, freqs_cos_h, freqs_cos_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
        freqs_sin = torch.cat([freqs_sin_f, freqs_sin_h, freqs_sin_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)

        return freqs_cos, freqs_sin


@maybe_allow_in_graph
class WanTransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        ffn_dim: int,
        num_heads: int,
        qk_norm: str = "rms_norm_across_heads",
        cross_attn_norm: bool = False,
        eps: float = 1e-6,
        added_kv_proj_dim: Optional[int] = None,
    ):
        super().__init__()

        # 1. Self-attention
        self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_heads,
            kv_heads=num_heads,
            dim_head=dim // num_heads,
            qk_norm=qk_norm,
            eps=eps,
            bias=True,
            cross_attention_dim=None,
            out_bias=True,
            processor=WanAttnProcessor2_0(),
        )

        # 2. Cross-attention
        self.attn2 = Attention(
            query_dim=dim,
            heads=num_heads,
            kv_heads=num_heads,
            dim_head=dim // num_heads,
            qk_norm=qk_norm,
            eps=eps,
            bias=True,
            cross_attention_dim=None,
            out_bias=True,
            added_kv_proj_dim=added_kv_proj_dim,
            added_proj_bias=True,
            processor=WanAttnProcessor2_0(),
        )
        self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()

        # 3. Feed-forward
        self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
        self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)

        self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        rotary_emb: torch.Tensor,
    ) -> torch.Tensor:
        if temb.ndim == 4:
            # temb: batch_size, seq_len, 6, inner_dim (wan2.2 ti2v)
            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
                self.scale_shift_table.unsqueeze(0) + temb.float()
            ).chunk(6, dim=2)
            # batch_size, seq_len, 1, inner_dim
            shift_msa = shift_msa.squeeze(2)
            scale_msa = scale_msa.squeeze(2)
            gate_msa = gate_msa.squeeze(2)
            c_shift_msa = c_shift_msa.squeeze(2)
            c_scale_msa = c_scale_msa.squeeze(2)
            c_gate_msa = c_gate_msa.squeeze(2)
        else:
            # temb: batch_size, 6, inner_dim (wan2.1/wan2.2 14B)
            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
                self.scale_shift_table + temb.float()
            ).chunk(6, dim=1)

        # 1. Self-attention
        norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
        attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
        hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)

        # 2. Cross-attention
        norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
        attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
        hidden_states = hidden_states + attn_output

        # 3. Feed-forward
        norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
            hidden_states
        )
        ff_output = self.ffn(norm_hidden_states)
        hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)

        return hidden_states


class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
    r"""
    A Transformer model for video-like data used in the Wan model.

    Args:
        patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
            3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
        num_attention_heads (`int`, defaults to `40`):
            The number of heads to use for multi-head attention.
        attention_head_dim (`int`, defaults to `128`):
            The number of channels in each head.
        in_channels (`int`, defaults to `16`):
            The number of channels in the input.
        out_channels (`int`, defaults to `16`):
            The number of channels in the output.
        text_dim (`int`, defaults to `4096`):
            Input dimension for text embeddings.
        freq_dim (`int`, defaults to `256`):
            Dimension for sinusoidal time embeddings.
        ffn_dim (`int`, defaults to `13824`):
            Intermediate dimension in feed-forward network.
        num_layers (`int`, defaults to `40`):
            The number of layers of transformer blocks to use.
        window_size (`Tuple[int]`, defaults to `(-1, -1)`):
            Window size for local attention (-1 indicates global attention).
        cross_attn_norm (`bool`, defaults to `True`):
            Enable cross-attention normalization.
        qk_norm (`bool`, defaults to `True`):
            Enable query/key normalization.
        eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        add_img_emb (`bool`, defaults to `False`):
            Whether to use img_emb.
        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
            The number of channels to use for the added key and value projections. If `None`, no projection is used.
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
    _no_split_modules = ["WanTransformerBlock"]
    _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
    _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
    _repeated_blocks = ["WanTransformerBlock"]

    @register_to_config
    def __init__(
        self,
        patch_size: Tuple[int] = (1, 2, 2),
        num_attention_heads: int = 40,
        attention_head_dim: int = 128,
        in_channels: int = 16,
        out_channels: int = 16,
        text_dim: int = 4096,
        freq_dim: int = 256,
        ffn_dim: int = 13824,
        num_layers: int = 40,
        cross_attn_norm: bool = True,
        qk_norm: Optional[str] = "rms_norm_across_heads",
        eps: float = 1e-6,
        image_dim: Optional[int] = None,
        added_kv_proj_dim: Optional[int] = None,
        rope_max_seq_len: int = 1024,
        pos_embed_seq_len: Optional[int] = None,
    ) -> None:
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim
        out_channels = out_channels or in_channels

        # 1. Patch & position embedding
        self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
        self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)

        # 2. Condition embeddings
        # image_embedding_dim=1280 for I2V model
        self.condition_embedder = WanTimeTextImageEmbedding(
            dim=inner_dim,
            time_freq_dim=freq_dim,
            time_proj_dim=inner_dim * 6,
            text_embed_dim=text_dim,
            image_embed_dim=image_dim,
            pos_embed_seq_len=pos_embed_seq_len,
        )

        # 3. Transformer blocks
        self.blocks = nn.ModuleList(
            [
                WanTransformerBlock(
                    inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Output norm & projection
        self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
        self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
        self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.config.patch_size
        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p_h
        post_patch_width = width // p_w

        rotary_emb = self.rope(hidden_states)

        hidden_states = self.patch_embedding(hidden_states)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        # timestep shape: batch_size, or batch_size, seq_len (wan 2.2 ti2v)
        if timestep.ndim == 2:
            ts_seq_len = timestep.shape[1]
            timestep = timestep.flatten()  # batch_size * seq_len
        else:
            ts_seq_len = None

        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
            timestep, encoder_hidden_states, encoder_hidden_states_image, timestep_seq_len=ts_seq_len
        )
        if ts_seq_len is not None:
            # batch_size, seq_len, 6, inner_dim
            timestep_proj = timestep_proj.unflatten(2, (6, -1))
        else:
            # batch_size, 6, inner_dim
            timestep_proj = timestep_proj.unflatten(1, (6, -1))

        if encoder_hidden_states_image is not None:
            encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
|
| 508 |
+
|
| 509 |
+
# 4. Transformer blocks
|
| 510 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 511 |
+
for block in self.blocks:
|
| 512 |
+
hidden_states = self._gradient_checkpointing_func(
|
| 513 |
+
block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
|
| 514 |
+
)
|
| 515 |
+
else:
|
| 516 |
+
for block in self.blocks:
|
| 517 |
+
hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
|
| 518 |
+
|
| 519 |
+
# 5. Output norm, projection & unpatchify
|
| 520 |
+
if temb.ndim == 3:
|
| 521 |
+
# batch_size, seq_len, inner_dim (wan 2.2 ti2v)
|
| 522 |
+
shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
|
| 523 |
+
shift = shift.squeeze(2)
|
| 524 |
+
scale = scale.squeeze(2)
|
| 525 |
+
else:
|
| 526 |
+
# batch_size, inner_dim
|
| 527 |
+
shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
|
| 528 |
+
|
| 529 |
+
# Move the shift and scale tensors to the same device as hidden_states.
|
| 530 |
+
# When using multi-GPU inference via accelerate these will be on the
|
| 531 |
+
# first device rather than the last device, which hidden_states ends up
|
| 532 |
+
# on.
|
| 533 |
+
shift = shift.to(hidden_states.device)
|
| 534 |
+
scale = scale.to(hidden_states.device)
|
| 535 |
+
|
| 536 |
+
hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
|
| 537 |
+
hidden_states = self.proj_out(hidden_states)
|
| 538 |
+
|
| 539 |
+
hidden_states = hidden_states.reshape(
|
| 540 |
+
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
|
| 541 |
+
)
|
| 542 |
+
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
|
| 543 |
+
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
| 544 |
+
|
| 545 |
+
if USE_PEFT_BACKEND:
|
| 546 |
+
# remove `lora_scale` from each PEFT layer
|
| 547 |
+
unscale_lora_layers(self, lora_scale)
|
| 548 |
+
|
| 549 |
+
if not return_dict:
|
| 550 |
+
return (output,)
|
| 551 |
+
|
| 552 |
+
return Transformer2DModelOutput(sample=output)
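The reshape → permute → flatten sequence above is the inverse of the 3D patch embedding: it folds the per-patch channels back into (frames, height, width). As a quick, self-contained sanity check of that unpatchify step (the sizes below are toy values chosen for illustration, not taken from the model config):

import torch

B, C_out, F, H, W = 1, 16, 3, 8, 8        # assumed toy latent sizes
p_t, p_h, p_w = 1, 2, 2                   # the default (1, 2, 2) patch size
seq_len = (F // p_t) * (H // p_h) * (W // p_w)

x = torch.randn(B, seq_len, C_out * p_t * p_h * p_w)               # shape of proj_out's output
x = x.reshape(B, F // p_t, H // p_h, W // p_w, p_t, p_h, p_w, -1)
x = x.permute(0, 7, 1, 4, 2, 5, 3, 6)                               # (B, C, F', p_t, H', p_h, W', p_w)
out = x.flatten(6, 7).flatten(4, 5).flatten(2, 3)                   # (B, C, F, H, W)
assert out.shape == (B, C_out, F, H, W)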
|
config/accelerate_config_4GPU.json
ADDED
|
@@ -0,0 +1,18 @@
| 1 |
+
{
|
| 2 |
+
"compute_environment": "LOCAL_MACHINE",
|
| 3 |
+
"debug": false,
|
| 4 |
+
"distributed_type": "MULTI_GPU",
|
| 5 |
+
"downcast_bf16": "no",
|
| 6 |
+
"gpu_ids": "all",
|
| 7 |
+
"machine_rank": 0,
|
| 8 |
+
"main_training_function": "main",
|
| 9 |
+
"mixed_precision": "bf16",
|
| 10 |
+
"num_machines": 1,
|
| 11 |
+
"num_processes": 4,
|
| 12 |
+
"rdzv_backend": "static",
|
| 13 |
+
"same_network": true,
|
| 14 |
+
"tpu_env": [],
|
| 15 |
+
"tpu_use_cluster": false,
|
| 16 |
+
"tpu_use_sudo": false,
|
| 17 |
+
"use_cpu": false
|
| 18 |
+
}
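This accelerate config describes single-node DDP training with 4 processes and bf16 mixed precision. It is typically consumed by pointing `accelerate launch` at it, e.g. `accelerate launch --config_file config/accelerate_config_4GPU.json <training_script> ...`; the training script itself is not part of this section, so the script name here is only a placeholder.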
|
config/train_cogvideox_motion.yaml
ADDED
|
@@ -0,0 +1,88 @@
| 1 |
+
|
| 2 |
+
experiment_name: CogVideoX_5B_Motion_480P # Store Folder Name
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# Model Setting
|
| 6 |
+
base_model_path: zai-org/CogVideoX-5b-I2V
|
| 7 |
+
pretrained_transformer_path: # No need to set; if set, this loads a transformer checkpoint other than the base model's default transformer
|
| 8 |
+
enable_slicing: True
|
| 9 |
+
enable_tiling: True
|
| 10 |
+
use_learned_positional_embeddings: True
|
| 11 |
+
use_rotary_positional_embeddings: True
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# Dataset Setting
|
| 16 |
+
download_folder_path: FrameINO_data/ # Set the downloaded folder path, all the other csv will be read automatically
|
| 17 |
+
train_csv_relative_path: dataset_csv_files/train_sample_short_dataset # No need to change, Fixed
|
| 18 |
+
train_video_relative_path: video_dataset/train_sample_dataset # No need to change, Fixed
|
| 19 |
+
validation_csv_relative_path: dataset_csv_files/val_sample_short_dataset # No need to change, Fixed
|
| 20 |
+
validation_video_relative_path: video_dataset/val_sample_dataset # No need to change, Fixed
|
| 21 |
+
dataloader_num_workers: 4 # This is per GPU; in debug, we set it to 1
|
| 22 |
+
# height_range: [480, 480] # Height Range; By slightly modify the dataloader code and use this setting, we can use variable resolution training
|
| 23 |
+
target_height: 480
|
| 24 |
+
target_width: 720
|
| 25 |
+
sample_accelerate_factor: 2 # Imitate 12FPS we have set before.
|
| 26 |
+
train_frame_num_range: [49, 49] # Number of frames for training, required to be 4N+1
|
| 27 |
+
# min_train_frame_num: 49 # If it is less than this number, the dataloader will raise Exception and skip to the next one valid!
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# Motion Setting
|
| 31 |
+
dot_radius: 6 # Set with respect to a 384-pixel height; it will be adjusted when the height changes
|
| 32 |
+
point_keep_ratio: 0.4 # The ratio of points kept; each tracking point is kept with this likelihood via random.choices, so it can be quite versatile; 0.33 is also recommended
|
| 33 |
+
faster_motion_prob: 0.0 # Probability of faster motion (~8 FPS); 0.0 - 0.1 is recommended (0.0 by default).
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# Denoise + Text Setting
|
| 37 |
+
noised_image_dropout: 0.05 # No First Frame Setting, becomes T2V
|
| 38 |
+
empty_text_prompt: False # For TI2V, we need to use a text prompt
|
| 39 |
+
text_mask_ratio: 0.05 # Follow InstructPix2Pix
|
| 40 |
+
max_text_seq_length: 226
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# Training Setting
|
| 44 |
+
resume_from_checkpoint: False # latest / False; latest will automatically fetch the newest checkpoint
|
| 45 |
+
max_train_steps: 1002 # Based on the needs; This is just a demo dataset, so training low is not needed
|
| 46 |
+
train_batch_size: 1 # batch size per GPU
|
| 47 |
+
gradient_accumulation_steps: 2 # Effectively multiplies the batch size across all GPUs
|
| 48 |
+
checkpointing_steps: 2000 # Checkpoint frequency; not recommended to be too frequent
|
| 49 |
+
checkpoints_total_limit: 8 # The transformer checkpoints are very large (~32 GB each), so keep this limit modest
|
| 50 |
+
mixed_precision: bf16 # The official CogVideoX code usually uses bf16
|
| 51 |
+
gradient_checkpointing: True # Saves memory but is slower; even with 80 GB of GPU memory this still needs to be enabled, otherwise OOM
|
| 52 |
+
seed: # If a seed is set here, every resume reads data in the same order as the first run, so resumed training cannot cover the full dataset
|
| 53 |
+
output_folder: checkpoints/
|
| 54 |
+
logging_name: logging
|
| 55 |
+
nccl_timeout: 1800
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# Validation Setting
|
| 59 |
+
validation_step: 2000 # Don't set this too frequently; validation is very resource-consuming
|
| 60 |
+
first_iter_validation: True # Whether we do the first iter validation
|
| 61 |
+
num_inference_steps: 50
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# Learning Rate and Optimizer
|
| 65 |
+
optimizer: adamw # Choose between ["adam", "adamw", "prodigy"]
|
| 66 |
+
learning_rate: 2e-5 # 1e-4 might be too big
|
| 67 |
+
scale_lr: False
|
| 68 |
+
lr_scheduler: constant_with_warmup # Most cases should be constant
|
| 69 |
+
adam_beta1: 0.9
|
| 70 |
+
adam_beta2: 0.95 # In the past, this used to be 0.999; smaller than usual
|
| 71 |
+
adam_beta3: 0.98
|
| 72 |
+
lr_power: 1.0
|
| 73 |
+
lr_num_cycles: 1.0
|
| 74 |
+
max_grad_norm: 1.0
|
| 75 |
+
prodigy_beta3: # Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses the value of square root of beta2
|
| 76 |
+
adam_weight_decay: 1e-04
|
| 77 |
+
adam_epsilon: 1e-08
|
| 78 |
+
lr_warmup_steps: 400
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# Other Setting
|
| 83 |
+
report_to: tensorboard
|
| 84 |
+
allow_tf32: True
|
| 85 |
+
revision:
|
| 86 |
+
variant:
|
| 87 |
+
cache_dir:
|
| 88 |
+
tracker_name:
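The training script that reads this file is not shown in this section, but the file is plain YAML, so the sketch below (PyYAML is an assumption here, not necessarily what the repo uses) shows one way to load it and to check the 4N+1 frame-count rule mentioned in the comments above:

import yaml

with open("config/train_cogvideox_motion.yaml") as f:
    cfg = yaml.safe_load(f)

# the comments above require frame counts of the form 4N + 1 (e.g. 49)
low, high = cfg["train_frame_num_range"]
assert (low - 1) % 4 == 0 and (high - 1) % 4 == 0, "train_frame_num_range must be 4N + 1"

print(cfg["base_model_path"], cfg["target_height"], cfg["target_width"])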
|
config/train_cogvideox_motion_FrameINO.yaml
ADDED
|
@@ -0,0 +1,96 @@
| 1 |
+
|
| 2 |
+
experiment_name: CogVideoX_5B_Motion_FINO_480P
|
| 3 |
+
|
| 4 |
+
# Model Setting
|
| 5 |
+
base_model_path: zai-org/CogVideoX-5b-I2V
|
| 6 |
+
pretrained_transformer_path: uva-cv-lab/FrameINO_CogVideoX_Stage1_Motion_v1.0 # Use the stage1 weights here; if you use your own trained weights, point this at the transformer subfolder (TODO: verify this)
|
| 7 |
+
enable_slicing: True
|
| 8 |
+
enable_tiling: True
|
| 9 |
+
use_learned_positional_embeddings: True
|
| 10 |
+
use_rotary_positional_embeddings: True
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# Dataset Setting
|
| 15 |
+
download_folder_path: FrameINO_data/ # Set the downloaded folder path, all the other csv will be read automatically
|
| 16 |
+
train_csv_relative_path: dataset_csv_files/train_sample_short_dataset # No need to change, Fixed
|
| 17 |
+
train_video_relative_path: video_dataset/train_sample_dataset # No need to change, Fixed
|
| 18 |
+
train_ID_relative_path: video_dataset/train_ID_FrameIn # No need to change, Fixed
|
| 19 |
+
validation_csv_relative_path: dataset_csv_files/val_sample_short_dataset # No need to change, Fixed
|
| 20 |
+
validation_video_relative_path: video_dataset/val_sample_dataset # No need to change, Fixed
|
| 21 |
+
validation_ID_relative_path: video_dataset/val_ID_FrameIn # No need to change, Fixed
|
| 22 |
+
dataloader_num_workers: 4 # This should be per GPU
|
| 23 |
+
# height_range: [480, 704] # Height Range; By slightly modify the dataloader code and use this setting, we can use variable resolution training
|
| 24 |
+
target_height: 480
|
| 25 |
+
target_width: 720
|
| 26 |
+
sample_accelerate_factor: 2 # Imitate 12FPS we have set before.
|
| 27 |
+
train_frame_num_range: [49, 49] # Number of frames for training, required to be 4N+1
|
| 28 |
+
min_train_frame_num: 49 # If a clip has fewer frames than this, the dataloader raises an Exception and skips to the next valid one! We recommend exactly 49 frames for CogVideoX.
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Motion Setting
|
| 32 |
+
dot_radius: 6 # Set with respect to a 384-pixel height; it will be adjusted when the height changes
|
| 33 |
+
point_keep_ratio_regular: 0.33 # Fewer points than plain motion control; the ratio of points kept inside the region box; for non-main-object motion
|
| 34 |
+
faster_motion_prob: 0.0 # Probability of faster motion (~8 FPS); 0.0 - 0.1 is recommended (0.0 by default).
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# Frame In and Out Setting
|
| 38 |
+
drop_FrameIn_prob: 0.15 # Cases where only FrameOut occurs; the FrameIn condition becomes an all-white placeholder (recommended: 0.15)
|
| 39 |
+
point_keep_ratio_ID: 0.33 # The ratio of points kept for the newly introduced ID
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# Denoise + Text Setting
|
| 43 |
+
noised_image_dropout: 0.05 # No First Frame Setting, becomes T2V
|
| 44 |
+
empty_text_prompt: False # For TI2V, we need to use a text prompt
|
| 45 |
+
text_mask_ratio: 0.05 # Follow InstructPix2Pix
|
| 46 |
+
max_text_seq_length: 226
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# Training Setting
|
| 50 |
+
resume_from_checkpoint: False # latest / False; latest will automatically fetch the newest checkpoint
|
| 51 |
+
max_train_steps: 1002 # Based on the needs; This is just a demo dataset, so training low is not needed
|
| 52 |
+
train_batch_size: 1 # batch size per GPU
|
| 53 |
+
gradient_accumulation_steps: 2 # This should be set to 1 usually.
|
| 54 |
+
checkpointing_steps: 2000 # Checkpoint frequency; not recommended to be too frequent
|
| 55 |
+
checkpoints_total_limit: 8 # The transformer checkpoints are very large (~32 GB each), so keep this limit modest
|
| 56 |
+
mixed_precision: bf16 # The official CogVideoX code usually uses bf16
|
| 57 |
+
gradient_checkpointing: True # Saves memory but is slower; even with 80 GB of GPU memory this still needs to be enabled, otherwise OOM
|
| 58 |
+
seed: # If a seed is set here, every resume reads data in exactly the same order as before the resume; if not even one epoch has been trained, the same data loops every time
|
| 59 |
+
output_folder: checkpoints/
|
| 60 |
+
logging_name: logging
|
| 61 |
+
nccl_timeout: 1800
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# Validation Setting
|
| 65 |
+
validation_step: 2000 # Don't set this too frequently; validation is very resource-consuming
|
| 66 |
+
first_iter_validation: True # Whether we do the first iter validation
|
| 67 |
+
num_inference_steps: 50
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# Learning Rate and Optimizer
|
| 72 |
+
optimizer: adamw # Choose between ["adam", "adamw", "prodigy"]
|
| 73 |
+
learning_rate: 2e-5 # 1e-4 might be too big
|
| 74 |
+
scale_lr: False
|
| 75 |
+
lr_scheduler: constant_with_warmup # Most cases should be constant
|
| 76 |
+
adam_beta1: 0.9
|
| 77 |
+
adam_beta2: 0.95 # In the past, this used to be 0.999; smaller than usual
|
| 78 |
+
adam_beta3: 0.98
|
| 79 |
+
lr_power: 1.0
|
| 80 |
+
lr_num_cycles: 1.0
|
| 81 |
+
max_grad_norm: 1.0
|
| 82 |
+
prodigy_beta3: # Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses the value of square root of beta2
|
| 83 |
+
# use_8bit_adam: False # This saves a lot of GPU memory, but slightly slower
|
| 84 |
+
adam_weight_decay: 1e-04
|
| 85 |
+
adam_epsilon: 1e-08
|
| 86 |
+
lr_warmup_steps: 400
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# Other Setting
|
| 91 |
+
report_to: tensorboard
|
| 92 |
+
allow_tf32: True
|
| 93 |
+
revision:
|
| 94 |
+
variant:
|
| 95 |
+
cache_dir:
|
| 96 |
+
tracker_name:
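The point_keep_ratio_* values above are applied independently to every tracking point via random.choices, matching the sampling done in data_loader/video_dataset_motion.py. A tiny illustrative sketch of that rule (the point count below is made up):

import random

point_keep_ratio = 0.33
num_points = 20
kept_point_status = random.choices(
    [True, False],
    weights=[point_keep_ratio, 1 - point_keep_ratio],
    k=num_points)
kept_indices = [i for i, keep in enumerate(kept_point_status) if keep]  # surviving points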
|
config/train_wan_motion.yaml
ADDED
|
@@ -0,0 +1,104 @@
| 1 |
+
|
| 2 |
+
experiment_name: Wan_5B_Motion_704P
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# Model Setting
|
| 6 |
+
base_model_path: Wan-AI/Wan2.2-TI2V-5B-Diffusers
|
| 7 |
+
pretrained_transformer_path: # No need to set; if set, this loads a transformer checkpoint other than the default Wan transformer
|
| 8 |
+
enable_slicing: True
|
| 9 |
+
enable_tiling: True
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# Dataset Setting
|
| 14 |
+
download_folder_path: FrameINO_data/ # Set the downloaded folder path, all the other csv will be read automatically
|
| 15 |
+
train_csv_relative_path: dataset_csv_files/train_sample_short_dataset # No need to change, Fixed
|
| 16 |
+
train_video_relative_path: video_dataset/train_sample_dataset # No need to change, Fixed
|
| 17 |
+
validation_csv_relative_path: dataset_csv_files/val_sample_short_dataset # No need to change, Fixed
|
| 18 |
+
validation_video_relative_path: video_dataset/val_sample_dataset # No need to change, Fixed
|
| 19 |
+
dataloader_num_workers: 4 # This should be per GPU; In Debug, we set to 1
|
| 20 |
+
# height_range: [480, 704] # Height Range; By slightly modify the dataloader code and use this setting, we can use variable resolution training
|
| 21 |
+
target_height: 704
|
| 22 |
+
target_width: 1280
|
| 23 |
+
sample_accelerate_factor: 2 # Imitate 12FPS we have set before.
|
| 24 |
+
train_frame_num_range: [81, 81] # Number of frames for training, required to be 4N+1; if a clip has fewer frames than the lower bound, just use the minimum available; currently set to 81 frames
|
| 25 |
+
# min_train_frame_num: 49 # If it is less than this number, the dataloader will raise Exception and skip to the next one valid!
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# Motion Setting
|
| 29 |
+
dot_radius: 7 # Due to the Wan VAE, this is slightly larger than for CogVideoX; set with respect to a 384-pixel height and adjusted when the height changes
|
| 30 |
+
point_keep_ratio: 0.4 # The ratio of points kept; each tracking point is kept with this likelihood via random.choices, so it can be quite versatile; 0.33 is also recommended
|
| 31 |
+
faster_motion_prob: 0.0 # Probability of faster motion (~8 FPS); 0.0 - 0.1 is recommended (0.0 by default).
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# Denoise (for flow matching-based training)
|
| 35 |
+
noised_image_dropout: 0.0 # No First Frame Setting, becomes T2V; not used for Wan
|
| 36 |
+
train_sampling_steps: 1000
|
| 37 |
+
noise_scheduler_kwargs:
|
| 38 |
+
num_train_timesteps: 1000 # 1000 is the default value
|
| 39 |
+
shift: 5.0
|
| 40 |
+
use_dynamic_shifting: false # false is the default value
|
| 41 |
+
base_shift: 0.5 # 0.5 is the default value
|
| 42 |
+
max_shift: 1.15 # 1.15 is the default value
|
| 43 |
+
base_image_seq_len: 256 # 256 is the default value
|
| 44 |
+
max_image_seq_len: 4096 # 4096 is the default value
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# Text Setting
|
| 48 |
+
text_mask_ratio: 0.0 # Follow InstructPix2Pix
|
| 49 |
+
empty_text_prompt: False # FOR TI2V, we start using text prompt
|
| 50 |
+
max_text_seq_length: 512 # For the Wan
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# Training Setting
|
| 55 |
+
resume_from_checkpoint: False # latest / False; latest will automatically fetch the newest checkpoint
|
| 56 |
+
max_train_steps: 1002 # Based on the needs; This is just a demo dataset, so training low is not needed
|
| 57 |
+
train_batch_size: 1 # batch size per GPU
|
| 58 |
+
gradient_accumulation_steps: 2 # Effectively multiplies the batch size across all GPUs
|
| 59 |
+
checkpointing_steps: 2000 # Checkpoint frequency; not recommended to be too frequent
|
| 60 |
+
checkpoints_total_limit: 8 # The transformer checkpoints are very large (~32 GB each), so keep this limit modest
|
| 61 |
+
mixed_precision: bf16 # The official CogVideoX code usually uses bf16
|
| 62 |
+
gradient_checkpointing: True # Saves memory but is slower; even with 80 GB of GPU memory this still needs to be enabled, otherwise OOM
|
| 63 |
+
seed: # If a seed is set here, every resume reads data in the same order as the first run, so resumed training cannot cover the full dataset
|
| 64 |
+
output_folder: checkpoints/
|
| 65 |
+
logging_name: logging
|
| 66 |
+
nccl_timeout: 1800
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# Validation Setting
|
| 71 |
+
validation_step: 2000 # Don't set this too frequently; validation is very resource-consuming
|
| 72 |
+
first_iter_validation: True # Whether we do the first iter validation
|
| 73 |
+
num_inference_steps: 38
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# Learning Rate and Optimizer
|
| 78 |
+
optimizer: adamw # Choose between ["adam", "adamw", "prodigy"]
|
| 79 |
+
learning_rate: 3e-5 # 1e-4 might be too big
|
| 80 |
+
scale_lr: False
|
| 81 |
+
lr_scheduler: constant_with_warmup # Most cases should be constant
|
| 82 |
+
adam_beta1: 0.9 # This Setting is different from CogVideoX, we follow VideoFun
|
| 83 |
+
adam_beta2: 0.999
|
| 84 |
+
# adam_beta3: 0.98
|
| 85 |
+
lr_power: 1.0
|
| 86 |
+
lr_num_cycles: 1.0
|
| 87 |
+
initial_grad_norm_ratio: 5
|
| 88 |
+
abnormal_norm_clip_start: 1000 # Follow VideoFun
|
| 89 |
+
max_grad_norm: 0.05 # Follow VideoFun
|
| 90 |
+
prodigy_beta3: # Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses the value of square root of beta2
|
| 91 |
+
# use_8bit_adam: False # This saves a lot of GPU memory, but slightly slower; Recommend to open
|
| 92 |
+
adam_weight_decay: 1e-4
|
| 93 |
+
adam_epsilon: 1e-10
|
| 94 |
+
lr_warmup_steps: 100
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# Other Setting
|
| 99 |
+
report_to: tensorboard
|
| 100 |
+
allow_tf32: True
|
| 101 |
+
revision:
|
| 102 |
+
variant:
|
| 103 |
+
cache_dir:
|
| 104 |
+
tracker_name:
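The noise_scheduler_kwargs block mirrors the constructor arguments of a flow-matching scheduler. A hedged sketch, assuming the training code builds diffusers' FlowMatchEulerDiscreteScheduler from it (the actual training script is not shown in this section):

import yaml
from diffusers import FlowMatchEulerDiscreteScheduler

with open("config/train_wan_motion.yaml") as f:
    cfg = yaml.safe_load(f)

# forward the YAML block directly as keyword arguments
noise_scheduler = FlowMatchEulerDiscreteScheduler(**cfg["noise_scheduler_kwargs"])
print(noise_scheduler.config.shift)  # 5.0, the Wan-style timestep shift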
|
config/train_wan_motion_FrameINO.yaml
ADDED
|
@@ -0,0 +1,110 @@
| 1 |
+
|
| 2 |
+
experiment_name: Wan_5B_Motion_FINO_704P
|
| 3 |
+
|
| 4 |
+
# Model Setting
|
| 5 |
+
base_model_path: Wan-AI/Wan2.2-TI2V-5B-Diffusers
|
| 6 |
+
pretrained_transformer_path: uva-cv-lab/FrameINO_Wan2.2_5B_Stage1_Motion_v1.5 # Use the one trained with the motion
|
| 7 |
+
enable_slicing: True
|
| 8 |
+
enable_tiling: True
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# Dataset Setting
|
| 13 |
+
download_folder_path: FrameINO_data/ # Set the downloaded folder path, all the other csv will be read automatically
|
| 14 |
+
train_csv_relative_path: dataset_csv_files/train_sample_short_dataset # No need to change, Fixed
|
| 15 |
+
train_video_relative_path: video_dataset/train_sample_dataset # No need to change, Fixed
|
| 16 |
+
train_ID_relative_path: video_dataset/train_ID_FrameIn # No need to change, Fixed
|
| 17 |
+
validation_csv_relative_path: dataset_csv_files/val_sample_short_dataset # No need to change, Fixed
|
| 18 |
+
validation_video_relative_path: video_dataset/val_sample_dataset # No need to change, Fixed
|
| 19 |
+
validation_ID_relative_path: video_dataset/val_ID_FrameIn # No need to change, Fixed
|
| 20 |
+
dataloader_num_workers: 4 # This is per GPU; in debug, we set it to 1
|
| 21 |
+
# height_range: [480, 704] # Height Range; By slightly modify the dataloader code and use this setting, we can use variable resolution training
|
| 22 |
+
target_height: 704 # Recommend 704 x 1280 for the Wan2.2
|
| 23 |
+
target_width: 1280
|
| 24 |
+
sample_accelerate_factor: 2 # Imitate 12FPS we have set before.
|
| 25 |
+
train_frame_num_range: [81, 81] # Number of frames for training, required to be 4N+1
|
| 26 |
+
min_train_frame_num: 49 # If a clip has fewer frames than this, the dataloader raises an Exception and skips to the next valid one!
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Motion Setting
|
| 30 |
+
dot_radius: 7 # Due to the Wan VAE, this is slightly larger than for CogVideoX; set with respect to a 384-pixel height and adjusted when the height changes
|
| 31 |
+
point_keep_ratio_regular: 0.33 # Fewer points than plain motion control; the ratio of points kept inside the region box; for non-main-object motion
|
| 32 |
+
faster_motion_prob: 0.0 # Probability of faster motion (~8 FPS); 0.0 - 0.1 is recommended (0.0 by default).
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# Frame In and Out Setting
|
| 36 |
+
drop_FrameIn_prob: 0.15 # Cases where only FrameOut occurs; the ID tokens are filled with an all-white placeholder (recommended value: 0.15)
|
| 37 |
+
point_keep_ratio_ID: 0.33 # The Ratio of points left for new ID introduced; For Main ID Object Motion
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# Denoise
|
| 41 |
+
noised_image_dropout: 0.0 # No First Frame Setting, becomes T2V; not used for Wan
|
| 42 |
+
train_sampling_steps: 1000
|
| 43 |
+
noise_scheduler_kwargs:
|
| 44 |
+
num_train_timesteps: 1000 # 1000 is the default value
|
| 45 |
+
shift: 5.0
|
| 46 |
+
use_dynamic_shifting: false # false is the default value
|
| 47 |
+
base_shift: 0.5 # 0.5 is the default value
|
| 48 |
+
max_shift: 1.15 # 1.15 is the default value
|
| 49 |
+
base_image_seq_len: 256 # 256 is the default value
|
| 50 |
+
max_image_seq_len: 4096 # 4096 is the default value
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# Text Setting
|
| 54 |
+
text_mask_ratio: 0.0 # Follow InstructPix2Pix; currently set to 0; at most 0.05 is recommended
|
| 55 |
+
empty_text_prompt: False # For TI2V, we need to use a text prompt
|
| 56 |
+
max_text_seq_length: 512 # For the Wan
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# Training setting
|
| 61 |
+
resume_from_checkpoint: False # latest / False; latest will automatically fetch the newest checkpoint
|
| 62 |
+
max_train_steps: 1002 # Based on the needs; This is just a demo dataset, so training low is not needed
|
| 63 |
+
train_batch_size: 1 # batch size per GPU
|
| 64 |
+
gradient_accumulation_steps: 2 # This should be set to 1 usually.
|
| 65 |
+
checkpointing_steps: 2000 # Checkpoint frequency; not recommended to be too frequent
|
| 66 |
+
checkpoints_total_limit: 8 # The transformer checkpoints are very large (~32 GB each), so keep this limit modest
|
| 67 |
+
mixed_precision: bf16 # The official CogVideoX code usually uses bf16
|
| 68 |
+
gradient_checkpointing: True # Saves memory but is slower; even with 80 GB of GPU memory this still needs to be enabled, otherwise OOM
|
| 69 |
+
seed: # If a seed is set here, every resume reads data in exactly the same order as before the resume; if not even one epoch has been trained, the same data loops every time
|
| 70 |
+
output_folder: checkpoints/
|
| 71 |
+
logging_name: logging
|
| 72 |
+
nccl_timeout: 1800
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# Validation Setting
|
| 77 |
+
validation_step: 2000 # Don't set this too frequently; validation is very resource-consuming
|
| 78 |
+
first_iter_validation: True # Whether we do the first iter validation
|
| 79 |
+
num_inference_steps: 38
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# Learning Rate and Optimizer
|
| 84 |
+
optimizer: adamw # Choose between ["adam", "adamw", "prodigy"]
|
| 85 |
+
learning_rate: 3e-5 # 1e-4 might be too big
|
| 86 |
+
scale_lr: False
|
| 87 |
+
lr_scheduler: constant_with_warmup # Most cases should be constant
|
| 88 |
+
adam_beta1: 0.9 # This Setting is different from CogVideoX, we follow VideoFun
|
| 89 |
+
adam_beta2: 0.999
|
| 90 |
+
# adam_beta3: 0.98
|
| 91 |
+
lr_power: 1.0
|
| 92 |
+
lr_num_cycles: 1.0
|
| 93 |
+
initial_grad_norm_ratio: 5
|
| 94 |
+
abnormal_norm_clip_start: 1000 # Follow VideoFun
|
| 95 |
+
max_grad_norm: 0.05 # Follow VideoFun
|
| 96 |
+
prodigy_beta3: # Coefficients for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses the value of square root of beta2
|
| 97 |
+
# use_8bit_adam: False # This saves a lot of GPU memory, but slightly slower
|
| 98 |
+
adam_weight_decay: 1e-4
|
| 99 |
+
adam_epsilon: 1e-10
|
| 100 |
+
lr_warmup_steps: 100
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# Other Setting
|
| 105 |
+
report_to: tensorboard
|
| 106 |
+
allow_tf32: True
|
| 107 |
+
revision:
|
| 108 |
+
variant:
|
| 109 |
+
cache_dir:
|
| 110 |
+
tracker_name:
|
data_loader/sampler.py
ADDED
|
@@ -0,0 +1,110 @@
| 1 |
+
# Last modified: 2024-04-18
|
| 2 |
+
#
|
| 3 |
+
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
# --------------------------------------------------------------------------
|
| 17 |
+
# If you find this code useful, we kindly ask you to cite our paper in your work.
|
| 18 |
+
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
|
| 19 |
+
# If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold.
|
| 20 |
+
# More information about the method can be found at https://marigoldmonodepth.github.io
|
| 21 |
+
# --------------------------------------------------------------------------
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
from torch.utils.data import (
|
| 25 |
+
BatchSampler,
|
| 26 |
+
RandomSampler,
|
| 27 |
+
SequentialSampler,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class MixedBatchSampler(BatchSampler):
|
| 32 |
+
"""Sample one batch from a selected dataset with given probability.
|
| 33 |
+
Compatible with datasets at different resolutions.
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
def __init__(
|
| 37 |
+
self, src_dataset_ls, batch_size, drop_last, shuffle, prob=None, generator=None
|
| 38 |
+
):
|
| 39 |
+
self.base_sampler = None
|
| 40 |
+
self.batch_size = batch_size
|
| 41 |
+
self.shuffle = shuffle
|
| 42 |
+
self.drop_last = drop_last
|
| 43 |
+
self.generator = generator
|
| 44 |
+
|
| 45 |
+
self.src_dataset_ls = src_dataset_ls
|
| 46 |
+
self.n_dataset = len(self.src_dataset_ls)
|
| 47 |
+
|
| 48 |
+
# Dataset length
|
| 49 |
+
self.dataset_length = [len(ds) for ds in self.src_dataset_ls]
|
| 50 |
+
self.cum_dataset_length = [
|
| 51 |
+
sum(self.dataset_length[:i]) for i in range(self.n_dataset)
|
| 52 |
+
] # cumulative dataset length
|
| 53 |
+
|
| 54 |
+
# BatchSamplers for each source dataset
|
| 55 |
+
if self.shuffle:
|
| 56 |
+
self.src_batch_samplers = [
|
| 57 |
+
BatchSampler(
|
| 58 |
+
sampler=RandomSampler(
|
| 59 |
+
ds, replacement=False, generator=self.generator
|
| 60 |
+
),
|
| 61 |
+
batch_size=self.batch_size,
|
| 62 |
+
drop_last=self.drop_last,
|
| 63 |
+
)
|
| 64 |
+
for ds in self.src_dataset_ls
|
| 65 |
+
]
|
| 66 |
+
else:
|
| 67 |
+
self.src_batch_samplers = [
|
| 68 |
+
BatchSampler(
|
| 69 |
+
sampler=SequentialSampler(ds),
|
| 70 |
+
batch_size=self.batch_size,
|
| 71 |
+
drop_last=self.drop_last,
|
| 72 |
+
)
|
| 73 |
+
for ds in self.src_dataset_ls
|
| 74 |
+
]
|
| 75 |
+
self.raw_batches = [
|
| 76 |
+
list(bs) for bs in self.src_batch_samplers
|
| 77 |
+
] # index in original dataset
|
| 78 |
+
self.n_batches = [len(b) for b in self.raw_batches]
|
| 79 |
+
self.n_total_batch = sum(self.n_batches)
|
| 80 |
+
# sampling probability
|
| 81 |
+
if prob is None:
|
| 82 |
+
# if not given, decide by dataset length
|
| 83 |
+
self.prob = torch.tensor(self.n_batches) / self.n_total_batch
|
| 84 |
+
else:
|
| 85 |
+
self.prob = torch.as_tensor(prob)
|
| 86 |
+
|
| 87 |
+
def __iter__(self):
|
| 88 |
+
"""_summary_
|
| 89 |
+
|
| 90 |
+
Yields:
|
| 91 |
+
list(int): a batch of indices into the ConcatDataset of src_dataset_ls
|
| 92 |
+
"""
|
| 93 |
+
for _ in range(self.n_total_batch):
|
| 94 |
+
idx_ds = torch.multinomial(
|
| 95 |
+
self.prob, 1, replacement=True, generator=self.generator
|
| 96 |
+
).item()
|
| 97 |
+
# if batch list is empty, generate new list
|
| 98 |
+
if 0 == len(self.raw_batches[idx_ds]):
|
| 99 |
+
self.raw_batches[idx_ds] = list(self.src_batch_samplers[idx_ds])
|
| 100 |
+
# get a batch from list
|
| 101 |
+
batch_raw = self.raw_batches[idx_ds].pop()
|
| 102 |
+
# shift by cumulative dataset length
|
| 103 |
+
shift = self.cum_dataset_length[idx_ds]
|
| 104 |
+
batch = [n + shift for n in batch_raw]
|
| 105 |
+
|
| 106 |
+
yield batch
|
| 107 |
+
|
| 108 |
+
def __len__(self):
|
| 109 |
+
return self.n_total_batch
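MixedBatchSampler yields indices into the concatenation of its source datasets, so it is meant to be paired with torch.utils.data.ConcatDataset and passed as batch_sampler. A minimal usage sketch (TensorDataset stands in for the real video datasets):

import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset
from data_loader.sampler import MixedBatchSampler

ds_a = TensorDataset(torch.arange(10).float())
ds_b = TensorDataset(torch.arange(100, 130).float())

sampler = MixedBatchSampler(
    src_dataset_ls=[ds_a, ds_b], batch_size=4, drop_last=True,
    shuffle=True, prob=[0.5, 0.5], generator=torch.Generator().manual_seed(0))
loader = DataLoader(ConcatDataset([ds_a, ds_b]), batch_sampler=sampler)

for (batch,) in loader:
    print(batch)  # every batch comes entirely from one of the two source datasets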
|
| 110 |
+
|
data_loader/video_dataset_motion.py
ADDED
|
@@ -0,0 +1,407 @@
| 1 |
+
import os, sys, shutil
|
| 2 |
+
from typing import List, Optional, Tuple, Union
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import csv
|
| 5 |
+
import random
|
| 6 |
+
import math
|
| 7 |
+
import numpy as np
|
| 8 |
+
import ffmpeg
|
| 9 |
+
import json
|
| 10 |
+
import imageio
|
| 11 |
+
import collections
|
| 12 |
+
import cv2
|
| 13 |
+
import pdb
|
| 14 |
+
csv.field_size_limit(sys.maxsize) # Default setting is 131072, 100x expand should be enough
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
from torch.utils.data import Dataset
|
| 18 |
+
from torchvision import transforms
|
| 19 |
+
|
| 20 |
+
# Import files from the local folder
|
| 21 |
+
root_path = os.path.abspath('.')
|
| 22 |
+
sys.path.append(root_path)
|
| 23 |
+
from utils.optical_flow_utils import flow_to_image, filter_uv, bivariate_Gaussian
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# Init paramter and global shared setting
|
| 27 |
+
|
| 28 |
+
# Blurring Kernel
|
| 29 |
+
blur_kernel = bivariate_Gaussian(45, 3, 3, 0, grid = None, isotropic = True)
|
| 30 |
+
|
| 31 |
+
# Color
|
| 32 |
+
all_color_codes = [(255, 0, 0), (255, 255, 0), (0, 255, 0), (0, 255, 255),
|
| 33 |
+
(255, 0, 255), (0, 0, 255), (128, 128, 128), (64, 224, 208),
|
| 34 |
+
(233, 150, 122)]
|
| 35 |
+
for _ in range(100): # Should not be over 100 colors
|
| 36 |
+
all_color_codes.append((random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)))
|
| 37 |
+
|
| 38 |
+
# Data Transforms
|
| 39 |
+
train_transforms = transforms.Compose(
|
| 40 |
+
[
|
| 41 |
+
transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
|
| 42 |
+
]
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class VideoDataset_Motion(Dataset):
|
| 48 |
+
|
| 49 |
+
def __init__(
|
| 50 |
+
self,
|
| 51 |
+
config,
|
| 52 |
+
download_folder_path,
|
| 53 |
+
csv_relative_path,
|
| 54 |
+
video_relative_path,
|
| 55 |
+
is_diy_test = False,
|
| 56 |
+
) -> None:
|
| 57 |
+
super().__init__()
|
| 58 |
+
|
| 59 |
+
# Gen Size Settings
|
| 60 |
+
# self.height_range = config["height_range"]
|
| 61 |
+
# self.max_aspect_ratio = config["max_aspect_ratio"]
|
| 62 |
+
self.target_height = config["target_height"]
|
| 63 |
+
self.target_width = config["target_width"]
|
| 64 |
+
self.sample_accelerate_factor = config["sample_accelerate_factor"]
|
| 65 |
+
self.train_frame_num_range = config["train_frame_num_range"]
|
| 66 |
+
|
| 67 |
+
# Condition Settings (Text, Motion, etc.)
|
| 68 |
+
self.empty_text_prompt = config["empty_text_prompt"]
|
| 69 |
+
self.dot_radius = int(config["dot_radius"])
|
| 70 |
+
self.point_keep_ratio = config["point_keep_ratio"] # Point selection mechanism
|
| 71 |
+
self.faster_motion_prob = config["faster_motion_prob"]
|
| 72 |
+
|
| 73 |
+
# Other Settings
|
| 74 |
+
self.download_folder_path = download_folder_path
|
| 75 |
+
self.is_diy_test = is_diy_test
|
| 76 |
+
self.config = config
|
| 77 |
+
self.video_folder_path = os.path.join(download_folder_path, video_relative_path)
|
| 78 |
+
csv_folder_path = os.path.join(download_folder_path, csv_relative_path)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
# Sanity Check
|
| 82 |
+
assert(os.path.exists(csv_folder_path))
|
| 83 |
+
assert(self.point_keep_ratio <= 1.0)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# Read the CSV files
|
| 88 |
+
info_lists = []
|
| 89 |
+
for csv_file_name in os.listdir(csv_folder_path): # Read all csv files
|
| 90 |
+
csv_file_path = os.path.join(csv_folder_path, csv_file_name)
|
| 91 |
+
|
| 92 |
+
with open(csv_file_path) as file_obj:
|
| 93 |
+
reader_obj = csv.reader(file_obj)
|
| 94 |
+
|
| 95 |
+
# Iterate over each row in the csv
|
| 96 |
+
for idx, row in enumerate(reader_obj):
|
| 97 |
+
if idx == 0:
|
| 98 |
+
elements = dict()
|
| 99 |
+
for element_idx, key in enumerate(row):
|
| 100 |
+
elements[key] = element_idx
|
| 101 |
+
continue
|
| 102 |
+
|
| 103 |
+
# Read the important information
|
| 104 |
+
info_lists.append(row)
|
| 105 |
+
|
| 106 |
+
# Organize
|
| 107 |
+
self.info_lists = info_lists
|
| 108 |
+
self.element_idx_dict = elements
|
| 109 |
+
|
| 110 |
+
# Log
|
| 111 |
+
print("The number of videos for ", csv_folder_path, " is ", len(self.info_lists))
|
| 112 |
+
# print("The memory cost is ", sys.getsizeof(self.info_lists))
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def __len__(self):
|
| 116 |
+
return len(self.info_lists)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
@staticmethod
|
| 120 |
+
def prepare_traj_tensor(full_pred_tracks, original_height, original_width, selected_frames,
|
| 121 |
+
dot_radius, target_width, target_height, idx = 0, first_frame_img = None):
|
| 122 |
+
|
| 123 |
+
# Prepare the color
|
| 124 |
+
target_color_codes = all_color_codes[:len(full_pred_tracks[0])] # This means how many objects in total we have
|
| 125 |
+
|
| 126 |
+
# Prepare the traj image
|
| 127 |
+
traj_img_lists = []
|
| 128 |
+
|
| 129 |
+
# Set a new dot radius based on the resolution fluctuating
|
| 130 |
+
dot_radius_resize = int( dot_radius * original_height / 384 ) # Set with respect to the default 384-pixel height; adjusted based on the actual height
|
| 131 |
+
|
| 132 |
+
# Prepare base draw image if there is
|
| 133 |
+
if first_frame_img is not None:
|
| 134 |
+
img_with_traj = first_frame_img.copy()
|
| 135 |
+
|
| 136 |
+
# Iterate all temporal sequence
|
| 137 |
+
merge_frames = []
|
| 138 |
+
for temporal_idx, points_per_frame in enumerate(full_pred_tracks): # Iterate all downsampled frames, should be 13
|
| 139 |
+
|
| 140 |
+
# Init the base img for the traj figures
|
| 141 |
+
base_img = np.zeros((original_height, original_width, 3)).astype(np.float32) # Use the original image size
|
| 142 |
+
base_img.fill(255) # Whole white frames
|
| 143 |
+
|
| 144 |
+
# Iterate all points in each object
|
| 145 |
+
for obj_idx, points_per_obj in enumerate(points_per_frame):
|
| 146 |
+
|
| 147 |
+
# Basic setting
|
| 148 |
+
color_code = target_color_codes[obj_idx] # Color across frames should be consistent
|
| 149 |
+
|
| 150 |
+
# Process all points in this current object
|
| 151 |
+
for (horizontal, vertical) in points_per_obj:
|
| 152 |
+
if horizontal < 0 or horizontal >= original_width or vertical < 0 or vertical >= original_height:
|
| 153 |
+
continue # If the point is already out of the range, Don't draw
|
| 154 |
+
|
| 155 |
+
# Draw square around the target position
|
| 156 |
+
vertical_start = min(original_height, max(0, vertical - dot_radius_resize))
|
| 157 |
+
vertical_end = min(original_height, max(0, vertical + dot_radius_resize)) # Diameter, used to be 10, but want smaller if there are too many points now
|
| 158 |
+
horizontal_start = min(original_width, max(0, horizontal - dot_radius_resize))
|
| 159 |
+
horizontal_end = min(original_width, max(0, horizontal + dot_radius_resize))
|
| 160 |
+
|
| 161 |
+
# Paint
|
| 162 |
+
base_img[vertical_start:vertical_end, horizontal_start:horizontal_end, :] = color_code
|
| 163 |
+
|
| 164 |
+
# Draw the visual of traj if needed
|
| 165 |
+
if first_frame_img is not None:
|
| 166 |
+
img_with_traj[vertical_start:vertical_end, horizontal_start:horizontal_end, :] = color_code
|
| 167 |
+
|
| 168 |
+
# Resize frames; don't use negative values and don't resize while the data is in the [0, 1] range
|
| 169 |
+
base_img = cv2.resize(base_img, (target_width, target_height), interpolation = cv2.INTER_CUBIC)
|
| 170 |
+
|
| 171 |
+
# Dilate (Default to be True)
|
| 172 |
+
base_img = cv2.filter2D(base_img, -1, blur_kernel).astype(np.uint8)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
# Append selected_frames and the color together for visualization
|
| 176 |
+
if len(selected_frames) != 0:
|
| 177 |
+
merge_frame = selected_frames[temporal_idx].copy()
|
| 178 |
+
merge_frame[base_img < 250] = base_img[base_img < 250]
|
| 179 |
+
merge_frames.append(merge_frame)
|
| 180 |
+
# cv2.imwrite("Video"+str(idx) + "_traj" + str(temporal_idx).zfill(2) + ".png", cv2.cvtColor(merge_frame, cv2.COLOR_RGB2BGR)) # Comment Out Later
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
# Append to the temporal index
|
| 184 |
+
traj_img_lists.append(base_img)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
# Convert to tensor
|
| 188 |
+
traj_imgs_np = np.array(traj_img_lists)
|
| 189 |
+
traj_tensor = torch.tensor(traj_imgs_np)
|
| 190 |
+
|
| 191 |
+
# Transform
|
| 192 |
+
traj_tensor = traj_tensor.float()
|
| 193 |
+
traj_tensor = torch.stack([train_transforms(traj_frame) for traj_frame in traj_tensor], dim=0)
|
| 194 |
+
traj_tensor = traj_tensor.permute(0, 3, 1, 2).contiguous() # [F, C, H, W]
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
# Write to video (Comment Out Later)
|
| 198 |
+
# imageio.mimsave("merge_cond" + str(idx) + ".mp4", merge_frames, fps=12)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
# Return
|
| 202 |
+
merge_frames = np.array(merge_frames)
|
| 203 |
+
if first_frame_img is not None:
|
| 204 |
+
return traj_tensor, traj_imgs_np, merge_frames, img_with_traj
|
| 205 |
+
else:
|
| 206 |
+
return traj_tensor, traj_imgs_np, merge_frames # Need to return traj_imgs_np for other purpose
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def __getitem__(self, idx):
|
| 211 |
+
|
| 212 |
+
while True: # Iterate until there is a valid video read
|
| 213 |
+
|
| 214 |
+
# try:
|
| 215 |
+
|
| 216 |
+
# Fetch the information
|
| 217 |
+
info = self.info_lists[idx]
|
| 218 |
+
video_path = os.path.join(self.video_folder_path, info[self.element_idx_dict["video_path"]])
|
| 219 |
+
original_height = int(info[self.element_idx_dict["height"]])
|
| 220 |
+
original_width = int(info[self.element_idx_dict["width"]])
|
| 221 |
+
# num_frames = int(info[self.element_idx_dict["num_frames"]]) # Deprecated, this is about the whole frame duration, not just one
|
| 222 |
+
|
| 223 |
+
valid_duration = json.loads(info[self.element_idx_dict["valid_duration"]])
|
| 224 |
+
All_Frame_Panoptic_Segmentation = json.loads(info[self.element_idx_dict["Panoptic_Segmentation"]])
|
| 225 |
+
text_prompt_all = json.loads(info[self.element_idx_dict["Structured_Text_Prompt"]])
|
| 226 |
+
Track_Traj_all = json.loads(info[self.element_idx_dict["Track_Traj"]]) # NOTE: Same as above, only consider the first panoptic segmented frame
|
| 227 |
+
Obj_Info_all = json.loads(info[self.element_idx_dict["Obj_Info"]])
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
# Sanity check
|
| 231 |
+
if not os.path.exists(video_path):
|
| 232 |
+
raise Exception("This video path", video_path, "doesn't exists!")
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
########################################## Manage Resolution and Selected Clip Setting ##########################################
|
| 236 |
+
|
| 237 |
+
# Option1: Variable Resolution Gen
|
| 238 |
+
# # Check the resolution size
|
| 239 |
+
# aspect_ratio = min(self.max_aspect_ratio, original_width / original_height)
|
| 240 |
+
# target_height_raw = min(original_height, random.randint(*self.height_range))
|
| 241 |
+
# target_width_raw = min(original_width, int(target_height_raw * aspect_ratio))
|
| 242 |
+
# # Must be the multiplier of 32
|
| 243 |
+
# target_height = (target_height_raw // 32) * 32
|
| 244 |
+
# target_width = (target_width_raw // 32) * 32
|
| 245 |
+
# print("New Height and Width are ", target_height, target_width)
|
| 246 |
+
|
| 247 |
+
# Option2: Fixed Resolution Gen (Assume that the provided is 32x valid)
|
| 248 |
+
target_width = self.target_width
|
| 249 |
+
target_height = self.target_height
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
# Only choose the first clip
|
| 253 |
+
Obj_Info = Obj_Info_all[0] # For the Motion Training, we have enough dataset, so we just choose the first panoptic section
|
| 254 |
+
Track_Traj = Track_Traj_all[0]
|
| 255 |
+
text_prompt = text_prompt_all[0]
|
| 256 |
+
resolution = str(target_width) + "x" + str(target_height) # Used for ffmpeg load
|
| 257 |
+
frame_start_idx = Obj_Info[0][1] # NOTE: If there is multiple objects Obj_Info[X][1] should be the same
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
##############################################################################################################################
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
############################################## Read the video by ffmpeg #################################################
|
| 265 |
+
|
| 266 |
+
# Read the video by ffmpeg in the needed decode fps and resolution
|
| 267 |
+
video_stream, err = ffmpeg.input(
|
| 268 |
+
video_path
|
| 269 |
+
).output(
|
| 270 |
+
"pipe:", format = "rawvideo", pix_fmt = "rgb24", s = resolution, vsync = 'passthrough',
|
| 271 |
+
).run(
|
| 272 |
+
capture_stdout = True, capture_stderr = True # If there is bug, command capture_stderr
|
| 273 |
+
) # The resize is already included
|
| 274 |
+
video_np_full = np.frombuffer(video_stream, np.uint8).reshape(-1, target_height, target_width, 3)
|
| 275 |
+
|
| 276 |
+
# Fetch the valid duration
|
| 277 |
+
video_np = video_np_full[valid_duration[0] : valid_duration[1]]
|
| 278 |
+
valid_num_frames = len(video_np) # Update the number of frames
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
# Decide the accelerate factor
|
| 282 |
+
train_frame_num_raw = random.randint(*self.train_frame_num_range)
|
| 283 |
+
if frame_start_idx + 3 * train_frame_num_raw < valid_num_frames and random.random() < self.faster_motion_prob: # Should be (1) have enough frames and (2) in 10% probability
|
| 284 |
+
sample_accelerate_factor = self.sample_accelerate_factor + 1 # Hard Code
|
| 285 |
+
else:
|
| 286 |
+
sample_accelerate_factor = self.sample_accelerate_factor
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
# Check the number of frames needed this time
|
| 290 |
+
frame_end_idx = min(valid_num_frames, frame_start_idx + sample_accelerate_factor * train_frame_num_raw)
|
| 291 |
+
frame_end_idx = frame_start_idx + 4 * math.floor(( (frame_end_idx-frame_start_idx) - 1) / 4) + 1 # Rounded to the closest 4N + 1 size
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
# Select Frames and Convert to Tensor
|
| 295 |
+
selected_frames = video_np[ frame_start_idx : frame_end_idx : sample_accelerate_factor] # NOTE: start from the first frame
|
| 296 |
+
video_tensor = torch.tensor(selected_frames) # Convert to tensor
|
| 297 |
+
first_frame_np = selected_frames[0] # Needs to return for Validation
|
| 298 |
+
train_frame_num = len(video_tensor) # Read the actual number of frames from the video (Must be 4N+1)
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
# Data transforms and shape organize
|
| 302 |
+
video_tensor = video_tensor.float()
|
| 303 |
+
video_tensor = torch.stack([train_transforms(frame) for frame in video_tensor], dim=0)
|
| 304 |
+
video_tensor = video_tensor.permute(0, 3, 1, 2).contiguous() # [F, C, H, W]
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
#############################################################################################################################
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
######################################### Define the text prompt #######################################################
|
| 312 |
+
|
| 313 |
+
# NOTE: text prompt is fetched above; here, we just decide whether to use an empty string
|
| 314 |
+
if self.empty_text_prompt or random.random() < self.config["text_mask_ratio"]:
|
| 315 |
+
text_prompt = ""
|
| 316 |
+
# print("Text Prompt for Video", idx, " is ", text_prompt)
|
| 317 |
+
|
| 318 |
+
########################################################################################################################
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
###################### Prepare the Tracking points for each object (each object has different color) #################################
|
| 323 |
+
|
| 324 |
+
# Iterate all the segmentation info
|
| 325 |
+
full_pred_tracks = [[] for _ in range(train_frame_num)] # The dim should be: (temporal, object, points, xy) The fps should be fixed to 12 fps, which is the same as training decode fps
|
| 326 |
+
for track_obj_idx in range(len(Obj_Info)):
|
| 327 |
+
|
| 328 |
+
# Read the basic info
|
| 329 |
+
text_name, frame_idx_raw = Obj_Info[track_obj_idx] # This is expected to be all the same in the video
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
# Sanity Check: make sure that the number of frames is consistent
|
| 333 |
+
if track_obj_idx > 0:
|
| 334 |
+
if frame_idx_raw != previous_frame_idx_raw:
|
| 335 |
+
raise Exception("The panoptic_frame_idx cannot pass the sanity check")
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
# Prepare the trajectory
|
| 339 |
+
pred_tracks_full = Track_Traj[track_obj_idx]
|
| 340 |
+
pred_tracks = pred_tracks_full[ frame_start_idx : frame_end_idx : sample_accelerate_factor]
|
| 341 |
+
if len(pred_tracks) != train_frame_num:
|
| 342 |
+
raise Exception("The length of tracking images does not match the video GT.")
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
# Randomly select points based on the given probability; the number of kept points differs for each object
|
| 346 |
+
kept_point_status = random.choices([True, False], weights = [self.point_keep_ratio, 1 - self.point_keep_ratio], k = len(pred_tracks[0]))
|
| 347 |
+
if len(kept_point_status) != len(pred_tracks[-1]):
|
| 348 |
+
raise Exception("The number of points filterred is not match with the dataset")
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
# Iterate and add all temporally
|
| 352 |
+
for temporal_idx, pred_track in enumerate(pred_tracks):
|
| 353 |
+
|
| 354 |
+
# Iterate all point one by one
|
| 355 |
+
left_points = []
|
| 356 |
+
for point_idx in range(len(pred_track)):
|
| 357 |
+
if kept_point_status[point_idx]:
|
| 358 |
+
left_points.append(pred_track[point_idx])
|
| 359 |
+
# Append the left points to the list
|
| 360 |
+
full_pred_tracks[temporal_idx].append(left_points) # pred_tracks will be 49 frames, and each one represent all tracking points for single objects; only one object here
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
# Other update
|
| 364 |
+
previous_frame_idx_raw = frame_idx_raw
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
# Draw the dilated traj points
|
| 368 |
+
traj_tensor, traj_imgs_np, merge_frames = self.prepare_traj_tensor(full_pred_tracks, original_height, original_width, selected_frames,
|
| 369 |
+
self.dot_radius, target_width, target_height, idx)
|
| 370 |
+
|
| 371 |
+
# Sanity Check to make sure that the traj tensor and ground truth has the same number of frames
|
| 372 |
+
if len(traj_tensor) != len(video_tensor): # If these two do not match, the torch.cat on latents will fail
|
| 373 |
+
raise Exception("Traj length and Video length does not matched!")
|
| 374 |
+
|
| 375 |
+
#########################################################################################################################################
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
# except Exception as e: # Note: You can uncomment this part to skip failure cases in mass training.
|
| 379 |
+
# print("The exception is ", e)
|
| 380 |
+
# old_idx = idx
|
| 381 |
+
# idx = (idx + 1) % len(self.info_lists)
|
| 382 |
+
# print("We cannot process the video", old_idx, " and we choose a new idx of ", idx)
|
| 383 |
+
# continue # If any error occurs, we run it again with the newly proposed idx
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
# If everything is ok, we should break at the end
|
| 387 |
+
break
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
# Return the information
|
| 391 |
+
return {
|
| 392 |
+
"video_tensor": video_tensor,
|
| 393 |
+
"traj_tensor": traj_tensor,
|
| 394 |
+
"text_prompt": text_prompt,
|
| 395 |
+
|
| 396 |
+
# The rest are auxiliary data for the validation/testing purposes
|
| 397 |
+
"video_gt_np": selected_frames,
|
| 398 |
+
"first_frame_np": first_frame_np,
|
| 399 |
+
"traj_imgs_np": traj_imgs_np,
|
| 400 |
+
"merge_frames": merge_frames,
|
| 401 |
+
"gt_video_path": video_path,
|
| 402 |
+
}
|
| 403 |
+
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
|
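A quick note on batching, added here for clarity: the __getitem__ above returns a dict that mixes fixed-shape tensors (video_tensor, traj_tensor) with numpy arrays, strings, and paths kept only for validation, so PyTorch's default collate can be awkward once batch_size > 1. Below is a minimal, hypothetical sketch (not part of the original repository) of a collate function; the name motion_collate is an assumption, and it assumes every sample in a batch shares the same frame count (otherwise keep batch_size = 1).

import torch

def motion_collate(samples):
    # Hypothetical helper, not from the original repository.
    # Stack the fixed-shape tensor fields; keep validation-only fields as plain lists.
    batch = {
        "video_tensor": torch.stack([s["video_tensor"] for s in samples], dim=0),
        "traj_tensor": torch.stack([s["traj_tensor"] for s in samples], dim=0),
        "text_prompt": [s["text_prompt"] for s in samples],
    }
    # Auxiliary numpy / path fields are only used for validation, so lists are enough.
    for key in ("video_gt_np", "first_frame_np", "traj_imgs_np", "merge_frames", "gt_video_path"):
        batch[key] = [s[key] for s in samples]
    return batch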
data_loader/video_dataset_motion_FrameINO.py
ADDED
|
@@ -0,0 +1,578 @@
|
| 1 |
+
import os, sys, shutil
|
| 2 |
+
from typing import List, Optional, Tuple, Union
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import csv
|
| 5 |
+
import random
|
| 6 |
+
import numpy as np
|
| 7 |
+
import ffmpeg
|
| 8 |
+
import json
|
| 9 |
+
import imageio
|
| 10 |
+
import collections
|
| 11 |
+
import cv2
|
| 12 |
+
import pdb
|
| 13 |
+
import math
|
| 14 |
+
import PIL.Image as Image
|
| 15 |
+
csv.field_size_limit(sys.maxsize) # The default limit is 131072; raise it so very large CSV fields can be parsed
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
from torch.utils.data import Dataset
|
| 19 |
+
from torchvision import transforms
|
| 20 |
+
|
| 21 |
+
# Import files from the local folder
|
| 22 |
+
root_path = os.path.abspath('.')
|
| 23 |
+
sys.path.append(root_path)
|
| 24 |
+
from utils.optical_flow_utils import flow_to_image, filter_uv, bivariate_Gaussian
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# Init parameters and global shared settings
|
| 28 |
+
|
| 29 |
+
# Blurring Kernel
|
| 30 |
+
blur_kernel = bivariate_Gaussian(45, 3, 3, 0, grid = None, isotropic = True)
|
| 31 |
+
|
| 32 |
+
# Color
|
| 33 |
+
all_color_codes = [(255, 0, 0), (255, 255, 0), (0, 255, 0), (0, 255, 255),
|
| 34 |
+
(255, 0, 255), (0, 0, 255), (128, 128, 128), (64, 224, 208),
|
| 35 |
+
(233, 150, 122)]
|
| 36 |
+
for _ in range(100): # Should not be over 100 colors
|
| 37 |
+
all_color_codes.append((random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)))
|
| 38 |
+
|
| 39 |
+
# Data Transforms
|
| 40 |
+
train_transforms = transforms.Compose(
|
| 41 |
+
[
|
| 42 |
+
transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
|
| 43 |
+
]
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class VideoDataset_Motion_FrameINO(Dataset):
|
| 48 |
+
def __init__(
|
| 49 |
+
self,
|
| 50 |
+
config,
|
| 51 |
+
download_folder_path,
|
| 52 |
+
csv_relative_path,
|
| 53 |
+
video_relative_path,
|
| 54 |
+
ID_relative_path,
|
| 55 |
+
FrameOut_only = False,
|
| 56 |
+
one_point_one_obj = False,
|
| 57 |
+
strict_validation_match = False,
|
| 58 |
+
) -> None:
|
| 59 |
+
super().__init__()
|
| 60 |
+
|
| 61 |
+
# Gen Size Settings
|
| 62 |
+
# self.height_range = config["height_range"]
|
| 63 |
+
# self.max_aspect_ratio = config["max_aspect_ratio"]
|
| 64 |
+
self.target_height = config["target_height"]
|
| 65 |
+
self.target_width = config["target_width"]
|
| 66 |
+
self.sample_accelerate_factor = config["sample_accelerate_factor"]
|
| 67 |
+
self.train_frame_num_range = config["train_frame_num_range"]
|
| 68 |
+
self.min_train_frame_num = config["min_train_frame_num"]
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# Condition Settings (Text, Motion, etc.)
|
| 72 |
+
self.empty_text_prompt = config["empty_text_prompt"]
|
| 73 |
+
self.dot_radius = int(config["dot_radius"])
|
| 74 |
+
self.point_keep_ratio_ID = config["point_keep_ratio_ID"]
|
| 75 |
+
self.point_keep_ratio_regular = config["point_keep_ratio_regular"]
|
| 76 |
+
self.faster_motion_prob = config["faster_motion_prob"]
|
| 77 |
+
|
| 78 |
+
# Other Settings
|
| 79 |
+
self.FrameOut_only = FrameOut_only
|
| 80 |
+
self.one_point_one_obj = one_point_one_obj # Currently, this is only enabled when FrameOut_only = True
|
| 81 |
+
self.strict_validation_match = strict_validation_match
|
| 82 |
+
self.config = config
|
| 83 |
+
self.video_folder_path = os.path.join(download_folder_path, video_relative_path)
|
| 84 |
+
self.ID_folder_path = os.path.join(download_folder_path, ID_relative_path)
|
| 85 |
+
csv_folder_path = os.path.join(download_folder_path, csv_relative_path)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# Sanity Check
|
| 89 |
+
assert(os.path.exists(csv_folder_path))
|
| 90 |
+
assert(self.point_keep_ratio_ID <= 1.0)
|
| 91 |
+
assert(self.point_keep_ratio_regular <= 1.0)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
# Read the CSV files
|
| 95 |
+
info_lists = []
|
| 96 |
+
for csv_file_name in os.listdir(csv_folder_path): # Read all csv files
|
| 97 |
+
csv_file_path = os.path.join(csv_folder_path, csv_file_name)
|
| 98 |
+
|
| 99 |
+
with open(csv_file_path) as file_obj:
|
| 100 |
+
reader_obj = csv.reader(file_obj)
|
| 101 |
+
|
| 102 |
+
# Iterate over each row in the csv
|
| 103 |
+
for idx, row in enumerate(reader_obj):
|
| 104 |
+
if idx == 0:
|
| 105 |
+
elements = dict()
|
| 106 |
+
for element_idx, key in enumerate(row):
|
| 107 |
+
elements[key] = element_idx
|
| 108 |
+
continue
|
| 109 |
+
|
| 110 |
+
# Read the important information
|
| 111 |
+
info_lists.append(row)
|
| 112 |
+
|
| 113 |
+
# Organize
|
| 114 |
+
self.info_lists = info_lists
|
| 115 |
+
self.element_idx_dict = elements
|
| 116 |
+
|
| 117 |
+
# Log
|
| 118 |
+
print("The number of videos for ", csv_folder_path, " is ", len(self.info_lists))
|
| 119 |
+
# print("The memory cost is ", sys.getsizeof(self.info_lists))
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def __len__(self):
|
| 123 |
+
return len(self.info_lists)
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
@staticmethod
|
| 127 |
+
def prepare_traj_tensor(full_pred_tracks, original_height, original_width, selected_frames,
|
| 128 |
+
dot_radius, target_width, target_height, region_box, idx = 0, first_frame_img = None):
|
| 129 |
+
|
| 130 |
+
# Prepare the color and other stuff
|
| 131 |
+
target_color_codes = all_color_codes[:len(full_pred_tracks[0])] # This means how many objects in total we have
|
| 132 |
+
(top_left_x, top_left_y), (bottom_right_x, bottom_right_y) = region_box
|
| 133 |
+
|
| 134 |
+
# Prepare the traj image
|
| 135 |
+
traj_img_lists = []
|
| 136 |
+
|
| 137 |
+
# Set a new dot radius based on the resolution fluctuating
|
| 138 |
+
dot_radius_resize = int( dot_radius * original_height / 384 ) # This is set with respect to the default 384 height and will be adjusted based on the height change
|
| 139 |
+
|
| 140 |
+
# Prepare base draw image if there is
|
| 141 |
+
if first_frame_img is not None:
|
| 142 |
+
img_with_traj = first_frame_img.copy()
|
| 143 |
+
|
| 144 |
+
# Iterate all object instance
|
| 145 |
+
merge_frames = []
|
| 146 |
+
for temporal_idx, obj_points in enumerate(full_pred_tracks): # Iterate all downsampled frames, should be 13
|
| 147 |
+
|
| 148 |
+
# Init the base img for the traj figures
|
| 149 |
+
base_img = np.zeros((original_height, original_width, 3)).astype(np.float32) # Use the original image size
|
| 150 |
+
base_img.fill(255) # Whole white frames
|
| 151 |
+
|
| 152 |
+
# Iterate for the per object
|
| 153 |
+
for obj_idx, points in enumerate(obj_points):
|
| 154 |
+
|
| 155 |
+
# Basic setting
|
| 156 |
+
color_code = target_color_codes[obj_idx] # Color across frames should be consistent
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
# Process all points in this current object
|
| 160 |
+
for (horizontal, vertical) in points:
|
| 161 |
+
if horizontal < 0 or horizontal >= original_width or vertical < 0 or vertical >= original_height:
|
| 162 |
+
continue # If the point is already out of the range, Don't draw
|
| 163 |
+
|
| 164 |
+
# Draw square around the target position
|
| 165 |
+
vertical_start = min(original_height, max(0, vertical - dot_radius_resize))
|
| 166 |
+
vertical_end = min(original_height, max(0, vertical + dot_radius_resize)) # Diameter, used to be 10, but want smaller if there are too many points now
|
| 167 |
+
horizontal_start = min(original_width, max(0, horizontal - dot_radius_resize))
|
| 168 |
+
horizontal_end = min(original_width, max(0, horizontal + dot_radius_resize))
|
| 169 |
+
|
| 170 |
+
# Paint
|
| 171 |
+
base_img[vertical_start:vertical_end, horizontal_start:horizontal_end, :] = color_code
|
| 172 |
+
|
| 173 |
+
# Draw the visual of traj if needed
|
| 174 |
+
if first_frame_img is not None:
|
| 175 |
+
img_with_traj[vertical_start:vertical_end, horizontal_start:horizontal_end, :] = color_code
|
| 176 |
+
|
| 177 |
+
# Resize frames; don't use negative values and don't resize in the [0, 1] range
|
| 178 |
+
base_img = cv2.resize(base_img, (target_width, target_height), interpolation = cv2.INTER_CUBIC)
|
| 179 |
+
|
| 180 |
+
# Dilate (Default to be True)
|
| 181 |
+
base_img = cv2.filter2D(base_img, -1, blur_kernel).astype(np.uint8)
|
| 182 |
+
|
| 183 |
+
# Append selected_frames and the color together for visualization
|
| 184 |
+
merge_frame = selected_frames[temporal_idx].copy()
|
| 185 |
+
merge_frame = cv2.rectangle(merge_frame, (top_left_x, top_left_y), (bottom_right_x, bottom_right_y), (255, 0, 0), 5) # Draw the Region Box Area
|
| 186 |
+
merge_frame[base_img < 250] = base_img[base_img < 250]
|
| 187 |
+
merge_frames.append(merge_frame)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
# Append to the temporal index
|
| 191 |
+
traj_img_lists.append(base_img)
|
| 192 |
+
|
| 193 |
+
# Convert to tensor
|
| 194 |
+
traj_imgs_np = np.array(traj_img_lists)
|
| 195 |
+
traj_tensor = torch.tensor(traj_imgs_np)
|
| 196 |
+
|
| 197 |
+
# Transform
|
| 198 |
+
traj_tensor = traj_tensor.float()
|
| 199 |
+
traj_tensor = torch.stack([train_transforms(traj_frame) for traj_frame in traj_tensor], dim=0)
|
| 200 |
+
traj_tensor = traj_tensor.permute(0, 3, 1, 2).contiguous() # [F, C, H, W]
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
# Write to video (For Debug Purpose)
|
| 204 |
+
# imageio.mimsave("merge_cond" + str(idx) + ".mp4", merge_frames, fps=12)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
# Return
|
| 209 |
+
merge_frames = np.array(merge_frames)
|
| 210 |
+
if first_frame_img is not None:
|
| 211 |
+
return traj_tensor, traj_imgs_np, merge_frames, img_with_traj
|
| 212 |
+
else:
|
| 213 |
+
return traj_tensor, traj_imgs_np, merge_frames # Need to return traj_imgs_np for other purpose
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def __getitem__(self, idx):
|
| 218 |
+
|
| 219 |
+
while True: # Iterate until there is a valid video read
|
| 220 |
+
|
| 221 |
+
# try:
|
| 222 |
+
|
| 223 |
+
# Fetch the information
|
| 224 |
+
info = self.info_lists[idx]
|
| 225 |
+
video_path = os.path.join(self.video_folder_path, info[self.element_idx_dict["video_path"]])
|
| 226 |
+
original_height = int(info[self.element_idx_dict["height"]])
|
| 227 |
+
original_width = int(info[self.element_idx_dict["width"]])
|
| 228 |
+
# num_frames = int(info[self.element_idx_dict["num_frames"]]) # Deprecated, this is about the whole frame duration, not just one
|
| 229 |
+
|
| 230 |
+
valid_duration = json.loads(info[self.element_idx_dict["valid_duration"]])
|
| 231 |
+
All_Frame_Panoptic_Segmentation = json.loads(info[self.element_idx_dict["Panoptic_Segmentation"]])
|
| 232 |
+
text_prompt_all = json.loads(info[self.element_idx_dict["Structured_Text_Prompt"]])
|
| 233 |
+
Track_Traj_all = json.loads(info[self.element_idx_dict["Track_Traj"]])
|
| 234 |
+
Obj_Info_all = json.loads(info[self.element_idx_dict["Obj_Info"]])
|
| 235 |
+
ID_info_all = json.loads(info[self.element_idx_dict["ID_info"]]) # New elements compared to motion data loader
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
# Sanity check
|
| 239 |
+
if not os.path.exists(video_path):
|
| 240 |
+
raise Exception("This video path", video_path, "doesn't exists!")
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
########################################## Manage Resolution and Selected Clip Settings ##########################################
|
| 244 |
+
|
| 245 |
+
# Option1: Variable Resolution Gen
|
| 246 |
+
# # Check the resolution size
|
| 247 |
+
# aspect_ratio = min(self.max_aspect_ratio, original_width / original_height)
|
| 248 |
+
# target_height_raw = min(original_height, random.randint(*self.height_range))
|
| 249 |
+
# target_width_raw = min(original_width, int(target_height_raw * aspect_ratio))
|
| 250 |
+
# # Must be the multiplier of 32
|
| 251 |
+
# target_height = (target_height_raw // 32) * 32
|
| 252 |
+
# target_width = (target_width_raw // 32) * 32
|
| 253 |
+
# print("New Height and Width are ", target_height, target_width)
|
| 254 |
+
|
| 255 |
+
# Option2: Fixed Resolution Gen (Assume that the provided is 32x valid)
|
| 256 |
+
target_width = self.target_width
|
| 257 |
+
target_height = self.target_height
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
# NOTE: Here, we only choose the first Panoptic choice, to avoid multiple panoptic choices.
|
| 261 |
+
Obj_Info = Obj_Info_all[0] # For panoptic Segmentation
|
| 262 |
+
Track_Traj = Track_Traj_all[0]
|
| 263 |
+
text_prompt = text_prompt_all[0]
|
| 264 |
+
ID_info = ID_info_all[0] # For Frame In ID information, Just one Panoptic Frame
|
| 265 |
+
resolution = str(target_width) + "x" + str(target_height)
|
| 266 |
+
frame_start_idx = Obj_Info[0][1] # NOTE: If there is multiple objects Obj_Info[X][1] should be the same
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
##############################################################################################################################
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
#################################################### Fetch FrameIn ID information ###############################################################
|
| 274 |
+
|
| 275 |
+
# FrameIn drop
|
| 276 |
+
if self.FrameOut_only or random.random() < self.config["drop_FrameIn_prob"]:
|
| 277 |
+
drop_FrameIn = True
|
| 278 |
+
else:
|
| 279 |
+
drop_FrameIn = False
|
| 280 |
+
|
| 281 |
+
# Not all objects are ideal for FrameIn, so we need to select
|
| 282 |
+
if not self.strict_validation_match:
|
| 283 |
+
effective_ID_idxs = []
|
| 284 |
+
for ID_idx, ID_Info_obj in enumerate(ID_info):
|
| 285 |
+
if ID_Info_obj != []:
|
| 286 |
+
effective_ID_idxs.append(ID_idx)
|
| 287 |
+
main_target_ID_idx = random.choice(effective_ID_idxs) # NOTE: we should only have one object to be processed for now
|
| 288 |
+
else:
|
| 289 |
+
main_target_ID_idx = 0 # Always choose the first one
|
| 290 |
+
|
| 291 |
+
# Fetch the FrameIn ID info
|
| 292 |
+
segmentation_info, useful_region_box = ID_info[main_target_ID_idx] # There might be multiple objects ideal, but we just randomly choose one
|
| 293 |
+
if not self.FrameOut_only:
|
| 294 |
+
_, first_frame_reference_path, _ = segmentation_info # bbox_info, first_frame_reference_path, store_img_path_lists
|
| 295 |
+
first_frame_reference_path = os.path.join(self.ID_folder_path, first_frame_reference_path)
|
| 296 |
+
if not os.path.exists(first_frame_reference_path):
|
| 297 |
+
raise Exception("Cannot find ID path", first_frame_reference_path)
|
| 298 |
+
##################################################################################################################################################
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
################ Randomly choose one mask among the multiple choices available (Resolution is with respect to the original resolution) #################
|
| 303 |
+
|
| 304 |
+
# Choose one region box
|
| 305 |
+
useful_region_box.sort(key=lambda x: x[0]) # Sort based on the BBox size
|
| 306 |
+
if not self.strict_validation_match:
|
| 307 |
+
mask_region = random.choice(useful_region_box[-5:])[1:] # Choose among the largest 5 BBox available
|
| 308 |
+
else:
|
| 309 |
+
mask_region = useful_region_box[-1][1:] # Choose the last one
|
| 310 |
+
|
| 311 |
+
# Fetch
|
| 312 |
+
(top_left_x_raw, top_left_y_raw), (bottom_right_x_raw, bottom_right_y_raw) = mask_region # As Original Resolution
|
| 313 |
+
|
| 314 |
+
# Resize the mask based on the CURRENT target resolution (now the 384x480 resolution)
|
| 315 |
+
top_left_x = int(top_left_x_raw * target_width / original_width)
|
| 316 |
+
top_left_y = int(top_left_y_raw * target_height / original_height)
|
| 317 |
+
bottom_right_x = int(bottom_right_x_raw * target_width / original_width)
|
| 318 |
+
bottom_right_y = int(bottom_right_y_raw * target_height / original_height)
|
| 319 |
+
resized_mask_region_box = (top_left_x, top_left_y), (bottom_right_x, bottom_right_y)
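# Illustrative example (added, not from the original code): with an original 1920x1080 clip
# and a 480x384 (w x h) target, a raw box corner (960, 540) maps to
# (int(960 * 480 / 1920), int(540 * 384 / 1080)) = (240, 192), so the region box
# scales linearly with the resize that ffmpeg applies below.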
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
###################################################################################################################################################
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
################################################ Read the video by ffmpeg #########################################################################
|
| 327 |
+
|
| 328 |
+
# Read the video by ffmpeg in the needed decode fps and resolution
|
| 329 |
+
video_stream, err = ffmpeg.input(
|
| 330 |
+
video_path
|
| 331 |
+
).output(
|
| 332 |
+
"pipe:", format = "rawvideo", pix_fmt = "rgb24", s = resolution, vsync = 'passthrough',
|
| 333 |
+
).run(
|
| 334 |
+
capture_stdout = True, capture_stderr = True # If there is a bug, check capture_stderr
|
| 335 |
+
) # The resize is already included
|
| 336 |
+
video_np_full = np.frombuffer(video_stream, np.uint8).reshape(-1, target_height, target_width, 3)
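# Note (added for clarity): rawvideo at rgb24 yields target_height * target_width * 3 bytes per
# frame, so the reshape above infers the frame count from the buffer size; e.g. a 480x384 decode
# produces 552,960 bytes per frame.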
|
| 337 |
+
|
| 338 |
+
# Fetch the valid duration
|
| 339 |
+
video_np = video_np_full[valid_duration[0] : valid_duration[1]]
|
| 340 |
+
valid_num_frames = len(video_np) # Update the number of frames
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
# Decide the accelerate factor
|
| 344 |
+
train_frame_num_raw = random.randint(*self.train_frame_num_range)
|
| 345 |
+
if frame_start_idx + 3 * train_frame_num_raw < valid_num_frames and random.random() < self.faster_motion_prob: # Requires (1) enough frames and (2) hitting the faster-motion probability (10%)
|
| 346 |
+
sample_accelerate_factor = self.sample_accelerate_factor + 1 # Hard Code
|
| 347 |
+
else:
|
| 348 |
+
sample_accelerate_factor = self.sample_accelerate_factor
|
| 349 |
+
|
| 350 |
+
# Check the number of frames needed this time
|
| 351 |
+
frame_end_idx = min(valid_num_frames, frame_start_idx + sample_accelerate_factor * train_frame_num_raw)
|
| 352 |
+
frame_end_idx = frame_start_idx + 4 * math.floor(( (frame_end_idx-frame_start_idx) - 1) / 4) + 1 # Rounded to the closest 4N + 1 size
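# Illustrative example (added, not from the original code): if frame_start_idx = 10 and the
# clamped frame_end_idx is 70, then math.floor((70 - 10 - 1) / 4) = 14, so the new end becomes
# 10 + 4 * 14 + 1 = 67 and the selected span covers 67 - 10 = 57 = 4 * 14 + 1 raw frames.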
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
# Select Frames based on the start and end idx; then, Convert to Tensor
|
| 356 |
+
selected_frames = video_np[ frame_start_idx : frame_end_idx : sample_accelerate_factor] # NOTE: start from the first frame
|
| 357 |
+
if len(selected_frames) < self.min_train_frame_num:
|
| 358 |
+
print(len(selected_frames), len(video_np), frame_start_idx, frame_end_idx, sample_accelerate_factor)
|
| 359 |
+
raise Exception(f"selected_frames is less than {self.min_train_frame_num} frames preset! We jump to the next valid one!") # 我这里让Number of Frames Exactly = 49
|
| 360 |
+
video_tensor = torch.tensor(selected_frames) # Convert to tensor
|
| 361 |
+
train_frame_num = len(video_tensor) # Read the actual number of frames from the video (Must be 4N+1)
|
| 362 |
+
# print("Number of frames is", train_frame_num)
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
# Data transforms and shape organize
|
| 366 |
+
video_tensor = video_tensor.float()
|
| 367 |
+
video_tensor = torch.stack([train_transforms(frame) for frame in video_tensor], dim=0)
|
| 368 |
+
video_tensor = video_tensor.permute(0, 3, 1, 2).contiguous() # [F, C, H, W]
|
| 369 |
+
|
| 370 |
+
|
| 371 |
+
# Crop the tensor so that all non-interest regions become blank (black, 0-value); the region is at the target training resolution with VAE step-size adjustment
|
| 372 |
+
video_np_masked = np.zeros(selected_frames.shape, dtype = np.uint8)
|
| 373 |
+
video_np_masked[:, top_left_y:bottom_right_y, top_left_x:bottom_right_x, :] = selected_frames[:, top_left_y:bottom_right_y, top_left_x:bottom_right_x, :]
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
# Decide the first frame with the masked one instead of the full one.
|
| 377 |
+
first_frame_np = video_np_masked[0] # Needs to return for Validation
|
| 378 |
+
# cv2.imwrite("first_frame"+str(idx)+".png", cv2.cvtColor(first_frame_np, cv2.COLOR_BGR2RGB)) # Comment Out Later
|
| 379 |
+
|
| 380 |
+
# Convert to Tensor and then Transforms
|
| 381 |
+
first_frame_tensor = torch.tensor(first_frame_np)
|
| 382 |
+
first_frame_tensor = train_transforms(first_frame_tensor).permute(2, 0, 1).contiguous()
|
| 383 |
+
|
| 384 |
+
#########################################################################################################################################
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
############################################# Define the text prompt #######################################################
|
| 389 |
+
|
| 390 |
+
# NOTE: the text prompt was already extracted above; here we just decide whether to set it to an empty string
|
| 391 |
+
if self.empty_text_prompt or random.random() < self.config["text_mask_ratio"]:
|
| 392 |
+
text_prompt = ""
|
| 393 |
+
# print("Text Prompt for Video", idx, " is ", text_prompt) # Comment Out Later
|
| 394 |
+
|
| 395 |
+
#############################################################################################################################
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
########################### Prepare the Tracking points for each object (each object has different color) #################################
|
| 400 |
+
|
| 401 |
+
# Iterate all the Segmentation Info
|
| 402 |
+
full_pred_tracks = [[] for _ in range(train_frame_num)] # The dim should be: (temporal, object, points, xy) The fps should be fixed to 12 fps, which is the same as training decode fps
|
| 403 |
+
for track_obj_idx in range(len(Obj_Info)):
|
| 404 |
+
|
| 405 |
+
# Read the basic info
|
| 406 |
+
text_name, frame_idx_raw = Obj_Info[track_obj_idx] # This is expected to be all the same in the video
|
| 407 |
+
|
| 408 |
+
# Sanity Check: make sure that the number of frames is consistent
|
| 409 |
+
if track_obj_idx > 0:
|
| 410 |
+
if frame_idx_raw != previous_frame_idx_raw:
|
| 411 |
+
raise Exception("The panoptic_frame_idx cannot pass the sanity check")
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
# Prepare the trajectory
|
| 415 |
+
pred_tracks_full = Track_Traj[track_obj_idx]
|
| 416 |
+
pred_tracks = pred_tracks_full[ frame_start_idx : frame_end_idx : sample_accelerate_factor]
|
| 417 |
+
if len(pred_tracks) != train_frame_num:
|
| 418 |
+
raise Exception("The length of tracking images does not match the video GT.")
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
# FrameINO-specific kept-point setting: for a non-main obj idx, we must ensure all points are inside the region box; if it is the main obj, the ID must be outside the region box
|
| 422 |
+
if track_obj_idx != main_target_ID_idx or self.FrameOut_only: # Non-main obj (Usually, for Frame Out cases)
|
| 423 |
+
|
| 424 |
+
# Randomly select the points based on the given probability; the number of points differs for each object
|
| 425 |
+
kept_point_status = random.choices([True, False], weights = [self.point_keep_ratio_regular, 1 - self.point_keep_ratio_regular], k = len(pred_tracks[0]))
|
| 426 |
+
|
| 427 |
+
# Check if point of the object is within the first frame; No need to check for following frames (allowed to have FrameOut effect)
|
| 428 |
+
first_frame_points = pred_tracks[0]
|
| 429 |
+
for point_idx in range(len(first_frame_points)):
|
| 430 |
+
(horizontal, vertical) = first_frame_points[point_idx]
|
| 431 |
+
if horizontal < top_left_x_raw or horizontal >= bottom_right_x_raw or vertical < top_left_y_raw or vertical >= bottom_right_y_raw: # Whether Outside the BBox region
|
| 432 |
+
kept_point_status[point_idx] = False
|
| 433 |
+
|
| 434 |
+
else: # For main object
|
| 435 |
+
|
| 436 |
+
# Randomly select the points based on the given probability; the number of points differs for each object
|
| 437 |
+
if drop_FrameIn:
|
| 438 |
+
# No motion provided on ID for Drop FrameIn cases
|
| 439 |
+
kept_point_status = random.choices([False], k = len(pred_tracks[0]))
|
| 440 |
+
|
| 441 |
+
else: # Regular FrameIn case
|
| 442 |
+
kept_point_status = random.choices([True, False], weights = [self.point_keep_ratio_ID, 1 - self.point_keep_ratio_ID], k = len(pred_tracks[0]))
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
# Sanity Check
|
| 446 |
+
if len(kept_point_status) != len(pred_tracks[-1]):
|
| 447 |
+
raise Exception("The number of points filterred does not match with the dataset")
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
# Iterate and add all temporally
|
| 451 |
+
for temporal_idx, pred_track in enumerate(pred_tracks): # The length = number of frames
|
| 452 |
+
|
| 453 |
+
# Iterate all point one by one
|
| 454 |
+
left_points = []
|
| 455 |
+
for point_idx in range(len(pred_track)):
|
| 456 |
+
# Select kept points
|
| 457 |
+
if kept_point_status[point_idx]:
|
| 458 |
+
left_points.append(pred_track[point_idx])
|
| 459 |
+
|
| 460 |
+
# Append the left points to the list
|
| 461 |
+
full_pred_tracks[temporal_idx].append(left_points) # pred_tracks will be 49 frames, and each one represents all tracking points for a single object; only one object here
|
| 462 |
+
|
| 463 |
+
# Other update
|
| 464 |
+
previous_frame_idx_raw = frame_idx_raw
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
# Fetch One Point
|
| 468 |
+
if self.one_point_one_obj:
|
| 469 |
+
one_track_point = []
|
| 470 |
+
for full_pred_track_per_frame in full_pred_tracks:
|
| 471 |
+
one_track_point.append( [[full_pred_track_per_frame[0][0]]])
|
| 472 |
+
|
| 473 |
+
#######################################################################################################################################
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
############################### Process the Video Tensor (based on info fetched from traj) ############################################
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
if drop_FrameIn:
|
| 481 |
+
|
| 482 |
+
ID_img = np.uint8(np.zeros((target_height, target_width, 3))) # Whole Black (0-value) pixel placeholder
|
| 483 |
+
|
| 484 |
+
else:
|
| 485 |
+
|
| 486 |
+
# Fetch the reference and resize
|
| 487 |
+
ID_img = np.asarray(Image.open(first_frame_reference_path))
|
| 488 |
+
|
| 489 |
+
# Resize to the same size as the video
|
| 490 |
+
ref_h, ref_w = ID_img.shape[:2]
|
| 491 |
+
scale_h = target_height / max(ref_h, ref_w)
|
| 492 |
+
scale_w = target_width / max(ref_h, ref_w)
|
| 493 |
+
new_h, new_w = int(ref_h * scale_h), int(ref_w * scale_w)
|
| 494 |
+
ID_img = cv2.resize(ID_img, (new_w, new_h), interpolation = cv2.INTER_AREA)
|
| 495 |
+
|
| 496 |
+
# Calculate padding amounts in each direction
|
| 497 |
+
pad_height1 = (target_height - ID_img.shape[0]) // 2
|
| 498 |
+
pad_height2 = target_height - ID_img.shape[0] - pad_height1
|
| 499 |
+
pad_width1 = (target_width - ID_img.shape[1]) // 2
|
| 500 |
+
pad_width2 = target_width - ID_img.shape[1] - pad_width1
|
| 501 |
+
|
| 502 |
+
# Apply padding to reach the same resolution as the training frames
|
| 503 |
+
ID_img = np.pad(
|
| 504 |
+
ID_img,
|
| 505 |
+
((pad_height1, pad_height2), (pad_width1, pad_width2), (0, 0)),
|
| 506 |
+
mode = 'constant',
|
| 507 |
+
constant_values = 0
|
| 508 |
+
)
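# Illustrative example (added, not from the original code): a 300x200 (h x w) reference gives
# max(ref_h, ref_w) = 300, so new_h = int(300 * 384/300) = 384 and new_w = int(200 * 480/300) = 320;
# the pads are then (0, 0) vertically and (80, 80) horizontally, centering the ID image on the
# 384x480 black canvas.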
|
| 509 |
+
|
| 510 |
+
# Visualize; Comment Out Later
|
| 511 |
+
# cv2.imwrite("ID_img_padded"+str(idx)+".png", cv2.cvtColor(ID_img, cv2.COLOR_BGR2RGB))
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
# Convert to tensor (Same as others)
|
| 515 |
+
ID_tensor = torch.tensor(ID_img)
|
| 516 |
+
ID_tensor = train_transforms(ID_tensor).permute(2, 0, 1).contiguous()
|
| 517 |
+
|
| 518 |
+
#######################################################################################################################################
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
############################################## Draw the Traj Points and Transform to Tensor #############################################
|
| 523 |
+
|
| 524 |
+
# Draw the dilated points
|
| 525 |
+
if self.one_point_one_obj:
|
| 526 |
+
target_pred_tracks = one_track_point # In this case, we only keep one point per object
|
| 527 |
+
else:
|
| 528 |
+
target_pred_tracks = full_pred_tracks
|
| 529 |
+
|
| 530 |
+
traj_tensor, traj_imgs_np, merge_frames = self.prepare_traj_tensor(target_pred_tracks, original_height, original_width, selected_frames,
|
| 531 |
+
self.dot_radius, target_width, target_height, resized_mask_region_box, idx)
|
| 532 |
+
|
| 533 |
+
# Sanity Check to make sure that the traj tensor and ground truth has the same number of frames
|
| 534 |
+
if len(traj_tensor) != len(video_tensor): # If these two do not match, the torch.cat on latents will fail
|
| 535 |
+
raise Exception("Traj length and Video length does not matched!")
|
| 536 |
+
|
| 537 |
+
#########################################################################################################################################
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
# Write some processed meta data
|
| 541 |
+
processed_meta_data = {
|
| 542 |
+
"full_pred_tracks": full_pred_tracks,
|
| 543 |
+
"original_width": original_width,
|
| 544 |
+
"original_height": original_height,
|
| 545 |
+
"mask_region": mask_region,
|
| 546 |
+
"resized_mask_region_box": resized_mask_region_box,
|
| 547 |
+
}
|
| 548 |
+
|
| 549 |
+
# except Exception as e: # Note: You can uncomment this part to skip failure cases in mass training.
|
| 550 |
+
# print("The exception is ", e)
|
| 551 |
+
# old_idx = idx
|
| 552 |
+
# idx = (idx + 1) % len(self.info_lists)
|
| 553 |
+
# print("We cannot process the video", old_idx, " and we choose a new idx of ", idx)
|
| 554 |
+
# continue # If any error occurs, we run it again with the newly proposed idx
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
# If everything is ok, we should break at the end
|
| 558 |
+
break
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
# Return the information
|
| 562 |
+
return {
|
| 563 |
+
"video_tensor": video_tensor,
|
| 564 |
+
"traj_tensor": traj_tensor,
|
| 565 |
+
"first_frame_tensor": first_frame_tensor,
|
| 566 |
+
"ID_tensor": ID_tensor,
|
| 567 |
+
"text_prompt": text_prompt,
|
| 568 |
+
|
| 569 |
+
# The rest are auxiliary data for the validation/testing purposes
|
| 570 |
+
"video_gt_np": selected_frames,
|
| 571 |
+
"first_frame_np": first_frame_np,
|
| 572 |
+
"ID_np": ID_img,
|
| 573 |
+
"processed_meta_data": processed_meta_data,
|
| 574 |
+
"traj_imgs_np": traj_imgs_np,
|
| 575 |
+
"merge_frames" : merge_frames,
|
| 576 |
+
"gt_video_path": video_path,
|
| 577 |
+
}
|
| 578 |
+
|
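To make the expected inputs concrete, here is a minimal, hypothetical usage sketch of the VideoDataset_Motion_FrameINO class defined above. The config keys mirror the ones read in __init__ and __getitem__; the numeric values and folder paths are placeholders (not taken from the project's training configs), so adjust them to your own data layout. batch_size = 1 sidesteps collation of the variable-size validation fields.

from torch.utils.data import DataLoader
from data_loader.video_dataset_motion_FrameINO import VideoDataset_Motion_FrameINO

config = {
    "target_height": 384, "target_width": 480,              # placeholder resolution, multiples of 32
    "sample_accelerate_factor": 1,
    "train_frame_num_range": (29, 49), "min_train_frame_num": 29,
    "empty_text_prompt": False, "text_mask_ratio": 0.1,
    "dot_radius": 6,
    "point_keep_ratio_ID": 0.5, "point_keep_ratio_regular": 0.1,
    "faster_motion_prob": 0.1, "drop_FrameIn_prob": 0.1,
}

dataset = VideoDataset_Motion_FrameINO(
    config,
    download_folder_path = "/path/to/dataset",               # placeholder path
    csv_relative_path = "csv",
    video_relative_path = "videos",
    ID_relative_path = "IDs",
)
loader = DataLoader(dataset, batch_size = 1, shuffle = True, num_workers = 2)
sample = next(iter(loader))
print(sample["video_tensor"].shape, sample["traj_tensor"].shape, sample["ID_tensor"].shape)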
data_loader/video_dataset_motion_FrameINO_old.py
ADDED
|
@@ -0,0 +1,538 @@
|
| 1 |
+
import os, sys, shutil
|
| 2 |
+
from typing import List, Optional, Tuple, Union
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import csv
|
| 5 |
+
import random
|
| 6 |
+
import numpy as np
|
| 7 |
+
import ffmpeg
|
| 8 |
+
import json
|
| 9 |
+
import imageio
|
| 10 |
+
import collections
|
| 11 |
+
import cv2
|
| 12 |
+
import pdb
|
| 13 |
+
import math
|
| 14 |
+
import PIL.Image as Image
|
| 15 |
+
csv.field_size_limit(13107200) # Default setting is 131072, 100x expand should be enough
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
from torch.utils.data import Dataset
|
| 19 |
+
from torchvision import transforms
|
| 20 |
+
|
| 21 |
+
# Import files from the local folder
|
| 22 |
+
root_path = os.path.abspath('.')
|
| 23 |
+
sys.path.append(root_path)
|
| 24 |
+
from utils.optical_flow_utils import flow_to_image, filter_uv, bivariate_Gaussian
|
| 25 |
+
|
| 26 |
+
# Init parameters and global shared settings
|
| 27 |
+
|
| 28 |
+
# Blurring Kernel
|
| 29 |
+
blur_kernel = bivariate_Gaussian(45, 3, 3, 0, grid = None, isotropic = True)
|
| 30 |
+
|
| 31 |
+
# Color
|
| 32 |
+
all_color_codes = [(255, 0, 0), (255, 255, 0), (0, 255, 0), (0, 255, 255),
|
| 33 |
+
(255, 0, 255), (0, 0, 255), (128, 128, 128), (64, 224, 208),
|
| 34 |
+
(233, 150, 122)]
|
| 35 |
+
for _ in range(100): # Should not be over 100 colors
|
| 36 |
+
all_color_codes.append((random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)))
|
| 37 |
+
|
| 38 |
+
# Data Transforms
|
| 39 |
+
train_transforms = transforms.Compose(
|
| 40 |
+
[
|
| 41 |
+
transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
|
| 42 |
+
]
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class VideoDataset_Motion_FrameINO(Dataset):
|
| 47 |
+
def __init__(
|
| 48 |
+
self,
|
| 49 |
+
config,
|
| 50 |
+
csv_folder_path,
|
| 51 |
+
FrameOut_only = False,
|
| 52 |
+
one_point_one_obj = False,
|
| 53 |
+
strict_validation_match = False,
|
| 54 |
+
) -> None:
|
| 55 |
+
super().__init__()
|
| 56 |
+
|
| 57 |
+
# Fetch the Fundamental Setting
|
| 58 |
+
self.dataset_folder_path = config["dataset_folder_path"]
|
| 59 |
+
if not FrameOut_only: # Frame In mode
|
| 60 |
+
self.ID_folder_path = config["ID_folder_path"]
|
| 61 |
+
self.target_height = config["height"]
|
| 62 |
+
self.target_width = config["width"]
|
| 63 |
+
# self.ref_cond_size = config["ref_cond_size"]
|
| 64 |
+
self.preset_decode_fps = config["preset_decode_fps"] # Set to be 16
|
| 65 |
+
self.train_frame_num = config["train_frame_num"]
|
| 66 |
+
self.empty_text_prompt = config["empty_text_prompt"]
|
| 67 |
+
self.start_skip = config["start_skip"]
|
| 68 |
+
self.end_skip = config["end_skip"]
|
| 69 |
+
self.dot_radius = int(config["dot_radius"]) # Set to be 6
|
| 70 |
+
self.point_keep_ratio_ID = config["point_keep_ratio_ID"]
|
| 71 |
+
self.point_keep_ratio_regular = config["point_keep_ratio_regular"]
|
| 72 |
+
self.faster_motion_prob = config["faster_motion_prob"]
|
| 73 |
+
self.FrameOut_only = FrameOut_only
|
| 74 |
+
self.one_point_one_obj = one_point_one_obj # Currently, this is only enabled when FrameOut_only = True
|
| 75 |
+
self.strict_validation_match = strict_validation_match
|
| 76 |
+
self.config = config
|
| 77 |
+
|
| 78 |
+
# Sanity Check
|
| 79 |
+
assert(self.point_keep_ratio_ID <= 1.0)
|
| 80 |
+
assert(self.point_keep_ratio_regular <= 1.0)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# Read the CSV files
|
| 84 |
+
info_lists = []
|
| 85 |
+
for csv_file_name in os.listdir(csv_folder_path): # Read all csv files
|
| 86 |
+
csv_file_path = os.path.join(csv_folder_path, csv_file_name)
|
| 87 |
+
with open(csv_file_path) as file_obj:
|
| 88 |
+
reader_obj = csv.reader(file_obj)
|
| 89 |
+
|
| 90 |
+
# Iterate over each row in the csv
|
| 91 |
+
for idx, row in enumerate(reader_obj):
|
| 92 |
+
if idx == 0:
|
| 93 |
+
elements = dict()
|
| 94 |
+
for element_idx, key in enumerate(row):
|
| 95 |
+
elements[key] = element_idx
|
| 96 |
+
continue
|
| 97 |
+
|
| 98 |
+
# Read the important information
|
| 99 |
+
info_lists.append(row)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# Organize
|
| 103 |
+
self.info_lists = info_lists
|
| 104 |
+
self.element_idx_dict = elements
|
| 105 |
+
|
| 106 |
+
# Log
|
| 107 |
+
print("The number of videos for ", csv_folder_path, " is ", len(self.info_lists))
|
| 108 |
+
# print("The memory cost is ", sys.getsizeof(self.info_lists))
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def __len__(self):
|
| 112 |
+
return len(self.info_lists)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
@staticmethod
|
| 116 |
+
def prepare_traj_tensor(full_pred_tracks, original_height, original_width, selected_frames,
|
| 117 |
+
dot_radius, target_width, target_height, region_box, idx = 0, first_frame_img = None):
|
| 118 |
+
|
| 119 |
+
# Prepare the color and other stuff
|
| 120 |
+
target_color_codes = all_color_codes[:len(full_pred_tracks[0])] # This means how many objects in total we have
|
| 121 |
+
(top_left_x, top_left_y), (bottom_right_x, bottom_right_y) = region_box
|
| 122 |
+
|
| 123 |
+
# Prepare the traj image
|
| 124 |
+
traj_img_lists = []
|
| 125 |
+
|
| 126 |
+
# Set a new dot radius based on the resolution fluctuating
|
| 127 |
+
dot_radius_resize = int( dot_radius * original_height / 384 ) # This is set with respect to the default 384 height and will be adjusted based on the height change
|
| 128 |
+
|
| 129 |
+
# Prepare base draw image if there is
|
| 130 |
+
if first_frame_img is not None:
|
| 131 |
+
img_with_traj = first_frame_img.copy()
|
| 132 |
+
|
| 133 |
+
# Iterate all object instance
|
| 134 |
+
merge_frames = []
|
| 135 |
+
for temporal_idx, obj_points in enumerate(full_pred_tracks): # Iterate all downsampled frames, should be 13
|
| 136 |
+
|
| 137 |
+
# Init the base img for the traj figures
|
| 138 |
+
base_img = np.zeros((original_height, original_width, 3)).astype(np.float32) # Use the original image size
|
| 139 |
+
base_img.fill(255) # Whole white frames
|
| 140 |
+
|
| 141 |
+
# Iterate for the per object
|
| 142 |
+
for obj_idx, points in enumerate(obj_points):
|
| 143 |
+
|
| 144 |
+
# Basic setting
|
| 145 |
+
color_code = target_color_codes[obj_idx] # Color across frames should be consistent
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
# Process all points in this current object
|
| 149 |
+
for (horizontal, vertical) in points:
|
| 150 |
+
if horizontal < 0 or horizontal >= original_width or vertical < 0 or vertical >= original_height:
|
| 151 |
+
continue # If the point is already out of the range, Don't draw
|
| 152 |
+
|
| 153 |
+
# Draw square around the target position
|
| 154 |
+
vertical_start = min(original_height, max(0, vertical - dot_radius_resize))
|
| 155 |
+
vertical_end = min(original_height, max(0, vertical + dot_radius_resize)) # Diameter, used to be 10, but want smaller if there are too many points now
|
| 156 |
+
horizontal_start = min(original_width, max(0, horizontal - dot_radius_resize))
|
| 157 |
+
horizontal_end = min(original_width, max(0, horizontal + dot_radius_resize))
|
| 158 |
+
|
| 159 |
+
# Paint
|
| 160 |
+
base_img[vertical_start:vertical_end, horizontal_start:horizontal_end, :] = color_code
|
| 161 |
+
|
| 162 |
+
# Draw the visual of traj if needed
|
| 163 |
+
if first_frame_img is not None:
|
| 164 |
+
img_with_traj[vertical_start:vertical_end, horizontal_start:horizontal_end, :] = color_code
|
| 165 |
+
|
| 166 |
+
# Resize frames; don't use negative values and don't resize in the [0, 1] range
|
| 167 |
+
base_img = cv2.resize(base_img, (target_width, target_height), interpolation = cv2.INTER_CUBIC)
|
| 168 |
+
|
| 169 |
+
# Dilate (Default to be True)
|
| 170 |
+
base_img = cv2.filter2D(base_img, -1, blur_kernel).astype(np.uint8)
|
| 171 |
+
|
| 172 |
+
# Append selected_frames and the color together for visualization
|
| 173 |
+
merge_frame = selected_frames[temporal_idx].copy()
|
| 174 |
+
merge_frame = cv2.rectangle(merge_frame, (top_left_x, top_left_y), (bottom_right_x, bottom_right_y), (255, 0, 0), 5) # Draw the Region Box Area
|
| 175 |
+
merge_frame[base_img < 250] = base_img[base_img < 250]
|
| 176 |
+
merge_frames.append(merge_frame)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
# Append to the temporal index
|
| 180 |
+
traj_img_lists.append(base_img)
|
| 181 |
+
|
| 182 |
+
# Convert to tensor
|
| 183 |
+
traj_imgs_np = np.array(traj_img_lists)
|
| 184 |
+
traj_tensor = torch.tensor(traj_imgs_np)
|
| 185 |
+
|
| 186 |
+
# Transform
|
| 187 |
+
traj_tensor = traj_tensor.float()
|
| 188 |
+
traj_tensor = torch.stack([train_transforms(traj_frame) for traj_frame in traj_tensor], dim=0)
|
| 189 |
+
traj_tensor = traj_tensor.permute(0, 3, 1, 2).contiguous() # [F, C, H, W]
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
# Write to video (For Debug Purpose)
|
| 193 |
+
# imageio.mimsave("merge_cond" + str(idx) + ".mp4", merge_frames, fps=12)
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
# Return
|
| 197 |
+
merge_frames = np.array(merge_frames)
|
| 198 |
+
if first_frame_img is not None:
|
| 199 |
+
return traj_tensor, traj_imgs_np, merge_frames, img_with_traj
|
| 200 |
+
else:
|
| 201 |
+
return traj_tensor, traj_imgs_np, merge_frames # Need to return traj_imgs_np for other purpose
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def __getitem__(self, idx):
|
| 207 |
+
|
| 208 |
+
while True: # Iterate until there is a valid video read
|
| 209 |
+
|
| 210 |
+
try:
|
| 211 |
+
|
| 212 |
+
# Fetch the information
|
| 213 |
+
info = self.info_lists[idx]
|
| 214 |
+
video_path = os.path.join(self.dataset_folder_path, info[self.element_idx_dict["video_path"]])
|
| 215 |
+
original_height = int(info[self.element_idx_dict["height"]])
|
| 216 |
+
original_width = int(info[self.element_idx_dict["width"]])
|
| 217 |
+
num_frames = int(info[self.element_idx_dict["num_frames"]])
|
| 218 |
+
fps = float(info[self.element_idx_dict["fps"]])
|
| 219 |
+
|
| 220 |
+
# Fetch all panoptic frames
|
| 221 |
+
FrameIN_info_all = json.loads(info[self.element_idx_dict["FrameIN_info"]])
|
| 222 |
+
Track_Traj_all = json.loads(info[self.element_idx_dict["Track_Traj"]])
|
| 223 |
+
text_prompt_all = json.loads(info[self.element_idx_dict["Improved_Text_Prompt"]])
|
| 224 |
+
ID_info_all = json.loads(info[self.element_idx_dict["ID_info"]])
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
# Randomly Choose one available
|
| 228 |
+
panoptic_idx = random.choice(range(len(FrameIN_info_all)))
|
| 229 |
+
FrameIN_info = FrameIN_info_all[panoptic_idx]
|
| 230 |
+
Track_Traj = Track_Traj_all[panoptic_idx]
|
| 231 |
+
text_prompt = text_prompt_all[panoptic_idx]
|
| 232 |
+
ID_info_panoptic = ID_info_all[panoptic_idx]
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
# Organize
|
| 236 |
+
resolution = str(self.target_width) + "x" + str(self.target_height)
|
| 237 |
+
fps_scale = self.preset_decode_fps / fps
|
| 238 |
+
downsample_num_frames = int(num_frames * fps_scale)
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
# FrameIn drop
|
| 242 |
+
if self.FrameOut_only or random.random() < self.config["drop_FrameIn_prob"]:
|
| 243 |
+
drop_FrameIn = True
|
| 244 |
+
else:
|
| 245 |
+
drop_FrameIn = False
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
# Sanity check
|
| 250 |
+
if not os.path.exists(video_path):
|
| 251 |
+
raise Exception("This video path ", video_path, " doesn't exists!")
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
# Not all objects are ideal for FrameIn, so we need to select
|
| 255 |
+
if not self.strict_validation_match:
|
| 256 |
+
effective_obj_idxs = []
|
| 257 |
+
for obj_idx, obj_info in enumerate(ID_info_panoptic):
|
| 258 |
+
if obj_info != []:
|
| 259 |
+
effective_obj_idxs.append(obj_idx)
|
| 260 |
+
main_target_obj_idx = random.choice(effective_obj_idxs) # NOTE: we should only have one object to be processed for now
|
| 261 |
+
else:
|
| 262 |
+
main_target_obj_idx = 0 # Always choose the first one
|
| 263 |
+
|
| 264 |
+
#################################################### Fetch FrameIn ID information ###############################################################
|
| 265 |
+
|
| 266 |
+
# Fetch the FrameIn ID info
|
| 267 |
+
segmentation_info, useful_region_box = ID_info_panoptic[main_target_obj_idx] # There might be multiple objects ideal, but we just randomly choose one
|
| 268 |
+
if not self.FrameOut_only:
|
| 269 |
+
_, first_frame_reference_path, _ = segmentation_info # bbox_info, first_frame_reference_path, store_img_path_lists
|
| 270 |
+
first_frame_reference_path = os.path.join(self.ID_folder_path, first_frame_reference_path)
|
| 271 |
+
|
| 272 |
+
##################################################################################################################################################
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
############ Randomly choose one mask among the multiple choices available (Resolution is with respect to the original resolution) ############
|
| 277 |
+
useful_region_box.sort(key=lambda x: x[0])
|
| 278 |
+
|
| 279 |
+
# Choose one region box
|
| 280 |
+
if not self.strict_validation_match:
|
| 281 |
+
mask_region = random.choice(useful_region_box[-5:])[1:] # Choose among the largest 5 available
|
| 282 |
+
else:
|
| 283 |
+
mask_region = useful_region_box[-1][1:] # Choose the last one
|
| 284 |
+
|
| 285 |
+
# Fetch
|
| 286 |
+
(top_left_x_raw, top_left_y_raw), (bottom_right_x_raw, bottom_right_y_raw) = mask_region # As Original Resolution
|
| 287 |
+
|
| 288 |
+
# Resize the mask based on the CURRENT target resolution (now the 384x480 resolution)
|
| 289 |
+
top_left_x = int(top_left_x_raw * self.target_width / original_width)
|
| 290 |
+
top_left_y = int(top_left_y_raw * self.target_height / original_height)
|
| 291 |
+
bottom_right_x = int(bottom_right_x_raw * self.target_width / original_width)
|
| 292 |
+
bottom_right_y = int(bottom_right_y_raw * self.target_height / original_height)
|
| 293 |
+
resized_mask_region_box = (top_left_x, top_left_y), (bottom_right_x, bottom_right_y)
|
| 294 |
+
|
| 295 |
+
###########################################################################################################################################
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
############################################## Read the video by ffmpeg #############################################################
|
| 300 |
+
|
| 301 |
+
# Read the video by ffmpeg in the needed decode fps and resolution
|
| 302 |
+
video_stream, err = ffmpeg.input(
|
| 303 |
+
video_path
|
| 304 |
+
).filter(
|
| 305 |
+
'fps', fps = self.preset_decode_fps, round = 'up'
|
| 306 |
+
).output(
|
| 307 |
+
"pipe:", format = "rawvideo", pix_fmt = "rgb24", s = resolution
|
| 308 |
+
).run(
|
| 309 |
+
capture_stdout = True, capture_stderr = True
|
| 310 |
+
) # The resize is already included
|
| 311 |
+
video_np_raw = np.frombuffer(video_stream, np.uint8).reshape(-1, self.target_height, self.target_width, 3)
|
| 312 |
+
|
| 313 |
+
# Sanity Check
|
| 314 |
+
if len(video_np_raw) - self.start_skip - self.end_skip < self.train_frame_num:
|
| 315 |
+
raise Exception("The number of frames from the video is not enough")
|
| 316 |
+
|
| 317 |
+
# Crop the tensor so that all non-interest regions become blank (black, 0-value); the region is at the target training resolution with VAE step-size adjustment
|
| 318 |
+
video_np_masked = np.zeros(video_np_raw.shape, dtype = np.uint8)
|
| 319 |
+
video_np_masked[:, top_left_y:bottom_right_y, top_left_x:bottom_right_x, :] = video_np_raw[:, top_left_y:bottom_right_y, top_left_x:bottom_right_x, :]
|
| 320 |
+
|
| 321 |
+
#########################################################################################################################################
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
######################################### Define the text prompt #######################################################
|
| 326 |
+
|
| 327 |
+
# Decide whether to use an empty text prompt; the text prompt was already extracted above
|
| 328 |
+
if self.empty_text_prompt or random.random() < self.config["text_mask_ratio"]:
|
| 329 |
+
text_prompt = ""
|
| 330 |
+
|
| 331 |
+
########################################################################################################################
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
###################### Prepare the Tracking points for each object (each object has different color) #################################
|
| 336 |
+
|
| 337 |
+
# Make sure that the frame range from FrameIN_info has enough frames
|
| 338 |
+
_, original_start_frame_idx, fps_scale = FrameIN_info[main_target_obj_idx] # This is expected to be all the same in the video
|
| 339 |
+
downsample_start_frame_idx = max(0, int(original_start_frame_idx * fps_scale))
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
# Check the max number of frames available (NOTE: Recommended to use Full Text Prompt Version)
|
| 343 |
+
max_step_num = (downsample_num_frames - downsample_start_frame_idx) // self.train_frame_num
|
| 344 |
+
if max_step_num == 0:
|
| 345 |
+
print("This video is ", video_path)
|
| 346 |
+
raise Exception("The video is too short!")
|
| 347 |
+
elif max_step_num >= 2 and random.random() < self.faster_motion_prob:
|
| 348 |
+
iter_gap = 2 # Maximum Setting now is 2x; else, the VAE might not works well
|
| 349 |
+
else:
|
| 350 |
+
iter_gap = 1
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
# Iterate all the Segmentation Info
|
| 354 |
+
full_pred_tracks = [[] for _ in range(self.train_frame_num)] # The dim should be: (temporal, object, points, xy) The fps should be fixed to 12 fps, which is the same as training decode fps
|
| 355 |
+
|
| 356 |
+
# Iterate all objects but not the main objects
|
| 357 |
+
for obj_idx in range(len(ID_info_panoptic)):
|
| 358 |
+
|
| 359 |
+
# Prepare the trajectory
|
| 360 |
+
pred_tracks = Track_Traj[obj_idx]
|
| 361 |
+
pred_tracks = pred_tracks[downsample_start_frame_idx : downsample_start_frame_idx + iter_gap * self.train_frame_num : iter_gap]
|
| 362 |
+
if len(pred_tracks) != self.train_frame_num:
|
| 363 |
+
raise Exception("The len of pre_track does not match")
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
# For a non-main obj idx, we must ensure all points are inside the region box; if it is the main obj, the ID must be outside the region box
|
| 367 |
+
if obj_idx != main_target_obj_idx or self.FrameOut_only:
|
| 368 |
+
|
| 369 |
+
# Randomly select the points based on the given probability; the number of points differs for each object
|
| 370 |
+
kept_point_status = random.choices([True, False], weights = [self.point_keep_ratio_regular, 1 - self.point_keep_ratio_regular], k = len(pred_tracks[0]))
|
| 371 |
+
|
| 372 |
+
# Check with the first frame; no need to check the following frames (the FrameOut effect is allowed)
|
| 373 |
+
first_frame_points = pred_tracks[0]
|
| 374 |
+
for point_idx in range(len(first_frame_points)):
|
| 375 |
+
(horizontal, vertical) = first_frame_points[point_idx]
|
| 376 |
+
if horizontal < top_left_x_raw or horizontal >= bottom_right_x_raw or vertical < top_left_y_raw or vertical >= bottom_right_y_raw:
|
| 377 |
+
kept_point_status[point_idx] = False
|
| 378 |
+
|
| 379 |
+
else: # For main object
|
| 380 |
+
|
| 381 |
+
# Randomly select points with the given keep probability; the number of points differs per object
|
| 382 |
+
if drop_FrameIn:
|
| 383 |
+
# No motion provided on ID for Drop FrameIn cases
|
| 384 |
+
kept_point_status = random.choices([False], k = len(pred_tracks[0]))
|
| 385 |
+
|
| 386 |
+
else: # Regular FrameIn case
|
| 387 |
+
kept_point_status = random.choices([True, False], weights = [self.point_keep_ratio_ID, 1 - self.point_keep_ratio_ID], k = len(pred_tracks[0]))
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
# Sanity Check
|
| 391 |
+
if len(kept_point_status) != len(pred_tracks[-1]):
|
| 392 |
+
raise Exception("The number of points filterred is not match with the dataset")
|
| 393 |
+
|
| 394 |
+
# Iterate over all frames and append the kept points temporally
|
| 395 |
+
for temporal_idx, pred_track in enumerate(pred_tracks):
|
| 396 |
+
|
| 397 |
+
# Iterate over all points one by one
|
| 398 |
+
left_points = []
|
| 399 |
+
for point_idx in range(len(pred_track)):
|
| 400 |
+
# Select kept points
|
| 401 |
+
if kept_point_status[point_idx]:
|
| 402 |
+
left_points.append(pred_track[point_idx])
|
| 403 |
+
|
| 404 |
+
# Append the remaining points to the list
|
| 405 |
+
full_pred_tracks[temporal_idx].append(left_points) # 49 frames in total; each appended entry holds the kept tracking points of a single object (one object per iteration here)
|
| 406 |
+
|
| 407 |
+
# Fetch One Point
|
| 408 |
+
if self.one_point_one_obj:
|
| 409 |
+
one_track_point = []
|
| 410 |
+
for full_pred_track_per_frame in full_pred_tracks:
|
| 411 |
+
one_track_point.append([[full_pred_track_per_frame[0][0]]])
|
| 412 |
+
|
| 413 |
+
#######################################################################################################################################
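A condensed sketch of the per-object point filtering above (assumes `pred_tracks` is a list of frames, each a list of `(x, y)` points; the box bounds and keep ratio are illustrative):

```py
import random

def filter_points(pred_tracks, keep_ratio, box):
    top_left_x, top_left_y, bottom_right_x, bottom_right_y = box
    first_frame = pred_tracks[0]
    # Bernoulli keep/drop decision per point
    kept = [random.random() < keep_ratio for _ in first_frame]
    # Drop points that start outside the region box (only the first frame is checked)
    for i, (x, y) in enumerate(first_frame):
        if not (top_left_x <= x < bottom_right_x and top_left_y <= y < bottom_right_y):
            kept[i] = False
    # Keep the same subset of points consistently across all frames
    return [[pt for pt, keep in zip(frame, kept) if keep] for frame in pred_tracks]
```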
|
| 414 |
+
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
############################### Process the Video Tensor (based on info fetched from traj) ############################################
|
| 418 |
+
|
| 419 |
+
# Select Frames based on the panoptic range (No Mask here)
|
| 420 |
+
selected_frames = video_np_raw[downsample_start_frame_idx : downsample_start_frame_idx + iter_gap * self.train_frame_num : iter_gap]
|
| 421 |
+
|
| 422 |
+
# Prepare the video tensor; NOTE: in this branch the video tensor is the full image without the mask
|
| 423 |
+
video_tensor = torch.tensor(selected_frames) # Convert to tensor
|
| 424 |
+
if len(video_tensor) != self.train_frame_num:
|
| 425 |
+
raise Exception("The len of train frames does not match")
|
| 426 |
+
|
| 427 |
+
# Training transforms for the Video and condition
|
| 428 |
+
video_tensor = video_tensor.float()
|
| 429 |
+
video_tensor = torch.stack([train_transforms(frame) for frame in video_tensor], dim=0)
|
| 430 |
+
video_tensor = video_tensor.permute(0, 3, 1, 2).contiguous() # [F, C, H, W]
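A minimal sketch of the frames-to-tensor path above, assuming `train_transforms` simply rescales pixel values to `[-1, 1]` (the real transform is defined elsewhere in this file and may differ):

```py
import numpy as np
import torch

def frames_to_tensor(frames_np: np.ndarray) -> torch.Tensor:
    # frames_np: [F, H, W, C] uint8
    video = torch.tensor(frames_np).float()
    video = video / 127.5 - 1.0                      # assumed normalization to [-1, 1]
    return video.permute(0, 3, 1, 2).contiguous()    # [F, C, H, W]
```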
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
if drop_FrameIn:
|
| 435 |
+
main_reference_img = np.uint8(np.zeros((self.target_height, self.target_width, 3))) # Whole Black (0-value) pixel placeholder
|
| 436 |
+
|
| 437 |
+
else:
|
| 438 |
+
|
| 439 |
+
# Fetch the reference and resize
|
| 440 |
+
main_reference_img = np.asarray(Image.open(first_frame_reference_path))
|
| 441 |
+
|
| 442 |
+
# Resize to the same size as the video
|
| 443 |
+
ref_h, ref_w = main_reference_img.shape[:2]
|
| 444 |
+
scale_h = self.target_height / max(ref_h, ref_w)
|
| 445 |
+
scale_w = self.target_width / max(ref_h, ref_w)
|
| 446 |
+
new_h, new_w = int(ref_h * scale_h), int(ref_w * scale_w)
|
| 447 |
+
main_reference_img = cv2.resize(main_reference_img, (new_w, new_h), interpolation = cv2.INTER_AREA)
|
| 448 |
+
|
| 449 |
+
# Calculate padding amounts in all directions
|
| 450 |
+
pad_height1 = (self.target_height - main_reference_img.shape[0]) // 2
|
| 451 |
+
pad_height2 = self.target_height - main_reference_img.shape[0] - pad_height1
|
| 452 |
+
pad_width1 = (self.target_width - main_reference_img.shape[1]) // 2
|
| 453 |
+
pad_width2 = self.target_width - main_reference_img.shape[1] - pad_width1
|
| 454 |
+
|
| 455 |
+
# Apply padding to reach the same resolution as the training frames
|
| 456 |
+
main_reference_img = np.pad(
|
| 457 |
+
main_reference_img,
|
| 458 |
+
((pad_height1, pad_height2), (pad_width1, pad_width2), (0, 0)),
|
| 459 |
+
mode = 'constant',
|
| 460 |
+
constant_values = 0
|
| 461 |
+
)
|
| 462 |
+
# cv2.imwrite("main_reference_img_padded"+str(idx)+".png", cv2.cvtColor(main_reference_img, cv2.COLOR_BGR2RGB))
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
# Convert to tensor
|
| 466 |
+
main_reference_tensor = torch.tensor(main_reference_img)
|
| 467 |
+
main_reference_tensor = train_transforms(main_reference_tensor).permute(2, 0, 1).contiguous()
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
# Fetch the first frame and then do ID merge for this branch of training
|
| 471 |
+
first_frame_np = video_np_masked[downsample_start_frame_idx] # Needs to return for Validation
|
| 472 |
+
# cv2.imwrite("first_frame"+str(idx)+".png", cv2.cvtColor(first_frame_np, cv2.COLOR_BGR2RGB))
|
| 473 |
+
|
| 474 |
+
# Convert to tensor and then apply the training transforms
|
| 475 |
+
first_frame_tensor = torch.tensor(first_frame_np)
|
| 476 |
+
first_frame_tensor = train_transforms(first_frame_tensor).permute(2, 0, 1).contiguous()
|
| 477 |
+
|
| 478 |
+
#######################################################################################################################################
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
############################################## Draw the Traj Points and Transform to Tensor #############################################
|
| 483 |
+
|
| 484 |
+
# Draw the dilated points
|
| 485 |
+
if self.one_point_one_obj:
|
| 486 |
+
target_pred_tracks = one_track_point # In this case we only keep one point per object
|
| 487 |
+
else:
|
| 488 |
+
target_pred_tracks = full_pred_tracks
|
| 489 |
+
|
| 490 |
+
traj_tensor, traj_imgs_np, merge_frames = self.prepare_traj_tensor(target_pred_tracks, original_height, original_width, selected_frames,
|
| 491 |
+
self.dot_radius, self.target_width, self.target_height, resized_mask_region_box, idx)
|
| 492 |
+
|
| 493 |
+
#########################################################################################################################################
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
# Collect some processed metadata
|
| 497 |
+
processed_meta_data = {
|
| 498 |
+
"full_pred_tracks": full_pred_tracks,
|
| 499 |
+
"original_width": original_width,
|
| 500 |
+
"original_height": original_height,
|
| 501 |
+
"mask_region": mask_region,
|
| 502 |
+
"resized_mask_region_box": resized_mask_region_box,
|
| 503 |
+
}
|
| 504 |
+
|
| 505 |
+
except Exception as e:
|
| 506 |
+
print("The exception is ", e)
|
| 507 |
+
old_idx = idx
|
| 508 |
+
idx = random.randint(0, len(self.info_lists) - 1)
|
| 509 |
+
print("We cannot process the video", old_idx, " and we choose a new idx of ", idx)
|
| 510 |
+
continue # If any error occurs, retry with the newly sampled random index
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
# If everything is OK, break out of the retry loop
|
| 514 |
+
break
|
| 515 |
+
|
| 516 |
+
|
| 517 |
+
# Return the information
|
| 518 |
+
return {
|
| 519 |
+
"video_tensor": video_tensor,
|
| 520 |
+
"traj_tensor": traj_tensor,
|
| 521 |
+
"first_frame_tensor": first_frame_tensor,
|
| 522 |
+
"main_reference_tensor": main_reference_tensor,
|
| 523 |
+
"text_prompt": text_prompt,
|
| 524 |
+
|
| 525 |
+
# The rest is auxiliary data for validation/testing purposes
|
| 526 |
+
"video_gt_np": selected_frames,
|
| 527 |
+
"first_frame_np": first_frame_np,
|
| 528 |
+
"main_reference_np": main_reference_img,
|
| 529 |
+
"processed_meta_data": processed_meta_data,
|
| 530 |
+
"traj_imgs_np": traj_imgs_np,
|
| 531 |
+
"merge_frames" : merge_frames,
|
| 532 |
+
"gt_video_path": video_path,
|
| 533 |
+
}
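A usage sketch for this dataset (the class name `VideoDatasetMotionFrameINO` and the `config` object are placeholders; the mixed tensor/numpy/dict return values usually call for a custom `collate_fn`):

```py
from torch.utils.data import DataLoader

def identity_collate(batch):
    # Keep the per-sample dicts as-is; tensors can be stacked downstream if needed.
    return batch

# dataset = VideoDatasetMotionFrameINO(config)   # hypothetical constructor
# loader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=identity_collate)
# sample = next(iter(loader))[0]
# print(sample["video_tensor"].shape, sample["traj_tensor"].shape, sample["text_prompt"])
```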
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
|
pipelines/pipeline_cogvideox_i2v_motion.py
ADDED
|
@@ -0,0 +1,931 @@
| 1 |
+
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import os, sys, shutil
|
| 17 |
+
import inspect
|
| 18 |
+
import math
|
| 19 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import PIL
|
| 22 |
+
import torch
|
| 23 |
+
from transformers import T5EncoderModel, T5Tokenizer
|
| 24 |
+
|
| 25 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 26 |
+
from diffusers.image_processor import PipelineImageInput
|
| 27 |
+
from diffusers.loaders import CogVideoXLoraLoaderMixin
|
| 28 |
+
from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
|
| 29 |
+
# from diffusers.models.embeddings import get_3d_rotary_pos_embed
|
| 30 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 31 |
+
from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
|
| 32 |
+
from diffusers.utils import (
|
| 33 |
+
is_torch_xla_available,
|
| 34 |
+
logging,
|
| 35 |
+
replace_example_docstring,
|
| 36 |
+
)
|
| 37 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 38 |
+
from diffusers.video_processor import VideoProcessor
|
| 39 |
+
from diffusers.pipelines.cogvideo.pipeline_output import CogVideoXPipelineOutput
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# Import files from the local folder
|
| 43 |
+
root_path = os.path.abspath('.')
|
| 44 |
+
sys.path.append(root_path)
|
| 45 |
+
from architecture.embeddings import get_3d_rotary_pos_embed
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
if is_torch_xla_available():
|
| 49 |
+
import torch_xla.core.xla_model as xm
|
| 50 |
+
|
| 51 |
+
XLA_AVAILABLE = True
|
| 52 |
+
else:
|
| 53 |
+
XLA_AVAILABLE = False
|
| 54 |
+
|
| 55 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
EXAMPLE_DOC_STRING = """
|
| 59 |
+
Examples:
|
| 60 |
+
```py
|
| 61 |
+
>>> import torch
|
| 62 |
+
>>> from diffusers import CogVideoXImageToVideoPipeline
|
| 63 |
+
>>> from diffusers.utils import export_to_video, load_image
|
| 64 |
+
|
| 65 |
+
>>> pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16)
|
| 66 |
+
>>> pipe.to("cuda")
|
| 67 |
+
|
| 68 |
+
>>> prompt = "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
|
| 69 |
+
>>> image = load_image(
|
| 70 |
+
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
|
| 71 |
+
... )
|
| 72 |
+
>>> video = pipe(image, prompt, use_dynamic_cfg=True)
|
| 73 |
+
>>> export_to_video(video.frames[0], "output.mp4", fps=8)
|
| 74 |
+
```
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
|
| 79 |
+
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
|
| 80 |
+
|
| 81 |
+
tw = tgt_width
|
| 82 |
+
th = tgt_height
|
| 83 |
+
h, w = src
|
| 84 |
+
r = h / w
|
| 85 |
+
if r > (th / tw): # NOTE: align the aspect ratio to the target (similar to the reference resize approach seen earlier)
|
| 86 |
+
resize_height = th
|
| 87 |
+
resize_width = int(round(th / h * w)) # NOTE: in this branch there are extra positions left over
|
| 88 |
+
else:
|
| 89 |
+
resize_width = tw
|
| 90 |
+
resize_height = int(round(tw / w * h))
|
| 91 |
+
|
| 92 |
+
crop_top = int(round((th - resize_height) / 2.0))
|
| 93 |
+
crop_left = int(round((tw - resize_width) / 2.0)) # NOTE: this takes the midpoint (center crop)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
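A quick worked example, assuming the 480x720 default resolution (VAE spatial factor 8, patch size 2), where the latent grid equals the base grid and nothing is cropped:

```py
# (30, 45) latent grid against a 45-wide, 30-high base grid -> full-frame crop region,
# matching the values noted at the RoPE call further below.
coords = get_resize_crop_region_for_grid((30, 45), tgt_width=45, tgt_height=30)
assert coords == ((0, 0), (30, 45))
```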
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 100 |
+
def retrieve_timesteps(
|
| 101 |
+
scheduler,
|
| 102 |
+
num_inference_steps: Optional[int] = None,
|
| 103 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 104 |
+
timesteps: Optional[List[int]] = None,
|
| 105 |
+
sigmas: Optional[List[float]] = None,
|
| 106 |
+
**kwargs,
|
| 107 |
+
):
|
| 108 |
+
r"""
|
| 109 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 110 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 111 |
+
|
| 112 |
+
Args:
|
| 113 |
+
scheduler (`SchedulerMixin`):
|
| 114 |
+
The scheduler to get timesteps from.
|
| 115 |
+
num_inference_steps (`int`):
|
| 116 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 117 |
+
must be `None`.
|
| 118 |
+
device (`str` or `torch.device`, *optional*):
|
| 119 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 120 |
+
timesteps (`List[int]`, *optional*):
|
| 121 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 122 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 123 |
+
sigmas (`List[float]`, *optional*):
|
| 124 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 125 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 126 |
+
|
| 127 |
+
Returns:
|
| 128 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 129 |
+
second element is the number of inference steps.
|
| 130 |
+
"""
|
| 131 |
+
if timesteps is not None and sigmas is not None:
|
| 132 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 133 |
+
if timesteps is not None:
|
| 134 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 135 |
+
if not accepts_timesteps:
|
| 136 |
+
raise ValueError(
|
| 137 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 138 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 139 |
+
)
|
| 140 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 141 |
+
timesteps = scheduler.timesteps
|
| 142 |
+
num_inference_steps = len(timesteps)
|
| 143 |
+
elif sigmas is not None:
|
| 144 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 145 |
+
if not accept_sigmas:
|
| 146 |
+
raise ValueError(
|
| 147 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 148 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 149 |
+
)
|
| 150 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 151 |
+
timesteps = scheduler.timesteps
|
| 152 |
+
num_inference_steps = len(timesteps)
|
| 153 |
+
else:
|
| 154 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 155 |
+
timesteps = scheduler.timesteps
|
| 156 |
+
return timesteps, num_inference_steps
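Usage sketch (the `pipe` object is hypothetical; both calls rely on the scheduler's standard `set_timesteps` interface):

```py
# Let the scheduler space the steps itself:
# timesteps, n_steps = retrieve_timesteps(pipe.scheduler, num_inference_steps=50, device="cuda")

# Or pass an explicit, descending timestep list (only if the scheduler accepts `timesteps`):
# timesteps, n_steps = retrieve_timesteps(pipe.scheduler, device="cuda", timesteps=[999, 749, 499, 249])
```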
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
| 160 |
+
def retrieve_latents(
|
| 161 |
+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
| 162 |
+
):
|
| 163 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
| 164 |
+
return encoder_output.latent_dist.sample(generator)
|
| 165 |
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
| 166 |
+
return encoder_output.latent_dist.mode()
|
| 167 |
+
elif hasattr(encoder_output, "latents"):
|
| 168 |
+
return encoder_output.latents
|
| 169 |
+
else:
|
| 170 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
| 174 |
+
r"""
|
| 175 |
+
Pipeline for image-to-video generation using CogVideoX.
|
| 176 |
+
|
| 177 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 178 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 179 |
+
|
| 180 |
+
Args:
|
| 181 |
+
vae ([`AutoencoderKL`]):
|
| 182 |
+
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
| 183 |
+
text_encoder ([`T5EncoderModel`]):
|
| 184 |
+
Frozen text-encoder. CogVideoX uses
|
| 185 |
+
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
|
| 186 |
+
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
|
| 187 |
+
tokenizer (`T5Tokenizer`):
|
| 188 |
+
Tokenizer of class
|
| 189 |
+
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
| 190 |
+
transformer ([`CogVideoXTransformer3DModel`]):
|
| 191 |
+
A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
|
| 192 |
+
scheduler ([`SchedulerMixin`]):
|
| 193 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
|
| 194 |
+
"""
|
| 195 |
+
|
| 196 |
+
_optional_components = []
|
| 197 |
+
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
| 198 |
+
|
| 199 |
+
_callback_tensor_inputs = [
|
| 200 |
+
"latents",
|
| 201 |
+
"prompt_embeds",
|
| 202 |
+
"negative_prompt_embeds",
|
| 203 |
+
]
|
| 204 |
+
|
| 205 |
+
def __init__(
|
| 206 |
+
self,
|
| 207 |
+
tokenizer: T5Tokenizer,
|
| 208 |
+
text_encoder: T5EncoderModel,
|
| 209 |
+
vae: AutoencoderKLCogVideoX,
|
| 210 |
+
transformer: CogVideoXTransformer3DModel,
|
| 211 |
+
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
|
| 212 |
+
):
|
| 213 |
+
super().__init__()
|
| 214 |
+
|
| 215 |
+
self.register_modules(
|
| 216 |
+
tokenizer=tokenizer,
|
| 217 |
+
text_encoder=text_encoder,
|
| 218 |
+
vae=vae,
|
| 219 |
+
transformer=transformer,
|
| 220 |
+
scheduler=scheduler,
|
| 221 |
+
)
|
| 222 |
+
self.vae_scale_factor_spatial = (
|
| 223 |
+
2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
| 224 |
+
)
|
| 225 |
+
self.vae_scale_factor_temporal = (
|
| 226 |
+
self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
|
| 227 |
+
)
|
| 228 |
+
self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.7
|
| 229 |
+
|
| 230 |
+
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
| 231 |
+
|
| 232 |
+
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
|
| 233 |
+
def _get_t5_prompt_embeds(
|
| 234 |
+
self,
|
| 235 |
+
prompt: Union[str, List[str]] = None,
|
| 236 |
+
num_videos_per_prompt: int = 1,
|
| 237 |
+
max_sequence_length: int = 226,
|
| 238 |
+
device: Optional[torch.device] = None,
|
| 239 |
+
dtype: Optional[torch.dtype] = None,
|
| 240 |
+
):
|
| 241 |
+
device = device or self._execution_device
|
| 242 |
+
dtype = dtype or self.text_encoder.dtype
|
| 243 |
+
|
| 244 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 245 |
+
batch_size = len(prompt)
|
| 246 |
+
|
| 247 |
+
text_inputs = self.tokenizer(
|
| 248 |
+
prompt,
|
| 249 |
+
padding="max_length",
|
| 250 |
+
max_length=max_sequence_length,
|
| 251 |
+
truncation=True,
|
| 252 |
+
add_special_tokens=True,
|
| 253 |
+
return_tensors="pt",
|
| 254 |
+
)
|
| 255 |
+
text_input_ids = text_inputs.input_ids
|
| 256 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
| 257 |
+
|
| 258 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 259 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
| 260 |
+
logger.warning(
|
| 261 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
| 262 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
|
| 266 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 267 |
+
|
| 268 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 269 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 270 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
| 271 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
| 272 |
+
|
| 273 |
+
return prompt_embeds
|
| 274 |
+
|
| 275 |
+
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
|
| 276 |
+
def encode_prompt(
|
| 277 |
+
self,
|
| 278 |
+
prompt: Union[str, List[str]],
|
| 279 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 280 |
+
do_classifier_free_guidance: bool = True,
|
| 281 |
+
num_videos_per_prompt: int = 1,
|
| 282 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 283 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 284 |
+
max_sequence_length: int = 226,
|
| 285 |
+
device: Optional[torch.device] = None,
|
| 286 |
+
dtype: Optional[torch.dtype] = None,
|
| 287 |
+
):
|
| 288 |
+
r"""
|
| 289 |
+
Encodes the prompt into text encoder hidden states.
|
| 290 |
+
|
| 291 |
+
Args:
|
| 292 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 293 |
+
prompt to be encoded
|
| 294 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 295 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 296 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 297 |
+
less than `1`).
|
| 298 |
+
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
| 299 |
+
Whether to use classifier free guidance or not.
|
| 300 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 301 |
+
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
| 302 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 303 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 304 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 305 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 306 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 307 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 308 |
+
argument.
|
| 309 |
+
device: (`torch.device`, *optional*):
|
| 310 |
+
torch device
|
| 311 |
+
dtype: (`torch.dtype`, *optional*):
|
| 312 |
+
torch dtype
|
| 313 |
+
"""
|
| 314 |
+
device = device or self._execution_device
|
| 315 |
+
|
| 316 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 317 |
+
if prompt is not None:
|
| 318 |
+
batch_size = len(prompt)
|
| 319 |
+
else:
|
| 320 |
+
batch_size = prompt_embeds.shape[0]
|
| 321 |
+
|
| 322 |
+
if prompt_embeds is None:
|
| 323 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
| 324 |
+
prompt=prompt,
|
| 325 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 326 |
+
max_sequence_length=max_sequence_length,
|
| 327 |
+
device=device,
|
| 328 |
+
dtype=dtype,
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
| 332 |
+
negative_prompt = negative_prompt or ""
|
| 333 |
+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
| 334 |
+
|
| 335 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
| 336 |
+
raise TypeError(
|
| 337 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
| 338 |
+
f" {type(prompt)}."
|
| 339 |
+
)
|
| 340 |
+
elif batch_size != len(negative_prompt):
|
| 341 |
+
raise ValueError(
|
| 342 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 343 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 344 |
+
" the batch size of `prompt`."
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
| 348 |
+
prompt=negative_prompt,
|
| 349 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 350 |
+
max_sequence_length=max_sequence_length,
|
| 351 |
+
device=device,
|
| 352 |
+
dtype=dtype,
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
return prompt_embeds, negative_prompt_embeds
|
| 356 |
+
|
| 357 |
+
def prepare_latents(
|
| 358 |
+
self,
|
| 359 |
+
image: torch.Tensor,
|
| 360 |
+
batch_size: int = 1,
|
| 361 |
+
num_channels_latents: int = 16,
|
| 362 |
+
num_frames: int = 13,
|
| 363 |
+
height: int = 60,
|
| 364 |
+
width: int = 90,
|
| 365 |
+
dtype: Optional[torch.dtype] = None,
|
| 366 |
+
device: Optional[torch.device] = None,
|
| 367 |
+
generator: Optional[torch.Generator] = None,
|
| 368 |
+
latents: Optional[torch.Tensor] = None,
|
| 369 |
+
):
|
| 370 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 371 |
+
raise ValueError(
|
| 372 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 373 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
| 377 |
+
shape = (
|
| 378 |
+
batch_size,
|
| 379 |
+
num_frames,
|
| 380 |
+
num_channels_latents,
|
| 381 |
+
height // self.vae_scale_factor_spatial,
|
| 382 |
+
width // self.vae_scale_factor_spatial,
|
| 383 |
+
)
|
| 384 |
+
|
| 385 |
+
# For CogVideoX 1.5, the latent should add 1 frame for padding (not used here)
|
| 386 |
+
if self.transformer.config.patch_size_t is not None:
|
| 387 |
+
shape = shape[:1] + (shape[1] + shape[1] % self.transformer.config.patch_size_t,) + shape[2:]
|
| 388 |
+
|
| 389 |
+
image = image.unsqueeze(2) # [B, C, F, H, W]
|
| 390 |
+
|
| 391 |
+
if isinstance(generator, list):
|
| 392 |
+
image_latents = [
|
| 393 |
+
retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
|
| 394 |
+
]
|
| 395 |
+
else:
|
| 396 |
+
image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]
|
| 397 |
+
|
| 398 |
+
image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
|
| 399 |
+
|
| 400 |
+
if not self.vae.config.invert_scale_latents:
|
| 401 |
+
image_latents = self.vae_scaling_factor_image * image_latents
|
| 402 |
+
else:
|
| 403 |
+
# This is awkward but required because the CogVideoX team forgot to multiply the
|
| 404 |
+
# scaling factor during training :)
|
| 405 |
+
image_latents = 1 / self.vae_scaling_factor_image * image_latents
|
| 406 |
+
|
| 407 |
+
padding_shape = (
|
| 408 |
+
batch_size,
|
| 409 |
+
num_frames - 1,
|
| 410 |
+
num_channels_latents,
|
| 411 |
+
height // self.vae_scale_factor_spatial,
|
| 412 |
+
width // self.vae_scale_factor_spatial,
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype)
|
| 416 |
+
image_latents = torch.cat([image_latents, latent_padding], dim=1)
|
| 417 |
+
|
| 418 |
+
# Select the first frame along the second dimension
|
| 419 |
+
if self.transformer.config.patch_size_t is not None:
|
| 420 |
+
first_frame = image_latents[:, : image_latents.size(1) % self.transformer.config.patch_size_t, ...]
|
| 421 |
+
image_latents = torch.cat([first_frame, image_latents], dim=1)
|
| 422 |
+
|
| 423 |
+
if latents is None:
|
| 424 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 425 |
+
else:
|
| 426 |
+
latents = latents.to(device)
|
| 427 |
+
|
| 428 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
| 429 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 430 |
+
return latents, image_latents
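Shape arithmetic sketch for the defaults used later in `__call__` (49 frames at 480x720 with VAE temporal/spatial factors 4 and 8):

```py
num_frames, height, width = 49, 480, 720
vae_scale_factor_temporal, vae_scale_factor_spatial = 4, 8

latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1   # 13
latent_height = height // vae_scale_factor_spatial                  # 60
latent_width = width // vae_scale_factor_spatial                    # 90
# Noise latents are therefore [B, 13, 16, 60, 90]; image latents are padded to the same shape.
print(latent_frames, latent_height, latent_width)
```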
|
| 431 |
+
|
| 432 |
+
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents
|
| 433 |
+
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
|
| 434 |
+
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
|
| 435 |
+
latents = 1 / self.vae_scaling_factor_image * latents
|
| 436 |
+
|
| 437 |
+
frames = self.vae.decode(latents).sample
|
| 438 |
+
return frames
|
| 439 |
+
|
| 440 |
+
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps
|
| 441 |
+
def get_timesteps(self, num_inference_steps, timesteps, strength, device):
|
| 442 |
+
# get the original timestep using init_timestep
|
| 443 |
+
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
| 444 |
+
|
| 445 |
+
t_start = max(num_inference_steps - init_timestep, 0)
|
| 446 |
+
timesteps = timesteps[t_start * self.scheduler.order :]
|
| 447 |
+
|
| 448 |
+
return timesteps, num_inference_steps - t_start
|
| 449 |
+
|
| 450 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
| 451 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 452 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 453 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 454 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
| 455 |
+
# and should be between [0, 1]
|
| 456 |
+
|
| 457 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 458 |
+
extra_step_kwargs = {}
|
| 459 |
+
if accepts_eta:
|
| 460 |
+
extra_step_kwargs["eta"] = eta
|
| 461 |
+
|
| 462 |
+
# check if the scheduler accepts generator
|
| 463 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 464 |
+
if accepts_generator:
|
| 465 |
+
extra_step_kwargs["generator"] = generator
|
| 466 |
+
return extra_step_kwargs
|
| 467 |
+
|
| 468 |
+
def check_inputs(
|
| 469 |
+
self,
|
| 470 |
+
image,
|
| 471 |
+
prompt,
|
| 472 |
+
height,
|
| 473 |
+
width,
|
| 474 |
+
negative_prompt,
|
| 475 |
+
callback_on_step_end_tensor_inputs,
|
| 476 |
+
latents=None,
|
| 477 |
+
prompt_embeds=None,
|
| 478 |
+
negative_prompt_embeds=None,
|
| 479 |
+
):
|
| 480 |
+
if (
|
| 481 |
+
not isinstance(image, torch.Tensor)
|
| 482 |
+
and not isinstance(image, PIL.Image.Image)
|
| 483 |
+
and not isinstance(image, list)
|
| 484 |
+
):
|
| 485 |
+
raise ValueError(
|
| 486 |
+
"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
|
| 487 |
+
f" {type(image)}"
|
| 488 |
+
)
|
| 489 |
+
|
| 490 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 491 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 492 |
+
|
| 493 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 494 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 495 |
+
):
|
| 496 |
+
raise ValueError(
|
| 497 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 498 |
+
)
|
| 499 |
+
if prompt is not None and prompt_embeds is not None:
|
| 500 |
+
raise ValueError(
|
| 501 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 502 |
+
" only forward one of the two."
|
| 503 |
+
)
|
| 504 |
+
elif prompt is None and prompt_embeds is None:
|
| 505 |
+
raise ValueError(
|
| 506 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 507 |
+
)
|
| 508 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 509 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 510 |
+
|
| 511 |
+
if prompt is not None and negative_prompt_embeds is not None:
|
| 512 |
+
raise ValueError(
|
| 513 |
+
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
| 514 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 515 |
+
)
|
| 516 |
+
|
| 517 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 518 |
+
raise ValueError(
|
| 519 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 520 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 521 |
+
)
|
| 522 |
+
|
| 523 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 524 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 525 |
+
raise ValueError(
|
| 526 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 527 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 528 |
+
f" {negative_prompt_embeds.shape}."
|
| 529 |
+
)
|
| 530 |
+
|
| 531 |
+
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.fuse_qkv_projections
|
| 532 |
+
def fuse_qkv_projections(self) -> None:
|
| 533 |
+
r"""Enables fused QKV projections."""
|
| 534 |
+
self.fusing_transformer = True
|
| 535 |
+
self.transformer.fuse_qkv_projections()
|
| 536 |
+
|
| 537 |
+
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.unfuse_qkv_projections
|
| 538 |
+
def unfuse_qkv_projections(self) -> None:
|
| 539 |
+
r"""Disable QKV projection fusion if enabled."""
|
| 540 |
+
if not self.fusing_transformer:
|
| 541 |
+
logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
|
| 542 |
+
else:
|
| 543 |
+
self.transformer.unfuse_qkv_projections()
|
| 544 |
+
self.fusing_transformer = False
|
| 545 |
+
|
| 546 |
+
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings
|
| 547 |
+
def _prepare_rotary_positional_embeddings(
|
| 548 |
+
self,
|
| 549 |
+
height: int,
|
| 550 |
+
width: int,
|
| 551 |
+
num_frames: int,
|
| 552 |
+
device: torch.device,
|
| 553 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
| 557 |
+
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
| 558 |
+
|
| 559 |
+
p = self.transformer.config.patch_size
|
| 560 |
+
p_t = self.transformer.config.patch_size_t
|
| 561 |
+
|
| 562 |
+
base_size_width = self.transformer.config.sample_width // p
|
| 563 |
+
base_size_height = self.transformer.config.sample_height // p
|
| 564 |
+
|
| 565 |
+
# RoPE extrapolation factor in NTK
|
| 566 |
+
# token_factor_ratio = (grid_height * grid_width) / (base_size_width * base_size_height)
|
| 567 |
+
# if token_factor_ratio > 1.0:
|
| 568 |
+
# ntk_factor = token_factor_ratio
|
| 569 |
+
# else:
|
| 570 |
+
# ntk_factor = 1.0
|
| 571 |
+
|
| 572 |
+
|
| 573 |
+
if p_t is None: # NOTE: this branch is taken for CogVideoX 1.0
|
| 574 |
+
# CogVideoX 1.0
|
| 575 |
+
grid_crops_coords = get_resize_crop_region_for_grid(
|
| 576 |
+
(grid_height, grid_width), base_size_width, base_size_height
|
| 577 |
+
)
|
| 578 |
+
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
| 579 |
+
embed_dim=self.transformer.config.attention_head_dim,
|
| 580 |
+
crops_coords=grid_crops_coords, # ((0, 0), (30, 45))
|
| 581 |
+
grid_size=(grid_height, grid_width), # (30, 45)
|
| 582 |
+
# ntk_factor = ntk_factor, # For the extrapolation
|
| 583 |
+
temporal_size=num_frames,
|
| 584 |
+
device=device,
|
| 585 |
+
)
|
| 586 |
+
else:
|
| 587 |
+
# CogVideoX 1.5
|
| 588 |
+
base_num_frames = (num_frames + p_t - 1) // p_t
|
| 589 |
+
|
| 590 |
+
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
| 591 |
+
embed_dim=self.transformer.config.attention_head_dim,
|
| 592 |
+
crops_coords=None,
|
| 593 |
+
grid_size=(grid_height, grid_width),
|
| 594 |
+
temporal_size=base_num_frames,
|
| 595 |
+
grid_type="slice",
|
| 596 |
+
max_size=(base_size_height, base_size_width),
|
| 597 |
+
device=device,
|
| 598 |
+
)
|
| 599 |
+
|
| 600 |
+
return freqs_cos, freqs_sin
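Grid-size sketch for the RoPE above at the 480x720 default (`vae_scale_factor_spatial = 8`, transformer `patch_size = 2`):

```py
height, width = 480, 720
vae_scale_factor_spatial, patch_size = 8, 2

grid_height = height // (vae_scale_factor_spatial * patch_size)   # 30
grid_width = width // (vae_scale_factor_spatial * patch_size)     # 45
print(grid_height, grid_width)
```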
|
| 601 |
+
|
| 602 |
+
@property
|
| 603 |
+
def guidance_scale(self):
|
| 604 |
+
return self._guidance_scale
|
| 605 |
+
|
| 606 |
+
@property
|
| 607 |
+
def num_timesteps(self):
|
| 608 |
+
return self._num_timesteps
|
| 609 |
+
|
| 610 |
+
@property
|
| 611 |
+
def attention_kwargs(self):
|
| 612 |
+
return self._attention_kwargs
|
| 613 |
+
|
| 614 |
+
@property
|
| 615 |
+
def interrupt(self):
|
| 616 |
+
return self._interrupt
|
| 617 |
+
|
| 618 |
+
@torch.no_grad()
|
| 619 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 620 |
+
def __call__(
|
| 621 |
+
self,
|
| 622 |
+
image: PipelineImageInput,
|
| 623 |
+
traj_tensor = None,
|
| 624 |
+
prompt: Optional[Union[str, List[str]]] = None,
|
| 625 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 626 |
+
height: Optional[int] = None,
|
| 627 |
+
width: Optional[int] = None,
|
| 628 |
+
num_frames: int = 49,
|
| 629 |
+
num_inference_steps: int = 50,
|
| 630 |
+
timesteps: Optional[List[int]] = None,
|
| 631 |
+
guidance_scale: float = 6,
|
| 632 |
+
use_dynamic_cfg: bool = False,
|
| 633 |
+
num_videos_per_prompt: int = 1,
|
| 634 |
+
eta: float = 0.0,
|
| 635 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 636 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 637 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 638 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 639 |
+
output_type: str = "pil",
|
| 640 |
+
return_dict: bool = True,
|
| 641 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 642 |
+
callback_on_step_end: Optional[
|
| 643 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
| 644 |
+
] = None,
|
| 645 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 646 |
+
max_sequence_length: int = 226,
|
| 647 |
+
) -> Union[CogVideoXPipelineOutput, Tuple]:
|
| 648 |
+
"""
|
| 649 |
+
Function invoked when calling the pipeline for generation.
|
| 650 |
+
|
| 651 |
+
Args:
|
| 652 |
+
image (`PipelineImageInput`):
|
| 653 |
+
The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
|
| 654 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 655 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
| 656 |
+
instead.
|
| 657 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 658 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 659 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 660 |
+
less than `1`).
|
| 661 |
+
height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
|
| 662 |
+
The height in pixels of the generated image. This is set to 480 by default for the best results.
|
| 663 |
+
width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
|
| 664 |
+
The width in pixels of the generated image. This is set to 720 by default for the best results.
|
| 665 |
+
num_frames (`int`, defaults to `48`):
|
| 666 |
+
Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
|
| 667 |
+
contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
|
| 668 |
+
num_seconds is 6 and fps is 8. However, since videos can be saved at any fps, the only condition that
|
| 669 |
+
needs to be satisfied is that of divisibility mentioned above.
|
| 670 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 671 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 672 |
+
expense of slower inference.
|
| 673 |
+
timesteps (`List[int]`, *optional*):
|
| 674 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
| 675 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
| 676 |
+
passed will be used. Must be in descending order.
|
| 677 |
+
guidance_scale (`float`, *optional*, defaults to 7.0):
|
| 678 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
| 679 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
| 680 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
| 681 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
| 682 |
+
usually at the expense of lower image quality.
|
| 683 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 684 |
+
The number of videos to generate per prompt.
|
| 685 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 686 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 687 |
+
to make generation deterministic.
|
| 688 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 689 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 690 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 691 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
| 692 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 693 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 694 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 695 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 696 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 697 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 698 |
+
argument.
|
| 699 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 700 |
+
The output format of the generate image. Choose between
|
| 701 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 702 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 703 |
+
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
| 704 |
+
of a plain tuple.
|
| 705 |
+
attention_kwargs (`dict`, *optional*):
|
| 706 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 707 |
+
`self.processor` in
|
| 708 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 709 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 710 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
| 711 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 712 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 713 |
+
`callback_on_step_end_tensor_inputs`.
|
| 714 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 715 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 716 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 717 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 718 |
+
max_sequence_length (`int`, defaults to `226`):
|
| 719 |
+
Maximum sequence length in encoded prompt. Must be consistent with
|
| 720 |
+
`self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
|
| 721 |
+
|
| 722 |
+
Examples:
|
| 723 |
+
|
| 724 |
+
Returns:
|
| 725 |
+
[`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] or `tuple`:
|
| 726 |
+
[`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
|
| 727 |
+
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
| 728 |
+
"""
|
| 729 |
+
|
| 730 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 731 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 732 |
+
|
| 733 |
+
height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
|
| 734 |
+
width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
|
| 735 |
+
num_frames = num_frames or self.transformer.config.sample_frames
|
| 736 |
+
|
| 737 |
+
num_videos_per_prompt = 1
|
| 738 |
+
|
| 739 |
+
# 1. Check inputs. Raise error if not correct
|
| 740 |
+
self.check_inputs(
|
| 741 |
+
image=image,
|
| 742 |
+
prompt=prompt,
|
| 743 |
+
height=height,
|
| 744 |
+
width=width,
|
| 745 |
+
negative_prompt=negative_prompt,
|
| 746 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
| 747 |
+
latents=latents,
|
| 748 |
+
prompt_embeds=prompt_embeds,
|
| 749 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 750 |
+
)
|
| 751 |
+
self._guidance_scale = guidance_scale
|
| 752 |
+
self._attention_kwargs = attention_kwargs
|
| 753 |
+
self._interrupt = False
|
| 754 |
+
|
| 755 |
+
# 2. Default call parameters
|
| 756 |
+
if prompt is not None and isinstance(prompt, str):
|
| 757 |
+
batch_size = 1
|
| 758 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 759 |
+
batch_size = len(prompt)
|
| 760 |
+
else:
|
| 761 |
+
batch_size = prompt_embeds.shape[0]
|
| 762 |
+
|
| 763 |
+
device = self._execution_device
|
| 764 |
+
|
| 765 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 766 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 767 |
+
# corresponds to doing no classifier free guidance.
|
| 768 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
| 769 |
+
|
| 770 |
+
# 3. Encode input prompt
|
| 771 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
| 772 |
+
prompt=prompt,
|
| 773 |
+
negative_prompt=negative_prompt,
|
| 774 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
| 775 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 776 |
+
prompt_embeds=prompt_embeds,
|
| 777 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 778 |
+
max_sequence_length=max_sequence_length,
|
| 779 |
+
device=device,
|
| 780 |
+
)
|
| 781 |
+
if do_classifier_free_guidance:
|
| 782 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
| 783 |
+
|
| 784 |
+
# 4. Prepare timesteps
|
| 785 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
| 786 |
+
self._num_timesteps = len(timesteps)
|
| 787 |
+
|
| 788 |
+
# 5. Prepare latents
|
| 789 |
+
latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
| 790 |
+
|
| 791 |
+
# For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
|
| 792 |
+
patch_size_t = self.transformer.config.patch_size_t
|
| 793 |
+
additional_frames = 0
|
| 794 |
+
if patch_size_t is not None and latent_frames % patch_size_t != 0:
|
| 795 |
+
additional_frames = patch_size_t - latent_frames % patch_size_t
|
| 796 |
+
num_frames += additional_frames * self.vae_scale_factor_temporal
|
| 797 |
+
|
| 798 |
+
image = self.video_processor.preprocess(image, height=height, width=width).to(
|
| 799 |
+
device, dtype=prompt_embeds.dtype
|
| 800 |
+
)
|
| 801 |
+
|
| 802 |
+
latent_channels = 16 # self.transformer.config.in_channels // 2
|
| 803 |
+
latents, image_latents = self.prepare_latents(
|
| 804 |
+
image,
|
| 805 |
+
batch_size * num_videos_per_prompt,
|
| 806 |
+
latent_channels,
|
| 807 |
+
num_frames,
|
| 808 |
+
height,
|
| 809 |
+
width,
|
| 810 |
+
prompt_embeds.dtype,
|
| 811 |
+
device,
|
| 812 |
+
generator,
|
| 813 |
+
latents,
|
| 814 |
+
)
|
| 815 |
+
|
| 816 |
+
|
| 817 |
+
# 5.5. Traj Preprocess
|
| 818 |
+
traj_tensor = traj_tensor.to(device, dtype = self.vae.dtype)[None] #.unsqueeze(0)
|
| 819 |
+
traj_tensor = traj_tensor.permute(0, 2, 1, 3, 4)
|
| 820 |
+
traj_latents = self.vae.encode(traj_tensor).latent_dist
|
| 821 |
+
|
| 822 |
+
# Scale, Permute, and other conversion
|
| 823 |
+
traj_latents = traj_latents.sample() * self.vae.config.scaling_factor
|
| 824 |
+
traj_latents = traj_latents.permute(0, 2, 1, 3, 4)
|
| 825 |
+
traj_latents = traj_latents.to(memory_format = torch.contiguous_format).float().to(dtype = prompt_embeds.dtype) # [B, F, C, H, W]
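Shape-flow sketch for the trajectory conditioning above (illustrative sizes: 49 frames at 480x720 and a CogVideoX-style VAE with temporal/spatial factors 4/8 and 16 latent channels; no VAE is actually called here):

```py
import torch

traj = torch.zeros(49, 3, 480, 720)                  # [F, C, H, W] trajectory video from the dataset
traj = traj[None].permute(0, 2, 1, 3, 4)             # [1, 3, 49, 480, 720]; the VAE expects [B, C, F, H, W]

# After vae.encode(...).latent_dist.sample() * scaling_factor the latents would be:
traj_latents = torch.zeros(1, 16, 13, 60, 90)        # [B, C, F, H, W] in latent space
traj_latents = traj_latents.permute(0, 2, 1, 3, 4)   # [1, 13, 16, 60, 90], i.e. [B, F, C, H, W]
print(traj.shape, traj_latents.shape)
```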
|
| 826 |
+
|
| 827 |
+
|
| 828 |
+
|
| 829 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 830 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 831 |
+
|
| 832 |
+
# 7. Create rotary embeds if required
|
| 833 |
+
image_rotary_emb = (
|
| 834 |
+
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
|
| 835 |
+
if self.transformer.config.use_rotary_positional_embeddings
|
| 836 |
+
else None
|
| 837 |
+
)
|
| 838 |
+
|
| 839 |
+
# 8. Create ofs embeds if required
|
| 840 |
+
ofs_emb = None if self.transformer.config.ofs_embed_dim is None else latents.new_full((1,), fill_value=2.0)
|
| 841 |
+
|
| 842 |
+
# 9. Denoising loop
|
| 843 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 844 |
+
|
| 845 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 846 |
+
# for DPM-solver++
|
| 847 |
+
old_pred_original_sample = None
|
| 848 |
+
for i, t in enumerate(timesteps):
|
| 849 |
+
if self.interrupt:
|
| 850 |
+
continue
|
| 851 |
+
|
| 852 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 853 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 854 |
+
|
| 855 |
+
latent_traj = torch.cat([traj_latents] * 2) if do_classifier_free_guidance else traj_latents
|
| 856 |
+
|
| 857 |
+
latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents
|
| 858 |
+
latent_model_input = torch.cat([latent_model_input, latent_image_input, latent_traj], dim=2) # The third dim (channels) grows from 16 to 48
|
| 859 |
+
|
| 860 |
+
|
| 861 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 862 |
+
timestep = t.expand(latent_model_input.shape[0])
|
| 863 |
+
|
| 864 |
+
# predict noise model_output
|
| 865 |
+
noise_pred = self.transformer(
|
| 866 |
+
hidden_states=latent_model_input,
|
| 867 |
+
encoder_hidden_states=prompt_embeds,
|
| 868 |
+
timestep=timestep,
|
| 869 |
+
ofs=ofs_emb,
|
| 870 |
+
image_rotary_emb=image_rotary_emb,
|
| 871 |
+
attention_kwargs=attention_kwargs,
|
| 872 |
+
return_dict=False,
|
| 873 |
+
)[0]
|
| 874 |
+
noise_pred = noise_pred.float()
|
| 875 |
+
|
| 876 |
+
# perform guidance
|
| 877 |
+
if use_dynamic_cfg:
|
| 878 |
+
self._guidance_scale = 1 + guidance_scale * (
|
| 879 |
+
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
|
| 880 |
+
)
|
| 881 |
+
if do_classifier_free_guidance:
|
| 882 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 883 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 884 |
+
|
| 885 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 886 |
+
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
|
| 887 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
| 888 |
+
else:
|
| 889 |
+
latents, old_pred_original_sample = self.scheduler.step(
|
| 890 |
+
noise_pred,
|
| 891 |
+
old_pred_original_sample,
|
| 892 |
+
t,
|
| 893 |
+
timesteps[i - 1] if i > 0 else None,
|
| 894 |
+
latents,
|
| 895 |
+
**extra_step_kwargs,
|
| 896 |
+
return_dict=False,
|
| 897 |
+
)
|
| 898 |
+
latents = latents.to(prompt_embeds.dtype)
|
| 899 |
+
|
| 900 |
+
# call the callback, if provided
|
| 901 |
+
if callback_on_step_end is not None:
|
| 902 |
+
callback_kwargs = {}
|
| 903 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 904 |
+
callback_kwargs[k] = locals()[k]
|
| 905 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 906 |
+
|
| 907 |
+
latents = callback_outputs.pop("latents", latents)
|
| 908 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 909 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 910 |
+
|
| 911 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 912 |
+
progress_bar.update()
|
| 913 |
+
|
| 914 |
+
if XLA_AVAILABLE:
|
| 915 |
+
xm.mark_step()
|
| 916 |
+
|
| 917 |
+
if not output_type == "latent":
|
| 918 |
+
# Discard any padding frames that were added for CogVideoX 1.5
|
| 919 |
+
latents = latents[:, additional_frames:]
|
| 920 |
+
video = self.decode_latents(latents)
|
| 921 |
+
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
| 922 |
+
else:
|
| 923 |
+
video = latents
|
| 924 |
+
|
| 925 |
+
# Offload all models
|
| 926 |
+
self.maybe_free_model_hooks()
|
| 927 |
+
|
| 928 |
+
if not return_dict:
|
| 929 |
+
return (video,)
|
| 930 |
+
|
| 931 |
+
return CogVideoXPipelineOutput(frames=video)
|
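This completes `pipelines/pipeline_cogvideox_i2v_motion.py`: on top of the stock CogVideoX image-to-video flow, the pipeline VAE-encodes a trajectory video and concatenates it with the noisy latents and the first-frame latents along the channel dimension at every denoising step. A minimal invocation sketch follows; the checkpoint id, trajectory source, and value range are assumptions, and the class is assumed to keep the upstream name `CogVideoXImageToVideoPipeline`. The pipeline itself appears to expect `traj_tensor` shaped `[F, C, H, W]` in pixel space, since it adds a batch dim and permutes to `[B, C, F, H, W]` before VAE encoding.

```py
# Hedged usage sketch (not part of the repo): calling the motion-conditioned pipeline.
import torch
from diffusers.utils import export_to_video, load_image

# Hypothetical import path for the class defined in this file.
from pipelines.pipeline_cogvideox_i2v_motion import CogVideoXImageToVideoPipeline

pipe = CogVideoXImageToVideoPipeline.from_pretrained(
    "THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16
).to("cuda")

image = load_image("first_frame.png")  # conditioning frame (path is a placeholder)

# Trajectory video rendered as a pixel-space tensor [F, C, H, W]; assumed to be
# pre-normalized the same way as the training data (the pipeline does not rescale it).
traj_tensor = torch.zeros(49, 3, 480, 720)

video = pipe(
    image=image,
    traj_tensor=traj_tensor,
    prompt="A cat walking across the frame",
    num_frames=49,
    guidance_scale=6.0,
    use_dynamic_cfg=True,
).frames[0]
export_to_video(video, "motion_output.mp4", fps=8)
```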
pipelines/pipeline_cogvideox_i2v_motion_FrameINO.py
ADDED
|
@@ -0,0 +1,960 @@
|
| 1 |
+
# Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import inspect
|
| 17 |
+
import math
|
| 18 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import PIL
|
| 21 |
+
import torch
|
| 22 |
+
from transformers import T5EncoderModel, T5Tokenizer
|
| 23 |
+
|
| 24 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 25 |
+
from diffusers.image_processor import PipelineImageInput
|
| 26 |
+
from diffusers.loaders import CogVideoXLoraLoaderMixin
|
| 27 |
+
from diffusers.models import AutoencoderKLCogVideoX, CogVideoXTransformer3DModel
|
| 28 |
+
from diffusers.models.embeddings import get_3d_rotary_pos_embed
|
| 29 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 30 |
+
from diffusers.schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
|
| 31 |
+
from diffusers.utils import (
|
| 32 |
+
is_torch_xla_available,
|
| 33 |
+
logging,
|
| 34 |
+
replace_example_docstring,
|
| 35 |
+
)
|
| 36 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 37 |
+
from diffusers.video_processor import VideoProcessor
|
| 38 |
+
from diffusers.pipelines.cogvideo.pipeline_output import CogVideoXPipelineOutput
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
if is_torch_xla_available():
|
| 42 |
+
import torch_xla.core.xla_model as xm
|
| 43 |
+
|
| 44 |
+
XLA_AVAILABLE = True
|
| 45 |
+
else:
|
| 46 |
+
XLA_AVAILABLE = False
|
| 47 |
+
|
| 48 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
EXAMPLE_DOC_STRING = """
|
| 52 |
+
Examples:
|
| 53 |
+
```py
|
| 54 |
+
>>> import torch
|
| 55 |
+
>>> from diffusers import CogVideoXImageToVideoPipeline
|
| 56 |
+
>>> from diffusers.utils import export_to_video, load_image
|
| 57 |
+
|
| 58 |
+
>>> pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16)
|
| 59 |
+
>>> pipe.to("cuda")
|
| 60 |
+
|
| 61 |
+
>>> prompt = "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
|
| 62 |
+
>>> image = load_image(
|
| 63 |
+
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
|
| 64 |
+
... )
|
| 65 |
+
>>> video = pipe(image, prompt, use_dynamic_cfg=True)
|
| 66 |
+
>>> export_to_video(video.frames[0], "output.mp4", fps=8)
|
| 67 |
+
```
|
| 68 |
+
"""
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
|
| 72 |
+
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
|
| 73 |
+
|
| 74 |
+
tw = tgt_width
|
| 75 |
+
th = tgt_height
|
| 76 |
+
h, w = src
|
| 77 |
+
r = h / w
|
| 78 |
+
if r > (th / tw):
|
| 79 |
+
resize_height = th
|
| 80 |
+
resize_width = int(round(th / h * w))
|
| 81 |
+
else:
|
| 82 |
+
resize_width = tw
|
| 83 |
+
resize_height = int(round(tw / w * h))
|
| 84 |
+
|
| 85 |
+
crop_top = int(round((th - resize_height) / 2.0))
|
| 86 |
+
crop_left = int(round((tw - resize_width) / 2.0))
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
| 93 |
+
def retrieve_timesteps(
|
| 94 |
+
scheduler,
|
| 95 |
+
num_inference_steps: Optional[int] = None,
|
| 96 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 97 |
+
timesteps: Optional[List[int]] = None,
|
| 98 |
+
sigmas: Optional[List[float]] = None,
|
| 99 |
+
**kwargs,
|
| 100 |
+
):
|
| 101 |
+
r"""
|
| 102 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 103 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 104 |
+
|
| 105 |
+
Args:
|
| 106 |
+
scheduler (`SchedulerMixin`):
|
| 107 |
+
The scheduler to get timesteps from.
|
| 108 |
+
num_inference_steps (`int`):
|
| 109 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 110 |
+
must be `None`.
|
| 111 |
+
device (`str` or `torch.device`, *optional*):
|
| 112 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 113 |
+
timesteps (`List[int]`, *optional*):
|
| 114 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 115 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 116 |
+
sigmas (`List[float]`, *optional*):
|
| 117 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 118 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 119 |
+
|
| 120 |
+
Returns:
|
| 121 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 122 |
+
second element is the number of inference steps.
|
| 123 |
+
"""
|
| 124 |
+
if timesteps is not None and sigmas is not None:
|
| 125 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 126 |
+
if timesteps is not None:
|
| 127 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 128 |
+
if not accepts_timesteps:
|
| 129 |
+
raise ValueError(
|
| 130 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 131 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 132 |
+
)
|
| 133 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 134 |
+
timesteps = scheduler.timesteps
|
| 135 |
+
num_inference_steps = len(timesteps)
|
| 136 |
+
elif sigmas is not None:
|
| 137 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 138 |
+
if not accept_sigmas:
|
| 139 |
+
raise ValueError(
|
| 140 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 141 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 142 |
+
)
|
| 143 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 144 |
+
timesteps = scheduler.timesteps
|
| 145 |
+
num_inference_steps = len(timesteps)
|
| 146 |
+
else:
|
| 147 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 148 |
+
timesteps = scheduler.timesteps
|
| 149 |
+
return timesteps, num_inference_steps
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
| 153 |
+
def retrieve_latents(
|
| 154 |
+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
| 155 |
+
):
|
| 156 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
| 157 |
+
return encoder_output.latent_dist.sample(generator)
|
| 158 |
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
| 159 |
+
return encoder_output.latent_dist.mode()
|
| 160 |
+
elif hasattr(encoder_output, "latents"):
|
| 161 |
+
return encoder_output.latents
|
| 162 |
+
else:
|
| 163 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class CogVideoXImageToVideoPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
| 167 |
+
r"""
|
| 168 |
+
Pipeline for image-to-video generation using CogVideoX.
|
| 169 |
+
|
| 170 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
| 171 |
+
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
| 172 |
+
|
| 173 |
+
Args:
|
| 174 |
+
vae ([`AutoencoderKL`]):
|
| 175 |
+
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
| 176 |
+
text_encoder ([`T5EncoderModel`]):
|
| 177 |
+
Frozen text-encoder. CogVideoX uses
|
| 178 |
+
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
|
| 179 |
+
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
|
| 180 |
+
tokenizer (`T5Tokenizer`):
|
| 181 |
+
Tokenizer of class
|
| 182 |
+
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
|
| 183 |
+
transformer ([`CogVideoXTransformer3DModel`]):
|
| 184 |
+
A text conditioned `CogVideoXTransformer3DModel` to denoise the encoded video latents.
|
| 185 |
+
scheduler ([`SchedulerMixin`]):
|
| 186 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
|
| 187 |
+
"""
|
| 188 |
+
|
| 189 |
+
_optional_components = []
|
| 190 |
+
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
| 191 |
+
|
| 192 |
+
_callback_tensor_inputs = [
|
| 193 |
+
"latents",
|
| 194 |
+
"prompt_embeds",
|
| 195 |
+
"negative_prompt_embeds",
|
| 196 |
+
]
|
| 197 |
+
|
| 198 |
+
def __init__(
|
| 199 |
+
self,
|
| 200 |
+
tokenizer: T5Tokenizer,
|
| 201 |
+
text_encoder: T5EncoderModel,
|
| 202 |
+
vae: AutoencoderKLCogVideoX,
|
| 203 |
+
transformer: CogVideoXTransformer3DModel,
|
| 204 |
+
scheduler: Union[CogVideoXDDIMScheduler, CogVideoXDPMScheduler],
|
| 205 |
+
):
|
| 206 |
+
super().__init__()
|
| 207 |
+
|
| 208 |
+
self.register_modules(
|
| 209 |
+
tokenizer=tokenizer,
|
| 210 |
+
text_encoder=text_encoder,
|
| 211 |
+
vae=vae,
|
| 212 |
+
transformer=transformer,
|
| 213 |
+
scheduler=scheduler,
|
| 214 |
+
)
|
| 215 |
+
self.vae_scale_factor_spatial = (
|
| 216 |
+
2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
| 217 |
+
)
|
| 218 |
+
self.vae_scale_factor_temporal = (
|
| 219 |
+
self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 4
|
| 220 |
+
)
|
| 221 |
+
self.vae_scaling_factor_image = self.vae.config.scaling_factor if getattr(self, "vae", None) else 0.7
|
| 222 |
+
|
| 223 |
+
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
| 224 |
+
|
| 225 |
+
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
|
| 226 |
+
def _get_t5_prompt_embeds(
|
| 227 |
+
self,
|
| 228 |
+
prompt: Union[str, List[str]] = None,
|
| 229 |
+
num_videos_per_prompt: int = 1,
|
| 230 |
+
max_sequence_length: int = 226,
|
| 231 |
+
device: Optional[torch.device] = None,
|
| 232 |
+
dtype: Optional[torch.dtype] = None,
|
| 233 |
+
):
|
| 234 |
+
device = device or self._execution_device
|
| 235 |
+
dtype = dtype or self.text_encoder.dtype
|
| 236 |
+
|
| 237 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 238 |
+
batch_size = len(prompt)
|
| 239 |
+
|
| 240 |
+
text_inputs = self.tokenizer(
|
| 241 |
+
prompt,
|
| 242 |
+
padding="max_length",
|
| 243 |
+
max_length=max_sequence_length,
|
| 244 |
+
truncation=True,
|
| 245 |
+
add_special_tokens=True,
|
| 246 |
+
return_tensors="pt",
|
| 247 |
+
)
|
| 248 |
+
text_input_ids = text_inputs.input_ids
|
| 249 |
+
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
| 250 |
+
|
| 251 |
+
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
|
| 252 |
+
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
|
| 253 |
+
logger.warning(
|
| 254 |
+
"The following part of your input was truncated because `max_sequence_length` is set to "
|
| 255 |
+
f" {max_sequence_length} tokens: {removed_text}"
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
|
| 259 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 260 |
+
|
| 261 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 262 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 263 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
| 264 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
| 265 |
+
|
| 266 |
+
return prompt_embeds
|
| 267 |
+
|
| 268 |
+
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
|
| 269 |
+
def encode_prompt(
|
| 270 |
+
self,
|
| 271 |
+
prompt: Union[str, List[str]],
|
| 272 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 273 |
+
do_classifier_free_guidance: bool = True,
|
| 274 |
+
num_videos_per_prompt: int = 1,
|
| 275 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 276 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 277 |
+
max_sequence_length: int = 226,
|
| 278 |
+
device: Optional[torch.device] = None,
|
| 279 |
+
dtype: Optional[torch.dtype] = None,
|
| 280 |
+
):
|
| 281 |
+
r"""
|
| 282 |
+
Encodes the prompt into text encoder hidden states.
|
| 283 |
+
|
| 284 |
+
Args:
|
| 285 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 286 |
+
prompt to be encoded
|
| 287 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 288 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 289 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 290 |
+
less than `1`).
|
| 291 |
+
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
| 292 |
+
Whether to use classifier free guidance or not.
|
| 293 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 294 |
+
Number of videos that should be generated per prompt.
|
| 295 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 296 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 297 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 298 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 299 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 300 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 301 |
+
argument.
|
| 302 |
+
device: (`torch.device`, *optional*):
|
| 303 |
+
torch device
|
| 304 |
+
dtype: (`torch.dtype`, *optional*):
|
| 305 |
+
torch dtype
|
| 306 |
+
"""
|
| 307 |
+
device = device or self._execution_device
|
| 308 |
+
|
| 309 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 310 |
+
if prompt is not None:
|
| 311 |
+
batch_size = len(prompt)
|
| 312 |
+
else:
|
| 313 |
+
batch_size = prompt_embeds.shape[0]
|
| 314 |
+
|
| 315 |
+
if prompt_embeds is None:
|
| 316 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
| 317 |
+
prompt=prompt,
|
| 318 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 319 |
+
max_sequence_length=max_sequence_length,
|
| 320 |
+
device=device,
|
| 321 |
+
dtype=dtype,
|
| 322 |
+
)
|
| 323 |
+
|
| 324 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
| 325 |
+
negative_prompt = negative_prompt or ""
|
| 326 |
+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
| 327 |
+
|
| 328 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
| 329 |
+
raise TypeError(
|
| 330 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
| 331 |
+
f" {type(prompt)}."
|
| 332 |
+
)
|
| 333 |
+
elif batch_size != len(negative_prompt):
|
| 334 |
+
raise ValueError(
|
| 335 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 336 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 337 |
+
" the batch size of `prompt`."
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
| 341 |
+
prompt=negative_prompt,
|
| 342 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 343 |
+
max_sequence_length=max_sequence_length,
|
| 344 |
+
device=device,
|
| 345 |
+
dtype=dtype,
|
| 346 |
+
)
|
| 347 |
+
|
| 348 |
+
return prompt_embeds, negative_prompt_embeds
|
| 349 |
+
|
| 350 |
+
def prepare_latents(
|
| 351 |
+
self,
|
| 352 |
+
image: torch.Tensor,
|
| 353 |
+
batch_size: int = 1,
|
| 354 |
+
num_channels_latents: int = 16,
|
| 355 |
+
num_frames: int = 13,
|
| 356 |
+
height: int = 60,
|
| 357 |
+
width: int = 90,
|
| 358 |
+
dtype: Optional[torch.dtype] = None,
|
| 359 |
+
device: Optional[torch.device] = None,
|
| 360 |
+
generator: Optional[torch.Generator] = None,
|
| 361 |
+
latents: Optional[torch.Tensor] = None,
|
| 362 |
+
):
|
| 363 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 364 |
+
raise ValueError(
|
| 365 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 366 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
| 370 |
+
shape = (
|
| 371 |
+
batch_size,
|
| 372 |
+
num_frames,
|
| 373 |
+
num_channels_latents,
|
| 374 |
+
height // self.vae_scale_factor_spatial,
|
| 375 |
+
width // self.vae_scale_factor_spatial,
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
# For CogVideoX 1.5, add 1 latent frame for padding (not used here)
|
| 379 |
+
if self.transformer.config.patch_size_t is not None:
|
| 380 |
+
shape = shape[:1] + (shape[1] + shape[1] % self.transformer.config.patch_size_t,) + shape[2:]
|
| 381 |
+
|
| 382 |
+
image = image.unsqueeze(2) # [B, C, F, H, W]
|
| 383 |
+
|
| 384 |
+
if isinstance(generator, list):
|
| 385 |
+
image_latents = [
|
| 386 |
+
retrieve_latents(self.vae.encode(image[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
|
| 387 |
+
]
|
| 388 |
+
else:
|
| 389 |
+
image_latents = [retrieve_latents(self.vae.encode(img.unsqueeze(0)), generator) for img in image]
|
| 390 |
+
|
| 391 |
+
image_latents = torch.cat(image_latents, dim=0).to(dtype).permute(0, 2, 1, 3, 4) # [B, F, C, H, W]
|
| 392 |
+
|
| 393 |
+
if not self.vae.config.invert_scale_latents:
|
| 394 |
+
image_latents = self.vae_scaling_factor_image * image_latents
|
| 395 |
+
else:
|
| 396 |
+
# This is awkward but required because the CogVideoX team forgot to multiply the
|
| 397 |
+
# scaling factor during training :)
|
| 398 |
+
image_latents = 1 / self.vae_scaling_factor_image * image_latents
|
| 399 |
+
|
| 400 |
+
padding_shape = (
|
| 401 |
+
batch_size,
|
| 402 |
+
num_frames - 1,
|
| 403 |
+
num_channels_latents,
|
| 404 |
+
height // self.vae_scale_factor_spatial,
|
| 405 |
+
width // self.vae_scale_factor_spatial,
|
| 406 |
+
)
|
| 407 |
+
|
| 408 |
+
latent_padding = torch.zeros(padding_shape, device=device, dtype=dtype)
|
| 409 |
+
image_latents = torch.cat([image_latents, latent_padding], dim=1)
|
| 410 |
+
|
| 411 |
+
# Select the first frame along the second dimension
|
| 412 |
+
if self.transformer.config.patch_size_t is not None:
|
| 413 |
+
first_frame = image_latents[:, : image_latents.size(1) % self.transformer.config.patch_size_t, ...]
|
| 414 |
+
image_latents = torch.cat([first_frame, image_latents], dim=1)
|
| 415 |
+
|
| 416 |
+
if latents is None:
|
| 417 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 418 |
+
else:
|
| 419 |
+
latents = latents.to(device)
|
| 420 |
+
|
| 421 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
| 422 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 423 |
+
return latents, image_latents
|
| 424 |
+
|
| 425 |
+
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.decode_latents
|
| 426 |
+
def decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
|
| 427 |
+
latents = latents.permute(0, 2, 1, 3, 4) # [batch_size, num_channels, num_frames, height, width]
|
| 428 |
+
latents = 1 / self.vae_scaling_factor_image * latents
|
| 429 |
+
|
| 430 |
+
frames = self.vae.decode(latents).sample
|
| 431 |
+
return frames
|
| 432 |
+
|
| 433 |
+
# Copied from diffusers.pipelines.animatediff.pipeline_animatediff_video2video.AnimateDiffVideoToVideoPipeline.get_timesteps
|
| 434 |
+
def get_timesteps(self, num_inference_steps, timesteps, strength, device):
|
| 435 |
+
# get the original timestep using init_timestep
|
| 436 |
+
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
| 437 |
+
|
| 438 |
+
t_start = max(num_inference_steps - init_timestep, 0)
|
| 439 |
+
timesteps = timesteps[t_start * self.scheduler.order :]
|
| 440 |
+
|
| 441 |
+
return timesteps, num_inference_steps - t_start
|
| 442 |
+
|
| 443 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
| 444 |
+
def prepare_extra_step_kwargs(self, generator, eta):
|
| 445 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 446 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 447 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
| 448 |
+
# and should be between [0, 1]
|
| 449 |
+
|
| 450 |
+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 451 |
+
extra_step_kwargs = {}
|
| 452 |
+
if accepts_eta:
|
| 453 |
+
extra_step_kwargs["eta"] = eta
|
| 454 |
+
|
| 455 |
+
# check if the scheduler accepts generator
|
| 456 |
+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
| 457 |
+
if accepts_generator:
|
| 458 |
+
extra_step_kwargs["generator"] = generator
|
| 459 |
+
return extra_step_kwargs
|
| 460 |
+
|
| 461 |
+
def check_inputs(
|
| 462 |
+
self,
|
| 463 |
+
image,
|
| 464 |
+
prompt,
|
| 465 |
+
height,
|
| 466 |
+
width,
|
| 467 |
+
negative_prompt,
|
| 468 |
+
callback_on_step_end_tensor_inputs,
|
| 469 |
+
latents=None,
|
| 470 |
+
prompt_embeds=None,
|
| 471 |
+
negative_prompt_embeds=None,
|
| 472 |
+
):
|
| 473 |
+
if (
|
| 474 |
+
not isinstance(image, torch.Tensor)
|
| 475 |
+
and not isinstance(image, PIL.Image.Image)
|
| 476 |
+
and not isinstance(image, list)
|
| 477 |
+
):
|
| 478 |
+
raise ValueError(
|
| 479 |
+
"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
|
| 480 |
+
f" {type(image)}"
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 484 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 485 |
+
|
| 486 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 487 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 488 |
+
):
|
| 489 |
+
raise ValueError(
|
| 490 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 491 |
+
)
|
| 492 |
+
if prompt is not None and prompt_embeds is not None:
|
| 493 |
+
raise ValueError(
|
| 494 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 495 |
+
" only forward one of the two."
|
| 496 |
+
)
|
| 497 |
+
elif prompt is None and prompt_embeds is None:
|
| 498 |
+
raise ValueError(
|
| 499 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 500 |
+
)
|
| 501 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 502 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 503 |
+
|
| 504 |
+
if prompt is not None and negative_prompt_embeds is not None:
|
| 505 |
+
raise ValueError(
|
| 506 |
+
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
| 507 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 508 |
+
)
|
| 509 |
+
|
| 510 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 511 |
+
raise ValueError(
|
| 512 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 513 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 514 |
+
)
|
| 515 |
+
|
| 516 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 517 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 518 |
+
raise ValueError(
|
| 519 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 520 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 521 |
+
f" {negative_prompt_embeds.shape}."
|
| 522 |
+
)
|
| 523 |
+
|
| 524 |
+
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.fuse_qkv_projections
|
| 525 |
+
def fuse_qkv_projections(self) -> None:
|
| 526 |
+
r"""Enables fused QKV projections."""
|
| 527 |
+
self.fusing_transformer = True
|
| 528 |
+
self.transformer.fuse_qkv_projections()
|
| 529 |
+
|
| 530 |
+
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.unfuse_qkv_projections
|
| 531 |
+
def unfuse_qkv_projections(self) -> None:
|
| 532 |
+
r"""Disable QKV projection fusion if enabled."""
|
| 533 |
+
if not self.fusing_transformer:
|
| 534 |
+
logger.warning("The Transformer was not initially fused for QKV projections. Doing nothing.")
|
| 535 |
+
else:
|
| 536 |
+
self.transformer.unfuse_qkv_projections()
|
| 537 |
+
self.fusing_transformer = False
|
| 538 |
+
|
| 539 |
+
# Copied from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._prepare_rotary_positional_embeddings
|
| 540 |
+
def _prepare_rotary_positional_embeddings(
|
| 541 |
+
self,
|
| 542 |
+
height: int,
|
| 543 |
+
width: int,
|
| 544 |
+
num_frames: int,
|
| 545 |
+
device: torch.device,
|
| 546 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
| 550 |
+
grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
|
| 551 |
+
|
| 552 |
+
p = self.transformer.config.patch_size
|
| 553 |
+
p_t = self.transformer.config.patch_size_t
|
| 554 |
+
|
| 555 |
+
base_size_width = self.transformer.config.sample_width // p
|
| 556 |
+
base_size_height = self.transformer.config.sample_height // p
|
| 557 |
+
|
| 558 |
+
if p_t is None: # HACK: this branch is taken
|
| 559 |
+
# CogVideoX 1.0
|
| 560 |
+
grid_crops_coords = get_resize_crop_region_for_grid(
|
| 561 |
+
(grid_height, grid_width), base_size_width, base_size_height
|
| 562 |
+
)
|
| 563 |
+
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
| 564 |
+
embed_dim=self.transformer.config.attention_head_dim,
|
| 565 |
+
crops_coords=grid_crops_coords, # ((0, 0), (30, 45))
|
| 566 |
+
grid_size=(grid_height, grid_width), # (30, 45)
|
| 567 |
+
temporal_size=num_frames,
|
| 568 |
+
device=device,
|
| 569 |
+
)
|
| 570 |
+
else:
|
| 571 |
+
# CogVideoX 1.5
|
| 572 |
+
base_num_frames = (num_frames + p_t - 1) // p_t
|
| 573 |
+
|
| 574 |
+
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
|
| 575 |
+
embed_dim=self.transformer.config.attention_head_dim,
|
| 576 |
+
crops_coords=None,
|
| 577 |
+
grid_size=(grid_height, grid_width),
|
| 578 |
+
temporal_size=base_num_frames,
|
| 579 |
+
grid_type="slice",
|
| 580 |
+
max_size=(base_size_height, base_size_width),
|
| 581 |
+
device=device,
|
| 582 |
+
)
|
| 583 |
+
|
| 584 |
+
return freqs_cos, freqs_sin
|
| 585 |
+
|
| 586 |
+
@property
|
| 587 |
+
def guidance_scale(self):
|
| 588 |
+
return self._guidance_scale
|
| 589 |
+
|
| 590 |
+
@property
|
| 591 |
+
def num_timesteps(self):
|
| 592 |
+
return self._num_timesteps
|
| 593 |
+
|
| 594 |
+
@property
|
| 595 |
+
def attention_kwargs(self):
|
| 596 |
+
return self._attention_kwargs
|
| 597 |
+
|
| 598 |
+
@property
|
| 599 |
+
def interrupt(self):
|
| 600 |
+
return self._interrupt
|
| 601 |
+
|
| 602 |
+
@torch.no_grad()
|
| 603 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 604 |
+
def __call__(
|
| 605 |
+
self,
|
| 606 |
+
image: PipelineImageInput,
|
| 607 |
+
traj_tensor = None,
|
| 608 |
+
ID_tensor = None,
|
| 609 |
+
prompt: Optional[Union[str, List[str]]] = None,
|
| 610 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 611 |
+
height: Optional[int] = None,
|
| 612 |
+
width: Optional[int] = None,
|
| 613 |
+
num_frames: int = 49,
|
| 614 |
+
num_inference_steps: int = 50,
|
| 615 |
+
timesteps: Optional[List[int]] = None,
|
| 616 |
+
guidance_scale: float = 6,
|
| 617 |
+
use_dynamic_cfg: bool = False,
|
| 618 |
+
add_ID_reference_augment_noise: bool = True,
|
| 619 |
+
num_videos_per_prompt: int = 1,
|
| 620 |
+
eta: float = 0.0,
|
| 621 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 622 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 623 |
+
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 624 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 625 |
+
output_type: str = "pil",
|
| 626 |
+
return_dict: bool = True,
|
| 627 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 628 |
+
callback_on_step_end: Optional[
|
| 629 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
| 630 |
+
] = None,
|
| 631 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 632 |
+
max_sequence_length: int = 226,
|
| 633 |
+
) -> Union[CogVideoXPipelineOutput, Tuple]:
|
| 634 |
+
"""
|
| 635 |
+
Function invoked when calling the pipeline for generation.
|
| 636 |
+
|
| 637 |
+
Args:
|
| 638 |
+
image (`PipelineImageInput`):
|
| 639 |
+
The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
|
| 640 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 641 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
| 642 |
+
instead.
|
| 643 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 644 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 645 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 646 |
+
less than `1`).
|
| 647 |
+
height (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
|
| 648 |
+
The height in pixels of the generated image. This is set to 480 by default for the best results.
|
| 649 |
+
width (`int`, *optional*, defaults to self.transformer.config.sample_height * self.vae_scale_factor_spatial):
|
| 650 |
+
The width in pixels of the generated image. This is set to 720 by default for the best results.
|
| 651 |
+
num_frames (`int`, defaults to `49`):
|
| 652 |
+
Number of frames to generate. Must be divisible by self.vae_scale_factor_temporal. Generated video will
|
| 653 |
+
contain 1 extra frame because CogVideoX is conditioned with (num_seconds * fps + 1) frames where
|
| 654 |
+
num_seconds is 6 and fps is 8. However, since videos can be saved at any fps, the only condition that
|
| 655 |
+
needs to be satisfied is that of divisibility mentioned above.
|
| 656 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 657 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 658 |
+
expense of slower inference.
|
| 659 |
+
timesteps (`List[int]`, *optional*):
|
| 660 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
| 661 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
| 662 |
+
passed will be used. Must be in descending order.
|
| 663 |
+
guidance_scale (`float`, *optional*, defaults to 6):
|
| 664 |
+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
| 665 |
+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
| 666 |
+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
| 667 |
+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
| 668 |
+
usually at the expense of lower image quality.
|
| 669 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 670 |
+
The number of videos to generate per prompt.
|
| 671 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 672 |
+
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
| 673 |
+
to make generation deterministic.
|
| 674 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 675 |
+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
| 676 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 677 |
+
tensor will be generated by sampling using the supplied random `generator`.
|
| 678 |
+
prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 679 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 680 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 681 |
+
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
| 682 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 683 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 684 |
+
argument.
|
| 685 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 686 |
+
The output format of the generate image. Choose between
|
| 687 |
+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
| 688 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 689 |
+
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
|
| 690 |
+
of a plain tuple.
|
| 691 |
+
attention_kwargs (`dict`, *optional*):
|
| 692 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 693 |
+
`self.processor` in
|
| 694 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 695 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 696 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
| 697 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 698 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 699 |
+
`callback_on_step_end_tensor_inputs`.
|
| 700 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 701 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 702 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 703 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 704 |
+
max_sequence_length (`int`, defaults to `226`):
|
| 705 |
+
Maximum sequence length in encoded prompt. Must be consistent with
|
| 706 |
+
`self.transformer.config.max_text_seq_length` otherwise may lead to poor results.
|
| 707 |
+
|
| 708 |
+
Examples:
|
| 709 |
+
|
| 710 |
+
Returns:
|
| 711 |
+
[`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] or `tuple`:
|
| 712 |
+
[`~pipelines.cogvideo.pipeline_output.CogVideoXPipelineOutput`] if `return_dict` is True, otherwise a
|
| 713 |
+
`tuple`. When returning a tuple, the first element is a list with the generated images.
|
| 714 |
+
"""
|
| 715 |
+
|
| 716 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 717 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 718 |
+
|
| 719 |
+
height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
|
| 720 |
+
width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
|
| 721 |
+
num_frames = num_frames or self.transformer.config.sample_frames
|
| 722 |
+
|
| 723 |
+
num_videos_per_prompt = 1
|
| 724 |
+
|
| 725 |
+
# 1. Check inputs. Raise error if not correct
|
| 726 |
+
self.check_inputs(
|
| 727 |
+
image=image,
|
| 728 |
+
prompt=prompt,
|
| 729 |
+
height=height,
|
| 730 |
+
width=width,
|
| 731 |
+
negative_prompt=negative_prompt,
|
| 732 |
+
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
| 733 |
+
latents=latents,
|
| 734 |
+
prompt_embeds=prompt_embeds,
|
| 735 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 736 |
+
)
|
| 737 |
+
self._guidance_scale = guidance_scale
|
| 738 |
+
self._attention_kwargs = attention_kwargs
|
| 739 |
+
self._interrupt = False
|
| 740 |
+
|
| 741 |
+
# 2. Default call parameters
|
| 742 |
+
if prompt is not None and isinstance(prompt, str):
|
| 743 |
+
batch_size = 1
|
| 744 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 745 |
+
batch_size = len(prompt)
|
| 746 |
+
else:
|
| 747 |
+
batch_size = prompt_embeds.shape[0]
|
| 748 |
+
|
| 749 |
+
device = self._execution_device
|
| 750 |
+
|
| 751 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 752 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 753 |
+
# corresponds to doing no classifier free guidance.
|
| 754 |
+
do_classifier_free_guidance = guidance_scale > 1.0
|
| 755 |
+
|
| 756 |
+
# 3. Encode input prompt
|
| 757 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
| 758 |
+
prompt=prompt,
|
| 759 |
+
negative_prompt=negative_prompt,
|
| 760 |
+
do_classifier_free_guidance=do_classifier_free_guidance,
|
| 761 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 762 |
+
prompt_embeds=prompt_embeds,
|
| 763 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 764 |
+
max_sequence_length=max_sequence_length,
|
| 765 |
+
device=device,
|
| 766 |
+
)
|
| 767 |
+
if do_classifier_free_guidance:
|
| 768 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
| 769 |
+
|
| 770 |
+
# 4. Prepare timesteps
|
| 771 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
| 772 |
+
self._num_timesteps = len(timesteps)
|
| 773 |
+
|
| 774 |
+
# 5. Prepare latents
|
| 775 |
+
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
| 776 |
+
|
| 777 |
+
# For CogVideoX 1.5, the latent frames should be padded to make them divisible by patch_size_t
|
| 778 |
+
patch_size_t = self.transformer.config.patch_size_t
|
| 779 |
+
additional_frames = 0
|
| 780 |
+
if patch_size_t is not None and num_latent_frames % patch_size_t != 0:
|
| 781 |
+
additional_frames = patch_size_t - num_latent_frames % patch_size_t
|
| 782 |
+
num_frames += additional_frames * self.vae_scale_factor_temporal
|
| 783 |
+
|
| 784 |
+
image = self.video_processor.preprocess(image, height=height, width=width).to(
|
| 785 |
+
device, dtype=prompt_embeds.dtype
|
| 786 |
+
)
|
| 787 |
+
|
| 788 |
+
latent_channels = 16 # self.transformer.config.in_channels // 2
|
| 789 |
+
latents, image_latents = self.prepare_latents(
|
| 790 |
+
image,
|
| 791 |
+
batch_size * num_videos_per_prompt,
|
| 792 |
+
latent_channels,
|
| 793 |
+
num_frames,
|
| 794 |
+
height,
|
| 795 |
+
width,
|
| 796 |
+
prompt_embeds.dtype,
|
| 797 |
+
device,
|
| 798 |
+
generator,
|
| 799 |
+
latents,
|
| 800 |
+
)
|
| 801 |
+
|
| 802 |
+
|
| 803 |
+
# 5.1. Traj Preprocess
|
| 804 |
+
traj_tensor = traj_tensor.to(device, dtype = self.vae.dtype)[None] #.unsqueeze(0)
|
| 805 |
+
traj_tensor = traj_tensor.permute(0, 2, 1, 3, 4)
|
| 806 |
+
traj_latents = self.vae.encode(traj_tensor).latent_dist
|
| 807 |
+
|
| 808 |
+
# Scale, Permute, and other conversion
|
| 809 |
+
traj_latents = traj_latents.sample() * self.vae.config.scaling_factor
|
| 810 |
+
traj_latents = traj_latents.permute(0, 2, 1, 3, 4)
|
| 811 |
+
traj_latents = traj_latents.to(memory_format = torch.contiguous_format).float().to(dtype = prompt_embeds.dtype) # [B, F, C, H, W]
|
| 812 |
+
|
| 813 |
+
|
| 814 |
+
# 5.2. ID Reference Preprocess
|
| 815 |
+
if ID_tensor is not None:
|
| 816 |
+
from train_code.train_cogvideox_motion_FrameINO import img_tensor_to_vae_latent # Put it here to avoid circular import
|
| 817 |
+
|
| 818 |
+
# TODO: verify whether augment noise should also be added at test time
|
| 819 |
+
ID_latent = img_tensor_to_vae_latent(ID_tensor.unsqueeze(0), self.vae, traj_latents.device, add_augment_noise = add_ID_reference_augment_noise)
|
| 820 |
+
ID_latent = ID_latent.unsqueeze(1).to(dtype = prompt_embeds.dtype)
|
| 821 |
+
|
| 822 |
+
|
| 823 |
+
|
| 824 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 825 |
+
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 826 |
+
|
| 827 |
+
# 7. Create rotary embeds if required
|
| 828 |
+
image_rotary_emb = (
|
| 829 |
+
self._prepare_rotary_positional_embeddings(height, width, latents.size(1), device)
|
| 830 |
+
if self.transformer.config.use_rotary_positional_embeddings
|
| 831 |
+
else None
|
| 832 |
+
)
|
| 833 |
+
|
| 834 |
+
# Give the appended (14th) latent frame the same rotary PE as the first frame
|
| 835 |
+
freqs_cos, freqs_sin = image_rotary_emb
|
| 836 |
+
first_frame_token_num = freqs_cos.shape[0] // num_latent_frames
|
| 837 |
+
freqs_cos = torch.cat([freqs_cos, freqs_cos[:first_frame_token_num]], dim=0) # Hard Code
|
| 838 |
+
freqs_sin = torch.cat([freqs_sin, freqs_sin[:first_frame_token_num]], dim=0)
|
| 839 |
+
image_rotary_emb = (freqs_cos, freqs_sin)
|
| 840 |
+
|
| 841 |
+
|
| 842 |
+
# 8. Create ofs embeds if required
|
| 843 |
+
ofs_emb = None if self.transformer.config.ofs_embed_dim is None else latents.new_full((1,), fill_value=2.0)
|
| 844 |
+
|
| 845 |
+
# 9. Denoising loop
|
| 846 |
+
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
| 847 |
+
|
| 848 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 849 |
+
# for DPM-solver++
|
| 850 |
+
old_pred_original_sample = None
|
| 851 |
+
for i, t in enumerate(timesteps):
|
| 852 |
+
if self.interrupt:
|
| 853 |
+
continue
|
| 854 |
+
|
| 855 |
+
# Noisy latents prepare
|
| 856 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 857 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 858 |
+
|
| 859 |
+
# First Frame latents prepare
|
| 860 |
+
latent_image_input = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents
|
| 861 |
+
|
| 862 |
+
# Traj latents prepare
|
| 863 |
+
latent_traj = torch.cat([traj_latents] * 2) if do_classifier_free_guidance else traj_latents
|
| 864 |
+
|
| 865 |
+
# ID Reference prepare
|
| 866 |
+
if ID_tensor is not None:
|
| 867 |
+
|
| 868 |
+
# CFG Double Batch Size
|
| 869 |
+
latent_ID = torch.cat([ID_latent] * 2) if do_classifier_free_guidance else ID_latent
|
| 870 |
+
|
| 871 |
+
# Frame-Wise Token Increase
|
| 872 |
+
latent_model_input = torch.cat([latent_model_input, latent_ID], dim = 1)
|
| 873 |
+
|
| 874 |
+
# Increase the frame dimension of the Traj latents and the first frame latent
|
| 875 |
+
latent_ID_padding = latent_model_input.new_zeros(latent_ID.shape) # Zero latent values
|
| 876 |
+
latent_image_input = torch.cat([latent_image_input, latent_ID_padding], dim=1)
|
| 877 |
+
latent_traj = torch.cat([latent_traj, latent_ID_padding], dim=1)
|
| 878 |
+
|
| 879 |
+
|
| 880 |
+
# Dimension-Wise Concatenation
|
| 881 |
+
latent_model_input = torch.cat([latent_model_input, latent_image_input, latent_traj], dim=2) # The third dim (channels) grows from 16 to 48
|
| 882 |
+
|
| 883 |
+
|
| 884 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 885 |
+
timestep = t.expand(latent_model_input.shape[0])
|
| 886 |
+
|
| 887 |
+
# predict noise model_output
|
| 888 |
+
noise_pred = self.transformer(
|
| 889 |
+
hidden_states=latent_model_input,
|
| 890 |
+
encoder_hidden_states=prompt_embeds,
|
| 891 |
+
timestep=timestep,
|
| 892 |
+
ofs=ofs_emb,
|
| 893 |
+
image_rotary_emb=image_rotary_emb,
|
| 894 |
+
attention_kwargs=attention_kwargs,
|
| 895 |
+
return_dict=False,
|
| 896 |
+
)[0]
|
| 897 |
+
noise_pred = noise_pred.float()
|
| 898 |
+
|
| 899 |
+
|
| 900 |
+
# Discard the Extra ID tokens in the Noise Prediction
|
| 901 |
+
if ID_tensor is not None:
|
| 902 |
+
noise_pred = noise_pred[:, :num_latent_frames, :, :, :]
|
| 903 |
+
|
| 904 |
+
|
| 905 |
+
# perform guidance
|
| 906 |
+
if use_dynamic_cfg:
|
| 907 |
+
self._guidance_scale = 1 + guidance_scale * (
|
| 908 |
+
(1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2
|
| 909 |
+
)
|
| 910 |
+
if do_classifier_free_guidance:
|
| 911 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 912 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 913 |
+
|
| 914 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 915 |
+
if not isinstance(self.scheduler, CogVideoXDPMScheduler):
|
| 916 |
+
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
| 917 |
+
else:
|
| 918 |
+
latents, old_pred_original_sample = self.scheduler.step(
|
| 919 |
+
noise_pred,
|
| 920 |
+
old_pred_original_sample,
|
| 921 |
+
t,
|
| 922 |
+
timesteps[i - 1] if i > 0 else None,
|
| 923 |
+
latents,
|
| 924 |
+
**extra_step_kwargs,
|
| 925 |
+
return_dict=False,
|
| 926 |
+
)
|
| 927 |
+
latents = latents.to(prompt_embeds.dtype)
|
| 928 |
+
|
| 929 |
+
# call the callback, if provided
|
| 930 |
+
if callback_on_step_end is not None:
|
| 931 |
+
callback_kwargs = {}
|
| 932 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 933 |
+
callback_kwargs[k] = locals()[k]
|
| 934 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 935 |
+
|
| 936 |
+
latents = callback_outputs.pop("latents", latents)
|
| 937 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 938 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 939 |
+
|
| 940 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 941 |
+
progress_bar.update()
|
| 942 |
+
|
| 943 |
+
if XLA_AVAILABLE:
|
| 944 |
+
xm.mark_step()
|
| 945 |
+
|
| 946 |
+
if not output_type == "latent":
|
| 947 |
+
# Discard any padding frames that were added for CogVideoX 1.5
|
| 948 |
+
latents = latents[:, additional_frames:]
|
| 949 |
+
video = self.decode_latents(latents)
|
| 950 |
+
video = self.video_processor.postprocess_video(video=video, output_type=output_type)
|
| 951 |
+
else:
|
| 952 |
+
video = latents
|
| 953 |
+
|
| 954 |
+
# Offload all models
|
| 955 |
+
self.maybe_free_model_hooks()
|
| 956 |
+
|
| 957 |
+
if not return_dict:
|
| 958 |
+
return (video,)
|
| 959 |
+
|
| 960 |
+
return CogVideoXPipelineOutput(frames=video)
|
pipelines/pipeline_wan_i2v_motion.py
ADDED
|
@@ -0,0 +1,861 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import html
|
| 16 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import os, sys, shutil
|
| 19 |
+
import PIL
|
| 20 |
+
import regex as re
|
| 21 |
+
import torch
|
| 22 |
+
from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
|
| 23 |
+
|
| 24 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 25 |
+
from diffusers.image_processor import PipelineImageInput
|
| 26 |
+
from diffusers.loaders import WanLoraLoaderMixin
|
| 27 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
| 28 |
+
from diffusers.utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
|
| 29 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 30 |
+
from diffusers.video_processor import VideoProcessor
|
| 31 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 32 |
+
from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# Import files from the local folder
|
| 36 |
+
root_path = os.path.abspath('.')
|
| 37 |
+
sys.path.append(root_path)
|
| 38 |
+
from architecture.transformer_wan import WanTransformer3DModel
|
| 39 |
+
from architecture.autoencoder_kl_wan import AutoencoderKLWan
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
if is_torch_xla_available():
|
| 43 |
+
import torch_xla.core.xla_model as xm
|
| 44 |
+
|
| 45 |
+
XLA_AVAILABLE = True
|
| 46 |
+
else:
|
| 47 |
+
XLA_AVAILABLE = False
|
| 48 |
+
|
| 49 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 50 |
+
|
| 51 |
+
if is_ftfy_available():
|
| 52 |
+
import ftfy
|
| 53 |
+
|
| 54 |
+
EXAMPLE_DOC_STRING = """
|
| 55 |
+
Examples:
|
| 56 |
+
```python
|
| 57 |
+
>>> import torch
|
| 58 |
+
>>> import numpy as np
|
| 59 |
+
>>> from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
|
| 60 |
+
>>> from diffusers.utils import export_to_video, load_image
|
| 61 |
+
>>> from transformers import CLIPVisionModel
|
| 62 |
+
|
| 63 |
+
>>> # Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
|
| 64 |
+
>>> model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
|
| 65 |
+
>>> image_encoder = CLIPVisionModel.from_pretrained(
|
| 66 |
+
... model_id, subfolder="image_encoder", torch_dtype=torch.float32
|
| 67 |
+
... )
|
| 68 |
+
>>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
| 69 |
+
>>> pipe = WanImageToVideoPipeline.from_pretrained(
|
| 70 |
+
... model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
|
| 71 |
+
... )
|
| 72 |
+
>>> pipe.to("cuda")
|
| 73 |
+
|
| 74 |
+
>>> image = load_image(
|
| 75 |
+
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
|
| 76 |
+
... )
|
| 77 |
+
>>> max_area = 480 * 832
|
| 78 |
+
>>> aspect_ratio = image.height / image.width
|
| 79 |
+
>>> mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
|
| 80 |
+
>>> height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
|
| 81 |
+
>>> width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
|
| 82 |
+
>>> image = image.resize((width, height))
|
| 83 |
+
>>> prompt = (
|
| 84 |
+
... "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
|
| 85 |
+
... "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
|
| 86 |
+
... )
|
| 87 |
+
>>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
| 88 |
+
|
| 89 |
+
>>> output = pipe(
|
| 90 |
+
... image=image,
|
| 91 |
+
... prompt=prompt,
|
| 92 |
+
... negative_prompt=negative_prompt,
|
| 93 |
+
... height=height,
|
| 94 |
+
... width=width,
|
| 95 |
+
... num_frames=81,
|
| 96 |
+
... guidance_scale=5.0,
|
| 97 |
+
... ).frames[0]
|
| 98 |
+
>>> export_to_video(output, "output.mp4", fps=16)
|
| 99 |
+
```
|
| 100 |
+
"""
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def basic_clean(text):
|
| 104 |
+
text = ftfy.fix_text(text)
|
| 105 |
+
text = html.unescape(html.unescape(text))
|
| 106 |
+
return text.strip()
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def whitespace_clean(text):
|
| 110 |
+
text = re.sub(r"\s+", " ", text)
|
| 111 |
+
text = text.strip()
|
| 112 |
+
return text
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def prompt_clean(text):
|
| 116 |
+
text = whitespace_clean(basic_clean(text))
|
| 117 |
+
return text
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
| 121 |
+
def retrieve_latents(
|
| 122 |
+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
| 123 |
+
):
|
| 124 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
| 125 |
+
return encoder_output.latent_dist.sample(generator)
|
| 126 |
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
| 127 |
+
return encoder_output.latent_dist.mode()
|
| 128 |
+
elif hasattr(encoder_output, "latents"):
|
| 129 |
+
return encoder_output.latents
|
| 130 |
+
else:
|
| 131 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
| 135 |
+
r"""
|
| 136 |
+
Pipeline for image-to-video generation using Wan.
|
| 137 |
+
|
| 138 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
| 139 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
| 140 |
+
|
| 141 |
+
Args:
|
| 142 |
+
tokenizer ([`T5Tokenizer`]):
|
| 143 |
+
Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
|
| 144 |
+
specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
|
| 145 |
+
text_encoder ([`T5EncoderModel`]):
|
| 146 |
+
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
| 147 |
+
the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
|
| 148 |
+
image_encoder ([`CLIPVisionModel`]):
|
| 149 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel), specifically
|
| 150 |
+
the
|
| 151 |
+
[clip-vit-huge-patch14](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md#vit-h14-xlm-roberta-large)
|
| 152 |
+
variant.
|
| 153 |
+
transformer ([`WanTransformer3DModel`]):
|
| 154 |
+
Conditional Transformer to denoise the input latents.
|
| 155 |
+
scheduler ([`UniPCMultistepScheduler`]):
|
| 156 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
| 157 |
+
vae ([`AutoencoderKLWan`]):
|
| 158 |
+
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
| 159 |
+
transformer_2 ([`WanTransformer3DModel`], *optional*):
|
| 160 |
+
Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
|
| 161 |
+
`transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only
|
| 162 |
+
`transformer` is used.
|
| 163 |
+
boundary_ratio (`float`, *optional*, defaults to `None`):
|
| 164 |
+
Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
|
| 165 |
+
The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
|
| 166 |
+
`transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
|
| 167 |
+
boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
|
| 168 |
+
"""
|
| 169 |
+
|
| 170 |
+
model_cpu_offload_seq = "text_encoder->image_encoder->transformer->transformer_2->vae"
|
| 171 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
| 172 |
+
_optional_components = ["transformer", "transformer_2", "image_encoder", "image_processor"]
|
| 173 |
+
|
| 174 |
+
def __init__(
|
| 175 |
+
self,
|
| 176 |
+
tokenizer: AutoTokenizer,
|
| 177 |
+
text_encoder: UMT5EncoderModel,
|
| 178 |
+
vae: AutoencoderKLWan,
|
| 179 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 180 |
+
image_processor: CLIPImageProcessor = None,
|
| 181 |
+
image_encoder: CLIPVisionModel = None,
|
| 182 |
+
transformer: WanTransformer3DModel = None,
|
| 183 |
+
transformer_2: WanTransformer3DModel = None,
|
| 184 |
+
boundary_ratio: Optional[float] = None,
|
| 185 |
+
expand_timesteps: bool = False,
|
| 186 |
+
):
|
| 187 |
+
super().__init__()
|
| 188 |
+
|
| 189 |
+
self.register_modules(
|
| 190 |
+
vae=vae,
|
| 191 |
+
text_encoder=text_encoder,
|
| 192 |
+
tokenizer=tokenizer,
|
| 193 |
+
image_encoder=image_encoder,
|
| 194 |
+
transformer=transformer,
|
| 195 |
+
scheduler=scheduler,
|
| 196 |
+
image_processor=image_processor,
|
| 197 |
+
transformer_2=transformer_2,
|
| 198 |
+
)
|
| 199 |
+
self.register_to_config(boundary_ratio=boundary_ratio, expand_timesteps=expand_timesteps)
|
| 200 |
+
|
| 201 |
+
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
|
| 202 |
+
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
|
| 203 |
+
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
| 204 |
+
self.image_processor = image_processor
|
| 205 |
+
|
| 206 |
+
def _get_t5_prompt_embeds(
|
| 207 |
+
self,
|
| 208 |
+
prompt: Union[str, List[str]] = None,
|
| 209 |
+
num_videos_per_prompt: int = 1,
|
| 210 |
+
max_sequence_length: int = 512,
|
| 211 |
+
device: Optional[torch.device] = None,
|
| 212 |
+
dtype: Optional[torch.dtype] = None,
|
| 213 |
+
):
|
| 214 |
+
device = device or self._execution_device
|
| 215 |
+
dtype = dtype or self.text_encoder.dtype
|
| 216 |
+
|
| 217 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 218 |
+
prompt = [prompt_clean(u) for u in prompt]
|
| 219 |
+
batch_size = len(prompt)
|
| 220 |
+
|
| 221 |
+
text_inputs = self.tokenizer(
|
| 222 |
+
prompt,
|
| 223 |
+
padding="max_length",
|
| 224 |
+
max_length=max_sequence_length,
|
| 225 |
+
truncation=True,
|
| 226 |
+
add_special_tokens=True,
|
| 227 |
+
return_attention_mask=True,
|
| 228 |
+
return_tensors="pt",
|
| 229 |
+
)
|
| 230 |
+
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
|
| 231 |
+
seq_lens = mask.gt(0).sum(dim=1).long()
|
| 232 |
+
|
| 233 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
|
| 234 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 235 |
+
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
|
| 236 |
+
prompt_embeds = torch.stack(
|
| 237 |
+
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 241 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 242 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
| 243 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
| 244 |
+
|
| 245 |
+
return prompt_embeds
|
| 246 |
+
|
| 247 |
+
def encode_image(
|
| 248 |
+
self,
|
| 249 |
+
image: PipelineImageInput,
|
| 250 |
+
device: Optional[torch.device] = None,
|
| 251 |
+
):
|
| 252 |
+
device = device or self._execution_device
|
| 253 |
+
image = self.image_processor(images=image, return_tensors="pt").to(device)
|
| 254 |
+
image_embeds = self.image_encoder(**image, output_hidden_states=True)
|
| 255 |
+
return image_embeds.hidden_states[-2]
|
| 256 |
+
|
| 257 |
+
# Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
|
| 258 |
+
def encode_prompt(
|
| 259 |
+
self,
|
| 260 |
+
prompt: Union[str, List[str]],
|
| 261 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 262 |
+
do_classifier_free_guidance: bool = True,
|
| 263 |
+
num_videos_per_prompt: int = 1,
|
| 264 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 265 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 266 |
+
max_sequence_length: int = 226,
|
| 267 |
+
device: Optional[torch.device] = None,
|
| 268 |
+
dtype: Optional[torch.dtype] = None,
|
| 269 |
+
):
|
| 270 |
+
r"""
|
| 271 |
+
Encodes the prompt into text encoder hidden states.
|
| 272 |
+
|
| 273 |
+
Args:
|
| 274 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 275 |
+
prompt to be encoded
|
| 276 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 277 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 278 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 279 |
+
less than `1`).
|
| 280 |
+
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
| 281 |
+
Whether to use classifier free guidance or not.
|
| 282 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 283 |
+
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
| 284 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 285 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 286 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 287 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 288 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 289 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 290 |
+
argument.
|
| 291 |
+
device: (`torch.device`, *optional*):
|
| 292 |
+
torch device
|
| 293 |
+
dtype: (`torch.dtype`, *optional*):
|
| 294 |
+
torch dtype
|
| 295 |
+
"""
|
| 296 |
+
device = device or self._execution_device
|
| 297 |
+
|
| 298 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 299 |
+
if prompt is not None:
|
| 300 |
+
batch_size = len(prompt)
|
| 301 |
+
else:
|
| 302 |
+
batch_size = prompt_embeds.shape[0]
|
| 303 |
+
|
| 304 |
+
if prompt_embeds is None:
|
| 305 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
| 306 |
+
prompt=prompt,
|
| 307 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 308 |
+
max_sequence_length=max_sequence_length,
|
| 309 |
+
device=device,
|
| 310 |
+
dtype=dtype,
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
| 314 |
+
negative_prompt = negative_prompt or ""
|
| 315 |
+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
| 316 |
+
|
| 317 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
| 318 |
+
raise TypeError(
|
| 319 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
| 320 |
+
f" {type(prompt)}."
|
| 321 |
+
)
|
| 322 |
+
elif batch_size != len(negative_prompt):
|
| 323 |
+
raise ValueError(
|
| 324 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 325 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 326 |
+
" the batch size of `prompt`."
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
| 330 |
+
prompt=negative_prompt,
|
| 331 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 332 |
+
max_sequence_length=max_sequence_length,
|
| 333 |
+
device=device,
|
| 334 |
+
dtype=dtype,
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
return prompt_embeds, negative_prompt_embeds
|
| 338 |
+
|
| 339 |
+
def check_inputs(
|
| 340 |
+
self,
|
| 341 |
+
prompt,
|
| 342 |
+
negative_prompt,
|
| 343 |
+
image,
|
| 344 |
+
height,
|
| 345 |
+
width,
|
| 346 |
+
prompt_embeds=None,
|
| 347 |
+
negative_prompt_embeds=None,
|
| 348 |
+
image_embeds=None,
|
| 349 |
+
callback_on_step_end_tensor_inputs=None,
|
| 350 |
+
guidance_scale_2=None,
|
| 351 |
+
):
|
| 352 |
+
if image is not None and image_embeds is not None:
|
| 353 |
+
raise ValueError(
|
| 354 |
+
f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to"
|
| 355 |
+
" only forward one of the two."
|
| 356 |
+
)
|
| 357 |
+
if image is None and image_embeds is None:
|
| 358 |
+
raise ValueError(
|
| 359 |
+
"Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined."
|
| 360 |
+
)
|
| 361 |
+
if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
|
| 362 |
+
raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}")
|
| 363 |
+
if height % 16 != 0 or width % 16 != 0:
|
| 364 |
+
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
| 365 |
+
|
| 366 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 367 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 368 |
+
):
|
| 369 |
+
raise ValueError(
|
| 370 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 371 |
+
)
|
| 372 |
+
|
| 373 |
+
if prompt is not None and prompt_embeds is not None:
|
| 374 |
+
raise ValueError(
|
| 375 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 376 |
+
" only forward one of the two."
|
| 377 |
+
)
|
| 378 |
+
elif negative_prompt is not None and negative_prompt_embeds is not None:
|
| 379 |
+
raise ValueError(
|
| 380 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
|
| 381 |
+
" only forward one of the two."
|
| 382 |
+
)
|
| 383 |
+
elif prompt is None and prompt_embeds is None:
|
| 384 |
+
raise ValueError(
|
| 385 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 386 |
+
)
|
| 387 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 388 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 389 |
+
elif negative_prompt is not None and (
|
| 390 |
+
not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
|
| 391 |
+
):
|
| 392 |
+
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
|
| 393 |
+
|
| 394 |
+
if self.config.boundary_ratio is None and guidance_scale_2 is not None:
|
| 395 |
+
raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
|
| 396 |
+
|
| 397 |
+
if self.config.boundary_ratio is not None and image_embeds is not None:
|
| 398 |
+
raise ValueError("Cannot forward `image_embeds` when the pipeline's `boundary_ratio` is not configured.")
|
| 399 |
+
|
| 400 |
+
def prepare_latents(
|
| 401 |
+
self,
|
| 402 |
+
image: PipelineImageInput,
|
| 403 |
+
traj_tensor,
|
| 404 |
+
batch_size: int,
|
| 405 |
+
num_channels_latents: int = 16,
|
| 406 |
+
height: int = 480,
|
| 407 |
+
width: int = 832,
|
| 408 |
+
num_frames: int = 81,
|
| 409 |
+
dtype: Optional[torch.dtype] = None,
|
| 410 |
+
device: Optional[torch.device] = None,
|
| 411 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 412 |
+
latents: Optional[torch.Tensor] = None,
|
| 413 |
+
last_image: Optional[torch.Tensor] = None,
|
| 414 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 415 |
+
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
| 416 |
+
latent_height = height // self.vae_scale_factor_spatial
|
| 417 |
+
latent_width = width // self.vae_scale_factor_spatial
|
| 418 |
+
|
| 419 |
+
shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
|
| 420 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 421 |
+
raise ValueError(
|
| 422 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 423 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 424 |
+
)
|
| 425 |
+
|
| 426 |
+
if latents is None:
|
| 427 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 428 |
+
else:
|
| 429 |
+
latents = latents.to(device=device, dtype=dtype)
|
| 430 |
+
|
| 431 |
+
image = image.unsqueeze(2) # [batch_size, channels, 1, height, width]
|
| 432 |
+
|
| 433 |
+
if self.config.expand_timesteps:
|
| 434 |
+
video_condition = image
|
| 435 |
+
|
| 436 |
+
elif last_image is None:
|
| 437 |
+
video_condition = torch.cat(
|
| 438 |
+
[image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
|
| 439 |
+
)
|
| 440 |
+
else:
|
| 441 |
+
last_image = last_image.unsqueeze(2)
|
| 442 |
+
video_condition = torch.cat(
|
| 443 |
+
[image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image],
|
| 444 |
+
dim=2,
|
| 445 |
+
)
|
| 446 |
+
video_condition = video_condition.to(device=device, dtype=self.vae.dtype)
|
| 447 |
+
|
| 448 |
+
latents_mean = (
|
| 449 |
+
torch.tensor(self.vae.config.latents_mean)
|
| 450 |
+
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
| 451 |
+
.to(latents.device, latents.dtype)
|
| 452 |
+
)
|
| 453 |
+
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
| 454 |
+
latents.device, latents.dtype
|
| 455 |
+
)
|
| 456 |
+
|
| 457 |
+
if isinstance(generator, list):
|
| 458 |
+
latent_condition = [
|
| 459 |
+
retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator
|
| 460 |
+
]
|
| 461 |
+
latent_condition = torch.cat(latent_condition)
|
| 462 |
+
else:
|
| 463 |
+
latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax")
|
| 464 |
+
latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
|
| 465 |
+
|
| 466 |
+
latent_condition = latent_condition.to(dtype)
|
| 467 |
+
latent_condition = (latent_condition - latents_mean) * latents_std
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
# Prepare the traj latent
|
| 472 |
+
traj_tensor = traj_tensor.to(device, dtype=self.vae.dtype) #.unsqueeze(0)
|
| 473 |
+
traj_tensor = traj_tensor.unsqueeze(0)
|
| 474 |
+
traj_tensor = traj_tensor.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
|
| 475 |
+
|
| 476 |
+
# VAE encode
|
| 477 |
+
traj_latents = retrieve_latents(self.vae.encode(traj_tensor), sample_mode="argmax")
|
| 478 |
+
|
| 479 |
+
# Extract Mean and Variance
|
| 480 |
+
traj_latents = (traj_latents - latents_mean) * latents_std
|
| 481 |
+
|
| 482 |
+
# Final Convert
|
| 483 |
+
traj_latents = traj_latents.to(memory_format = torch.contiguous_format).float()
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
if self.config.expand_timesteps:
|
| 488 |
+
first_frame_mask = torch.ones(
|
| 489 |
+
1, 1, num_latent_frames, latent_height, latent_width, dtype=dtype, device=device
|
| 490 |
+
)
|
| 491 |
+
first_frame_mask[:, :, 0] = 0
|
| 492 |
+
return latents, latent_condition, traj_latents, first_frame_mask
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
|
| 497 |
+
|
| 498 |
+
if last_image is None:
|
| 499 |
+
mask_lat_size[:, :, list(range(1, num_frames))] = 0
|
| 500 |
+
else:
|
| 501 |
+
mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0
|
| 502 |
+
first_frame_mask = mask_lat_size[:, :, 0:1]
|
| 503 |
+
first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal)
|
| 504 |
+
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
|
| 505 |
+
mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width)
|
| 506 |
+
mask_lat_size = mask_lat_size.transpose(1, 2)
|
| 507 |
+
mask_lat_size = mask_lat_size.to(latent_condition.device)
|
| 508 |
+
|
| 509 |
+
return latents, torch.concat([mask_lat_size, latent_condition], dim=1)
|
| 510 |
+
|
| 511 |
+
@property
|
| 512 |
+
def guidance_scale(self):
|
| 513 |
+
return self._guidance_scale
|
| 514 |
+
|
| 515 |
+
@property
|
| 516 |
+
def do_classifier_free_guidance(self):
|
| 517 |
+
return self._guidance_scale > 1
|
| 518 |
+
|
| 519 |
+
@property
|
| 520 |
+
def num_timesteps(self):
|
| 521 |
+
return self._num_timesteps
|
| 522 |
+
|
| 523 |
+
@property
|
| 524 |
+
def current_timestep(self):
|
| 525 |
+
return self._current_timestep
|
| 526 |
+
|
| 527 |
+
@property
|
| 528 |
+
def interrupt(self):
|
| 529 |
+
return self._interrupt
|
| 530 |
+
|
| 531 |
+
@property
|
| 532 |
+
def attention_kwargs(self):
|
| 533 |
+
return self._attention_kwargs
|
| 534 |
+
|
| 535 |
+
@torch.no_grad()
|
| 536 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 537 |
+
def __call__(
|
| 538 |
+
self,
|
| 539 |
+
image: PipelineImageInput,
|
| 540 |
+
prompt: Union[str, List[str]] = None,
|
| 541 |
+
negative_prompt: Union[str, List[str]] = None,
|
| 542 |
+
traj_tensor = None,
|
| 543 |
+
height: int = 480,
|
| 544 |
+
width: int = 832,
|
| 545 |
+
num_frames: int = 81,
|
| 546 |
+
num_inference_steps: int = 50,
|
| 547 |
+
guidance_scale: float = 5.0,
|
| 548 |
+
guidance_scale_2: Optional[float] = None,
|
| 549 |
+
num_videos_per_prompt: Optional[int] = 1,
|
| 550 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 551 |
+
latents: Optional[torch.Tensor] = None,
|
| 552 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 553 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 554 |
+
image_embeds: Optional[torch.Tensor] = None,
|
| 555 |
+
last_image: Optional[torch.Tensor] = None,
|
| 556 |
+
output_type: Optional[str] = "np",
|
| 557 |
+
return_dict: bool = True,
|
| 558 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 559 |
+
callback_on_step_end: Optional[
|
| 560 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
| 561 |
+
] = None,
|
| 562 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 563 |
+
max_sequence_length: int = 512,
|
| 564 |
+
):
|
| 565 |
+
r"""
|
| 566 |
+
The call function to the pipeline for generation.
|
| 567 |
+
|
| 568 |
+
Args:
|
| 569 |
+
image (`PipelineImageInput`):
|
| 570 |
+
The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
|
| 571 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 572 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
| 573 |
+
instead.
|
| 574 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 575 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 576 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 577 |
+
less than `1`).
|
| 578 |
+
height (`int`, defaults to `480`):
|
| 579 |
+
The height of the generated video.
|
| 580 |
+
width (`int`, defaults to `832`):
|
| 581 |
+
The width of the generated video.
|
| 582 |
+
num_frames (`int`, defaults to `81`):
|
| 583 |
+
The number of frames in the generated video.
|
| 584 |
+
num_inference_steps (`int`, defaults to `50`):
|
| 585 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 586 |
+
expense of slower inference.
|
| 587 |
+
guidance_scale (`float`, defaults to `5.0`):
|
| 588 |
+
Guidance scale as defined in [Classifier-Free Diffusion
|
| 589 |
+
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
| 590 |
+
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
| 591 |
+
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
| 592 |
+
the text `prompt`, usually at the expense of lower image quality.
|
| 593 |
+
guidance_scale_2 (`float`, *optional*, defaults to `None`):
|
| 594 |
+
Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
|
| 595 |
+
`boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
|
| 596 |
+
and the pipeline's `boundary_ratio` are not None.
|
| 597 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 598 |
+
The number of images to generate per prompt.
|
| 599 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 600 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
| 601 |
+
generation deterministic.
|
| 602 |
+
latents (`torch.Tensor`, *optional*):
|
| 603 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
| 604 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 605 |
+
tensor is generated by sampling using the supplied random `generator`.
|
| 606 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 607 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
| 608 |
+
provided, text embeddings are generated from the `prompt` input argument.
|
| 609 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 610 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
| 611 |
+
provided, text embeddings are generated from the `negative_prompt` input argument.
|
| 612 |
+
image_embeds (`torch.Tensor`, *optional*):
|
| 613 |
+
Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided,
|
| 614 |
+
image embeddings are generated from the `image` input argument.
|
| 615 |
+
output_type (`str`, *optional*, defaults to `"np"`):
|
| 616 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
| 617 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 618 |
+
Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
|
| 619 |
+
attention_kwargs (`dict`, *optional*):
|
| 620 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 621 |
+
`self.processor` in
|
| 622 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 623 |
+
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
| 624 |
+
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
| 625 |
+
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
| 626 |
+
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
| 627 |
+
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
| 628 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 629 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 630 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 631 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 632 |
+
max_sequence_length (`int`, defaults to `512`):
|
| 633 |
+
The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
|
| 634 |
+
truncated. If the prompt is shorter, it will be padded to this length.
|
| 635 |
+
|
| 636 |
+
Examples:
|
| 637 |
+
|
| 638 |
+
Returns:
|
| 639 |
+
[`~WanPipelineOutput`] or `tuple`:
|
| 640 |
+
If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where
|
| 641 |
+
the first element is a list with the generated images and the second element is a list of `bool`s
|
| 642 |
+
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
|
| 643 |
+
"""
|
| 644 |
+
|
| 645 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 646 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 647 |
+
|
| 648 |
+
# 1. Check inputs. Raise error if not correct
|
| 649 |
+
self.check_inputs(
|
| 650 |
+
prompt,
|
| 651 |
+
negative_prompt,
|
| 652 |
+
image,
|
| 653 |
+
height,
|
| 654 |
+
width,
|
| 655 |
+
prompt_embeds,
|
| 656 |
+
negative_prompt_embeds,
|
| 657 |
+
image_embeds,
|
| 658 |
+
callback_on_step_end_tensor_inputs,
|
| 659 |
+
guidance_scale_2,
|
| 660 |
+
)
|
| 661 |
+
|
| 662 |
+
if num_frames % self.vae_scale_factor_temporal != 1:
|
| 663 |
+
logger.warning(
|
| 664 |
+
f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
|
| 665 |
+
)
|
| 666 |
+
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
|
| 667 |
+
num_frames = max(num_frames, 1)
|
| 668 |
+
|
| 669 |
+
if self.config.boundary_ratio is not None and guidance_scale_2 is None:
|
| 670 |
+
guidance_scale_2 = guidance_scale
|
| 671 |
+
|
| 672 |
+
self._guidance_scale = guidance_scale
|
| 673 |
+
self._guidance_scale_2 = guidance_scale_2
|
| 674 |
+
self._attention_kwargs = attention_kwargs
|
| 675 |
+
self._current_timestep = None
|
| 676 |
+
self._interrupt = False
|
| 677 |
+
|
| 678 |
+
device = self._execution_device
|
| 679 |
+
|
| 680 |
+
# 2. Define call parameters
|
| 681 |
+
if prompt is not None and isinstance(prompt, str):
|
| 682 |
+
batch_size = 1
|
| 683 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 684 |
+
batch_size = len(prompt)
|
| 685 |
+
else:
|
| 686 |
+
batch_size = prompt_embeds.shape[0]
|
| 687 |
+
|
| 688 |
+
# 3. Encode input prompt
|
| 689 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
| 690 |
+
prompt=prompt,
|
| 691 |
+
negative_prompt=negative_prompt,
|
| 692 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
| 693 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 694 |
+
prompt_embeds=prompt_embeds,
|
| 695 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 696 |
+
max_sequence_length=max_sequence_length,
|
| 697 |
+
device=device,
|
| 698 |
+
)
|
| 699 |
+
|
| 700 |
+
# Encode image embedding
|
| 701 |
+
transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
|
| 702 |
+
prompt_embeds = prompt_embeds.to(transformer_dtype)
|
| 703 |
+
if negative_prompt_embeds is not None:
|
| 704 |
+
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
|
| 705 |
+
|
| 706 |
+
# only wan 2.1 i2v transformer accepts image_embeds
|
| 707 |
+
if self.transformer is not None and self.transformer.config.image_dim is not None:
|
| 708 |
+
if image_embeds is None:
|
| 709 |
+
if last_image is None:
|
| 710 |
+
image_embeds = self.encode_image(image, device)
|
| 711 |
+
else:
|
| 712 |
+
image_embeds = self.encode_image([image, last_image], device)
|
| 713 |
+
image_embeds = image_embeds.repeat(batch_size, 1, 1)
|
| 714 |
+
image_embeds = image_embeds.to(transformer_dtype)
|
| 715 |
+
|
| 716 |
+
# 4. Prepare timesteps
|
| 717 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
| 718 |
+
timesteps = self.scheduler.timesteps
|
| 719 |
+
|
| 720 |
+
# 5. Prepare latent variables
|
| 721 |
+
num_channels_latents = self.vae.config.z_dim
|
| 722 |
+
image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32)
|
| 723 |
+
if last_image is not None:
|
| 724 |
+
last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
|
| 725 |
+
device, dtype=torch.float32
|
| 726 |
+
)
|
| 727 |
+
|
| 728 |
+
latents_outputs = self.prepare_latents(
|
| 729 |
+
image,
|
| 730 |
+
traj_tensor,
|
| 731 |
+
batch_size * num_videos_per_prompt,
|
| 732 |
+
num_channels_latents,
|
| 733 |
+
height,
|
| 734 |
+
width,
|
| 735 |
+
num_frames,
|
| 736 |
+
torch.float32,
|
| 737 |
+
device,
|
| 738 |
+
generator,
|
| 739 |
+
latents,
|
| 740 |
+
last_image,
|
| 741 |
+
)
|
| 742 |
+
if self.config.expand_timesteps:
|
| 743 |
+
# wan 2.2 5b i2v use firt_frame_mask to mask timesteps
|
| 744 |
+
latents, condition, traj_latents, first_frame_mask = latents_outputs
|
| 745 |
+
else:
|
| 746 |
+
latents, condition = latents_outputs
|
| 747 |
+
|
| 748 |
+
|
| 749 |
+
# 6. Denoising loop
|
| 750 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
| 751 |
+
self._num_timesteps = len(timesteps)
|
| 752 |
+
|
| 753 |
+
if self.config.boundary_ratio is not None:
|
| 754 |
+
boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
|
| 755 |
+
else:
|
| 756 |
+
boundary_timestep = None
|
| 757 |
+
|
| 758 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 759 |
+
for i, t in enumerate(timesteps):
|
| 760 |
+
if self.interrupt:
|
| 761 |
+
continue
|
| 762 |
+
|
| 763 |
+
self._current_timestep = t
|
| 764 |
+
|
| 765 |
+
if boundary_timestep is None or t >= boundary_timestep:
|
| 766 |
+
# wan2.1 or high-noise stage in wan2.2
|
| 767 |
+
current_model = self.transformer
|
| 768 |
+
current_guidance_scale = guidance_scale
|
| 769 |
+
else:
|
| 770 |
+
# low-noise stage in wan2.2
|
| 771 |
+
current_model = self.transformer_2
|
| 772 |
+
current_guidance_scale = guidance_scale_2
|
| 773 |
+
|
| 774 |
+
if self.config.expand_timesteps:
|
| 775 |
+
latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents
|
| 776 |
+
latent_model_input = latent_model_input.to(transformer_dtype)
|
| 777 |
+
|
| 778 |
+
# seq_len: num_latent_frames * (latent_height // patch_size) * (latent_width // patch_size)
|
| 779 |
+
temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
|
| 780 |
+
# batch_size, seq_len
|
| 781 |
+
timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
|
| 782 |
+
else:
|
| 783 |
+
latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
|
| 784 |
+
timestep = t.expand(latents.shape[0])
|
| 785 |
+
|
| 786 |
+
|
| 787 |
+
# Concat the traj latents in channel dimension
|
| 788 |
+
latent_model_input = torch.cat([latent_model_input, traj_latents], dim=1).to(transformer_dtype)
|
| 789 |
+
|
| 790 |
+
|
| 791 |
+
# Predict the noise according to the timestep
|
| 792 |
+
with current_model.cache_context("cond"):
|
| 793 |
+
noise_pred = current_model(
|
| 794 |
+
hidden_states=latent_model_input,
|
| 795 |
+
timestep=timestep,
|
| 796 |
+
encoder_hidden_states=prompt_embeds,
|
| 797 |
+
encoder_hidden_states_image=image_embeds,
|
| 798 |
+
attention_kwargs=attention_kwargs,
|
| 799 |
+
return_dict=False,
|
| 800 |
+
)[0]
|
| 801 |
+
|
| 802 |
+
if self.do_classifier_free_guidance:
|
| 803 |
+
with current_model.cache_context("uncond"):
|
| 804 |
+
noise_uncond = current_model(
|
| 805 |
+
hidden_states=latent_model_input,
|
| 806 |
+
timestep=timestep,
|
| 807 |
+
encoder_hidden_states=negative_prompt_embeds,
|
| 808 |
+
encoder_hidden_states_image=image_embeds,
|
| 809 |
+
attention_kwargs=attention_kwargs,
|
| 810 |
+
return_dict=False,
|
| 811 |
+
)[0]
|
| 812 |
+
noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
|
| 813 |
+
|
| 814 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 815 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
| 816 |
+
|
| 817 |
+
if callback_on_step_end is not None:
|
| 818 |
+
callback_kwargs = {}
|
| 819 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 820 |
+
callback_kwargs[k] = locals()[k]
|
| 821 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 822 |
+
|
| 823 |
+
latents = callback_outputs.pop("latents", latents)
|
| 824 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 825 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 826 |
+
|
| 827 |
+
# call the callback, if provided
|
| 828 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 829 |
+
progress_bar.update()
|
| 830 |
+
|
| 831 |
+
if XLA_AVAILABLE:
|
| 832 |
+
xm.mark_step()
|
| 833 |
+
|
| 834 |
+
self._current_timestep = None
|
| 835 |
+
|
| 836 |
+
if self.config.expand_timesteps:
|
| 837 |
+
latents = (1 - first_frame_mask) * condition + first_frame_mask * latents
|
| 838 |
+
|
| 839 |
+
if not output_type == "latent":
|
| 840 |
+
latents = latents.to(self.vae.dtype)
|
| 841 |
+
latents_mean = (
|
| 842 |
+
torch.tensor(self.vae.config.latents_mean)
|
| 843 |
+
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
| 844 |
+
.to(latents.device, latents.dtype)
|
| 845 |
+
)
|
| 846 |
+
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
| 847 |
+
latents.device, latents.dtype
|
| 848 |
+
)
|
| 849 |
+
latents = latents / latents_std + latents_mean
|
| 850 |
+
video = self.vae.decode(latents, return_dict=False)[0]
|
| 851 |
+
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
| 852 |
+
else:
|
| 853 |
+
video = latents
|
| 854 |
+
|
| 855 |
+
# Offload all models
|
| 856 |
+
self.maybe_free_model_hooks()
|
| 857 |
+
|
| 858 |
+
if not return_dict:
|
| 859 |
+
return (video,)
|
| 860 |
+
|
| 861 |
+
return WanPipelineOutput(frames=video)
|
pipelines/pipeline_wan_i2v_motion_FrameINO.py
ADDED
|
@@ -0,0 +1,937 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import html
|
| 16 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import os, sys, shutil
|
| 19 |
+
import PIL
|
| 20 |
+
import regex as re
|
| 21 |
+
import torch
|
| 22 |
+
from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
|
| 23 |
+
|
| 24 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 25 |
+
from diffusers.image_processor import PipelineImageInput
|
| 26 |
+
from diffusers.loaders import WanLoraLoaderMixin
|
| 27 |
+
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
| 28 |
+
from diffusers.utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
|
| 29 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 30 |
+
from diffusers.video_processor import VideoProcessor
|
| 31 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 32 |
+
from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# Import files from the local folder
|
| 36 |
+
root_path = os.path.abspath('.')
|
| 37 |
+
sys.path.append(root_path)
|
| 38 |
+
from architecture.transformer_wan import WanTransformer3DModel
|
| 39 |
+
from architecture.autoencoder_kl_wan import AutoencoderKLWan
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
if is_torch_xla_available():
|
| 43 |
+
import torch_xla.core.xla_model as xm
|
| 44 |
+
|
| 45 |
+
XLA_AVAILABLE = True
|
| 46 |
+
else:
|
| 47 |
+
XLA_AVAILABLE = False
|
| 48 |
+
|
| 49 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 50 |
+
|
| 51 |
+
if is_ftfy_available():
|
| 52 |
+
import ftfy
|
| 53 |
+
|
| 54 |
+
EXAMPLE_DOC_STRING = """
|
| 55 |
+
Examples:
|
| 56 |
+
```python
|
| 57 |
+
>>> import torch
|
| 58 |
+
>>> import numpy as np
|
| 59 |
+
>>> from diffusers import AutoencoderKLWan, WanImageToVideoPipeline
|
| 60 |
+
>>> from diffusers.utils import export_to_video, load_image
|
| 61 |
+
>>> from transformers import CLIPVisionModel
|
| 62 |
+
|
| 63 |
+
>>> # Available models: Wan-AI/Wan2.1-I2V-14B-480P-Diffusers, Wan-AI/Wan2.1-I2V-14B-720P-Diffusers
|
| 64 |
+
>>> model_id = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
|
| 65 |
+
>>> image_encoder = CLIPVisionModel.from_pretrained(
|
| 66 |
+
... model_id, subfolder="image_encoder", torch_dtype=torch.float32
|
| 67 |
+
... )
|
| 68 |
+
>>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
| 69 |
+
>>> pipe = WanImageToVideoPipeline.from_pretrained(
|
| 70 |
+
... model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
|
| 71 |
+
... )
|
| 72 |
+
>>> pipe.to("cuda")
|
| 73 |
+
|
| 74 |
+
>>> image = load_image(
|
| 75 |
+
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
|
| 76 |
+
... )
|
| 77 |
+
>>> max_area = 480 * 832
|
| 78 |
+
>>> aspect_ratio = image.height / image.width
|
| 79 |
+
>>> mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
|
| 80 |
+
>>> height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
|
| 81 |
+
>>> width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
|
| 82 |
+
>>> image = image.resize((width, height))
|
| 83 |
+
>>> prompt = (
|
| 84 |
+
... "An astronaut hatching from an egg, on the surface of the moon, the darkness and depth of space realised in "
|
| 85 |
+
... "the background. High quality, ultrarealistic detail and breath-taking movie-like camera shot."
|
| 86 |
+
... )
|
| 87 |
+
>>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
|
| 88 |
+
|
| 89 |
+
>>> output = pipe(
|
| 90 |
+
... image=image,
|
| 91 |
+
... prompt=prompt,
|
| 92 |
+
... negative_prompt=negative_prompt,
|
| 93 |
+
... height=height,
|
| 94 |
+
... width=width,
|
| 95 |
+
... num_frames=81,
|
| 96 |
+
... guidance_scale=5.0,
|
| 97 |
+
... ).frames[0]
|
| 98 |
+
>>> export_to_video(output, "output.mp4", fps=16)
|
| 99 |
+
```
|
| 100 |
+
"""
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def basic_clean(text):
|
| 104 |
+
text = ftfy.fix_text(text)
|
| 105 |
+
text = html.unescape(html.unescape(text))
|
| 106 |
+
return text.strip()
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def whitespace_clean(text):
|
| 110 |
+
text = re.sub(r"\s+", " ", text)
|
| 111 |
+
text = text.strip()
|
| 112 |
+
return text
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def prompt_clean(text):
|
| 116 |
+
text = whitespace_clean(basic_clean(text))
|
| 117 |
+
return text
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
| 121 |
+
def retrieve_latents(
|
| 122 |
+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
| 123 |
+
):
|
| 124 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
| 125 |
+
return encoder_output.latent_dist.sample(generator)
|
| 126 |
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
| 127 |
+
return encoder_output.latent_dist.mode()
|
| 128 |
+
elif hasattr(encoder_output, "latents"):
|
| 129 |
+
return encoder_output.latents
|
| 130 |
+
else:
|
| 131 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
|
| 135 |
+
r"""
|
| 136 |
+
Pipeline for image-to-video generation using Wan.
|
| 137 |
+
|
| 138 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
| 139 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
| 140 |
+
|
| 141 |
+
Args:
|
| 142 |
+
tokenizer ([`T5Tokenizer`]):
|
| 143 |
+
Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
|
| 144 |
+
specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
|
| 145 |
+
text_encoder ([`T5EncoderModel`]):
|
| 146 |
+
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
| 147 |
+
the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
|
| 148 |
+
image_encoder ([`CLIPVisionModel`]):
|
| 149 |
+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPVisionModel), specifically
|
| 150 |
+
the
|
| 151 |
+
[clip-vit-huge-patch14](https://github.com/mlfoundations/open_clip/blob/main/docs/PRETRAINED.md#vit-h14-xlm-roberta-large)
|
| 152 |
+
variant.
|
| 153 |
+
transformer ([`WanTransformer3DModel`]):
|
| 154 |
+
Conditional Transformer to denoise the input latents.
|
| 155 |
+
scheduler ([`UniPCMultistepScheduler`]):
|
| 156 |
+
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
| 157 |
+
vae ([`AutoencoderKLWan`]):
|
| 158 |
+
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
| 159 |
+
transformer_2 ([`WanTransformer3DModel`], *optional*):
|
| 160 |
+
Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
|
| 161 |
+
`transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only
|
| 162 |
+
`transformer` is used.
|
| 163 |
+
boundary_ratio (`float`, *optional*, defaults to `None`):
|
| 164 |
+
Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
|
| 165 |
+
The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
|
| 166 |
+
`transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
|
| 167 |
+
boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
|
| 168 |
+
"""
|
| 169 |
+
|
| 170 |
+
model_cpu_offload_seq = "text_encoder->image_encoder->transformer->transformer_2->vae"
|
| 171 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
| 172 |
+
_optional_components = ["transformer", "transformer_2", "image_encoder", "image_processor"]
|
| 173 |
+
|
| 174 |
+
def __init__(
|
| 175 |
+
self,
|
| 176 |
+
tokenizer: AutoTokenizer,
|
| 177 |
+
text_encoder: UMT5EncoderModel,
|
| 178 |
+
vae: AutoencoderKLWan,
|
| 179 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 180 |
+
image_processor: CLIPImageProcessor = None,
|
| 181 |
+
image_encoder: CLIPVisionModel = None,
|
| 182 |
+
transformer: WanTransformer3DModel = None,
|
| 183 |
+
transformer_2: WanTransformer3DModel = None,
|
| 184 |
+
boundary_ratio: Optional[float] = None,
|
| 185 |
+
expand_timesteps: bool = False,
|
| 186 |
+
):
|
| 187 |
+
super().__init__()
|
| 188 |
+
|
| 189 |
+
self.register_modules(
|
| 190 |
+
vae=vae,
|
| 191 |
+
text_encoder=text_encoder,
|
| 192 |
+
tokenizer=tokenizer,
|
| 193 |
+
image_encoder=image_encoder,
|
| 194 |
+
transformer=transformer,
|
| 195 |
+
scheduler=scheduler,
|
| 196 |
+
image_processor=image_processor,
|
| 197 |
+
transformer_2=transformer_2,
|
| 198 |
+
)
|
| 199 |
+
self.register_to_config(boundary_ratio=boundary_ratio, expand_timesteps=expand_timesteps)
|
| 200 |
+
|
| 201 |
+
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
|
| 202 |
+
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
|
| 203 |
+
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
| 204 |
+
self.image_processor = image_processor
|
| 205 |
+
|
| 206 |
+
def _get_t5_prompt_embeds(
|
| 207 |
+
self,
|
| 208 |
+
prompt: Union[str, List[str]] = None,
|
| 209 |
+
num_videos_per_prompt: int = 1,
|
| 210 |
+
max_sequence_length: int = 512,
|
| 211 |
+
device: Optional[torch.device] = None,
|
| 212 |
+
dtype: Optional[torch.dtype] = None,
|
| 213 |
+
):
|
| 214 |
+
device = device or self._execution_device
|
| 215 |
+
dtype = dtype or self.text_encoder.dtype
|
| 216 |
+
|
| 217 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 218 |
+
prompt = [prompt_clean(u) for u in prompt]
|
| 219 |
+
batch_size = len(prompt)
|
| 220 |
+
|
| 221 |
+
text_inputs = self.tokenizer(
|
| 222 |
+
prompt,
|
| 223 |
+
padding="max_length",
|
| 224 |
+
max_length=max_sequence_length,
|
| 225 |
+
truncation=True,
|
| 226 |
+
add_special_tokens=True,
|
| 227 |
+
return_attention_mask=True,
|
| 228 |
+
return_tensors="pt",
|
| 229 |
+
)
|
| 230 |
+
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
|
| 231 |
+
seq_lens = mask.gt(0).sum(dim=1).long()
|
| 232 |
+
|
| 233 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
|
| 234 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
| 235 |
+
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
|
| 236 |
+
prompt_embeds = torch.stack(
|
| 237 |
+
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
|
| 238 |
+
)
|
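# Note: the two steps above first trim each sequence to its true (unpadded) length and then
# zero-pad it back to max_sequence_length, so padding positions carry exact zeros rather
# than the encoder's embeddings of pad tokens.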
| 239 |
+
|
| 240 |
+
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
| 241 |
+
_, seq_len, _ = prompt_embeds.shape
|
| 242 |
+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
|
| 243 |
+
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
|
| 244 |
+
|
| 245 |
+
return prompt_embeds
|
| 246 |
+
|
| 247 |
+
def encode_image(
|
| 248 |
+
self,
|
| 249 |
+
image: PipelineImageInput,
|
| 250 |
+
device: Optional[torch.device] = None,
|
| 251 |
+
):
|
| 252 |
+
device = device or self._execution_device
|
| 253 |
+
image = self.image_processor(images=image, return_tensors="pt").to(device)
|
| 254 |
+
image_embeds = self.image_encoder(**image, output_hidden_states=True)
|
| 255 |
+
return image_embeds.hidden_states[-2]
|
| 256 |
+
|
| 257 |
+
# Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.encode_prompt
|
| 258 |
+
def encode_prompt(
|
| 259 |
+
self,
|
| 260 |
+
prompt: Union[str, List[str]],
|
| 261 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 262 |
+
do_classifier_free_guidance: bool = True,
|
| 263 |
+
num_videos_per_prompt: int = 1,
|
| 264 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 265 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 266 |
+
max_sequence_length: int = 226,
|
| 267 |
+
device: Optional[torch.device] = None,
|
| 268 |
+
dtype: Optional[torch.dtype] = None,
|
| 269 |
+
):
|
| 270 |
+
r"""
|
| 271 |
+
Encodes the prompt into text encoder hidden states.
|
| 272 |
+
|
| 273 |
+
Args:
|
| 274 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 275 |
+
prompt to be encoded
|
| 276 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 277 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 278 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 279 |
+
less than `1`).
|
| 280 |
+
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
| 281 |
+
Whether to use classifier free guidance or not.
|
| 282 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 283 |
+
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
|
| 284 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 285 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
| 286 |
+
provided, text embeddings will be generated from `prompt` input argument.
|
| 287 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 288 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
| 289 |
+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
| 290 |
+
argument.
|
| 291 |
+
device: (`torch.device`, *optional*):
|
| 292 |
+
torch device
|
| 293 |
+
dtype: (`torch.dtype`, *optional*):
|
| 294 |
+
torch dtype
|
| 295 |
+
"""
|
| 296 |
+
device = device or self._execution_device
|
| 297 |
+
|
| 298 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 299 |
+
if prompt is not None:
|
| 300 |
+
batch_size = len(prompt)
|
| 301 |
+
else:
|
| 302 |
+
batch_size = prompt_embeds.shape[0]
|
| 303 |
+
|
| 304 |
+
if prompt_embeds is None:
|
| 305 |
+
prompt_embeds = self._get_t5_prompt_embeds(
|
| 306 |
+
prompt=prompt,
|
| 307 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 308 |
+
max_sequence_length=max_sequence_length,
|
| 309 |
+
device=device,
|
| 310 |
+
dtype=dtype,
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
| 314 |
+
negative_prompt = negative_prompt or ""
|
| 315 |
+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
| 316 |
+
|
| 317 |
+
if prompt is not None and type(prompt) is not type(negative_prompt):
|
| 318 |
+
raise TypeError(
|
| 319 |
+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
| 320 |
+
f" {type(prompt)}."
|
| 321 |
+
)
|
| 322 |
+
elif batch_size != len(negative_prompt):
|
| 323 |
+
raise ValueError(
|
| 324 |
+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
| 325 |
+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
| 326 |
+
" the batch size of `prompt`."
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
negative_prompt_embeds = self._get_t5_prompt_embeds(
|
| 330 |
+
prompt=negative_prompt,
|
| 331 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 332 |
+
max_sequence_length=max_sequence_length,
|
| 333 |
+
device=device,
|
| 334 |
+
dtype=dtype,
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
return prompt_embeds, negative_prompt_embeds
|
| 338 |
+
|
| 339 |
+
def check_inputs(
|
| 340 |
+
self,
|
| 341 |
+
prompt,
|
| 342 |
+
negative_prompt,
|
| 343 |
+
image,
|
| 344 |
+
height,
|
| 345 |
+
width,
|
| 346 |
+
prompt_embeds=None,
|
| 347 |
+
negative_prompt_embeds=None,
|
| 348 |
+
image_embeds=None,
|
| 349 |
+
callback_on_step_end_tensor_inputs=None,
|
| 350 |
+
guidance_scale_2=None,
|
| 351 |
+
):
|
| 352 |
+
if image is not None and image_embeds is not None:
|
| 353 |
+
raise ValueError(
|
| 354 |
+
f"Cannot forward both `image`: {image} and `image_embeds`: {image_embeds}. Please make sure to"
|
| 355 |
+
" only forward one of the two."
|
| 356 |
+
)
|
| 357 |
+
if image is None and image_embeds is None:
|
| 358 |
+
raise ValueError(
|
| 359 |
+
"Provide either `image` or `prompt_embeds`. Cannot leave both `image` and `image_embeds` undefined."
|
| 360 |
+
)
|
| 361 |
+
if image is not None and not isinstance(image, torch.Tensor) and not isinstance(image, PIL.Image.Image):
|
| 362 |
+
raise ValueError(f"`image` has to be of type `torch.Tensor` or `PIL.Image.Image` but is {type(image)}")
|
| 363 |
+
if height % 16 != 0 or width % 16 != 0:
|
| 364 |
+
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
| 365 |
+
|
| 366 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 367 |
+
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
| 368 |
+
):
|
| 369 |
+
raise ValueError(
|
| 370 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 371 |
+
)
|
| 372 |
+
|
| 373 |
+
if prompt is not None and prompt_embeds is not None:
|
| 374 |
+
raise ValueError(
|
| 375 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 376 |
+
" only forward one of the two."
|
| 377 |
+
)
|
| 378 |
+
elif negative_prompt is not None and negative_prompt_embeds is not None:
|
| 379 |
+
raise ValueError(
|
| 380 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
|
| 381 |
+
" only forward one of the two."
|
| 382 |
+
)
|
| 383 |
+
elif prompt is None and prompt_embeds is None:
|
| 384 |
+
raise ValueError(
|
| 385 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 386 |
+
)
|
| 387 |
+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
| 388 |
+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
| 389 |
+
elif negative_prompt is not None and (
|
| 390 |
+
not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
|
| 391 |
+
):
|
| 392 |
+
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
|
| 393 |
+
|
| 394 |
+
if self.config.boundary_ratio is None and guidance_scale_2 is not None:
|
| 395 |
+
raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
|
| 396 |
+
|
| 397 |
+
if self.config.boundary_ratio is not None and image_embeds is not None:
|
| 398 |
+
raise ValueError("Cannot forward `image_embeds` when the pipeline's `boundary_ratio` is not configured.")
|
| 399 |
+
|
| 400 |
+
def prepare_latents(
|
| 401 |
+
self,
|
| 402 |
+
image: PipelineImageInput,
|
| 403 |
+
traj_tensor,
|
| 404 |
+
ID_tensor,
|
| 405 |
+
batch_size: int,
|
| 406 |
+
num_channels_latents: int = 16,
|
| 407 |
+
height: int = 480,
|
| 408 |
+
width: int = 832,
|
| 409 |
+
num_frames: int = 81,
|
| 410 |
+
dtype: Optional[torch.dtype] = None,
|
| 411 |
+
device: Optional[torch.device] = None,
|
| 412 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 413 |
+
latents: Optional[torch.Tensor] = None,
|
| 414 |
+
last_image: Optional[torch.Tensor] = None,
|
| 415 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 416 |
+
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
| 417 |
+
latent_height = height // self.vae_scale_factor_spatial
|
| 418 |
+
latent_width = width // self.vae_scale_factor_spatial
|
| 419 |
+
|
| 420 |
+
shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)
|
| 421 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 422 |
+
raise ValueError(
|
| 423 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 424 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
if latents is None:
|
| 428 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 429 |
+
else:
|
| 430 |
+
latents = latents.to(device=device, dtype=dtype)
|
| 431 |
+
|
| 432 |
+
image = image.unsqueeze(2) # [batch_size, channels, 1, height, width]
|
| 433 |
+
|
| 434 |
+
if self.config.expand_timesteps:
|
| 435 |
+
video_condition = image
|
| 436 |
+
|
| 437 |
+
elif last_image is None:
|
| 438 |
+
video_condition = torch.cat(
|
| 439 |
+
[image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
|
| 440 |
+
)
|
| 441 |
+
else:
|
| 442 |
+
last_image = last_image.unsqueeze(2)
|
| 443 |
+
video_condition = torch.cat(
|
| 444 |
+
[image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image],
|
| 445 |
+
dim=2,
|
| 446 |
+
)
|
| 447 |
+
video_condition = video_condition.to(device=device, dtype=self.vae.dtype)
|
| 448 |
+
|
| 449 |
+
latents_mean = (
|
| 450 |
+
torch.tensor(self.vae.config.latents_mean)
|
| 451 |
+
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
| 452 |
+
.to(latents.device, latents.dtype)
|
| 453 |
+
)
|
| 454 |
+
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
| 455 |
+
latents.device, latents.dtype
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
if isinstance(generator, list):
|
| 459 |
+
latent_condition = [
|
| 460 |
+
retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator
|
| 461 |
+
]
|
| 462 |
+
latent_condition = torch.cat(latent_condition)
|
| 463 |
+
else:
|
| 464 |
+
latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax")
|
| 465 |
+
latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
|
| 466 |
+
|
| 467 |
+
latent_condition = latent_condition.to(dtype)
|
| 468 |
+
latent_condition = (latent_condition - latents_mean) * latents_std
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
# Prepare the traj latent
|
| 473 |
+
traj_tensor = traj_tensor.to(device, dtype=self.vae.dtype) #.unsqueeze(0)
|
| 474 |
+
traj_tensor = traj_tensor.unsqueeze(0)
|
| 475 |
+
traj_tensor = traj_tensor.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
|
| 476 |
+
|
| 477 |
+
# VAE Encode
|
| 478 |
+
traj_latents = retrieve_latents(self.vae.encode(traj_tensor), sample_mode="argmax")
|
| 479 |
+
|
| 480 |
+
# Extract Mean and Variance
|
| 481 |
+
traj_latents = (traj_latents - latents_mean) * latents_std
|
| 482 |
+
|
| 483 |
+
# Final Convert
|
| 484 |
+
traj_latents = traj_latents.to(memory_format = torch.contiguous_format).float()
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
# Prepare the ID latents
|
| 489 |
+
if ID_tensor.shape[2] != 0: # Must have at least one ID frame; it can occasionally be empty
|
| 490 |
+
|
| 491 |
+
# Transform
|
| 492 |
+
ID_tensor = ID_tensor.to(device=device, dtype=self.vae.dtype)
|
| 493 |
+
|
| 494 |
+
# VAE encode for each frame One by One
|
| 495 |
+
ID_latents = []
|
| 496 |
+
for frame_idx in range(ID_tensor.shape[2]):
|
| 497 |
+
|
| 498 |
+
# Fetch
|
| 499 |
+
ID_frame = ID_tensor[:, :, frame_idx].unsqueeze(2)  # use a separate variable so ID_tensor is not overwritten inside the loop
|
| 500 |
+
|
| 501 |
+
# Single Frame Encode, which will be single frame token
|
| 502 |
+
ID_latent = retrieve_latents(self.vae.encode(ID_frame), sample_mode="argmax")
|
| 503 |
+
ID_latent = ID_latent.repeat(batch_size, 1, 1, 1, 1)
|
| 504 |
+
|
| 505 |
+
# Convert
|
| 506 |
+
ID_latent = ID_latent.to(dtype)
|
| 507 |
+
ID_latent = (ID_latent - latents_mean) * latents_std
|
| 508 |
+
|
| 509 |
+
# Append
|
| 510 |
+
ID_latents.append(ID_latent)
|
| 511 |
+
|
| 512 |
+
# Final Convert
|
| 513 |
+
ID_latent_condition = torch.cat(ID_latents, dim = 2)
|
| 514 |
+
|
| 515 |
+
# Add padding to the traj latents
|
| 516 |
+
ID_latent_padding = torch.zeros_like(ID_latent_condition)
|
| 517 |
+
traj_latents = torch.cat([traj_latents, ID_latent_padding], dim=2)
|
| 518 |
+
|
| 519 |
+
# Update the number of latents frames for the first frame mask
|
| 520 |
+
# num_latent_frames = num_latent_frames + len(ID_latents)
|
| 521 |
+
|
| 522 |
+
else:
|
| 523 |
+
# Return an empty one
|
| 524 |
+
ID_latent_condition = None
|
| 525 |
+
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
if self.config.expand_timesteps: # For Wan2.2
|
| 529 |
+
first_frame_mask = torch.ones(
|
| 530 |
+
1, 1, num_latent_frames, latent_height, latent_width, dtype=dtype, device=device
|
| 531 |
+
)
|
| 532 |
+
first_frame_mask[:, :, 0] = 0
|
| 533 |
+
|
| 534 |
+
# Return all condition information needed
|
| 535 |
+
return latents, latent_condition, traj_latents, ID_latent_condition, first_frame_mask
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
# The rest is for Wan2.1
|
| 540 |
+
mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
|
| 541 |
+
|
| 542 |
+
if last_image is None:
|
| 543 |
+
mask_lat_size[:, :, list(range(1, num_frames))] = 0
|
| 544 |
+
else:
|
| 545 |
+
mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0
|
| 546 |
+
first_frame_mask = mask_lat_size[:, :, 0:1]
|
| 547 |
+
first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=self.vae_scale_factor_temporal)
|
| 548 |
+
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
|
| 549 |
+
mask_lat_size = mask_lat_size.view(batch_size, -1, self.vae_scale_factor_temporal, latent_height, latent_width)
|
| 550 |
+
mask_lat_size = mask_lat_size.transpose(1, 2)
|
| 551 |
+
mask_lat_size = mask_lat_size.to(latent_condition.device)
|
| 552 |
+
|
| 553 |
+
return latents, torch.concat([mask_lat_size, latent_condition], dim=1)
|
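For orientation, the shape arithmetic of the Wan2.1 conditioning mask built above can be checked in isolation; the sizes below are hypothetical and only illustrate how the per-frame mask is regrouped into `vae_scale_factor_temporal` channels per latent frame:
```python
import torch

# Hypothetical sizes: 81 frames, 4x temporal VAE compression, a 60x104 latent grid.
batch_size, num_frames, temporal_scale = 1, 81, 4
latent_height, latent_width = 60, 104

mask = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
mask[:, :, 1:] = 0  # only the first frame is "known"
first = mask[:, :, 0:1].repeat_interleave(temporal_scale, dim=2)
mask = torch.cat([first, mask[:, :, 1:]], dim=2)                  # 4 + 80 = 84 entries
mask = mask.view(batch_size, -1, temporal_scale, latent_height, latent_width)
mask = mask.transpose(1, 2)

print(mask.shape)  # torch.Size([1, 4, 21, 60, 104]); 21 = (81 - 1) // 4 + 1 latent frames
```
Concatenating this 4-channel mask with the encoded first-frame condition (16 latent channels for the default Wan VAE) yields the 20-channel conditioning tensor returned above.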
| 554 |
+
|
| 555 |
+
@property
|
| 556 |
+
def guidance_scale(self):
|
| 557 |
+
return self._guidance_scale
|
| 558 |
+
|
| 559 |
+
@property
|
| 560 |
+
def do_classifier_free_guidance(self):
|
| 561 |
+
return self._guidance_scale > 1
|
| 562 |
+
|
| 563 |
+
@property
|
| 564 |
+
def num_timesteps(self):
|
| 565 |
+
return self._num_timesteps
|
| 566 |
+
|
| 567 |
+
@property
|
| 568 |
+
def current_timestep(self):
|
| 569 |
+
return self._current_timestep
|
| 570 |
+
|
| 571 |
+
@property
|
| 572 |
+
def interrupt(self):
|
| 573 |
+
return self._interrupt
|
| 574 |
+
|
| 575 |
+
@property
|
| 576 |
+
def attention_kwargs(self):
|
| 577 |
+
return self._attention_kwargs
|
| 578 |
+
|
| 579 |
+
@torch.no_grad()
|
| 580 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 581 |
+
def __call__(
|
| 582 |
+
self,
|
| 583 |
+
image: PipelineImageInput,
|
| 584 |
+
prompt: Union[str, List[str]] = None,
|
| 585 |
+
negative_prompt: Union[str, List[str]] = None,
|
| 586 |
+
traj_tensor = None,
|
| 587 |
+
ID_tensor = None,
|
| 588 |
+
height: int = 480,
|
| 589 |
+
width: int = 832,
|
| 590 |
+
num_frames: int = 81,
|
| 591 |
+
num_inference_steps: int = 50,
|
| 592 |
+
guidance_scale: float = 5.0,
|
| 593 |
+
guidance_scale_2: Optional[float] = None,
|
| 594 |
+
num_videos_per_prompt: Optional[int] = 1,
|
| 595 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 596 |
+
latents: Optional[torch.Tensor] = None,
|
| 597 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 598 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 599 |
+
image_embeds: Optional[torch.Tensor] = None,
|
| 600 |
+
last_image: Optional[torch.Tensor] = None,
|
| 601 |
+
output_type: Optional[str] = "np",
|
| 602 |
+
return_dict: bool = True,
|
| 603 |
+
attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 604 |
+
callback_on_step_end: Optional[
|
| 605 |
+
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
| 606 |
+
] = None,
|
| 607 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 608 |
+
max_sequence_length: int = 512,
|
| 609 |
+
):
|
| 610 |
+
r"""
|
| 611 |
+
The call function to the pipeline for generation.
|
| 612 |
+
|
| 613 |
+
Args:
|
| 614 |
+
image (`PipelineImageInput`):
|
| 615 |
+
The input image to condition the generation on. Must be an image, a list of images or a `torch.Tensor`.
|
| 616 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 617 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
| 618 |
+
instead.
|
| 619 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 620 |
+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
| 621 |
+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
| 622 |
+
less than `1`).
|
| 623 |
+
height (`int`, defaults to `480`):
|
| 624 |
+
The height of the generated video.
|
| 625 |
+
width (`int`, defaults to `832`):
|
| 626 |
+
The width of the generated video.
|
| 627 |
+
num_frames (`int`, defaults to `81`):
|
| 628 |
+
The number of frames in the generated video.
|
| 629 |
+
num_inference_steps (`int`, defaults to `50`):
|
| 630 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 631 |
+
expense of slower inference.
|
| 632 |
+
guidance_scale (`float`, defaults to `5.0`):
|
| 633 |
+
Guidance scale as defined in [Classifier-Free Diffusion
|
| 634 |
+
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
| 635 |
+
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
| 636 |
+
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
| 637 |
+
the text `prompt`, usually at the expense of lower image quality.
|
| 638 |
+
guidance_scale_2 (`float`, *optional*, defaults to `None`):
|
| 639 |
+
Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
|
| 640 |
+
`boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
|
| 641 |
+
and the pipeline's `boundary_ratio` are not None.
|
| 642 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 643 |
+
The number of images to generate per prompt.
|
| 644 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 645 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
| 646 |
+
generation deterministic.
|
| 647 |
+
latents (`torch.Tensor`, *optional*):
|
| 648 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
| 649 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 650 |
+
tensor is generated by sampling using the supplied random `generator`.
|
| 651 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 652 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
| 653 |
+
provided, text embeddings are generated from the `prompt` input argument.
|
| 654 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 655 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
| 656 |
+
provided, text embeddings are generated from the `negative_prompt` input argument.
|
| 657 |
+
image_embeds (`torch.Tensor`, *optional*):
|
| 658 |
+
Pre-generated image embeddings. Can be used to easily tweak image inputs (weighting). If not provided,
|
| 659 |
+
image embeddings are generated from the `image` input argument.
|
| 660 |
+
output_type (`str`, *optional*, defaults to `"np"`):
|
| 661 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
| 662 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 663 |
+
Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
|
| 664 |
+
attention_kwargs (`dict`, *optional*):
|
| 665 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 666 |
+
`self.processor` in
|
| 667 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 668 |
+
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
| 669 |
+
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
| 670 |
+
each denoising step during inference with the following arguments: `callback_on_step_end(self:
|
| 671 |
+
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
| 672 |
+
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
| 673 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 674 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 675 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 676 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 677 |
+
max_sequence_length (`int`, defaults to `512`):
|
| 678 |
+
The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
|
| 679 |
+
truncated. If the prompt is shorter, it will be padded to this length.
|
| 680 |
+
|
| 681 |
+
Examples:
|
| 682 |
+
|
| 683 |
+
Returns:
|
| 684 |
+
[`~WanPipelineOutput`] or `tuple`:
|
| 685 |
+
If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where
|
| 686 |
+
the first element is a list with the generated images and the second element is a list of `bool`s
|
| 687 |
+
indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content.
|
| 688 |
+
"""
|
| 689 |
+
|
| 690 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 691 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 692 |
+
|
| 693 |
+
# 1. Check inputs. Raise error if not correct
|
| 694 |
+
self.check_inputs(
|
| 695 |
+
prompt,
|
| 696 |
+
negative_prompt,
|
| 697 |
+
image,
|
| 698 |
+
height,
|
| 699 |
+
width,
|
| 700 |
+
prompt_embeds,
|
| 701 |
+
negative_prompt_embeds,
|
| 702 |
+
image_embeds,
|
| 703 |
+
callback_on_step_end_tensor_inputs,
|
| 704 |
+
guidance_scale_2,
|
| 705 |
+
)
|
| 706 |
+
|
| 707 |
+
if num_frames % self.vae_scale_factor_temporal != 1:
|
| 708 |
+
logger.warning(
|
| 709 |
+
f"`num_frames - 1` has to be divisible by {self.vae_scale_factor_temporal}. Rounding to the nearest number."
|
| 710 |
+
)
|
| 711 |
+
num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
|
| 712 |
+
num_frames = max(num_frames, 1)
|
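# e.g. num_frames = 80 with vae_scale_factor_temporal = 4 is rounded to 80 // 4 * 4 + 1 = 81.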
| 713 |
+
|
| 714 |
+
if self.config.boundary_ratio is not None and guidance_scale_2 is None:
|
| 715 |
+
guidance_scale_2 = guidance_scale
|
| 716 |
+
|
| 717 |
+
self._guidance_scale = guidance_scale
|
| 718 |
+
self._guidance_scale_2 = guidance_scale_2
|
| 719 |
+
self._attention_kwargs = attention_kwargs
|
| 720 |
+
self._current_timestep = None
|
| 721 |
+
self._interrupt = False
|
| 722 |
+
|
| 723 |
+
device = self._execution_device
|
| 724 |
+
|
| 725 |
+
# 2. Define call parameters
|
| 726 |
+
if prompt is not None and isinstance(prompt, str):
|
| 727 |
+
batch_size = 1
|
| 728 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 729 |
+
batch_size = len(prompt)
|
| 730 |
+
else:
|
| 731 |
+
batch_size = prompt_embeds.shape[0]
|
| 732 |
+
|
| 733 |
+
# 3. Encode input prompt
|
| 734 |
+
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
| 735 |
+
prompt=prompt,
|
| 736 |
+
negative_prompt=negative_prompt,
|
| 737 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
| 738 |
+
num_videos_per_prompt=num_videos_per_prompt,
|
| 739 |
+
prompt_embeds=prompt_embeds,
|
| 740 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 741 |
+
max_sequence_length=max_sequence_length,
|
| 742 |
+
device=device,
|
| 743 |
+
)
|
| 744 |
+
|
| 745 |
+
# Encode image embedding
|
| 746 |
+
transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
|
| 747 |
+
prompt_embeds = prompt_embeds.to(transformer_dtype)
|
| 748 |
+
if negative_prompt_embeds is not None:
|
| 749 |
+
negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
|
| 750 |
+
|
| 751 |
+
# only wan 2.1 i2v transformer accepts image_embeds
|
| 752 |
+
if self.transformer is not None and self.transformer.config.image_dim is not None:
|
| 753 |
+
if image_embeds is None:
|
| 754 |
+
if last_image is None:
|
| 755 |
+
image_embeds = self.encode_image(image, device)
|
| 756 |
+
else:
|
| 757 |
+
image_embeds = self.encode_image([image, last_image], device)
|
| 758 |
+
image_embeds = image_embeds.repeat(batch_size, 1, 1)
|
| 759 |
+
image_embeds = image_embeds.to(transformer_dtype)
|
| 760 |
+
|
| 761 |
+
# 4. Prepare timesteps
|
| 762 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
| 763 |
+
timesteps = self.scheduler.timesteps
|
| 764 |
+
|
| 765 |
+
# 5. Prepare latent variables
|
| 766 |
+
num_channels_latents = self.vae.config.z_dim
|
| 767 |
+
image = self.video_processor.preprocess(image, height=height, width=width).to(device, dtype=torch.float32)
|
| 768 |
+
if last_image is not None:
|
| 769 |
+
last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
|
| 770 |
+
device, dtype=torch.float32
|
| 771 |
+
)
|
| 772 |
+
|
| 773 |
+
latents_outputs = self.prepare_latents(
|
| 774 |
+
image,
|
| 775 |
+
traj_tensor,
|
| 776 |
+
ID_tensor,
|
| 777 |
+
batch_size * num_videos_per_prompt,
|
| 778 |
+
num_channels_latents,
|
| 779 |
+
height,
|
| 780 |
+
width,
|
| 781 |
+
num_frames,
|
| 782 |
+
torch.float32,
|
| 783 |
+
device,
|
| 784 |
+
generator,
|
| 785 |
+
latents,
|
| 786 |
+
last_image,
|
| 787 |
+
)
|
| 788 |
+
if self.config.expand_timesteps:
|
| 789 |
+
# wan 2.2 5b i2v uses first_frame_mask to mask timesteps
|
| 790 |
+
latents, condition, traj_latents, ID_latent_condition, first_frame_mask = latents_outputs
|
| 791 |
+
else:
|
| 792 |
+
latents, condition = latents_outputs
|
| 793 |
+
|
| 794 |
+
|
| 795 |
+
# 5.5. For ID reference change, we need to add padding for the latents each time
|
| 796 |
+
_, channel_num, num_gen_frames, latent_height, latent_width = latents.shape
|
| 797 |
+
|
| 798 |
+
|
| 799 |
+
|
| 800 |
+
# 6. Denoising loop
|
| 801 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
| 802 |
+
self._num_timesteps = len(timesteps)
|
| 803 |
+
|
| 804 |
+
if self.config.boundary_ratio is not None:
|
| 805 |
+
boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
|
| 806 |
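# Illustrative numbers: boundary_ratio = 0.9 with 1000 training timesteps gives
# boundary_timestep = 900, so steps with t >= 900 use `transformer` (high noise)
# and the remaining steps use `transformer_2` (low noise) in the loop below.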
+
else:
|
| 807 |
+
boundary_timestep = None
|
| 808 |
+
|
| 809 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 810 |
+
for i, t in enumerate(timesteps):
|
| 811 |
+
if self.interrupt:
|
| 812 |
+
continue
|
| 813 |
+
|
| 814 |
+
self._current_timestep = t
|
| 815 |
+
|
| 816 |
+
if boundary_timestep is None or t >= boundary_timestep:
|
| 817 |
+
# wan2.1 or high-noise stage in wan2.2
|
| 818 |
+
current_model = self.transformer
|
| 819 |
+
current_guidance_scale = guidance_scale
|
| 820 |
+
else:
|
| 821 |
+
# low-noise stage in wan2.2
|
| 822 |
+
current_model = self.transformer_2
|
| 823 |
+
current_guidance_scale = guidance_scale_2
|
| 824 |
+
|
| 825 |
+
|
| 826 |
+
if self.config.expand_timesteps:
|
| 827 |
+
|
| 828 |
+
# Blend with the mask so that the first-frame latent of the model input is the clean latent of the first-frame condition (for FrameINO, this first frame should be the masked, outpainting-style frame)
|
| 829 |
+
latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents # NOTE: the first frame should now be set to the masked first frame (with the outpainting-style layout)
|
| 830 |
+
latent_model_input = latent_model_input.to(transformer_dtype)
|
| 831 |
+
|
| 832 |
+
# Add padding for the first_frame_mask here with the length of ID tokens
|
| 833 |
+
if ID_latent_condition is not None:
|
| 834 |
+
mask_padding = torch.ones(
|
| 835 |
+
1, 1, ID_latent_condition.shape[2], latent_height, latent_width, dtype=transformer_dtype, device=device
|
| 836 |
+
)
|
| 837 |
+
first_frame_mask_adjust = torch.cat([first_frame_mask, mask_padding], dim = 2)
|
| 838 |
+
else:
|
| 839 |
+
first_frame_mask_adjust = first_frame_mask
|
| 840 |
+
|
| 841 |
+
# Reshape to num_latent_frames * (latent_height // patch_size) * (latent_width // patch_size)
|
| 842 |
+
temp_ts = (first_frame_mask_adjust[0][0][:, ::2, ::2] * t).flatten()
|
| 843 |
+
timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
|
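# Shape note (assuming a (1, 2, 2) patch size): first_frame_mask_adjust[0][0] is
# (num_latent_frames, latent_height, latent_width); the ::2 strides keep one mask value
# per spatial patch, so `timestep` becomes (batch, num_tokens) with 0 for the clean
# first-frame tokens and t for all tokens that are still being denoised.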
| 844 |
+
|
| 845 |
+
else:
|
| 846 |
+
|
| 847 |
+
latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
|
| 848 |
+
timestep = t.expand(latents.shape[0])
|
| 849 |
+
# TODO: not yet fully sure whether this timestep setup is aligned with the one used during training.
|
| 850 |
+
|
| 851 |
+
|
| 852 |
+
# Frame-Wise concatenate ID tokens
|
| 853 |
+
if ID_latent_condition is not None:
|
| 854 |
+
latent_model_input = torch.cat([latent_model_input, ID_latent_condition], dim = 2)
|
| 855 |
+
|
| 856 |
+
|
| 857 |
+
# Concat the trajectory latents in Channel dimension
|
| 858 |
+
latent_model_input = torch.cat([latent_model_input, traj_latents], dim = 1).to(transformer_dtype)
|
| 859 |
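# At this point latent_model_input stacks the (masked) video latents and the trajectory
# latents along the channel dimension (dim=1) and appends the ID reference tokens along
# the frame dimension (dim=2); the extra ID frames are sliced off again after the model
# call (see the `noise_pred[:, :, :num_gen_frames]` line below).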
+
|
| 860 |
+
|
| 861 |
+
# Predict the noise according to the timestep
|
| 862 |
+
with current_model.cache_context("cond"):
|
| 863 |
+
noise_pred = current_model(
|
| 864 |
+
hidden_states = latent_model_input,
|
| 865 |
+
timestep = timestep,
|
| 866 |
+
encoder_hidden_states = prompt_embeds,
|
| 867 |
+
encoder_hidden_states_image = image_embeds,
|
| 868 |
+
attention_kwargs = attention_kwargs,
|
| 869 |
+
return_dict = False,
|
| 870 |
+
)[0]
|
| 871 |
+
|
| 872 |
+
if self.do_classifier_free_guidance:
|
| 873 |
+
with current_model.cache_context("uncond"):
|
| 874 |
+
noise_uncond = current_model(
|
| 875 |
+
hidden_states = latent_model_input,
|
| 876 |
+
timestep = timestep,
|
| 877 |
+
encoder_hidden_states = negative_prompt_embeds,
|
| 878 |
+
encoder_hidden_states_image = image_embeds,
|
| 879 |
+
attention_kwargs = attention_kwargs,
|
| 880 |
+
return_dict = False,
|
| 881 |
+
)[0]
|
| 882 |
+
noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)
|
| 883 |
+
|
| 884 |
+
|
| 885 |
+
# Discard the Extra ID tokens in the Noise Prediction
|
| 886 |
+
noise_pred = noise_pred[:, :, :num_gen_frames]
|
| 887 |
+
|
| 888 |
+
|
| 889 |
+
|
| 890 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 891 |
+
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
| 892 |
+
|
| 893 |
+
if callback_on_step_end is not None:
|
| 894 |
+
callback_kwargs = {}
|
| 895 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 896 |
+
callback_kwargs[k] = locals()[k]
|
| 897 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 898 |
+
|
| 899 |
+
latents = callback_outputs.pop("latents", latents)
|
| 900 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 901 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 902 |
+
|
| 903 |
+
# call the callback, if provided
|
| 904 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 905 |
+
progress_bar.update()
|
| 906 |
+
|
| 907 |
+
if XLA_AVAILABLE:
|
| 908 |
+
xm.mark_step()
|
| 909 |
+
|
| 910 |
+
self._current_timestep = None
|
| 911 |
+
|
| 912 |
+
if self.config.expand_timesteps:
|
| 913 |
+
latents = (1 - first_frame_mask) * condition + first_frame_mask * latents
|
| 914 |
+
|
| 915 |
+
if not output_type == "latent":
|
| 916 |
+
latents = latents.to(self.vae.dtype)
|
| 917 |
+
latents_mean = (
|
| 918 |
+
torch.tensor(self.vae.config.latents_mean)
|
| 919 |
+
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
| 920 |
+
.to(latents.device, latents.dtype)
|
| 921 |
+
)
|
| 922 |
+
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
| 923 |
+
latents.device, latents.dtype
|
| 924 |
+
)
|
| 925 |
+
latents = latents / latents_std + latents_mean
|
| 926 |
+
video = self.vae.decode(latents, return_dict=False)[0]
|
| 927 |
+
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
| 928 |
+
else:
|
| 929 |
+
video = latents
|
| 930 |
+
|
| 931 |
+
# Offload all models
|
| 932 |
+
self.maybe_free_model_hooks()
|
| 933 |
+
|
| 934 |
+
if not return_dict:
|
| 935 |
+
return (video,)
|
| 936 |
+
|
| 937 |
+
return WanPipelineOutput(frames=video)
|
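For context, a minimal usage sketch of this modified pipeline is shown below. The checkpoint id, prompt, and tensor contents are hypothetical placeholders (the repository's training and demo scripts define how `traj_tensor` and `ID_tensor` are actually produced), and it assumes a Wan 2.2 5B-style checkpoint whose config enables `expand_timesteps`:
```python
import torch
from diffusers.utils import export_to_video, load_image

# Hypothetical checkpoint id; the custom WanTransformer3DModel / AutoencoderKLWan from
# `architecture/` are expected to be loaded in place of the stock diffusers classes.
pipe = WanImageToVideoPipeline.from_pretrained(
    "Wan-AI/Wan2.2-TI2V-5B-Diffusers", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

image = load_image("first_frame.png").resize((832, 480))
traj_tensor = torch.zeros(81, 3, 480, 832)   # rendered trajectory video, [F, C, H, W], values in [-1, 1]
ID_tensor = torch.zeros(1, 3, 2, 480, 832)   # ID reference frames, [B, C, F_id, H, W], values in [-1, 1]

video = pipe(
    image=image,
    prompt="a fish swimming across the frame",
    traj_tensor=traj_tensor,
    ID_tensor=ID_tensor,
    height=480,
    width=832,
    num_frames=81,
    guidance_scale=5.0,
).frames[0]
export_to_video(video, "output.mp4", fps=16)
```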
requirements.txt
ADDED
|
@@ -0,0 +1,24 @@
| 1 |
+
pandas
|
| 2 |
+
tqdm
|
| 3 |
+
opencv-python
|
| 4 |
+
pyiqa
|
| 5 |
+
numpy==1.26.0
|
| 6 |
+
ffmpeg-python
|
| 7 |
+
bitsandbytes
|
| 8 |
+
pyarrow
|
| 9 |
+
omegaconf
|
| 10 |
+
peft>=0.15.0
|
| 11 |
+
transformers>=4.56.2 # Install the newest version
|
| 12 |
+
git+https://github.com/huggingface/diffusers.git
|
| 13 |
+
sentencepiece
|
| 14 |
+
qwen-vl-utils[decord]==0.0.8
|
| 15 |
+
scikit-learn
|
| 16 |
+
matplotlib
|
| 17 |
+
gradio
|
| 18 |
+
imageio-ffmpeg
|
| 19 |
+
bitsandbytes
|
| 20 |
+
git+https://github.com/facebookresearch/segment-anything.git
|
| 21 |
+
git+https://github.com/facebookresearch/sam2.git
|
| 22 |
+
accelerate
|
| 23 |
+
hf-transfer
|
| 24 |
+
|
utils/optical_flow_utils.py
ADDED
|
@@ -0,0 +1,219 @@
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def make_colorwheel():
|
| 5 |
+
"""
|
| 6 |
+
Generates a color wheel for optical flow visualization as presented in:
|
| 7 |
+
Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
|
| 8 |
+
URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
|
| 9 |
+
|
| 10 |
+
Code follows the original C++ source code of Daniel Scharstein.
|
| 11 |
+
Code follows the Matlab source code of Deqing Sun.
|
| 12 |
+
|
| 13 |
+
Returns:
|
| 14 |
+
np.ndarray: Color wheel
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
RY = 15
|
| 18 |
+
YG = 6
|
| 19 |
+
GC = 4
|
| 20 |
+
CB = 11
|
| 21 |
+
BM = 13
|
| 22 |
+
MR = 6
|
| 23 |
+
|
| 24 |
+
ncols = RY + YG + GC + CB + BM + MR
|
| 25 |
+
colorwheel = np.zeros((ncols, 3))
|
| 26 |
+
col = 0
|
| 27 |
+
|
| 28 |
+
# RY
|
| 29 |
+
colorwheel[0:RY, 0] = 255
|
| 30 |
+
colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
|
| 31 |
+
col = col+RY
|
| 32 |
+
# YG
|
| 33 |
+
colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
|
| 34 |
+
colorwheel[col:col+YG, 1] = 255
|
| 35 |
+
col = col+YG
|
| 36 |
+
# GC
|
| 37 |
+
colorwheel[col:col+GC, 1] = 255
|
| 38 |
+
colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
|
| 39 |
+
col = col+GC
|
| 40 |
+
# CB
|
| 41 |
+
colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
|
| 42 |
+
colorwheel[col:col+CB, 2] = 255
|
| 43 |
+
col = col+CB
|
| 44 |
+
# BM
|
| 45 |
+
colorwheel[col:col+BM, 2] = 255
|
| 46 |
+
colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
|
| 47 |
+
col = col+BM
|
| 48 |
+
# MR
|
| 49 |
+
colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
|
| 50 |
+
colorwheel[col:col+MR, 0] = 255
|
| 51 |
+
return colorwheel
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def flow_uv_to_colors(u, v, convert_to_bgr=False):
|
| 55 |
+
"""
|
| 56 |
+
Applies the flow color wheel to (possibly clipped) flow components u and v.
|
| 57 |
+
|
| 58 |
+
According to the C++ source code of Daniel Scharstein
|
| 59 |
+
According to the Matlab source code of Deqing Sun
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
u (np.ndarray): Input horizontal flow of shape [H,W]
|
| 63 |
+
v (np.ndarray): Input vertical flow of shape [H,W]
|
| 64 |
+
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
|
| 65 |
+
|
| 66 |
+
Returns:
|
| 67 |
+
np.ndarray: Flow visualization image of shape [H,W,3] in range [0, 255]
|
| 68 |
+
"""
|
| 69 |
+
flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
|
| 70 |
+
colorwheel = make_colorwheel() # shape [55x3]
|
| 71 |
+
ncols = colorwheel.shape[0]
|
| 72 |
+
rad = np.sqrt(np.square(u) + np.square(v))
|
| 73 |
+
a = np.arctan2(-v, -u)/np.pi
|
| 74 |
+
fk = (a+1) / 2*(ncols-1)
|
| 75 |
+
k0 = np.floor(fk).astype(np.int32)
|
| 76 |
+
k1 = k0 + 1
|
| 77 |
+
k1[k1 == ncols] = 0
|
| 78 |
+
f = fk - k0
|
| 79 |
+
for i in range(colorwheel.shape[1]):
|
| 80 |
+
tmp = colorwheel[:,i]
|
| 81 |
+
col0 = tmp[k0] / 255.0
|
| 82 |
+
col1 = tmp[k1] / 255.0
|
| 83 |
+
col = (1-f)*col0 + f*col1
|
| 84 |
+
idx = (rad <= 1)
|
| 85 |
+
col[idx] = 1 - rad[idx] * (1-col[idx])
|
| 86 |
+
col[~idx] = col[~idx] * 0.75 # out of range
|
| 87 |
+
# Note the 2-i => BGR instead of RGB
|
| 88 |
+
ch_idx = 2-i if convert_to_bgr else i
|
| 89 |
+
flow_image[:,:,ch_idx] = np.floor(255 * col)
|
| 90 |
+
return flow_image
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
|
| 94 |
+
"""
|
| 95 |
+
Expects a two-dimensional flow image of shape [H,W,2].
|
| 96 |
+
|
| 97 |
+
Args:
|
| 98 |
+
flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
|
| 99 |
+
clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
|
| 100 |
+
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
|
| 101 |
+
|
| 102 |
+
Returns:
|
| 103 |
+
np.ndarray: Flow visualization image of shape [H,W,3]
|
| 104 |
+
"""
|
| 105 |
+
assert flow_uv.ndim == 3, 'input flow must have three dimensions'
|
| 106 |
+
assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
|
| 107 |
+
|
| 108 |
+
if clip_flow is not None:
|
| 109 |
+
flow_uv = np.clip(flow_uv, 0, clip_flow)
|
| 110 |
+
|
| 111 |
+
u = flow_uv[:,:,0]
|
| 112 |
+
v = flow_uv[:,:,1]
|
| 113 |
+
rad = np.sqrt(np.square(u) + np.square(v))
|
| 114 |
+
rad_max = np.max(rad)
|
| 115 |
+
epsilon = 1e-5
|
| 116 |
+
u = u / (rad_max + epsilon)
|
| 117 |
+
v = v / (rad_max + epsilon)
|
| 118 |
+
return flow_uv_to_colors(u, v, convert_to_bgr)
|
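A quick way to sanity-check the visualizer above is to feed it a synthetic flow field (sizes and values below are arbitrary):
```python
import numpy as np

# Toy flow: rightward motion on the top half, downward motion on the bottom half.
flow = np.zeros((240, 320, 2), dtype=np.float32)
flow[:120, :, 0] = 5.0   # u (horizontal) component
flow[120:, :, 1] = 5.0   # v (vertical) component

rgb = flow_to_image(flow)            # uint8 array of shape (240, 320, 3)
# bgr = flow_to_image(flow, convert_to_bgr=True)  # e.g. for cv2.imwrite
```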
| 119 |
+
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def filter_uv(flow, threshold_factor = 0.1, sample_prob = 1.0):
|
| 123 |
+
'''
|
| 124 |
+
Args:
|
| 125 |
+
flow (numpy): A 2-dim array that stores x and y change in optical flow
|
| 126 |
+
threshold_factor (float): Prob of discarding outliers vector
|
| 127 |
+
sample_prob (float): The selection rate of how much proportion of points we need to store
|
| 128 |
+
'''
|
| 129 |
+
u = flow[:,:,0]
|
| 130 |
+
v = flow[:,:,1]
|
| 131 |
+
|
| 132 |
+
# Filter out those less than the threshold
|
| 133 |
+
rad = np.sqrt(np.square(u) + np.square(v))
|
| 134 |
+
rad_max = np.max(rad)
|
| 135 |
+
|
| 136 |
+
threshold = threshold_factor * rad_max
|
| 137 |
+
flow[:,:,0][rad < threshold] = 0
|
| 138 |
+
flow[:,:,1][rad < threshold] = 0
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# Randomly sample based on sample_prob
|
| 142 |
+
zero_prob = 1 - sample_prob
|
| 143 |
+
random_array = np.random.rand(*flow.shape)  # uniform in [0, 1) so roughly sample_prob of the points survive the threshold below
|
| 144 |
+
random_array[random_array < zero_prob] = 0
|
| 145 |
+
random_array[random_array >= zero_prob] = 1
|
| 146 |
+
flow = flow * random_array
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
return flow
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
############################################# The following is for dilation method in optical flow ######################################
|
| 154 |
+
def sigma_matrix2(sig_x, sig_y, theta):
|
| 155 |
+
"""Calculate the rotated sigma matrix (two dimensional matrix).
|
| 156 |
+
Args:
|
| 157 |
+
sig_x (float):
|
| 158 |
+
sig_y (float):
|
| 159 |
+
theta (float): Radian measurement.
|
| 160 |
+
Returns:
|
| 161 |
+
ndarray: Rotated sigma matrix.
|
| 162 |
+
"""
|
| 163 |
+
d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]])
|
| 164 |
+
u_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
|
| 165 |
+
return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def mesh_grid(kernel_size):
|
| 169 |
+
"""Generate the mesh grid, centering at zero.
|
| 170 |
+
Args:
|
| 171 |
+
kernel_size (int):
|
| 172 |
+
Returns:
|
| 173 |
+
xy (ndarray): with the shape (kernel_size, kernel_size, 2)
|
| 174 |
+
xx (ndarray): with the shape (kernel_size, kernel_size)
|
| 175 |
+
yy (ndarray): with the shape (kernel_size, kernel_size)
|
| 176 |
+
"""
|
| 177 |
+
ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
|
| 178 |
+
xx, yy = np.meshgrid(ax, ax)
|
| 179 |
+
xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), yy.reshape(kernel_size * kernel_size,
|
| 180 |
+
1))).reshape(kernel_size, kernel_size, 2)
|
| 181 |
+
return xy, xx, yy
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def pdf2(sigma_matrix, grid):
|
| 185 |
+
"""Calculate PDF of the bivariate Gaussian distribution.
|
| 186 |
+
Args:
|
| 187 |
+
sigma_matrix (ndarray): with the shape (2, 2)
|
| 188 |
+
grid (ndarray): generated by :func:`mesh_grid`,
|
| 189 |
+
with the shape (K, K, 2), K is the kernel size.
|
| 190 |
+
Returns:
|
| 191 |
+
kernel (ndarray): un-normalized kernel.
|
| 192 |
+
"""
|
| 193 |
+
inverse_sigma = np.linalg.inv(sigma_matrix)
|
| 194 |
+
kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2))
|
| 195 |
+
return kernel
|
| 196 |
+
|
| 197 |
+
def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True):
|
| 198 |
+
"""Generate a bivariate isotropic or anisotropic Gaussian kernel.
|
| 199 |
+
In the isotropic mode, only `sig_x` is used; `sig_y` and `theta` are ignored.
|
| 200 |
+
Args:
|
| 201 |
+
kernel_size (int):
|
| 202 |
+
sig_x (float):
|
| 203 |
+
sig_y (float):
|
| 204 |
+
theta (float): Radian measurement.
|
| 205 |
+
grid (ndarray, optional): generated by :func:`mesh_grid`,
|
| 206 |
+
with the shape (K, K, 2), K is the kernel size. Default: None
|
| 207 |
+
isotropic (bool):
|
| 208 |
+
Returns:
|
| 209 |
+
kernel (ndarray): normalized kernel.
|
| 210 |
+
"""
|
| 211 |
+
if grid is None:
|
| 212 |
+
grid, _, _ = mesh_grid(kernel_size)
|
| 213 |
+
if isotropic:
|
| 214 |
+
sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]])
|
| 215 |
+
else:
|
| 216 |
+
sigma_matrix = sigma_matrix2(sig_x, sig_y, theta)
|
| 217 |
+
kernel = pdf2(sigma_matrix, grid)
|
| 218 |
+
kernel = kernel / np.sum(kernel)
|
| 219 |
+
return kernel
|
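Putting the pieces together, one plausible use of `bivariate_Gaussian` for the flow-dilation purpose mentioned above is to convolve a sparse flow map with the normalized kernel so that each retained vector spreads over a neighbourhood (the kernel size and sigmas below are illustrative, not values taken from the training configs):
```python
import cv2
import numpy as np

# Illustrative kernel; isotropic mode ignores sig_y and theta.
kernel = bivariate_Gaussian(kernel_size=99, sig_x=10, sig_y=10, theta=0, isotropic=True)

sparse_flow = np.zeros((240, 320, 2), dtype=np.float32)
sparse_flow[120, 160] = (30.0, -15.0)   # a single retained motion vector

# Convolve each flow channel with the normalized kernel; the vector is spread
# (and attenuated) over the kernel's footprint.
dilated = np.stack(
    [cv2.filter2D(sparse_flow[:, :, c], -1, kernel) for c in range(2)], axis=-1
)
```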