"""PyTorch datasets that pair a blurred image with a sequence of video frames.

Every dataset's ``__getitem__`` returns a dict containing at least
``'blur_img'`` (the blurred input) and ``'video'`` (the frame sequence), both
converted by :meth:`BaseClass.load_frames`.
"""
import io
import os
import glob
from pathlib import Path
import pickle
import random
import time
import cv2
import torch
import numpy as np
import pandas as pd
import torchvision.transforms as transforms
from PIL import Image, ImageOps, ImageCms
from decord import VideoReader
from torch.utils.data.dataset import Dataset
from controlnet_aux import CannyDetector, HEDdetector
import torch.nn.functional as F
from helpers import generate_1x_sequence, generate_2x_sequence, generate_large_blur_sequence, generate_test_case


def unpack_mm_params(p):
    """Return ``p`` as a 2-tuple.

    Args:
        p: A tuple/list (its first two items are returned) or a single
            int/float (returned twice).

    Raises:
        Exception: If ``p`` is none of the supported types.
    """
    if isinstance(p, (tuple, list)):
        return p[0], p[1]
    elif isinstance(p, (int, float)):
        return p, p
    raise Exception(f'Unknown input parameter type.\nParameter: {p}.\nType: {type(p)}')


def resize_for_crop(image, min_h, min_w):
    """Rescale ``image`` (aspect ratio kept) ahead of a ``(min_h, min_w)`` crop.

    Args:
        image: Tensor whose last two dimensions are (height, width).
        min_h: Target height.
        min_w: Target width.

    Returns:
        The image resized with antialiasing by a single scale factor.
    """
    img_h, img_w = image.shape[-2:]
    if img_h >= min_h and img_w >= min_w:
        # NOTE(review): `min` can leave one side shorter than the target when
        # the aspect ratios differ; kept as-is because the intent is unclear.
        coef = min(min_h / img_h, min_w / img_w)
    elif img_h <= min_h and img_w <= min_w:
        coef = max(min_h / img_h, min_w / img_w)
    else:
        coef = min_h / img_h if min_h > img_h else min_w / img_w
    out_h, out_w = int(img_h * coef), int(img_w * coef)
    resized_image = transforms.functional.resize(image, (out_h, out_w), antialias=True)
    return resized_image


def _gopro_split_roots(data_dir, split):
    """Return ``(blur_root, sharp_root)`` for a GoPro-style directory layout.

    ``'train'`` maps to ``<data_dir>/train``; ``'val'`` and ``'test'`` both map
    to ``<data_dir>/test``.

    Raises:
        ValueError: For any other split name.
    """
    if split == 'train':
        subset = 'train'
    elif split in ['val', 'test']:
        subset = 'test'
    else:
        raise ValueError(f"Unsupported split: {split}")
    return (
        os.path.join(data_dir, subset, 'blur'),
        os.path.join(data_dir, subset, 'sharp'),
    )


def _pending_blur_paths(blur_paths, output_dir):
    """Return the blur paths whose ``deblurred/<seq>/<frame>.mp4`` is missing."""
    deblurred_dir = os.path.join(output_dir, "deblurred")
    pending = []
    for path in blur_paths:
        # Last two components are <sequence>/<frame>.png.
        full_output_path = Path(deblurred_dir, *Path(path).parts[-2:]).with_suffix(".mp4")
        if not os.path.exists(full_output_path):
            pending.append(path)
    return pending


def _interval_tag(interval):
    """Format the interval parameters as the suffix used in output file names."""
    return (
        f"in{interval['in_start']:04d}_ie{interval['in_end']:04d}_"
        f"os{interval['out_start']:04d}_oe{interval['out_end']:04d}_"
        f"ctr{interval['center']:04d}_win{interval['window_size']:04d}_"
        f"fps{interval['fps']:04d}_{interval['mode']}"
    )


