| import torch | |
| import torch.nn as nn | |
class Quant(torch.autograd.Function):
    """Quantize values to 8-bit levels with a straight-through gradient.

    Forward clamps the input to ``[0, 1]`` and rounds it to the nearest
    multiple of ``1/255``. Backward returns the incoming gradient unchanged
    (straight-through estimator), so clamping and rounding are treated as the
    identity when differentiating.

    Call it as ``Quant.apply(x)``; autograd ``Function`` subclasses use
    static ``forward``/``backward`` and are not meant to be instantiated.
    """

    @staticmethod
    def forward(ctx, input):
        """Return ``round(clamp(input, 0, 1) * 255) / 255``.

        Args:
            ctx: Autograd context. Unused, because backward needs no saved
                tensors.
            input: Tensor to quantize.

        Returns:
            Tensor of the same shape as ``input``, with each value snapped to
            a multiple of ``1/255`` within ``[0, 1]``.
        """
        input = torch.clamp(input, 0, 1)
        output = (input * 255.).round() / 255.
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Pass ``grad_output`` through unchanged (straight-through estimator).

        This also forwards gradient for elements that ``forward`` clamped.
        The clamp is deliberately not reflected in the gradient.
        """
        return grad_output
class Quantization(nn.Module):
    """``nn.Module`` wrapper whose ``forward`` applies the ``Quant`` function.

    It registers no parameters or buffers of its own.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input):
        """Quantize ``input`` via ``Quant.apply`` and return the result."""
        quantized = Quant.apply(input)
        return quantized