"""
This script makes the ONNX model compatible with Triton inference server.
"""
import sys
import numpy as np
import onnx
import onnxruntime as ort
import onnx_graphsurgeon as gs
def add_squeeze(graph, speed_input, speed_unsqueezed):
    """
    Insert a Squeeze node that turns the [batch_size, 1] speed input back into
    [batch_size] and rewire all original consumers to the squeezed tensor.
    """
    # Create the squeeze node; without explicit axes it removes every dimension
    # of size 1, which here collapses [batch_size, 1] to [batch_size].
    squeeze_node = gs.Node(
        op="Squeeze",
        name="speed_squeeze",
        inputs=[speed_unsqueezed],
        outputs=[gs.Variable(name="speed_squeezed", dtype=speed_unsqueezed.dtype)]
    )
    # Find the first node that consumes the speed input so the squeeze node
    # can be inserted just before it.
    insert_idx = 0
    for idx, node in enumerate(graph.nodes):
        if any(inp.name == speed_unsqueezed.name for inp in node.inputs):
            insert_idx = idx
            break
    graph.nodes.insert(insert_idx, squeeze_node)
    # Re-point every consumer of the original speed input at the squeezed output.
    for node in graph.nodes:
        if node is squeeze_node:
            continue
        for i, inp in enumerate(node.inputs):
            if inp.name == speed_input.name:
                node.inputs[i] = squeeze_node.outputs[0]
    return graph
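# Schematically, the rewiring above produces (tensor names as created in
# add_squeeze):
#
#   speed [batch_size, 1] -> Squeeze("speed_squeeze") -> speed_squeezed [batch_size] -> original consumers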
def main():
    if len(sys.argv) != 2:
        print("Usage: python make_triton_compatible.py <onnx_model_path>")
        sys.exit(1)
    onnx_model_path = sys.argv[1]
    onnx_model = onnx.load(onnx_model_path)
    onnx.checker.check_model(onnx_model)
    print("Model is valid")
    graph = gs.import_onnx(onnx_model)
    # Locate the "speed" input and remember its position in graph.inputs.
    speed_idx, speed = None, None
    for idx, input_ in enumerate(graph.inputs):
        if input_.name == "speed":
            speed_idx = idx
            speed = input_
            break
    if speed is not None:
        print(f"Found speed input: {speed.name} (shape={speed.shape}, dtype={speed.dtype})")
        # Re-declare the speed input with shape [batch_size, 1].
        speed_unsqueezed = gs.Variable(name="speed", dtype=speed.dtype, shape=[speed.shape[0], 1])
        graph.inputs[speed_idx] = speed_unsqueezed
        # Add a Squeeze so the rest of the graph still sees shape [batch_size].
        graph = add_squeeze(graph, speed, speed_unsqueezed)
        graph.cleanup().toposort()
        # Export the modified graph back to ONNX and validate it.
        modified_model = gs.export_onnx(graph)
        onnx.checker.check_model(modified_model)
        # Save the modified model next to the original.
        output_path = onnx_model_path.replace('.onnx', '_triton.onnx')
        onnx.save(modified_model, output_path)
        print(f"Modified model saved to: {output_path}")
    else:
        print("Speed input not found in the model")
if __name__ == "__main__":
    main()