class BaseClass(Dataset):
    """Common state and frame conversion shared by the datasets in this file."""

    def __init__(
        self,
        data_dir,
        output_dir,
        image_size=(320, 512),
        hflip_p=0.5,
        controlnet_type='canny',
        split='train',
        *args,
        **kwargs
    ):
        """Store the configuration.

        Args:
            data_dir: Root directory of the input data.
            output_dir: Root directory where results are written; subclasses
                look under ``<output_dir>/deblurred`` to skip finished items.
            image_size: ``(height, width)`` or a single number used for both.
            hflip_p: Stored on the instance; not used in this file.
            controlnet_type: Accepted for compatibility; not used in this file.
            split: Split name, interpreted by subclasses.
        """
        self.split = split
        self.height, self.width = unpack_mm_params(image_size)
        self.data_dir = data_dir
        self.output_dir = output_dir
        self.hflip_p = hflip_p
        self.image_size = image_size
        self.length = 0

    def __len__(self):
        return self.length

    def load_frames(self, frames):
        """Convert ``(N, H, W, C)`` frames in 0-255 to a normalized tensor.

        Returns:
            Float tensor of shape ``(N, C, self.height, self.width)`` with
            values scaled to ``[-1, 1]`` and bilinearly resized.
        """
        pixel_values = torch.from_numpy(frames).permute(0, 3, 1, 2).contiguous().float()
        # normalize to [-1, 1]
        pixel_values = pixel_values / 127.5 - 1.0
        # resize to (self.height, self.width)
        pixel_values = F.interpolate(
            pixel_values,
            size=(self.height, self.width),
            mode="bilinear",
            align_corners=False
        )
        return pixel_values

    def get_batch(self, idx):
        """Return ``(video, caption, motion_blur)``; subclasses must override."""
        # NotImplementedError is still an Exception, so the retry loop in
        # __getitem__ and any caller catching Exception behave as before.
        raise NotImplementedError('Get batch method is not realized.')

    def __getitem__(self, idx):
        """Fetch a sample via ``get_batch``, retrying with a random index on error."""
        # NOTE(review): this loop has no retry limit; if get_batch fails for
        # every index it never terminates.
        while True:
            try:
                video, caption, _motion_blur = self.get_batch(idx)
                break
            except Exception as e:
                print(e)
                idx = random.randint(0, self.length - 1)

        video = resize_for_crop(video, self.height, self.width)
        video = transforms.functional.center_crop(video, (self.height, self.width))
        data = {
            'video': video,
            'caption': caption,
        }
        return data


def load_as_srgb(path):
    """Open an image, apply its EXIF orientation and return it as RGB.

    If the file embeds an ICC profile, the pixels are converted from that
    profile to sRGB; otherwise the image is simply converted to ``'RGB'``.
    """
    # The context manager closes the file handle; the returned image is a
    # separate, fully loaded copy.
    with Image.open(path) as opened:
        img = ImageOps.exif_transpose(opened)

        if 'icc_profile' in img.info:
            icc = img.info['icc_profile']
            src_profile = ImageCms.ImageCmsProfile(io.BytesIO(icc))
            dst_profile = ImageCms.createProfile("sRGB")
            img = ImageCms.profileToProfile(img, src_profile, dst_profile, outputMode='RGB')
        else:
            img = img.convert("RGB")  # Assume sRGB
    return img


