fix: lint & other minor issues

This commit is contained in:
blessedcoolant 2024-01-23 02:50:20 +05:30 committed by Kent Keirsey
parent 8f5e2cbcc7
commit c859eb865e
2 changed files with 8 additions and 13 deletions

View File

@ -65,13 +65,11 @@ class DepthAnythingDetector:
self.model_size = model_size self.model_size = model_size
if self.model_size == "small": if self.model_size == "small":
self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384], localhub=True) self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
if self.model_size == "base": if self.model_size == "base":
self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768], localhub=True) self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
if self.model_size == "large": if self.model_size == "large":
self.model = DPT_DINOv2( self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024], localhub=True
)
self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu")) self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
self.model.eval() self.model.eval()

View File

@ -22,9 +22,7 @@ def _make_fusion_block(features, use_bn, size=None):
class DPTHead(nn.Module): class DPTHead(nn.Module):
def __init__( def __init__(self, nclass, in_channels, features, out_channels, use_bn=False, use_clstoken=False):
self, nclass, in_channels, features=256, use_bn=False, out_channels=[256, 512, 1024, 1024], use_clstoken=False
):
super(DPTHead, self).__init__() super(DPTHead, self).__init__()
self.nclass = nclass self.nclass = nclass
@ -138,19 +136,18 @@ class DPTHead(nn.Module):
class DPT_DINOv2(nn.Module): class DPT_DINOv2(nn.Module):
def __init__( def __init__(
self, self,
features,
out_channels,
encoder="vitl", encoder="vitl",
features=256,
out_channels=[256, 512, 1024, 1024],
use_bn=False, use_bn=False,
use_clstoken=False, use_clstoken=False,
localhub=True,
): ):
super(DPT_DINOv2, self).__init__() super(DPT_DINOv2, self).__init__()
assert encoder in ["vits", "vitb", "vitl"] assert encoder in ["vits", "vitb", "vitl"]
# # in case the Internet connection is not stable, please load the DINOv2 locally # # in case the Internet connection is not stable, please load the DINOv2 locally
# if localhub: # if use_local:
# self.pretrained = torch.hub.load( # self.pretrained = torch.hub.load(
# torchhub_path / "facebookresearch_dinov2_main", # torchhub_path / "facebookresearch_dinov2_main",
# "dinov2_{:}14".format(encoder), # "dinov2_{:}14".format(encoder),
@ -170,7 +167,7 @@ class DPT_DINOv2(nn.Module):
dim = self.pretrained.blocks[0].attn.qkv.in_features dim = self.pretrained.blocks[0].attn.qkv.in_features
self.depth_head = DPTHead(1, dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken) self.depth_head = DPTHead(1, dim, features, out_channels=out_channels, use_bn=use_bn, use_clstoken=use_clstoken)
def forward(self, x): def forward(self, x):
h, w = x.shape[-2:] h, w = x.shape[-2:]