# RADAR-demo / models / feature_extractor.py
# arcanoXIII's picture
# Upload 13 files
# 7e08bf1 verified
import torch
from transformers import AutoModelForDepthEstimation
class FeatureExtractor(torch.nn.Module):
    """Wrap a DINOv2 backbone and expose its normalized patch tokens.

    The backbone is fetched through ``torch.hub`` from the
    ``facebookresearch/dinov2`` repository using the ``dinov2_vitb14_reg``
    entry point.

    Attributes:
        fe: The loaded backbone module.
        patch_size: Copied from the backbone's ``patch_size`` attribute.
        embed_dim: Copied from the backbone's ``embed_dim`` attribute.
    """

    def __init__(self):
        """Load the backbone and mirror its patch size and embedding width."""
        super().__init__()
        # NOTE: torch.hub.load may download code/weights on first use.
        self.fe = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg')
        self.patch_size = self.fe.patch_size
        self.embed_dim = self.fe.embed_dim

    def forward(self, x):
        """Return the ``'x_norm_patchtokens'`` entry of the backbone output.

        Args:
            x: Input passed unchanged to ``self.fe.forward_features``
                (presumably an image batch of shape (B, 3, H, W) -- TODO confirm).

        Returns:
            The value stored under ``'x_norm_patchtokens'`` in the dictionary
            returned by ``forward_features``.
        """
        features = self.fe.forward_features(x)
        return features['x_norm_patchtokens']
class FeatureExtractorDepth(torch.nn.Module):
    """Wrap a Depth Anything V2 model and expose its last hidden state.

    The model is loaded with ``AutoModelForDepthEstimation.from_pretrained``
    from the ``depth-anything/Depth-Anything-V2-Base-hf`` checkpoint.

    Attributes:
        fe: The loaded depth-estimation model.
        patch_size: Fixed to 14 (hard-coded, not read from the model).
        embed_dim: Fixed to 768 (hard-coded, not read from the model).
    """

    def __init__(self):
        """Load the pretrained model and set the fixed size attributes."""
        super().__init__()
        # NOTE: from_pretrained may download the checkpoint on first use.
        self.fe = AutoModelForDepthEstimation.from_pretrained("depth-anything/Depth-Anything-V2-Base-hf")
        # NOTE(review): these presumably match the Base checkpoint's backbone;
        # confirm if the checkpoint is ever changed.
        self.patch_size = 14
        self.embed_dim = 768

    def forward(self, x):
        """Return the final hidden state without its first token.

        Args:
            x: Input passed unchanged to the wrapped model
                (presumably a batch of pixel values -- TODO confirm).

        Returns:
            The last element of ``hidden_states`` with index 0 along
            dimension 1 removed, i.e. ``hidden_states[-1][:, 1:, :]``.
        """
        hidden_states = self.fe(x, output_hidden_states=True).hidden_states
        last_hidden = hidden_states[-1]
        # Drop the leading token (presumably the CLS token) so that only the
        # remaining tokens are returned.
        return last_hidden[:, 1:, :]