#!/usr/bin/env python3
# -*- coding:utf-8 -*-
import glob
import os
import os.path as osp
import random
import json
import time
import hashlib
from multiprocessing.pool import Pool

import cv2
import numpy as np
import torch
from PIL import ExifTags, Image, ImageOps
from torch.utils.data import Dataset
from tqdm import tqdm

from .data_augment import (
    augment_hsv,
    letterbox,
    mixup,
    random_affine,
    mosaic_augmentation,
)
from yolov6.utils.events import LOGGER

# Parameters
IMG_FORMATS = ["bmp", "jpg", "jpeg", "png", "tif", "tiff", "dng", "webp", "mpo"]

# Get orientation exif tag
for k, v in ExifTags.TAGS.items():
    if v == "Orientation":
        ORIENTATION = k
        break
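
# ExifTags.TAGS maps numeric EXIF tag ids to tag names, so the loop above
# leaves ORIENTATION holding the id of the "Orientation" tag (274);
# check_image() uses it below to swap width/height for images that are
# stored rotated by 90 degrees.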


class TrainValDataset(Dataset):
    # YOLOv6 train_loader/val_loader, loads images and labels for training and validation
    def __init__(
        self,
        img_dir,
        img_size=640,
        batch_size=16,
        augment=False,
        hyp=None,
        rect=False,
        check_images=False,
        check_labels=False,
        stride=32,
        pad=0.0,
        rank=-1,
        data_dict=None,
        task="train",
    ):
        assert task.lower() in ("train", "val", "speed"), f"Not supported task: {task}"
        t1 = time.time()
        self.__dict__.update(locals())
        self.main_process = self.rank in (-1, 0)
        self.task = self.task.capitalize()
        self.class_names = data_dict["names"]
        self.img_paths, self.labels = self.get_imgs_labels(self.img_dir)
        if self.rect:
            shapes = [self.img_info[p]["shape"] for p in self.img_paths]
            self.shapes = np.array(shapes, dtype=np.float64)
            self.batch_indices = np.floor(
                np.arange(len(shapes)) / self.batch_size
            ).astype(int)  # batch indices of each image (np.int was removed in NumPy 1.24)
            self.sort_files_shapes()
        t2 = time.time()
        if self.main_process:
            LOGGER.info("%.1fs for dataset initialization." % (t2 - t1))

    def __len__(self):
        """Get the length of dataset"""
        return len(self.img_paths)

    def __getitem__(self, index):
        """Fetching a data sample for a given key.
        This function applies mosaic and mixup augments during training.
        During validation, letterbox augment is applied.
        """
        # Mosaic Augmentation
        if self.augment and random.random() < self.hyp["mosaic"]:
            img, labels = self.get_mosaic(index)
            shapes = None

            # MixUp augmentation
            if random.random() < self.hyp["mixup"]:
                img_other, labels_other = self.get_mosaic(
                    random.randint(0, len(self.img_paths) - 1)
                )
                img, labels = mixup(img, labels, img_other, labels_other)
        else:
            # Load image
            img, (h0, w0), (h, w) = self.load_image(index)

            # Letterbox
            shape = (
                self.batch_shapes[self.batch_indices[index]]
                if self.rect
                else self.img_size
            )  # final letterboxed shape
            img, ratio, pad = letterbox(img, shape, auto=False, scaleup=self.augment)
            shapes = (h0, w0), ((h / h0, w / w0), pad)  # for COCO mAP rescaling

            labels = self.labels[index].copy()
            if labels.size:
                w *= ratio
                h *= ratio
                # new boxes
                boxes = np.copy(labels[:, 1:])
                boxes[:, 0] = (
                    w * (labels[:, 1] - labels[:, 3] / 2) + pad[0]
                )  # top left x
                boxes[:, 1] = (
                    h * (labels[:, 2] - labels[:, 4] / 2) + pad[1]
                )  # top left y
                boxes[:, 2] = (
                    w * (labels[:, 1] + labels[:, 3] / 2) + pad[0]
                )  # bottom right x
                boxes[:, 3] = (
                    h * (labels[:, 2] + labels[:, 4] / 2) + pad[1]
                )  # bottom right y
                labels[:, 1:] = boxes

            if self.augment:
                img, labels = random_affine(
                    img,
                    labels,
                    degrees=self.hyp["degrees"],
                    translate=self.hyp["translate"],
                    scale=self.hyp["scale"],
                    shear=self.hyp["shear"],
                    new_shape=(self.img_size, self.img_size),
                )

        if len(labels):
            h, w = img.shape[:2]

            labels[:, [1, 3]] = labels[:, [1, 3]].clip(0, w - 1e-3)  # x1, x2
            labels[:, [2, 4]] = labels[:, [2, 4]].clip(0, h - 1e-3)  # y1, y2

            boxes = np.copy(labels[:, 1:])
            boxes[:, 0] = ((labels[:, 1] + labels[:, 3]) / 2) / w  # x center
            boxes[:, 1] = ((labels[:, 2] + labels[:, 4]) / 2) / h  # y center
            boxes[:, 2] = (labels[:, 3] - labels[:, 1]) / w  # width
            boxes[:, 3] = (labels[:, 4] - labels[:, 2]) / h  # height
            labels[:, 1:] = boxes

        if self.augment:
            img, labels = self.general_augment(img, labels)

        labels_out = torch.zeros((len(labels), 6))
        if len(labels):
            labels_out[:, 1:] = torch.from_numpy(labels)

        # Convert
        img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
        img = np.ascontiguousarray(img)

        return torch.from_numpy(img), labels_out, self.img_paths[index], shapes
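
    # Layout of the sample returned above: `img` is a uint8 CHW RGB tensor,
    # and each row of `labels_out` is
    # [batch_index, class, x_center, y_center, width, height] with coordinates
    # normalized to the final image size; column 0 is left zero here and is
    # filled in by collate_fn() below.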

    def load_image(self, index):
        """Load image.
        This function loads image by cv2, resize original image to target shape(img_size) with keeping ratio.

        Returns:
            Image, original shape of image, resized image shape
        """
        path = self.img_paths[index]
        im = cv2.imread(path)
        assert im is not None, f"Image Not Found {path}, workdir: {os.getcwd()}"

        h0, w0 = im.shape[:2]  # origin shape
        r = self.img_size / max(h0, w0)
        if r != 1:
            im = cv2.resize(
                im,
                (int(w0 * r), int(h0 * r)),
                interpolation=cv2.INTER_AREA
                if r < 1 and not self.augment
                else cv2.INTER_LINEAR,
            )
        return im, (h0, w0), im.shape[:2]
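
    # cv2.INTER_AREA generally preserves more detail when shrinking an image,
    # hence its use on the non-augmented path when r < 1 above;
    # cv2.INTER_LINEAR is the cheaper default used everywhere else.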

    @staticmethod
    def collate_fn(batch):
        """Merges a list of samples to form a mini-batch of Tensor(s)"""
        img, label, path, shapes = zip(*batch)
        for i, l in enumerate(label):
            l[:, 0] = i  # add target image index for build_targets()
        return torch.stack(img, 0), torch.cat(label, 0), path, shapes
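
    # Illustrative usage (a sketch, not part of the original module): this
    # collate_fn is meant to be handed to a DataLoader, e.g.
    #
    #   loader = torch.utils.data.DataLoader(
    #       dataset,
    #       batch_size=16,
    #       shuffle=True,
    #       collate_fn=TrainValDataset.collate_fn,
    #   )
    #
    # which stacks images into an (N, 3, H, W) tensor and concatenates all
    # labels into one (M, 6) tensor keyed by the batch index in column 0.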

    def get_imgs_labels(self, img_dir):

        assert osp.exists(img_dir), f"{img_dir} is an invalid directory path!"
        valid_img_record = osp.join(
            osp.dirname(img_dir), "." + osp.basename(img_dir) + ".json"
        )
        NUM_THREADS = min(8, os.cpu_count())

        img_paths = glob.glob(osp.join(img_dir, "*"), recursive=True)
        img_paths = sorted(
            p for p in img_paths if p.split(".")[-1].lower() in IMG_FORMATS
        )
        assert img_paths, f"No images found in {img_dir}."

        img_hash = self.get_hash(img_paths)
        if osp.exists(valid_img_record):
            with open(valid_img_record, "r") as f:
                cache_info = json.load(f)
                if "image_hash" in cache_info and cache_info["image_hash"] == img_hash:
                    img_info = cache_info["information"]
                else:
                    self.check_images = True
        else:
            self.check_images = True

        # check images
        if self.check_images and self.main_process:
            img_info = {}
            nc, msgs = 0, []  # number corrupt, messages
            LOGGER.info(
                f"{self.task}: Checking formats of images with {NUM_THREADS} process(es): "
            )
            with Pool(NUM_THREADS) as pool:
                pbar = tqdm(
                    pool.imap(TrainValDataset.check_image, img_paths),
                    total=len(img_paths),
                )
                for img_path, shape_per_img, nc_per_img, msg in pbar:
                    if nc_per_img == 0:  # not corrupted
                        img_info[img_path] = {"shape": shape_per_img}
                    nc += nc_per_img
                    if msg:
                        msgs.append(msg)
                    pbar.desc = f"{nc} image(s) corrupted"
            pbar.close()
            if msgs:
                LOGGER.info("\n".join(msgs))

            cache_info = {"information": img_info, "image_hash": img_hash}
            # save valid image paths.
            with open(valid_img_record, "w") as f:
                json.dump(cache_info, f)

        # check and load anns
        label_dir = osp.join(
            osp.dirname(osp.dirname(img_dir)), "labels", osp.basename(img_dir)
        )
        assert osp.exists(label_dir), f"{label_dir} is an invalid directory path!"

        img_paths = list(img_info.keys())
        label_paths = sorted(
            osp.join(label_dir, osp.splitext(osp.basename(p))[0] + ".txt")
            for p in img_paths
        )
        label_hash = self.get_hash(label_paths)
        if "label_hash" not in cache_info or cache_info["label_hash"] != label_hash:
            self.check_labels = True

        if self.check_labels:
            cache_info["label_hash"] = label_hash
            nm, nf, ne, nc, msgs = 0, 0, 0, 0, []  # number missing, found, empty, corrupt; messages
            LOGGER.info(
                f"{self.task}: Checking formats of labels with {NUM_THREADS} process(es): "
            )
            with Pool(NUM_THREADS) as pool:
                pbar = pool.imap(
                    TrainValDataset.check_label_files, zip(img_paths, label_paths)
                )
                pbar = tqdm(pbar, total=len(label_paths)) if self.main_process else pbar
                for (
                    img_path,
                    labels_per_file,
                    nc_per_file,
                    nm_per_file,
                    nf_per_file,
                    ne_per_file,
                    msg,
                ) in pbar:
                    if nc_per_file == 0:
                        img_info[img_path]["labels"] = labels_per_file
                    else:
                        img_info.pop(img_path)
                    nc += nc_per_file
                    nm += nm_per_file
                    nf += nf_per_file
                    ne += ne_per_file
                    if msg:
                        msgs.append(msg)
                    if self.main_process:
                        pbar.desc = f"{nf} label(s) found, {nm} label(s) missing, {ne} label(s) empty, {nc} invalid label files"
            if self.main_process:
                pbar.close()
                with open(valid_img_record, "w") as f:
                    json.dump(cache_info, f)
            if msgs:
                LOGGER.info("\n".join(msgs))
            if nf == 0:
                # use the local label_dir here: self.img_paths is not assigned
                # until get_imgs_labels() returns, so referencing it would fail
                LOGGER.warning(f"WARNING: No labels found in {label_dir}. ")

        if self.task.lower() == "val":
            if self.data_dict.get("is_coco", False):  # use original json file when evaluating on coco dataset.
                assert osp.exists(self.data_dict["anno_path"]), (
                    "Eval on coco dataset must provide valid path of the annotation file in config file: data/coco.yaml"
                )
            else:
                assert (
                    self.class_names
                ), "Class names is required when converting labels to coco format for evaluating."
                save_dir = osp.join(osp.dirname(osp.dirname(img_dir)), "annotations")
                if not osp.exists(save_dir):
                    os.mkdir(save_dir)
                save_path = osp.join(
                    save_dir, "instances_" + osp.basename(img_dir) + ".json"
                )
                TrainValDataset.generate_coco_format_labels(
                    img_info, self.class_names, save_path
                )

        img_paths, labels = list(
            zip(
                *[
                    (
                        img_path,
                        np.array(info["labels"], dtype=np.float32)
                        if info["labels"]
                        else np.zeros((0, 5), dtype=np.float32),
                    )
                    for img_path, info in img_info.items()
                ]
            )
        )
        self.img_info = img_info
        LOGGER.info(
            f"{self.task}: Final numbers of valid images: {len(img_paths)} / labels: {len(labels)}. "
        )
        return img_paths, labels
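
    # Directory layout assumed by get_imgs_labels() above (inferred from the
    # label_dir construction; the "train" split name is just an example):
    #
    #   dataset_root/images/train/xxx.jpg   <- img_dir
    #   dataset_root/labels/train/xxx.txt   <- one label file per image
    #
    # The cache file ".<img_dir_basename>.json" is written next to img_dir and
    # is invalidated whenever the md5 hash of the image/label path lists changes.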

    def get_mosaic(self, index):
        """Gets images and labels after mosaic augments"""
        indices = [index] + random.choices(
            range(0, len(self.img_paths)), k=3
        )  # 3 additional image indices
        random.shuffle(indices)
        imgs, hs, ws, labels = [], [], [], []
        for index in indices:
            img, _, (h, w) = self.load_image(index)
            labels_per_img = self.labels[index]
            imgs.append(img)
            hs.append(h)
            ws.append(w)
            labels.append(labels_per_img)
        img, labels = mosaic_augmentation(self.img_size, imgs, hs, ws, labels, self.hyp)
        return img, labels

    def general_augment(self, img, labels):
        """Gets images and labels after general augment
        This function applies hsv, random up-down flip and random left-right flip augments.
        """
        nl = len(labels)

        # HSV color-space
        augment_hsv(
            img,
            hgain=self.hyp["hsv_h"],
            sgain=self.hyp["hsv_s"],
            vgain=self.hyp["hsv_v"],
        )

        # Flip up-down
        if random.random() < self.hyp["flipud"]:
            img = np.flipud(img)
            if nl:
                labels[:, 2] = 1 - labels[:, 2]

        # Flip left-right
        if random.random() < self.hyp["fliplr"]:
            img = np.fliplr(img)
            if nl:
                labels[:, 1] = 1 - labels[:, 1]

        return img, labels

    def sort_files_shapes(self):
        # Sort by aspect ratio
        batch_num = self.batch_indices[-1] + 1
        s = self.shapes  # wh
        ar = s[:, 1] / s[:, 0]  # aspect ratio
        irect = ar.argsort()
        self.img_paths = [self.img_paths[i] for i in irect]
        self.labels = [self.labels[i] for i in irect]
        self.shapes = s[irect]  # wh
        ar = ar[irect]

        # Set training image shapes
        shapes = [[1, 1]] * batch_num
        for i in range(batch_num):
            ari = ar[self.batch_indices == i]
            mini, maxi = ari.min(), ari.max()
            if maxi < 1:
                shapes[i] = [maxi, 1]
            elif mini > 1:
                shapes[i] = [1, 1 / mini]
        self.batch_shapes = (
            np.ceil(
                np.array(shapes) * self.img_size / self.stride + self.pad
            ).astype(int)  # np.int was removed in NumPy 1.24
            * self.stride
        )
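
    # Worked example for the rounding above, assuming the default img_size=640,
    # stride=32, pad=0.0: a batch whose largest aspect ratio h/w is 0.5 gets
    # shapes[i] = [0.5, 1], so ceil(0.5 * 640 / 32) * 32 = 320 and the batch
    # trains at (320, 640); letterbox targets are always rounded up to a
    # multiple of the stride.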

    @staticmethod
    def check_image(im_file):
        # verify an image.
        nc, msg = 0, ""
        try:
            im = Image.open(im_file)
            im.verify()  # PIL verify
            shape = im.size  # (width, height)
            im_exif = im._getexif()
            if im_exif and ORIENTATION in im_exif:
                rotation = im_exif[ORIENTATION]
                if rotation in (6, 8):  # EXIF codes for 90-degree rotations
                    shape = (shape[1], shape[0])

            assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
            assert im.format.lower() in IMG_FORMATS, f"invalid image format {im.format}"
            if im.format.lower() in ("jpg", "jpeg"):
                with open(im_file, "rb") as f:
                    f.seek(-2, 2)
                    if f.read() != b"\xff\xd9":  # corrupt JPEG, missing end-of-image marker
                        ImageOps.exif_transpose(Image.open(im_file)).save(
                            im_file, "JPEG", subsampling=0, quality=100
                        )
                        msg += f"WARNING: {im_file}: corrupt JPEG restored and saved"
            return im_file, shape, nc, msg
        except Exception as e:
            nc = 1
            msg = f"WARNING: {im_file}: ignoring corrupt image: {e}"
            return im_file, None, nc, msg

    @staticmethod
    def check_label_files(args):
        img_path, lb_path = args
        nm, nf, ne, nc, msg = 0, 0, 0, 0, ""  # number (missing, found, empty, corrupt), message
        try:
            if osp.exists(lb_path):
                nf = 1  # label found
                with open(lb_path, "r") as f:
                    labels = [
                        x.split() for x in f.read().strip().splitlines() if len(x)
                    ]
                    labels = np.array(labels, dtype=np.float32)
                if len(labels):
                    assert all(
                        len(l) == 5 for l in labels
                    ), f"{lb_path}: wrong label format."
                    assert (
                        labels >= 0
                    ).all(), f"{lb_path}: Label values error: all values in label file must be >= 0"
                    assert (
                        labels[:, 1:] <= 1
                    ).all(), f"{lb_path}: Label values error: all coordinates must be normalized"

                    _, indices = np.unique(labels, axis=0, return_index=True)
                    if len(indices) < len(labels):  # duplicate row check
                        # count duplicates before dropping them, otherwise the count is always 0
                        num_duplicates = len(labels) - len(indices)
                        labels = labels[indices]  # remove duplicates
                        msg += f"WARNING: {lb_path}: {num_duplicates} duplicate labels removed"
                    labels = labels.tolist()
                else:
                    ne = 1  # label empty
                    labels = []
            else:
                nm = 1  # label missing
                labels = []
            return img_path, labels, nc, nm, nf, ne, msg
        except Exception as e:
            nc = 1
            msg = f"WARNING: {lb_path}: ignoring invalid labels: {e}"
            return img_path, None, nc, nm, nf, ne, msg
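
    # A valid label file, as verified above, holds one object per line in the
    # normalized YOLO format "class x_center y_center width height", e.g. a
    # hypothetical line:
    #
    #   0 0.481719 0.634028 0.690625 0.713278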

    @staticmethod
    def generate_coco_format_labels(img_info, class_names, save_path):
        # for evaluation with pycocotools
        dataset = {"categories": [], "annotations": [], "images": []}
        for i, class_name in enumerate(class_names):
            dataset["categories"].append(
                {"id": i, "name": class_name, "supercategory": ""}
            )

        ann_id = 0
        LOGGER.info("Convert to COCO format")
        for i, (img_path, info) in enumerate(tqdm(img_info.items())):
            labels = info["labels"] if info["labels"] else []
            img_id = osp.splitext(osp.basename(img_path))[0]
            img_id = int(img_id) if img_id.isnumeric() else img_id
            img_w, img_h = info["shape"]
            dataset["images"].append(
                {
                    "file_name": os.path.basename(img_path),
                    "id": img_id,
                    "width": img_w,
                    "height": img_h,
                }
            )
            if labels:
                for label in labels:
                    c, x, y, w, h = label[:5]
                    # convert x,y,w,h to x1,y1,x2,y2
                    x1 = (x - w / 2) * img_w
                    y1 = (y - h / 2) * img_h
                    x2 = (x + w / 2) * img_w
                    y2 = (y + h / 2) * img_h
                    # cls_id starts from 0
                    cls_id = int(c)
                    w = max(0, x2 - x1)
                    h = max(0, y2 - y1)
                    dataset["annotations"].append(
                        {
                            "area": h * w,
                            "bbox": [x1, y1, w, h],
                            "category_id": cls_id,
                            "id": ann_id,
                            "image_id": img_id,
                            "iscrowd": 0,
                            # mask
                            "segmentation": [],
                        }
                    )
                    ann_id += 1

        with open(save_path, "w") as f:
            json.dump(dataset, f)
            LOGGER.info(f"Convert to COCO format finished. Results saved in {save_path}")

    @staticmethod
    def get_hash(paths):
        """Get the hash value of paths"""
        assert isinstance(paths, list), "Only support list currently."
        h = hashlib.md5("".join(paths).encode())
        return h.hexdigest()
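

# Minimal usage sketch (illustrative, not part of the original module). The
# paths, class names, and hyperparameter values below are placeholders, and
# the relative import at the top means this file must be imported as part of
# the yolov6 package:
#
#   from yolov6.data.datasets import TrainValDataset
#   from torch.utils.data import DataLoader
#
#   hyp = {
#       "mosaic": 1.0, "mixup": 0.0,
#       "degrees": 0.0, "translate": 0.1, "scale": 0.5, "shear": 0.0,
#       "hsv_h": 0.015, "hsv_s": 0.7, "hsv_v": 0.4,
#       "flipud": 0.0, "fliplr": 0.5,
#   }
#   dataset = TrainValDataset(
#       "dataset/images/train", img_size=640, batch_size=16, augment=True,
#       hyp=hyp, data_dict={"names": ["person"]}, task="train",
#   )
#   loader = DataLoader(dataset, batch_size=16, shuffle=True,
#                       collate_fn=TrainValDataset.collate_fn)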