Update app.py
Browse files
app.py
CHANGED
|
@@ -43,13 +43,13 @@ os.system("wget https://github.com/onnx/models/raw/main/vision/classification/zf
|
|
| 43 |
|
| 44 |
ort_session = ort.InferenceSession("zfnet512-12.onnx")
|
| 45 |
|
| 46 |
-
|
| 47 |
def predict(path):
|
| 48 |
img_batch = preprocess(get_image(path))
|
| 49 |
|
| 50 |
outputs = ort_session.run(
|
| 51 |
None,
|
| 52 |
-
{"data_0": img_batch.astype(np.float32)},
|
| 53 |
)
|
| 54 |
|
| 55 |
a = np.argsort(-outputs[0].flatten())
|
|
|
|
| 43 |
|
| 44 |
ort_session = ort.InferenceSession("zfnet512-12.onnx")
|
| 45 |
|
| 46 |
+
|
| 47 |
def predict(path):
|
| 48 |
img_batch = preprocess(get_image(path))
|
| 49 |
|
| 50 |
outputs = ort_session.run(
|
| 51 |
None,
|
| 52 |
+
{"gpu_0/data_0": img_batch.astype(np.float32)},
|
| 53 |
)
|
| 54 |
|
| 55 |
a = np.argsort(-outputs[0].flatten())
|