import os

import clip
import torch
import torch.onnx
from torch import nn

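# torch.onnx.export traces a module's forward(); CLIP exposes its text encoder
# as encode_text(), so this thin wrapper makes the text side exportable on its own.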
class TextTransformer(nn.Module):
    def __init__(self, clip_model):
        super().__init__()
        self.clip_model = clip_model

    def forward(self, x: torch.Tensor):
        return self.clip_model.encode_text(x)

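# Shared export helper: the dynamic axis on dimension 0 lets the exported
# graph accept any batch size, and constant folding pre-computes weight-only
# subgraphs at export time.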
def export(model, dummy_input, path):
    print(f"Exporting to {path}")
    torch.onnx.export(
        model,
        dummy_input,
        path,
        export_params=True,
        opset_version=16,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={
            "input": {0: "batch_size"},
            "output": {0: "batch_size"},
        },
    )

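# Converts one CLIP checkpoint into two ONNX files, one per encoder. Encoders
# whose files already exist are skipped, so interrupted runs can be restarted.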
def convert(model_name, dashed_name):
    visual_path = f"{output_dir}/clip-{dashed_name}-visual.onnx"
    textual_path = f"{output_dir}/clip-{dashed_name}-textual.onnx"
    visual_exists = os.path.exists(visual_path)
    textual_exists = os.path.exists(textual_path)
    if visual_exists and textual_exists:
        print(f"{visual_path} exists, skipping")
        print(f"{textual_path} exists, skipping")
        return

    print(f"Model: {model_name}")
    print("Loading CLIP")
    # clip.load already places the model on `device`; no extra .to() is needed.
    model, _ = clip.load(model_name, device=device)

    if not visual_exists:
        input_res = model.visual.input_resolution
        export(
            model.visual,
            torch.rand(1, 3, input_res, input_res, device=device),
            visual_path,
        )
    else:
        print(f"{visual_path} exists, skipping")

    if not textual_exists:
        text_transformer = TextTransformer(model)
        export(
            text_transformer,
            clip.tokenize(["hello onnx"]).to(device),
            textual_path,
        )
    else:
        print(f"{textual_path} exists, skipping")

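# The CUDA check below is immediately overridden to CPU, presumably because
# clip.load returns fp16 weights on CUDA while CPU loading yields fp32, which
# exports to ONNX more reliably.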
| | device = "cuda" if torch.cuda.is_available() else "cpu" |
| | device = "cpu" |
| | output_dir = "converted" |
if __name__ == "__main__":
    print(f"Torch device: {device}")

    available_models = clip.available_models()
    print(f"Available models: {available_models}")

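    # (CLIP model name, dashed name used in output filenames) pairs.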
    models = [
        ("RN50", "resnet-50"),
        ("RN101", "resnet-101"),
        ("RN50x4", "resnet-50x4"),
        ("RN50x16", "resnet-50x16"),
        ("RN50x64", "resnet-50x64"),
| | ("RN50", "resnet-50"), |
| | ("RN50", "resnet-50"), |
| | ("RN50", "resnet-50"), |
| | ("ViT-B/16", "vit-base-patch16"), |
| | ("ViT-B/32", "vit-base-patch32"), |
| | ("ViT-L/14", "vit-large-patch14"), |
| | ("ViT-L/14@336px", "vit-large-patch14-336"), |
| | ] |
| |
|
| | print(f"Converting models: {models}") |
| |
|
| | for model in models: |
| | convert(*model) |
| |
|
| | print("done") |
| |
|