import os, sys
import json
import cv2
import math
import shutil
import numpy as np
import random
import collections
from PIL import Image
import torch
from torch.utils.data import Dataset

# Import files from the local folder
root_path = os.path.abspath('.')
sys.path.append(root_path)
from utils.img_utils import resize_with_antialiasing, numpy_to_pt


def get_video_frames(config, video_frame_path, flip=False):
    video_seq_length = config["video_seq_length"]

    # Calculate needed parameters
    num_frames_input = 0
    for file_name in os.listdir(video_frame_path):
        if file_name.startswith("im_"):
            num_frames_input += 1
    total_frames_needed = video_seq_length
    division_factor = num_frames_input // total_frames_needed
    remain_frames = (num_frames_input % total_frames_needed) - 1   # -1 for adaptation

    # Define the gaps between sampled frames
    gaps = [division_factor for _ in range(total_frames_needed - 1)]
    for idx in range(remain_frames):
        if idx % 2 == 0:
            gaps[idx // 2] += 1                 # Start-to-end order
        else:
            gaps[-1 * (1 + (idx // 2))] += 1    # End-to-start order

    # Find the needed files
    needed_img_path = []
    cur_idx = 0
    for gap in gaps:
        img_path = os.path.join(video_frame_path, "im_" + str(cur_idx) + ".jpg")
        needed_img_path.append(img_path)
        # Update the idx
        cur_idx += gap
    # Append the last one
    img_path = os.path.join(video_frame_path, "im_" + str(cur_idx) + ".jpg")
    needed_img_path.append(img_path)

    # Read all img_path in order
    video_frames = []
    for img_path in needed_img_path:
        if not os.path.exists(img_path):
            print("We don't have ", img_path)
        frame = cv2.imread(img_path)
        try:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # cv2.imread returns BGR, so convert to RGB here
        except Exception:
            print("The exception place is ", img_path)
        # Resize frames
        frame = cv2.resize(frame, (config["width"], config["height"]), interpolation=cv2.INTER_CUBIC)
        # Flip aug
        if flip:
            frame = np.fliplr(frame)
        # Collect frames (already RGB after the conversion above, so nothing else is needed)
        video_frames.append(np.expand_dims(frame, axis=0))

    # Concatenate
    video_frames = np.concatenate(video_frames, axis=0)
    assert len(video_frames) == video_seq_length

    return video_frames
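
# Worked example of the frame sampling above (the numbers are hypothetical, purely for illustration):
# with num_frames_input = 20 and video_seq_length = 14,
#   division_factor = 20 // 14 = 1
#   remain_frames   = 20 % 14 - 1 = 5
#   gaps start as [1] * 13; the 5 extra steps are spread alternately from both ends,
#   giving [2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2],
# so the selected frames are im_0, im_2, im_4, im_6, im_7, im_8, ..., im_14, im_16, im_18
# (14 roughly evenly spaced frames; the "-1 for adaptation" keeps the last index in range).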


def tokenize_captions(prompt, tokenizer, config, is_train=True):
    '''
        Tokenize the text prompt with the prepared tokenizer from SD2.1
    '''
    captions = []
    if random.random() < config["empty_prompts_proportion"]:
        captions.append("")
    elif isinstance(prompt, str):
        captions.append(prompt)
    elif isinstance(prompt, (list, np.ndarray)):
        # Take a random caption if there are multiple
        captions.append(random.choice(prompt) if is_train else prompt[0])
    else:
        raise ValueError("Caption column should contain either strings or lists of strings.")

    inputs = tokenizer(
        captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
    )
    return inputs.input_ids[0]
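
# A minimal sketch of the tokenizer this function expects (an assumption based on the SD2.1
# mention above; the exact checkpoint used for training is not specified in this file):
#
#   from transformers import CLIPTokenizer
#   tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="tokenizer")
#   token_ids = tokenize_captions("pick up the cup", tokenizer, config)  # LongTensor of shape (model_max_length,)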


class Video_Dataset(Dataset):
    '''
        Video Dataset to load sequential frames for training, with the needed pre-processing
    '''
    def __init__(self, config, device, normalize=True, tokenizer=None):

        # Attribute variables
        self.config = config
        self.device = device
        self.normalize = normalize
        self.tokenizer = tokenizer

        # Obtain values
        self.video_seq_length = config["video_seq_length"]
        self.height = config["height"]
        self.width = config["width"]

        # Process data
        self.video_lists = []
        stats_analysis = collections.defaultdict(int)
        print("Process all files to check valid datasets....")
        for dataset_path in config["dataset_path"]:
            for video_name in sorted(os.listdir(dataset_path)):
                video_path = os.path.join(dataset_path, video_name)
                all_files = os.listdir(video_path)
                valid = True

                # Valid check 1: the frame files should exist in sequential order
                num_frames_input = 0
                for file_name in os.listdir(video_path):
                    if file_name.startswith("im_"):
                        num_frames_input += 1
                for idx in range(num_frames_input):
                    img_path = 'im_' + str(idx) + '.jpg'
                    if img_path not in all_files:   # Frames must exist sequentially, with no holes
                        valid = False
                        stats_analysis["incomplete_img"] += 1
                        break

                # Valid check 1.5: the number of frames must be at least video_seq_length and at most acceleration_tolerance * video_seq_length
                if num_frames_input < self.config["video_seq_length"]:
                    stats_analysis["too_little_frames"] += 1
                    valid = False
                if num_frames_input > self.config["acceleration_tolerance"] * self.config["video_seq_length"]:
                    stats_analysis["too_many_frames"] += 1
                    valid = False
                if not valid:   # Early exit here to speed things up
                    continue

                # Valid check 2: language file, if needed
                if config["use_text"] and not os.path.exists(os.path.join(dataset_path, video_name, "lang.txt")):
                    stats_analysis["no_lang_txt"] += 1
                    valid = False

                # Valid check 3: motion (optical flow) file, if needed
                if config["motion_bucket_id"] is None:
                    flow_path = os.path.join(dataset_path, video_name, "flow.txt")
                    if "flow.txt" not in all_files:
                        stats_analysis["no_flow_txt"] += 1
                        valid = False
                    else:
                        with open(flow_path, 'r') as file:
                            info = file.readlines()
                        if len(info) == 0:
                            stats_analysis["no_flow_txt"] += 1
                            valid = False

                if valid:
                    self.video_lists.append(video_path)
        print("stats_analysis is ", stats_analysis)
        print("Valid dataset length is ", len(self.video_lists))

    def __len__(self):
        return len(self.video_lists)

    def _get_motion_value(self, sub_folder_path):
        ''' Read the motion value from the prepared flow.txt file; the flow is pre-computed to accelerate loading
        '''
        # Read the flow.txt
        flow_path = os.path.join(sub_folder_path, 'flow.txt')
        with open(flow_path, 'r') as file:
            info = file.readlines()
        per_video_movement = float(info[0][:-2])

        # Map the raw motion value to the target range based on the number of frames we have
        num_frames_input = 0
        for file_name in os.listdir(sub_folder_path):   # num_frames_input is the total number of files whose name begins with im_
            if file_name.startswith("im_"):
                num_frames_input += 1

        # Correct the value based on the number of frames relative to video_seq_length
        per_video_movement_correct = per_video_movement * (num_frames_input / self.config["video_seq_length"])

        # Map from one Normal Distribution to another Normal Distribution
        z = (per_video_movement_correct - self.config["dataset_motion_mean"]) / (self.config["dataset_motion_std"] + 0.001)
        reflected_motion_bucket_id = int((z * self.config["svd_motion_std"]) + self.config["svd_motion_mean"])

        print("We map " + str(per_video_movement) + " to " + str(per_video_movement_correct) + " by length " + str(num_frames_input) + " to bucket_id of " + str(reflected_motion_bucket_id))
        return reflected_motion_bucket_id
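
    # Worked example of the mapping above (all numbers hypothetical, not taken from any real config):
    # with dataset_motion_mean = 25.0, dataset_motion_std = 10.0, svd_motion_mean = 127, svd_motion_std = 40,
    # per_video_movement = 20.0, num_frames_input = 28, and video_seq_length = 14:
    #   per_video_movement_correct = 20.0 * (28 / 14) = 40.0
    #   z = (40.0 - 25.0) / (10.0 + 0.001) ≈ 1.4999
    #   reflected_motion_bucket_id = int(1.4999 * 40 + 127) = int(186.99) = 186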

    def __getitem__(self, idx):
        ''' Get the item by idx and pre-process it by resizing and normalizing
            Args:
                idx (int): The index of the video folder in the directory
            Returns:
                video_frames (torch.float32): The PyTorch tensor of the obtained frames (in the range [-1, 1] when normalize is True)
                reflected_motion_bucket_id (int): Motion value if optical flow is provided, else the fixed value from the config
                prompt (tensor): Tokenized text
        '''
        # Prepare the text, if needed
        if self.config["use_text"]:
            # Read the file
            file_path = os.path.join(self.video_lists[idx], "lang.txt")
            with open(file_path, 'r') as file:
                prompt = file.readlines()[0]    # Only read the first line

            if self.config["mix_ambiguous"] and os.path.exists(os.path.join(self.video_lists[idx], "processed_text.txt")):
                # If we don't have this txt file, we skip
                ######################################################## Mix up prompt ########################################################
                # Read the file
                file_path = os.path.join(self.video_lists[idx], "processed_text.txt")
                with open(file_path, 'r') as file:
                    prompts = [line for line in file.readlines()]

                # Get the components
                action = prompts[0][:-1]
                this = prompts[1][:-1]
                there = prompts[2][:-1]

                random_value = random.random()
                # If less than 0.4, we keep the most concrete prompt unchanged
                if random_value >= 0.4 and random_value < 0.6:
                    # Mask the picked object to "this"
                    prompt = action + " this to " + there
                elif random_value >= 0.6 and random_value < 0.8:
                    # Mask the place position to "there"
                    prompt = action + " " + this + " to there"
                elif random_value >= 0.8 and random_value < 1.0:
                    # Just be like "this to there"
                    prompt = action + " this to there"
                # print("New prompt is ", prompt)
                ###############################################################################################################################
            # else:
            #     print("We don't have llama processed prompt at ", self.video_lists[idx])
        else:
            prompt = ""

        # Tokenize the text prompt
        tokenized_prompt = tokenize_captions(prompt, self.tokenizer, self.config)

        # Dataset aug by chance (we need to check whether any object-position word [left|right] appears in the prompt text)
        flip = False
        if random.random() < self.config["flip_aug_prob"]:
            if self.config["use_text"]:
                if prompt.find("left") == -1 and prompt.find("right") == -1:   # Cannot flip with a position word like left or right (up and down are ok)
                    flip = True
            else:
                flip = True

        # Read frames for different datasets; currently, we have WebVid / Bridge
        if self.config["dataset_name"] == "Bridge":
            video_frames = get_video_frames(self.config, self.video_lists[idx], flip=flip)
        else:
            raise NotImplementedError("We don't support this dataset loader")

        # Scale [0, 255] -> [-1, 1]
        if self.normalize:
            video_frames = video_frames.astype(np.float32) / 127.5 - 1     # Be careful to cast to float32

        # Transform to a PyTorch tensor in the range [-1, 1]
        video_frames = numpy_to_pt(video_frames)
        # print("length of input frames has ", len(video_frames))

        # Get the motion value based on the optical flow
        if self.config["motion_bucket_id"] is None:
            reflected_motion_bucket_id = self._get_motion_value(self.video_lists[idx])
        else:
            reflected_motion_bucket_id = self.config["motion_bucket_id"]

        # The tensor we return is torch.float32. We won't cast here for mixed-precision training!
        return {
            "video_frames": video_frames,
            "reflected_motion_bucket_id": reflected_motion_bucket_id,
            "prompt": tokenized_prompt,
        }
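

# A minimal usage sketch (all paths and config values below are hypothetical placeholders,
# not the settings of any actual training run; numpy_to_pt comes from the project-local utils):
if __name__ == "__main__":
    from transformers import CLIPTokenizer
    from torch.utils.data import DataLoader

    config = {
        "dataset_path": ["datasets/bridge"],    # hypothetical dataset root(s)
        "dataset_name": "Bridge",
        "video_seq_length": 14,
        "height": 256,
        "width": 384,
        "acceleration_tolerance": 4,
        "use_text": True,
        "mix_ambiguous": False,
        "empty_prompts_proportion": 0.1,
        "flip_aug_prob": 0.5,
        "motion_bucket_id": 127,                # set to None to read per-video flow.txt instead
        "dataset_motion_mean": 25.0,            # only used when motion_bucket_id is None
        "dataset_motion_std": 10.0,
        "svd_motion_mean": 127,
        "svd_motion_std": 40,
    }
    tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="tokenizer")

    dataset = Video_Dataset(config, device="cpu", normalize=True, tokenizer=tokenizer)
    loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=4)
    batch = next(iter(loader))
    print(batch["video_frames"].shape)          # e.g. (2, 14, 3, 256, 384), depending on numpy_to_pt's channel ordering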