class GoProMotionBlurDataset(BaseClass):
    """GoPro dataset: one blurred frame plus the 7 sharp frames centred on it."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Set blur and sharp directories based on split
        self.blur_root, self.sharp_root = _gopro_split_roots(self.data_dir, self.split)

        # Collect all blurred image paths
        pattern = os.path.join(self.blur_root, '*', '*.png')
        self.blur_paths = sorted(glob.glob(pattern))

        if self.split == 'val':
            # Optional: limit validation subset
            self.blur_paths = self.blur_paths[:5]

        # Skip items whose output already exists. The filtered list used to be
        # computed and then discarded; it is now applied, as in
        # GoPro2xMotionBlurDataset.
        self.blur_paths = _pending_blur_paths(self.blur_paths, self.output_dir)

        # Window and padding parameters
        self.window_size = 7  # original number of sharp frames
        self.pad = 2  # number of times to repeat last frame
        self.output_length = self.window_size + self.pad
        self.half_window = self.window_size // 2
        self.length = len(self.blur_paths)

        # Normalized input interval: always [-0.5, 0.5]
        self.input_interval = torch.tensor([[-0.5, 0.5]], dtype=torch.float)

        # Precompute normalized output intervals: first for window_size frames,
        # then pad duplicates
        step = 1.0 / (self.window_size - 1)
        window_intervals = []
        for i in range(self.window_size):
            start = -0.5 + i * step
            if i < self.window_size - 1:
                end = -0.5 + (i + 1) * step
            else:
                end = 0.5
            window_intervals.append([start, end])
        # append the last interval pad times
        intervals = window_intervals + [window_intervals[-1]] * self.pad
        self.output_interval = torch.tensor(intervals, dtype=torch.float)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return the blurred frame ``idx`` with its padded sharp-frame window."""
        # Path to the blurred (center) frame
        blur_path = self.blur_paths[idx]
        seq_name = os.path.basename(os.path.dirname(blur_path))
        frame_name = os.path.basename(blur_path)
        center_idx = int(os.path.splitext(frame_name)[0])

        # Compute sharp frame range [center-half, center+half]
        start_idx = center_idx - self.half_window
        end_idx = center_idx + self.half_window

        # Load sharp frames, closing each file as soon as it is decoded
        sharp_dir = os.path.join(self.sharp_root, seq_name)
        frames = []
        for i in range(start_idx, end_idx + 1):
            sharp_path = os.path.join(sharp_dir, f"{i:06d}.png")
            with Image.open(sharp_path) as img:
                frames.append(np.array(img.convert('RGB')))

        # Repeat last sharp frame so total frames == output_length
        while len(frames) < self.output_length:
            frames.append(frames[-1])

        # Load blurred image
        with Image.open(blur_path) as img:
            blur_np = np.array(img.convert('RGB'))

        # Convert to pixel values via BaseClass loader
        video = self.load_frames(np.stack(frames, axis=0))  # input: (output_length, H, W, C)
        blur_input = self.load_frames(np.expand_dims(blur_np, 0))  # input: (1, H, W, C)

        data = {
            'file_name': os.path.join(seq_name, frame_name),
            'blur_img': blur_input,
            'video': video,
            "caption": "",
            'motion_blur_amount': torch.tensor(self.half_window, dtype=torch.long),
            'input_interval': self.input_interval,
            'output_interval': self.output_interval,
            "num_frames": self.window_size,
            "mode": "1x",
        }
        return data


