Spaces:
Runtime error
Runtime error
| # code ref: https://github.com/yjn870/FSRCNN-pytorch/blob/master/models.py | |
| from . import common | |
| import math | |
| from argparse import Namespace | |
| import torch.nn as nn | |
| from models import register | |
| import torch.nn.functional as F | |
def make_model(args, parent=False):
    """Build an FSRCNN network from an options object.

    Args:
        args: Options object handed unchanged to ``FSRCNN``.
        parent: Accepted but not used by this function.

    Returns:
        The object produced by ``FSRCNN(args)``.
    """
    model = FSRCNN(args)
    return model
def FSRCNN(scale_ratio, rgb_range=1):
    """Build an FSRCNN for 3-channel input at a single scale factor.

    Packs the arguments into an ``argparse.Namespace`` with ``scale`` (a
    one-element list), ``n_colors`` fixed to 3, and ``rgb_range``, then passes
    it to ``FSRCNN``.

    NOTE(review): a class named ``FSRCNN`` is defined later in this module
    and rebinds the module-level name, so the ``FSRCNN(args)`` call below
    resolves to that class at call time. This function is then only reachable
    through a reference captured before the rebinding (e.g. by a registry
    decorator) -- confirm that one exists.

    Args:
        scale_ratio: Upscaling factor; stored as ``args.scale[0]``.
        rgb_range: Stored as ``args.rgb_range``.

    Returns:
        The object produced by ``FSRCNN(args)``.
    """
    args = Namespace(scale=[scale_ratio], n_colors=3, rgb_range=rgb_range)
    return FSRCNN(args)
class FSRCNN(nn.Module):
    """FSRCNN-style super-resolution network.

    The network has three stages:

    * ``first_part``: one 5x5 convolution followed by a PReLU.
    * ``mid_part``: a 1x1 convolution (``d`` -> ``s``), ``m`` 3x3
      convolutions (``s`` -> ``s``) and a 1x1 convolution (``s`` -> ``d``),
      each followed by a PReLU.
    * ``last_part``: a 9x9 transposed convolution with stride
      ``args.scale[0]`` producing ``args.n_colors`` channels.

    Args:
        args: Object exposing ``scale`` (a sequence; only the first entry is
            read) and ``n_colors``.
        conv: Factory invoked as ``conv(a, b, kernel_size=k)`` to build each
            convolution. Presumably ``(in_channels, out_channels)`` with
            size-preserving padding -- confirm against
            ``common.default_conv``.
        d: Channel width of the first and last feature maps.
        s: Channel width inside the middle stack.
        m: Number of 3x3 convolutions in the middle stack.
    """

    def __init__(self, args, conv=common.default_conv, d=56, s=12 * 3, m=8):
        super().__init__()
        scale = args.scale[0]

        def conv_act(in_ch, out_ch, kernel_size):
            # Every convolution gets its OWN PReLU. Reusing one instance for
            # all layers would tie every activation to a single learnable
            # slope and register that one parameter under many state_dict
            # keys. nn.PReLU() keeps the default single-slope weight, so
            # state_dict keys and shapes stay the same as before.
            return [conv(in_ch, out_ch, kernel_size=kernel_size), nn.PReLU()]

        self.first_part = nn.Sequential(*conv_act(args.n_colors, d, 5))

        mid_layers = conv_act(d, s, 1)
        for _ in range(m):
            mid_layers += conv_act(s, s, 3)
        mid_layers += conv_act(s, d, 1)
        self.mid_part = nn.Sequential(*mid_layers)

        # kernel 9, padding 4 and output_padding scale-1 make the transposed
        # convolution output exactly `scale` times its input height/width.
        self.last_part = nn.ConvTranspose2d(
            d,
            args.n_colors,
            kernel_size=9,
            stride=scale,
            padding=9 // 2,
            output_padding=scale - 1,
        )
        # NOTE: _initialize_weights() is deliberately left uncalled here (it
        # was commented out in the original); layers keep whatever
        # initialisation their constructors applied.

    def _initialize_weights(self):
        """Re-initialise the weights in place (not called by ``__init__``).

        Every ``nn.Conv2d`` found directly in ``first_part`` and ``mid_part``
        gets zero-mean normal weights with
        ``std = sqrt(2 / (out_channels * kernel_elements))`` and a zero bias.
        ``last_part`` gets zero-mean normal weights with ``std = 0.001`` and
        a zero bias.
        """
        for part in (self.first_part, self.mid_part):
            for layer in part:
                if not isinstance(layer, nn.Conv2d):
                    continue
                # weight[0][0] is one 2-D kernel slice; numel() is its area.
                kernel_elements = layer.weight.data[0][0].numel()
                std = math.sqrt(2 / (layer.out_channels * kernel_elements))
                nn.init.normal_(layer.weight.data, mean=0.0, std=std)
                nn.init.zeros_(layer.bias.data)
        nn.init.normal_(self.last_part.weight.data, mean=0.0, std=0.001)
        nn.init.zeros_(self.last_part.bias.data)

    def forward(self, x, out_size=None):
        """Run the three stages in order and return the result.

        Args:
            x: Input passed to ``first_part``.
            out_size: Accepted but not used by this method.

        Returns:
            The output of ``last_part``.
        """
        x = self.first_part(x)
        x = self.mid_part(x)
        x = self.last_part(x)
        return x

    def load_state_dict(self, state_dict, strict=False):
        """Copy matching entries of ``state_dict`` into this module in place.

        Unlike ``nn.Module.load_state_dict`` this defaults to
        ``strict=False`` and returns ``None``. Keys containing ``'tail'`` are
        exempt from the shape-mismatch and unexpected-key errors. This class
        defines no submodule with that name, so for its own keys the
        exemption never applies.

        Args:
            state_dict: Mapping from parameter/buffer name to tensor (or
                ``nn.Parameter``).
            strict: When true, also reject unexpected keys (other than
                ``'tail'`` ones) and keys missing from ``state_dict``.

        Raises:
            RuntimeError: A value could not be copied into the entry of the
                same name and the name does not contain ``'tail'``.
            KeyError: ``strict`` is true and ``state_dict`` has an unexpected
                non-``'tail'`` key or lacks a key this module owns.
        """
        own_state = self.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                if strict and 'tail' not in name:
                    raise KeyError('unexpected key "{}" in state_dict'
                                   .format(name))
                continue

            if isinstance(param, nn.Parameter):
                param = param.data
            try:
                own_state[name].copy_(param)
            except Exception as err:
                if 'tail' in name:
                    print('Replace pre-trained upsampler to new one...')
                else:
                    # Chain the cause so the underlying copy error is kept.
                    raise RuntimeError('While copying the parameter named {}, '
                                       'whose dimensions in the model are {} and '
                                       'whose dimensions in the checkpoint are {}.'
                                       .format(name, own_state[name].size(), param.size())) from err

        if strict:
            missing = set(own_state.keys()) - set(state_dict.keys())
            if missing:
                raise KeyError('missing keys in state_dict: "{}"'.format(missing))