Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -28,11 +28,17 @@ class Config:
|
|
| 28 |
|
| 29 |
class ModelManager:
|
| 30 |
@staticmethod
|
| 31 |
-
def load_model(checkpoint_name: str):
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
|
|
|
| 35 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
model = torch.jit.load(model_path)
|
| 37 |
model.eval()
|
| 38 |
model.to("cuda")
|
|
@@ -60,7 +66,7 @@ class ImageProcessor:
|
|
| 60 |
depth_map = depth_output.squeeze().cpu().numpy()
|
| 61 |
|
| 62 |
if seg_model_name != "no-bg-removal":
|
| 63 |
-
seg_model = ModelManager.load_model(Config.SEG_CHECKPOINTS[seg_model_name])
|
| 64 |
seg_output = ModelManager.run_model(seg_model, input_tensor, image.height, image.width)
|
| 65 |
seg_mask = (seg_output.argmax(dim=1) > 0).float().cpu().numpy()[0]
|
| 66 |
depth_map[seg_mask == 0] = np.nan
|
|
|
|
| 28 |
|
| 29 |
class ModelManager:
|
| 30 |
@staticmethod
|
| 31 |
+
def load_model(checkpoint_name: str, type='depth'):
|
| 32 |
+
if checkpoint_name == 'seg':
|
| 33 |
+
model_path = hf_hub_download(
|
| 34 |
+
repo_id="shimu0215/seg", # 你的模型仓库
|
| 35 |
+
filename="sapiens_1b_seg_foreground_epoch_8_torchscript.pt2", # 你的模型文件
|
| 36 |
)
|
| 37 |
+
else:
|
| 38 |
+
model_path = hf_hub_download(
|
| 39 |
+
repo_id="shimu0215/seg",
|
| 40 |
+
filename="sapiens_2b_render_people_epoch_25_torchscript.pt2",
|
| 41 |
+
)
|
| 42 |
model = torch.jit.load(model_path)
|
| 43 |
model.eval()
|
| 44 |
model.to("cuda")
|
|
|
|
| 66 |
depth_map = depth_output.squeeze().cpu().numpy()
|
| 67 |
|
| 68 |
if seg_model_name != "no-bg-removal":
|
| 69 |
+
seg_model = ModelManager.load_model(Config.SEG_CHECKPOINTS[seg_model_name],type='seg')
|
| 70 |
seg_output = ModelManager.run_model(seg_model, input_tensor, image.height, image.width)
|
| 71 |
seg_mask = (seg_output.argmax(dim=1) > 0).float().cpu().numpy()[0]
|
| 72 |
depth_map[seg_mask == 0] = np.nan
|