Spaces:
Runtime error
Runtime error
| """ | |
| Implementation of YOLOv3 architecture | |
| """ | |
| import pytorch_lightning as pl | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.optim.lr_scheduler import OneCycleLR | |
| from . import config | |
| from .loss import YoloLoss | |
# Architecture spec consumed by YOLOv3._create_conv_layers, read top to bottom:
#   (out_channels, kernel_size, stride) -> CNNBlock (padding 1 for 3x3, else 0)
#   ["B", num_repeats]                  -> ResidualBlock repeated num_repeats times
#   "S"                                 -> non-residual ResidualBlock, 1x1 CNNBlock
#                                          halving channels, then ScalePrediction
#   "U"                                 -> 2x upsample, concatenated in forward()
#                                          with the latest stored route connection
model_config = [
    # --- Darknet-53 backbone ---
    (32, 3, 1),
    (64, 3, 2),
    ["B", 1],
    (128, 3, 2),
    ["B", 2],
    (256, 3, 2),
    ["B", 8],
    (512, 3, 2),
    ["B", 8],
    (1024, 3, 2),
    ["B", 4],
    # --- first prediction head (coarsest grid) ---
    (512, 1, 1),
    (1024, 3, 1),
    "S",
    # --- upsample, then second prediction head ---
    (256, 1, 1),
    "U",
    (256, 1, 1),
    (512, 3, 1),
    "S",
    # --- upsample, then third prediction head (finest grid) ---
    (128, 1, 1),
    "U",
    (128, 1, 1),
    (256, 3, 1),
    "S",
]
class CNNBlock(pl.LightningModule):
    """Conv2d, optionally followed by BatchNorm2d and LeakyReLU(0.1).

    Args:
        in_channels: input channels of the convolution.
        out_channels: output channels of the convolution.
        bn_act: when True, the conv output goes through batch norm and a leaky
            ReLU, and the conv bias is dropped because the batch norm shift makes
            it redundant. When False, the raw conv output is returned, as in the
            final layer of ScalePrediction.
        **kwargs: forwarded unchanged to ``nn.Conv2d`` (kernel_size, stride,
            padding, ...).
    """

    def __init__(self, in_channels, out_channels, bn_act=True, **kwargs):
        super().__init__()
        self.use_bn_act = bn_act
        self.conv = nn.Conv2d(in_channels, out_channels, bias=not bn_act, **kwargs)
        # Always created, so the parameter layout (and state_dict keys) is the
        # same for every CNNBlock; it is simply unused when bn_act is False.
        self.bn = nn.BatchNorm2d(out_channels)
        self.leaky = nn.LeakyReLU(0.1)

    def forward(self, x):
        features = self.conv(x)
        if not self.use_bn_act:
            return features
        return self.leaky(self.bn(features))
class ResidualBlock(pl.LightningModule):
    """Stack of ``num_repeats`` (1x1 -> 3x3) CNNBlock pairs.

    Each pair squeezes ``channels`` to ``channels // 2`` with a 1x1 conv, then
    restores it with a padded 3x3 conv. Channel count and spatial size are
    therefore unchanged. With ``use_residual`` the pair's output is added to
    its input; otherwise the output replaces it.

    ``num_repeats`` is kept on the instance because YOLOv3.forward inspects it
    to decide which feature maps to store as route connections.
    """

    def __init__(self, channels, use_residual=True, num_repeats=1):
        super().__init__()
        half = channels // 2
        self.layers = nn.ModuleList(
            nn.Sequential(
                CNNBlock(channels, half, kernel_size=1),
                CNNBlock(half, channels, kernel_size=3, padding=1),
            )
            for _ in range(num_repeats)
        )
        self.use_residual = use_residual
        self.num_repeats = num_repeats

    def forward(self, x):
        for pair in self.layers:
            y = pair(x)
            x = x + y if self.use_residual else y
        return x
