import torch
from transformers import AutoModelForDepthEstimation


class FeatureExtractor(torch.nn.Module):
    """Wrap a DINOv2 model from torch.hub and return its normalized patch tokens."""

    def __init__(self, model_name: str = "dinov2_vitb14_reg") -> None:
        """Load ``model_name`` from the ``facebookresearch/dinov2`` torch.hub repo.

        Args:
            model_name: torch.hub entry point to load. The default is the
                checkpoint this class always used.
        """
        super().__init__()
        self.fe = torch.hub.load("facebookresearch/dinov2", model_name)
        # Geometry comes from the loaded model, so it always matches the checkpoint.
        self.patch_size = self.fe.patch_size
        self.embed_dim = self.fe.embed_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the ``'x_norm_patchtokens'`` entry of ``forward_features(x)``.

        NOTE(review): assumes ``x`` is an image batch (B, 3, H, W) with H and W
        divisible by ``patch_size``, and that the result is
        (B, num_patches, embed_dim) -- TODO confirm against callers.
        """
        return self.fe.forward_features(x)["x_norm_patchtokens"]


class FeatureExtractorDepth(torch.nn.Module):
    """Expose the backbone tokens of a Hugging Face Depth Anything checkpoint."""

    def __init__(
        self, model_name: str = "depth-anything/Depth-Anything-V2-Base-hf"
    ) -> None:
        """Load ``model_name`` with ``AutoModelForDepthEstimation``.

        Args:
            model_name: Hugging Face Hub id or local path. The default is the
                checkpoint this class always used.
        """
        super().__init__()
        self.fe = AutoModelForDepthEstimation.from_pretrained(model_name)
        # Read geometry from the checkpoint instead of hard-coding 14 / 768, so
        # these values cannot drift from the actual backbone.
        backbone_config = self.fe.config.backbone_config
        self.patch_size = backbone_config.patch_size
        self.embed_dim = backbone_config.hidden_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the last backbone hidden state with the first token removed.

        Only the backbone is run. The depth head's prediction was never used,
        so computing it was wasted work.

        NOTE(review): ``[:, 1:, :]`` drops only the leading (CLS) token. It
        assumes the backbone has no register tokens -- confirm if
        ``model_name`` is changed.
        NOTE(review): unlike ``FeatureExtractor``, these tokens presumably have
        not passed through the backbone's final LayerNorm -- confirm this
        asymmetry is intended.
        """
        hidden_states = self.fe.backbone(x, output_hidden_states=True).hidden_states
        return hidden_states[-1][:, 1:, :]