Spaces:
Paused
Paused
Update mono/model/decode_heads/RAFTDepthNormalDPTDecoder5.py
Browse files
mono/model/decode_heads/RAFTDepthNormalDPTDecoder5.py
CHANGED
|
@@ -792,8 +792,8 @@ class RAFTDepthNormalDPT5(nn.Module):
|
|
| 792 |
self.relu = nn.ReLU(inplace=True)
|
| 793 |
|
| 794 |
def get_bins(self, bins_num):
|
| 795 |
-
|
| 796 |
-
depth_bins_vec = torch.linspace(math.log(self.min_val), math.log(self.max_val), bins_num, device="cpu")
|
| 797 |
depth_bins_vec = torch.exp(depth_bins_vec)
|
| 798 |
return depth_bins_vec
|
| 799 |
|
|
@@ -848,7 +848,8 @@ class RAFTDepthNormalDPT5(nn.Module):
|
|
| 848 |
return norm_normalize(torch.cat([normal_out, confidence], dim=1))
|
| 849 |
#return norm_normalize(torch.cat([normal_out, confidence], dim=1).float())
|
| 850 |
|
| 851 |
-
def create_mesh_grid(self, height, width, batch, device="cpu", set_buffer=True):
|
|
|
|
| 852 |
y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=device),
|
| 853 |
torch.arange(0, width, dtype=torch.float32, device=device)], indexing='ij')
|
| 854 |
meshgrid = torch.stack((x, y))
|
|
|
|
| 792 |
self.relu = nn.ReLU(inplace=True)
|
| 793 |
|
| 794 |
def get_bins(self, bins_num):
|
| 795 |
+
depth_bins_vec = torch.linspace(math.log(self.min_val), math.log(self.max_val), bins_num, device="cuda")
|
| 796 |
+
#depth_bins_vec = torch.linspace(math.log(self.min_val), math.log(self.max_val), bins_num, device="cpu")
|
| 797 |
depth_bins_vec = torch.exp(depth_bins_vec)
|
| 798 |
return depth_bins_vec
|
| 799 |
|
|
|
|
| 848 |
return norm_normalize(torch.cat([normal_out, confidence], dim=1))
|
| 849 |
#return norm_normalize(torch.cat([normal_out, confidence], dim=1).float())
|
| 850 |
|
| 851 |
+
#def create_mesh_grid(self, height, width, batch, device="cpu", set_buffer=True):
|
| 852 |
+
def create_mesh_grid(self, height, width, batch, device="cuda", set_buffer=True):
|
| 853 |
y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=device),
|
| 854 |
torch.arange(0, width, dtype=torch.float32, device=device)], indexing='ij')
|
| 855 |
meshgrid = torch.stack((x, y))
|