File size: 3,020 Bytes
1b6c34a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""
This script makes the ONNX model compatible with Triton inference server.
"""

import sys
import numpy as np
import onnx
import onnxruntime as ort
import onnx_graphsurgeon as gs


def add_squeeze(graph, speed_input, speed_unsqueezed):
    """
    Insert a Squeeze node that turns the [batch_size, 1] speed input back into [batch_size].

    Every node that consumed the original speed tensor is rewired to read the
    Squeeze output instead, so the rest of the graph keeps seeing the shape it
    was exported with.

    Args:
        graph: onnx_graphsurgeon graph to modify in place.
        speed_input: the original speed tensor that existing nodes consume.
        speed_unsqueezed: the replacement graph input; the caller builds it with
            a trailing axis of size 1 and the same name as ``speed_input``.

    Returns:
        The same ``graph`` object, modified in place.
    """
    # Squeeze only the trailing axis. Without explicit axes, Squeeze drops every
    # size-1 dimension, so a runtime batch of 1 would collapse [1, 1] to a scalar.
    squeeze_axes = [1]
    squeezed = gs.Variable(name="speed_squeezed", dtype=speed_unsqueezed.dtype)

    if graph.opset >= 13:
        # From opset 13 on, Squeeze takes its axes as a second (int64) input.
        axes_tensor = gs.Constant(
            name="speed_squeeze_axes",
            values=np.array(squeeze_axes, dtype=np.int64),
        )
        squeeze_node = gs.Node(
            op="Squeeze",
            name="speed_squeeze",
            inputs=[speed_unsqueezed, axes_tensor],
            outputs=[squeezed],
        )
    else:
        # Before opset 13, axes is a node attribute.
        squeeze_node = gs.Node(
            op="Squeeze",
            name="speed_squeeze",
            attrs={"axes": squeeze_axes},
            inputs=[speed_unsqueezed],
            outputs=[squeezed],
        )

    # Locate the first node that reads the speed tensor. None (not 0) marks
    # "no consumer", because index 0 is itself a valid position.
    first_consumer_idx = next(
        (
            idx
            for idx, node in enumerate(graph.nodes)
            if any(tensor.name == speed_unsqueezed.name for tensor in node.inputs)
        ),
        None,
    )

    # Insert directly before the first consumer so the node list stays
    # topologically ordered. The Squeeze depends only on a graph input, so the
    # front of the list is a safe fallback when nothing consumes speed.
    insert_idx = 0 if first_consumer_idx is None else first_consumer_idx
    graph.nodes.insert(insert_idx, squeeze_node)

    # Rewire every former consumer of the speed tensor to the squeezed output.
    # The Squeeze node itself must keep reading the raw graph input.
    for node in graph.nodes:
        if node is squeeze_node:
            continue
        for i, tensor in enumerate(node.inputs):
            if tensor.name == speed_input.name:
                node.inputs[i] = squeezed

    return graph


def main():
    """
    Rewrite an ONNX model so that its "speed" input has shape [batch_size, 1].

    Reads the model path from ``sys.argv[1]``, validates the model, replaces the
    "speed" graph input with a two-dimensional variant, inserts a Squeeze so the
    rest of the graph still receives the original shape, and saves the result
    next to the input file with a ``_triton`` suffix before the extension.

    Exits with status 1 when the command line is malformed. If the model has no
    input named "speed", a message is printed and nothing is written.
    """
    import os

    if len(sys.argv) != 2:
        print("Usage: python make_triton_compatible.py <onnx_model_path>")
        sys.exit(1)

    onnx_model_path = sys.argv[1]
    onnx_model = onnx.load(onnx_model_path)
    onnx.checker.check_model(onnx_model)
    print("Model is valid")

    graph = gs.import_onnx(onnx_model)

    # Locate the graph input named "speed" together with its position.
    speed_idx, speed = next(
        ((idx, input_) for idx, input_ in enumerate(graph.inputs) if input_.name == "speed"),
        (None, None),
    )

    # Bail out before touching any attribute of `speed`: dereferencing it first
    # would raise AttributeError and make this message unreachable.
    if speed is None:
        print("Speed input not found in the model")
        return

    # Replace the speed input with a [batch_size, 1] variant. This indexes
    # speed.shape[0], so the original input must have at least one dimension.
    speed_unsqueezed = gs.Variable(name="speed", dtype=speed.dtype, shape=[speed.shape[0], 1])
    graph.inputs[speed_idx] = speed_unsqueezed

    print(f"Found speed input: {speed.name}")
    print(f"Found speed input shape: {speed.shape}")
    print(f"Found speed input dtype: {speed.dtype}")
    print(f"Found speed input: {speed}")
    print(f"Found speed input: {type(speed)}")

    # Add squeeze to change the speed shape from [batch_size, 1] back to [batch_size].
    graph = add_squeeze(graph, speed, speed_unsqueezed)

    # Export the modified graph back to ONNX and re-validate it.
    modified_model = gs.export_onnx(graph)
    onnx.checker.check_model(modified_model)

    # Build the output path from the extension only. str.replace would rewrite
    # every ".onnx" in the path (directory names included) and, for a path with
    # no ".onnx" at all, return the input path and overwrite the source model.
    root, ext = os.path.splitext(onnx_model_path)
    output_path = f"{root}_triton{ext or '.onnx'}"
    onnx.save(modified_model, output_path)
    print(f"Modified model saved to: {output_path}")


# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()