Spaces:
Runtime error
Runtime error
fix hard-coded cuda device
Browse files
model.py
CHANGED
|
@@ -10,14 +10,15 @@ checkpoint = "geetu040/DepthPro"
|
|
| 10 |
revision = "project"
|
| 11 |
image_processor = DepthProImageProcessorFast.from_pretrained(checkpoint, revision=revision)
|
| 12 |
model = DepthProForDepthEstimation.from_pretrained(checkpoint, revision=revision)
|
| 13 |
-
|
|
|
|
| 14 |
|
| 15 |
def predict(image):
|
| 16 |
# inference
|
| 17 |
|
| 18 |
# prepare image for the model
|
| 19 |
inputs = image_processor(images=image, return_tensors="pt")
|
| 20 |
-
inputs = {k: v.to(
|
| 21 |
|
| 22 |
with torch.no_grad():
|
| 23 |
outputs = model(**inputs)
|
|
|
|
| 10 |
revision = "project"
|
| 11 |
image_processor = DepthProImageProcessorFast.from_pretrained(checkpoint, revision=revision)
|
| 12 |
model = DepthProForDepthEstimation.from_pretrained(checkpoint, revision=revision)
|
| 13 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 14 |
+
model = model.to(device)
|
| 15 |
|
| 16 |
def predict(image):
|
| 17 |
# inference
|
| 18 |
|
| 19 |
# prepare image for the model
|
| 20 |
inputs = image_processor(images=image, return_tensors="pt")
|
| 21 |
+
inputs = {k: v.to(device) for k, v in inputs.items()}
|
| 22 |
|
| 23 |
with torch.no_grad():
|
| 24 |
outputs = model(**inputs)
|