class OutsidePhotosDataset(BaseClass):
    """Arbitrary photos under ``data_dir``, each paired with fixed intervals."""

    # One sample is produced per (image, interval) pair.
    INTERVALS = [
        {"in_start": 0, "in_end": 16, "out_start": 0, "out_end": 16, "center": 8, "window_size": 16, "mode": "1x", "fps": 240},
        {"in_start": 4, "in_end": 12, "out_start": 0, "out_end": 16, "center": 8, "window_size": 16, "mode": "2x", "fps": 240},
    ]
    # other modes commented out for faster processing
    # {"in_start": 0, "in_end": 4, "out_start": 0, "out_end": 4, "center": 2, "window_size": 4, "mode": "1x", "fps": 240},
    # {"in_start": 0, "in_end": 8, "out_start": 0, "out_end": 8, "center": 4, "window_size": 8, "mode": "1x", "fps": 240},
    # {"in_start": 0, "in_end": 12, "out_start": 0, "out_end": 12, "center": 6, "window_size": 12, "mode": "1x", "fps": 240},
    # {"in_start": 0, "in_end": 32, "out_start": 0, "out_end": 32, "center": 12, "window_size": 32, "mode": "lb", "fps": 120}
    # {"in_start": 0, "in_end": 48, "out_start": 0, "out_end": 48, "center": 24, "window_size": 48, "mode": "lb", "fps": 80}

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # '*.*' also matches directories whose name contains a dot; keep files only.
        self.image_paths = sorted(
            p
            for p in glob.glob(os.path.join(self.data_dir, '**', '*.*'), recursive=True)
            if os.path.isfile(p)
        )

        # Outputs live under <output_dir>/deblurred, like the other datasets
        # (this replaces a machine-specific hard-coded path).
        output_deblurred_dir = os.path.join(self.output_dir, "deblurred")

        self.cleaned_intervals = []
        for image_path in self.image_paths:
            # splitext (not split('.')) so dots in directory names are kept,
            # matching how __getitem__ derives the file stem.
            rel_stem = os.path.splitext(os.path.relpath(image_path, self.data_dir))[0]
            for interval in self.INTERVALS:
                # copy so the shared class-level dicts are never mutated
                entry = interval.copy()
                entry['video_name'] = image_path
                output_name = f"{rel_stem}_{entry['mode']}.mp4"
                full_output_path = os.path.join(output_deblurred_dir, output_name)
                # Keep only if output doesn't exist
                if not os.path.exists(full_output_path):
                    self.cleaned_intervals.append(entry)

        self.length = len(self.cleaned_intervals)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return the photo for entry ``idx`` with a placeholder frame sequence."""
        interval = self.cleaned_intervals[idx]
        window = interval['window_size']
        mode = interval['mode']
        image_path = interval['video_name']

        blur_img_original = load_as_srgb(image_path)
        # PIL reports size as (width, height); 'original_size' is emitted in
        # that same order, exactly as before.
        orig_w, orig_h = blur_img_original.size

        # No real frames exist for these photos: the same placeholder path is
        # repeated `window` times.
        frame_paths = ["../assets/dummy_image.png" for _ in range(window)]

        # generate test case
        _, seq_frames, inp_int, out_int, high_fps_video, num_frames = generate_test_case(
            frame_paths=frame_paths,
            window_max=window,
            in_start=interval['in_start'],
            in_end=interval['in_end'],
            out_start=interval['out_start'],
            out_end=interval['out_end'],
            center=interval['center'],
            mode=mode,
            fps=interval['fps'],
        )

        # Output name: <relative dir>/<stem>_<mode>.png
        relative_file_name = os.path.relpath(image_path, self.data_dir)
        base_dir = os.path.dirname(relative_file_name)
        frame_stem = os.path.splitext(os.path.basename(image_path))[0]
        relative_file_name = os.path.join(base_dir, f"{frame_stem}_{mode}.png")

        # PIL resize takes (width, height). self.width/self.height also work
        # when image_size was given as a single number.
        blur_img = blur_img_original.resize((self.width, self.height))
        blur_input = self.load_frames(np.expand_dims(np.array(blur_img), 0))

        # seq_frames is list of frames; stack along time dim
        video = self.load_frames(np.stack(seq_frames, axis=0))

        data = {
            'file_name': relative_file_name,
            "original_size": (orig_w, orig_h),
            'blur_img': blur_input,
            'video': video,
            'caption': "",
            'input_interval': inp_int,
            'output_interval': out_int,
            "num_frames": num_frames,
        }
        return data


class FullMotionBlurDataset(BaseClass):
    """
    A dataset that randomly selects among 1×, 2×, or large-blur modes per sample.
    Uses category-specific <split>_list.txt files under each subfolder of
    data_dir to assemble sequences, falling back to every directory under
    ``lower_fps_frames`` when a category has no list file.
    In 'test' split, it instead loads precomputed intervals from
    intervals_test.pkl and uses generate_test_case.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.seq_dirs = []

        # TEST split: load fixed intervals early
        if self.split == 'test':
            pkl_path = os.path.join(self.data_dir, 'intervals_test.pkl')
            # SECURITY: pickle.load executes arbitrary code from the file; only
            # use an intervals file from a trusted source.
            with open(pkl_path, 'rb') as f:
                self.test_intervals = pickle.load(f)
            assert self.test_intervals, f"No test intervals found in {pkl_path}"

            output_deblurred_dir = os.path.join(self.output_dir, "deblurred")
            frames_by_seq = {}  # avoid re-globbing the same sequence per interval
            cleaned_intervals = []
            for interval in self.test_intervals:
                # 'video_name' is "<category>/<sequence>"
                category, seq = interval['video_name'].split('/')
                seq_dir = os.path.join(self.data_dir, category, 'lower_fps_frames', seq)
                if seq_dir not in frames_by_seq:
                    frames_by_seq[seq_dir] = sorted(glob.glob(os.path.join(seq_dir, '*.png')))
                frame_paths = frames_by_seq[seq_dir]

                rel_path = os.path.relpath(frame_paths[interval['center']], self.data_dir)
                rel_path = os.path.splitext(rel_path)[0]  # remove the file extension
                output_name = f"{rel_path}_{_interval_tag(interval)}.mp4"
                full_output_path = os.path.join(output_deblurred_dir, output_name)

                # Keep only if output doesn't exist
                if not os.path.exists(full_output_path):
                    cleaned_intervals.append(interval)
            print("Len of test intervals after cleaning: ", len(cleaned_intervals))
            print("Len of test intervals before cleaning: ", len(self.test_intervals))
            self.test_intervals = cleaned_intervals

        # Build seq_dirs from each category's list or fallback. This runs for
        # every split (including 'test', which reads 'test_list.txt').
        list_file = 'train_list.txt' if self.split == 'train' else 'test_list.txt'
        for category in sorted(os.listdir(self.data_dir)):
            cat_dir = os.path.join(self.data_dir, category)
            if not os.path.isdir(cat_dir):
                continue
            list_path = os.path.join(cat_dir, list_file)
            if os.path.isfile(list_path):
                with open(list_path, 'r') as f:
                    for line in f:
                        rel = line.strip()
                        if not rel:
                            continue
                        seq_dir = os.path.join(self.data_dir, rel)
                        if os.path.isdir(seq_dir):
                            self.seq_dirs.append(seq_dir)
            else:
                fps_root = os.path.join(cat_dir, 'lower_fps_frames')
                if os.path.isdir(fps_root):
                    for seq in sorted(os.listdir(fps_root)):
                        seq_path = os.path.join(fps_root, seq)
                        if os.path.isdir(seq_path):
                            self.seq_dirs.append(seq_path)

        if self.split == 'val':
            self.seq_dirs = self.seq_dirs[:5]
        if self.split == 'train':
            # each sequence appears 10 times per epoch
            self.seq_dirs *= 10

        assert self.seq_dirs, \
            f"No sequences found for split '{self.split}' in {self.data_dir}"

    def __len__(self):
        return len(self.test_intervals) if self.split == 'test' else len(self.seq_dirs)

    def __getitem__(self, idx):
        """Return one sample: a fixed interval (test) or a random mode (train/val)."""
        if self.split == 'test':
            interval = self.test_intervals[idx]
            category, seq = interval['video_name'].split('/')
            seq_dir = os.path.join(self.data_dir, category, 'lower_fps_frames', seq)
            frame_paths = sorted(glob.glob(os.path.join(seq_dir, '*.png')))
            center = interval['center']
            mode = interval['mode']

            # generate test case
            blur_img, seq_frames, inp_int, out_int, high_fps_video, num_frames = generate_test_case(
                frame_paths=frame_paths,
                window_max=interval['window_size'],
                in_start=interval['in_start'],
                in_end=interval['in_end'],
                out_start=interval['out_start'],
                out_end=interval['out_end'],
                center=center,
                mode=mode,
                fps=interval['fps'],
            )
            file_name = frame_paths[center]
        else:
            seq_dir = self.seq_dirs[idx]
            frame_paths = sorted(glob.glob(os.path.join(seq_dir, '*.png')))
            mode = random.choice(['1x', '2x', 'large_blur'])
            if len(frame_paths) < 50:
                # Short sequences always use the 1x generator; record that so
                # the returned 'mode' matches the data actually produced.
                mode = '1x'

            if mode == '1x':
                base_rate = random.choice([1, 2])
                blur_img, seq_frames, inp_int, out_int, _ = generate_1x_sequence(
                    frame_paths, window_max=16, output_len=17, base_rate=base_rate
                )
            elif mode == '2x':
                base_rate = random.choice([1, 2])
                blur_img, seq_frames, inp_int, out_int, _ = generate_2x_sequence(
                    frame_paths, window_max=16, output_len=17, base_rate=base_rate
                )
            else:
                # len(frame_paths) >= 50 here, so max_base is at least 2
                max_base = min((len(frame_paths) - 1) // 17, 3)
                base_rate = random.randint(1, max_base)
                blur_img, seq_frames, inp_int, out_int, _ = generate_large_blur_sequence(
                    frame_paths, window_max=16, output_len=17, base_rate=base_rate
                )
            file_name = frame_paths[0]
            num_frames = 16

        # blur_img is a single frame; wrap in batch dim
        blur_input = self.load_frames(np.expand_dims(blur_img, 0))
        # seq_frames is list of frames; stack along time dim
        video = self.load_frames(np.stack(seq_frames, axis=0))

        relative_file_name = os.path.relpath(file_name, self.data_dir)

        if self.split == 'test':
            # Encode the interval parameters in the output file name
            base_dir = os.path.dirname(relative_file_name)
            frame_stem = os.path.splitext(os.path.basename(relative_file_name))[0]
            new_filename = f"{frame_stem}_{_interval_tag(interval)}.png"
            relative_file_name = os.path.join(base_dir, new_filename)

        data = {
            'file_name': relative_file_name,
            'blur_img': blur_input,
            'num_frames': num_frames,
            'video': video,
            'caption': "",
            'mode': mode,
            'input_interval': inp_int,
            'output_interval': out_int,
        }
        if self.split == 'test':
            high_fps_video = self.load_frames(np.stack(high_fps_video, axis=0))
            data['high_fps_video'] = high_fps_video
        return data


class GoPro2xMotionBlurDataset(BaseClass):
    """GoPro dataset in "2x" mode, built from 13 sharp frames around each blur frame."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Set blur and sharp directories based on split
        self.blur_root, self.sharp_root = _gopro_split_roots(self.data_dir, self.split)

        # Collect all blurred image paths
        pattern = os.path.join(self.blur_root, '*', '*.png')
        self.blur_paths = sorted(glob.glob(pattern))

        # Skip items whose output already exists
        self.blur_paths = _pending_blur_paths(self.blur_paths, self.output_dir)

        self.sharp_paths = self._sharp_sequences(self.blur_paths)
        if self.split == 'val':
            # Optional: limit validation subset
            self.sharp_paths = self.sharp_paths[:5]
        self.length = len(self.sharp_paths)

    @staticmethod
    def _sharp_sequences(blur_paths):
        """Return, per blur path, the 13 sharp paths at offsets -6..+6.

        Blur paths for which any of those sharp files is missing are dropped.
        """
        sharp_paths = []
        for blur_path in blur_paths:
            sharp_dir = os.path.dirname(blur_path.replace('/blur/', '/sharp/'))
            frame_num = int(os.path.basename(blur_path).split('.')[0])
            sequence = [
                os.path.join(sharp_dir, f"{frame_num + offset:06d}.png")
                for offset in range(-6, 7)
            ]
            if all(os.path.exists(path) for path in sequence):
                sharp_paths.append(sequence)
        return sharp_paths

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return the 2x sample built from the ``idx``-th sharp sequence."""
        # The 13 sharp frame paths centred on the blurred frame
        sharp_path = self.sharp_paths[idx]

        blur_img, seq_frames, inp_int, out_int, high_fps_video, num_frames = generate_test_case(
            frame_paths=sharp_path,
            window_max=13,
            in_start=3,
            in_end=10,
            out_start=0,
            out_end=13,
            center=6,
            mode="2x",
            fps=240,
        )

        # Convert to pixel values via BaseClass loader
        video = self.load_frames(np.array(seq_frames))
        blur_input = self.load_frames(np.expand_dims(np.array(blur_img), 0))

        # <sequence>/<frame>.png of the centre frame (index 6 of 13)
        last_two_parts_of_path = os.path.join(*sharp_path[6].split(os.sep)[-2:])

        data = {
            'file_name': last_two_parts_of_path,
            'blur_img': blur_input,
            'video': video,
            "caption": "",
            'input_interval': inp_int,
            'output_interval': out_int,
            "num_frames": num_frames,
            "mode": "2x",
        }
        return data


class BAISTDataset(BaseClass):
    """BAIST dataset: a blurred frame, 7 sharp sub-frames and a bounding box."""

    # Folders held out for the 'val'/'test' splits; every other folder is 'train'.
    TEST_FOLDERS = frozenset({
        "gWA_sBM_c01_d26_mWA0_ch06_cropped_32X",
        "gBR_sBM_c01_d05_mBR0_ch01_cropped_32X",
        "gMH_sBM_c01_d22_mMH0_ch04_cropped_32X",
        "gHO_sBM_c01_d20_mHO0_ch05_cropped_32X",
        "gMH_sBM_c01_d22_mMH0_ch08_cropped_32X",
        "gWA_sBM_c01_d26_mWA0_ch02_cropped_32X",
        "gJS_sBM_c01_d02_mJS0_ch08_cropped_32X",
        "gHO_sBM_c01_d20_mHO0_ch07_cropped_32X",
        "gHO_sBM_c01_d20_mHO0_ch06_cropped_32X",
        "gBR_sBM_c01_d05_mBR0_ch03_cropped_32X",
        "gBR_sBM_c01_d05_mBR0_ch05_cropped_32X",
        "gHO_sBM_c01_d20_mHO0_ch02_cropped_32X",
        "gHO_sBM_c01_d20_mHO0_ch03_cropped_32X",
        "gHO_sBM_c01_d20_mHO0_ch09_cropped_32X",
        "gMH_sBM_c01_d22_mMH0_ch10_cropped_32X",
        "gWA_sBM_c01_d26_mWA0_ch10_cropped_32X",
        "gBR_sBM_c01_d05_mBR0_ch06_cropped_32X",
        "gHO_sBM_c01_d20_mHO0_ch08_cropped_32X",
        "gMH_sBM_c01_d22_mMH0_ch06_cropped_32X",
        "gHO_sBM_c01_d20_mHO0_ch10_cropped_32X",
        "gMH_sBM_c01_d22_mMH0_ch09_cropped_32X",
        "gMH_sBM_c01_d22_mMH0_ch02_cropped_32X",
        "gBR_sBM_c01_d05_mBR0_ch04_cropped_32X",
        "gPO_sBM_c01_d10_mPO0_ch09_cropped_32X",
        "gMH_sBM_c01_d22_mMH0_ch01_cropped_32X",
        "gMH_sBM_c01_d22_mMH0_ch07_cropped_32X",
        "gMH_sBM_c01_d22_mMH0_ch03_cropped_32X",
        "gHO_sBM_c01_d20_mHO0_ch04_cropped_32X",
        "gBR_sBM_c01_d05_mBR0_ch02_cropped_32X",
        "gHO_sBM_c01_d20_mHO0_ch01_cropped_32X",
        "gMH_sBM_c01_d22_mMH0_ch05_cropped_32X",
        "gPO_sBM_c01_d10_mPO0_ch10_cropped_32X",
    })

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.image_paths = self._collect_blur_images(self.data_dir, self.TEST_FOLDERS)

        # Drop images without a bounding-box annotation file
        self.image_paths = [
            path for path in self.image_paths
            if os.path.exists(self._annotation_path(path))
        ]

        # Drop images whose 7 sharp sub-frames are not all present
        self.image_paths = [
            path for path in self.image_paths
            if all(os.path.exists(p) for p in self._sharp_frame_paths(path))
        ]

        if self.split == 'val':
            # Optional: limit validation subset
            self.image_paths = self.image_paths[:4]
        self.length = len(self.image_paths)

    @staticmethod
    def _annotation_path(image_path):
        """Map ``<seq>/blur/<stem>.<ext>`` to ``<seq>/blur_anno/<stem>.pkl``.

        Built from path components so a "blur" substring elsewhere in the path
        (for example in ``data_dir``) is left untouched.
        """
        blur_dir, file_name = os.path.split(image_path)
        stem = os.path.splitext(file_name)[0]
        return os.path.join(os.path.dirname(blur_dir), "blur_anno", stem + ".pkl")

    @staticmethod
    def _sharp_frame_paths(image_path):
        """Return ``<seq>/sharp/<stem>_000.png`` .. ``_006.png`` for a blur image."""
        base_dir = os.path.dirname(os.path.dirname(image_path))  # strip /blur
        stem = os.path.splitext(os.path.basename(image_path))[0]
        sharp_dir = os.path.join(base_dir, "sharp")
        return [os.path.join(sharp_dir, f"{stem}_{i:03d}.png") for i in range(7)]

    def _collect_blur_images(self, root_dir, allowed_folders, skip_start=40, skip_end=40):
        """Walk ``root_dir`` and return the blur image paths for this split.

        Args:
            root_dir: Directory tree to search for folders named ``blur``.
            allowed_folders: Names of the held-out (val/test) parent folders.
            skip_start: Number of leading frames dropped per folder.
            skip_end: Number of trailing frames dropped per folder.
        """
        wants_held_out = self.split in ["test", "val"]
        # Exact comparison: `self.split in "train"` was a substring test that
        # also matched "", "rain", and so on.
        is_train = self.split == "train"
        output_deblurred_dir = os.path.join(self.output_dir, "deblurred")

        blur_image_paths = []
        for dirpath, _dirnames, filenames in os.walk(root_dir):
            if os.path.basename(dirpath) != "blur":
                continue
            parent_folder = os.path.basename(os.path.dirname(dirpath))
            held_out = parent_folder in allowed_folders
            if not ((wants_held_out and held_out) or (is_train and not held_out)):
                continue

            # Filter and sort valid image filenames (numeric stems only)
            valid_files = [
                f for f in filenames
                if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff'))
                and os.path.splitext(f)[0].isdigit()
            ]
            valid_files.sort(key=lambda x: int(os.path.splitext(x)[0]))

            # Skip first and last N files
            middle_files = valid_files[skip_start:len(valid_files) - skip_end]
            for f in middle_files:
                full_path = Path(os.path.join(dirpath, f))
                full_output_path = Path(output_deblurred_dir, *full_path.parts[-3:]).with_suffix(".mp4")
                # Finished outputs are skipped only for 'test'
                if not os.path.exists(full_output_path) or self.split in ["train", "val"]:
                    blur_image_paths.append(os.path.join(dirpath, f))
        return blur_image_paths

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        """Return the blurred frame, its sharp frames and the rescaled bbox."""
        image_path = self.image_paths[idx]

        blur_img_original = load_as_srgb(image_path)

        # SECURITY: allow_pickle=True unpickles the annotation file; only load
        # annotations from a trusted source.
        # Presumably the first four 'bbox' values are (x0, y0, x1, y1), given
        # the x/y scaling applied below; confirm against the annotation format.
        bbx = np.load(self._annotation_path(image_path), allow_pickle=True)['bbox'][0:4]

        W, H = blur_img_original.size  # PIL: (width, height)

        # PIL resize takes (width, height). self.width/self.height also work
        # when image_size was given as a single number.
        blur_img = blur_img_original.resize((self.width, self.height), resample=Image.BILINEAR)
        blur_np = np.expand_dims(np.array(blur_img), 0)

        frame_paths = self._sharp_frame_paths(image_path)

        _, seq_frames, inp_int, out_int, high_fps_video, num_frames = generate_test_case(
            frame_paths=frame_paths,
            window_max=7,
            in_start=0,
            in_end=7,
            out_start=0,
            out_end=7,
            center=3,
            mode="1x",
            fps=240,
        )

        pixel_values = self.load_frames(np.stack(seq_frames, axis=0))
        blur_pixel_values = self.load_frames(blur_np)

        relative_file_name = os.path.relpath(image_path, self.data_dir)

        # Scale the bbox from original-image pixels to the resized resolution
        out_bbx = bbx.copy()
        scale_x = blur_pixel_values.shape[3] / W
        scale_y = blur_pixel_values.shape[2] / H
        out_bbx[0] = int(out_bbx[0] * scale_x)
        out_bbx[1] = int(out_bbx[1] * scale_y)
        out_bbx[2] = int(out_bbx[2] * scale_x)
        out_bbx[3] = int(out_bbx[3] * scale_y)
        out_bbx = torch.tensor(out_bbx, dtype=torch.uint32)

        data = {
            'file_name': relative_file_name,
            'blur_img': blur_pixel_values,
            'video': pixel_values,
            'bbx': out_bbx,
            'caption': "",
            'input_interval': inp_int,
            'output_interval': out_int,
            "num_frames": num_frames,
            'mode': "1x",
        }
        return data