|
|
""" |
|
|
This script makes the ONNX model compatible with Triton inference server. |
|
|
""" |
|
|
|
|
|
import sys |
|
|
import numpy as np |
|
|
import onnx |
|
|
import onnxruntime as ort |
|
|
import onnx_graphsurgeon as gs |
|
|
|
|
|
|
|
|
def add_squeeze(graph, speed_input, speed_unsqueezed): |
|
|
""" |
|
|
Add squeeze operation to the speed input to change shape from [batch_size, 1] to [batch_size] |
|
|
""" |
|
|
|
|
|
squeeze_node = gs.Node( |
|
|
op="Squeeze", |
|
|
name="speed_squeeze", |
|
|
inputs=[speed_unsqueezed], |
|
|
outputs=[gs.Variable(name="speed_squeezed", dtype=speed_unsqueezed.dtype)] |
|
|
) |
|
|
|
|
|
|
|
|
insert_idx = 0 |
|
|
for idx, node in enumerate(graph.nodes): |
|
|
for i, input_name in enumerate(node.inputs): |
|
|
if input_name.name == speed_unsqueezed.name: |
|
|
insert_idx = idx |
|
|
break |
|
|
if insert_idx != 0: |
|
|
break |
|
|
|
|
|
|
|
|
insert_idx = min(0, insert_idx - 1) |
|
|
graph.nodes.insert(insert_idx, squeeze_node) |
|
|
|
|
|
|
|
|
for node in graph.nodes: |
|
|
for i, input_name in enumerate(node.inputs): |
|
|
if input_name.name == speed_input.name and not node.name == "speed_squeeze": |
|
|
node.inputs[i] = squeeze_node.outputs[0] |
|
|
|
|
|
return graph |
|
|
|
|
|
|
|
|
def main(): |
|
|
if len(sys.argv) != 2: |
|
|
print("Usage: python make_triton_compatible.py <onnx_model_path>") |
|
|
sys.exit(1) |
|
|
|
|
|
onnx_model_path = sys.argv[1] |
|
|
onnx_model = onnx.load(onnx_model_path) |
|
|
onnx.checker.check_model(onnx_model) |
|
|
print("Model is valid") |
|
|
|
|
|
graph = gs.import_onnx(onnx_model) |
|
|
|
|
|
|
|
|
speed_idx, speed = None, None |
|
|
for idx, input_ in enumerate(graph.inputs): |
|
|
if input_.name=="speed": |
|
|
speed_idx = idx |
|
|
speed = input_ |
|
|
|
|
|
|
|
|
speed_unsqueezed = gs.Variable(name="speed", dtype=speed.dtype, shape=[speed.shape[0], 1]) |
|
|
graph.inputs[speed_idx] = speed_unsqueezed |
|
|
|
|
|
|
|
|
if speed is not None: |
|
|
print(f"Found speed input: {speed.name}") |
|
|
print(f"Found speed input shape: {speed.shape}") |
|
|
print(f"Found speed input dtype: {speed.dtype}") |
|
|
print(f"Found speed input: {speed}") |
|
|
print(f"Found speed input: {type(speed)}") |
|
|
graph = add_squeeze(graph, speed, speed_unsqueezed) |
|
|
|
|
|
|
|
|
modified_model = gs.export_onnx(graph) |
|
|
onnx.checker.check_model(modified_model) |
|
|
|
|
|
|
|
|
output_path = onnx_model_path.replace('.onnx', '_triton.onnx') |
|
|
onnx.save(modified_model, output_path) |
|
|
print(f"Modified model saved to: {output_path}") |
|
|
else: |
|
|
print("Speed input not found in the model") |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
|