fix: lint & other minor issues

This commit is contained in:
blessedcoolant 2024-01-23 02:50:20 +05:30 committed by Kent Keirsey
parent 8f5e2cbcc7
commit c859eb865e
2 changed files with 8 additions and 13 deletions

View File

@ -65,13 +65,11 @@ class DepthAnythingDetector:
self.model_size = model_size
if self.model_size == "small":
self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384], localhub=True)
self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
if self.model_size == "base":
self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768], localhub=True)
self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
if self.model_size == "large":
self.model = DPT_DINOv2(
encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024], localhub=True
)
self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
self.model.eval()

View File

@ -22,9 +22,7 @@ def _make_fusion_block(features, use_bn, size=None):
class DPTHead(nn.Module):
def __init__(
self, nclass, in_channels, features=256, use_bn=False, out_channels=[256, 512, 1024, 1024], use_clstoken=False
):
def __init__(self, nclass, in_channels, features, out_channels, use_bn=False, use_clstoken=False):
super(DPTHead, self).__init__()
self.nclass = nclass
@ -138,19 +136,18 @@ class DPTHead(nn.Module):
class DPT_DINOv2(nn.Module):
def __init__(
self,
features,
out_channels,
encoder="vitl",
features=256,
out_channels=[256, 512, 1024, 1024],
use_bn=False,
use_clstoken=False,
localhub=True,
):
super(DPT_DINOv2, self).__init__()
assert encoder in ["vits", "vitb", "vitl"]
# # in case the Internet connection is not stable, please load the DINOv2 locally
# if localhub:
# if use_local:
# self.pretrained = torch.hub.load(
# torchhub_path / "facebookresearch_dinov2_main",
# "dinov2_{:}14".format(encoder),
@ -170,7 +167,7 @@ class DPT_DINOv2(nn.Module):
dim = self.pretrained.blocks[0].attn.qkv.in_features
self.depth_head = DPTHead(1, dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)
self.depth_head = DPTHead(1, dim, features, out_channels=out_channels, use_bn=use_bn, use_clstoken=use_clstoken)
def forward(self, x):
h, w = x.shape[-2:]