class ScalePrediction(pl.LightningModule):
    """Prediction head for a single output scale.

    A padded 3x3 CNNBlock doubles the channels. A 1x1 conv without batch norm
    or activation then produces ``3 * (num_classes + 5)`` channels. These are
    regrouped into shape ``(batch, 3, height, width, num_classes + 5)``,
    presumably one slot per anchor (TODO confirm against YoloLoss).
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        hidden = 2 * in_channels
        self.pred = nn.Sequential(
            CNNBlock(in_channels, hidden, kernel_size=3, padding=1),
            CNNBlock(hidden, 3 * (num_classes + 5), kernel_size=1, bn_act=False),
        )
        self.num_classes = num_classes

    def forward(self, x):
        batch, height, width = x.shape[0], x.shape[2], x.shape[3]
        raw = self.pred(x)
        grouped = raw.reshape(batch, 3, self.num_classes + 5, height, width)
        # Move the per-prediction attribute axis to the end.
        return grouped.permute(0, 1, 3, 4, 2)
class YOLOv3(pl.LightningModule):
    """YOLOv3 detector built from ``model_config`` and trained with YoloLoss.

    ``forward`` returns a list with one tensor per ScalePrediction head. Each
    tensor has shape ``(batch, 3, grid_h, grid_w, num_classes + 5)``, ordered
    from the coarsest grid to the finest.
    """

    def __init__(self, in_channels=3, num_classes=20):
        super().__init__()
        self.num_classes = num_classes
        self.in_channels = in_channels
        self.layers = self._create_conv_layers()
        # Each scale's grid size in config.S is broadcast to shape
        # (len(S), 3, 2) and multiplies the anchors.
        # NOTE(review): assumes config.ANCHORS is (num_scales, 3, 2) with one
        # config.S entry per scale -- confirm against config.
        scaled_anchors = (
            torch.tensor(config.ANCHORS)
            * torch.tensor(config.S).unsqueeze(1).unsqueeze(1).repeat(1, 3, 2)
        ).to(config.DEVICE)
        # As a buffer the anchors follow the module through .to()/.cuda() and
        # Lightning's device placement, instead of staying on config.DEVICE.
        # persistent=False keeps them out of state_dict, so saved checkpoints
        # keep their previous layout.
        self.register_buffer("scaled_anchors", scaled_anchors, persistent=False)
        self.learning_rate = config.LEARNING_RATE
        self.weight_decay = config.WEIGHT_DECAY
        self.best_lr = 1e-3  # peak LR of the OneCycleLR schedule

    def forward(self, x):
        """Run the network and return the list of per-scale predictions."""
        outputs = []  # one entry per ScalePrediction head
        route_connections = []  # skip features consumed by later upsamples
        for layer in self.layers:
            if isinstance(layer, ScalePrediction):
                # Heads branch off the trunk; x itself is not replaced.
                outputs.append(layer(x))
                continue
            x = layer(x)
            if isinstance(layer, ResidualBlock) and layer.num_repeats == 8:
                route_connections.append(x)
            elif isinstance(layer, nn.Upsample):
                # Concatenate with the most recently stored route (LIFO).
                x = torch.cat([x, route_connections.pop()], dim=1)
        return outputs

    def _create_conv_layers(self):
        """Translate ``model_config`` into an ``nn.ModuleList`` of layers."""
        layers = nn.ModuleList()
        in_channels = self.in_channels  # channels flowing into the next layer
        for module in model_config:
            if isinstance(module, tuple):
                out_channels, kernel_size, stride = module
                layers.append(
                    CNNBlock(
                        in_channels,
                        out_channels,
                        kernel_size=kernel_size,
                        stride=stride,
                        padding=1 if kernel_size == 3 else 0,
                    )
                )
                in_channels = out_channels
            elif isinstance(module, list):
                # Residual blocks keep the channel count unchanged.
                num_repeats = module[1]
                layers.append(ResidualBlock(in_channels, num_repeats=num_repeats))
            elif isinstance(module, str):
                if module == "S":
                    layers += [
                        ResidualBlock(in_channels, use_residual=False, num_repeats=1),
                        CNNBlock(in_channels, in_channels // 2, kernel_size=1),
                        ScalePrediction(in_channels // 2, num_classes=self.num_classes),
                    ]
                    # The trunk continues from the 1x1 CNNBlock, not the head.
                    in_channels = in_channels // 2
                elif module == "U":
                    layers.append(nn.Upsample(scale_factor=2))
                    # forward() concatenates the upsampled map (C channels) with
                    # a route connection that carries 2C channels in this config.
                    in_channels = in_channels * 3
        return layers

    def yololoss(self):
        """Return a fresh YoloLoss instance."""
        return YoloLoss()

    def _shared_step(self, batch):
        """Return the YoloLoss summed over all three scales for ``batch``.

        ``batch`` is ``(x, y)``, where ``y[i]`` is the target matching
        ``forward(x)[i]``.
        """
        x, y = batch
        out = self.forward(x)
        # NOTE(review): one loss module is reused across the scales; this
        # assumes YoloLoss keeps no per-call state.
        loss_fn = self.yololoss()
        return (
            loss_fn(out[0], y[0], self.scaled_anchors[0])
            + loss_fn(out[1], y[1], self.scaled_anchors[1])
            + loss_fn(out[2], y[2], self.scaled_anchors[2])
        )

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log(
            "train_loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True
        )
        return loss

    def test_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log(
            "test_loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True
        )
        return loss

    def _print_epoch_loss(self, key):
        """Print the epoch-aggregated metric ``key``, or ``None`` if it is absent."""
        loss = self.trainer.callback_metrics.get(key)
        print(f"Epoch: {self.current_epoch}, Loss: {loss}")

    def on_train_epoch_end(self) -> None:
        # "_epoch" suffix: train_loss is logged with both on_step and on_epoch.
        self._print_epoch_loss("train_loss_epoch")

    def on_test_epoch_end(self) -> None:
        self._print_epoch_loss("test_loss_epoch")

    def configure_optimizers(self):
        """Adam with a per-step OneCycleLR schedule.

        Requires the trainer to be fitted with a datamodule, which is used to
        count the steps per epoch. OneCycleLR overrides the optimizer's
        starting lr, so training starts at ``best_lr / div_factor``. The rate
        rises linearly to ``best_lr`` over roughly the first 8 of
        ``config.NUM_EPOCHS`` epochs, then decays linearly.
        """
        optimizer = optim.Adam(
            self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay
        )
        scheduler = OneCycleLR(
            optimizer,
            max_lr=self.best_lr,
            steps_per_epoch=len(self.trainer.datamodule.train_dataloader()),
            epochs=config.NUM_EPOCHS,
            pct_start=8 / config.NUM_EPOCHS,
            div_factor=100,
            three_phase=False,
            final_div_factor=100,
            anneal_strategy="linear",
        )
        return [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}]

    def on_train_end(self) -> None:
        # NOTE(review): in multi-process training every rank writes this file;
        # consider guarding with self.trainer.is_global_zero.
        torch.save(self.state_dict(), config.MODEL_STATE_DICT_PATH)
if __name__ == "__main__":
    # Smoke test: check the output shapes for a 416x416 input.
    # Run as a module (python -m <package>.<module>) because of the relative imports.
    num_classes = 20
    IMAGE_SIZE = 416
    model = YOLOv3(num_classes=num_classes)
    x = torch.randn((2, 3, IMAGE_SIZE, IMAGE_SIZE))
    with torch.no_grad():
        out = model(x)
    strides = (32, 16, 8)  # outputs are ordered coarsest grid first
    assert len(out) == len(strides), f"expected {len(strides)} scales, got {len(out)}"
    for scale_out, stride in zip(out, strides):
        expected = (2, 3, IMAGE_SIZE // stride, IMAGE_SIZE // stride, num_classes + 5)
        assert scale_out.shape == expected, (
            f"stride {stride}: got {tuple(scale_out.shape)}, expected {expected}"
        )
    print("Image size compatibility check passed!")