from typing import List, Sequence

import onnx
import torch
import torch.nn as nn
from onnxsim import simplify
class Preprocess(nn.Module):
    """Resize an image batch to a fixed spatial size and normalize it per channel.

    ``forward`` resizes ``x`` to ``input_shape[2:]`` with
    ``torch.nn.functional.interpolate``, divides by 255 (mapping ``[0, 255]``
    pixel values to ``[0, 1]``), then subtracts ``mean`` and divides by
    ``std`` channel-wise.

    Args:
        input_shape: Reference shape ``[N, C, H, W]``. Only ``H`` and ``W`` are
            used (as the resize target); ``N`` is not enforced on the input.
        mean: Per-channel means, one per input channel. The defaults look like
            a rounding of the CLIP image-normalization constants
            (NOTE(review): confirm against the consuming model).
        std: Per-channel standard deviations, same length as ``mean``.
        mode: Interpolation mode forwarded to ``interpolate``. Defaults to
            ``"nearest"``, which is ``interpolate``'s own default, so existing
            callers see unchanged output.

    Raises:
        ValueError: If ``input_shape`` is not 4-D, or if ``mean`` and ``std``
            have different lengths.
    """

    def __init__(
        self,
        input_shape: List[int],
        mean: Sequence[float] = (0.4815, 0.4578, 0.4082),
        std: Sequence[float] = (0.2686, 0.2613, 0.2758),
        mode: str = "nearest",
    ):
        super().__init__()
        # The (1, C, 1, 1) mean/std views only broadcast correctly over 4-D
        # NCHW tensors, so reject other ranks up front instead of producing
        # silently mis-broadcast output.
        if len(input_shape) != 4:
            raise ValueError(
                f"input_shape must be [N, C, H, W], got {list(input_shape)}"
            )
        if len(mean) != len(std):
            raise ValueError(
                "mean and std must have the same length, "
                f"got {len(mean)} and {len(std)}"
            )
        self.input_shape = tuple(input_shape)
        self.mode = mode
        # Registered as buffers so .to(device) / .double() / .half() move them
        # together with the module. persistent=False keeps them out of
        # state_dict, matching the previous plain-attribute behavior.
        self.register_buffer(
            "mean", torch.tensor(tuple(mean)).view(1, -1, 1, 1), persistent=False
        )
        self.register_buffer(
            "std", torch.tensor(tuple(std)).view(1, -1, 1, 1), persistent=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Resize ``x`` to ``input_shape[2:]``, scale by 1/255, then normalize.

        Args:
            x: Image batch, presumably ``(N, C, H, W)`` with values in
                ``[0, 255]`` (TODO confirm with callers); ``C`` must match
                ``len(mean)``.

        Returns:
            Tensor of shape ``(N, C, *input_shape[2:])``.
        """
        x = torch.nn.functional.interpolate(
            input=x,
            size=self.input_shape[2:],
            mode=self.mode,
        )
        x = x / 255.0
        return (x - self.mean) / self.std
if __name__ == "__main__":
    # Export the preprocessing module to ONNX, then simplify it in place.
    input_shape = [1, 3, 448, 448]
    output_onnx_file = "preprocessing.onnx"

    # eval() changes nothing for this module today (no dropout/batch-norm),
    # but exporting in eval mode is the documented convention and guards
    # against future layers that behave differently during training.
    model = Preprocess(input_shape=input_shape).eval()

    torch.onnx.export(
        model,
        # Example input for export only; batch/height/width are declared
        # dynamic below, and the module resizes to input_shape[2:] anyway.
        torch.randn(input_shape),
        output_onnx_file,
        opset_version=20,
        input_names=["input_rgb"],
        output_names=["output_preprocessing"],
        dynamic_axes={
            "input_rgb": {
                0: "batch_size",
                2: "height",
                3: "width",
            },
            # Output H/W are fixed by input_shape, but the batch dimension
            # follows the input's, so declare it with the same symbolic name.
            "output_preprocessing": {
                0: "batch_size",
            },
        },
    )

    model_onnx = onnx.load(output_onnx_file)
    model_simplified, check_ok = simplify(model_onnx)
    # simplify() returns a validation flag; do not overwrite the exported
    # model with a simplified one that failed onnxsim's consistency check.
    if not check_ok:
        raise RuntimeError(
            "onnxsim could not validate the simplified model; "
            f"leaving the unsimplified export at {output_onnx_file}"
        )
    onnx.save(model_simplified, output_onnx_file)