diff --git a/invokeai/backend/image_util/depth_anything/__init__.py b/invokeai/backend/image_util/depth_anything/__init__.py
index 626e48a87a..37adaa3004 100644
--- a/invokeai/backend/image_util/depth_anything/__init__.py
+++ b/invokeai/backend/image_util/depth_anything/__init__.py
@@ -65,13 +65,11 @@ class DepthAnythingDetector:
         self.model_size = model_size
 
         if self.model_size == "small":
-            self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384], localhub=True)
+            self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
         if self.model_size == "base":
-            self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768], localhub=True)
+            self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
         if self.model_size == "large":
-            self.model = DPT_DINOv2(
-                encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024], localhub=True
-            )
+            self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
 
         self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
         self.model.eval()
diff --git a/invokeai/backend/image_util/depth_anything/model/dpt.py b/invokeai/backend/image_util/depth_anything/model/dpt.py
index 3be6f97fc4..e1101b3c39 100644
--- a/invokeai/backend/image_util/depth_anything/model/dpt.py
+++ b/invokeai/backend/image_util/depth_anything/model/dpt.py
@@ -22,9 +22,7 @@ def _make_fusion_block(features, use_bn, size=None):
 
 
 class DPTHead(nn.Module):
-    def __init__(
-        self, nclass, in_channels, features=256, use_bn=False, out_channels=[256, 512, 1024, 1024], use_clstoken=False
-    ):
+    def __init__(self, nclass, in_channels, features, out_channels, use_bn=False, use_clstoken=False):
         super(DPTHead, self).__init__()
 
         self.nclass = nclass
@@ -138,19 +136,18 @@ class DPTHead(nn.Module):
 class DPT_DINOv2(nn.Module):
     def __init__(
         self,
+        features,
+        out_channels,
         encoder="vitl",
-        features=256,
-        out_channels=[256, 512, 1024, 1024],
         use_bn=False,
         use_clstoken=False,
-        localhub=True,
     ):
         super(DPT_DINOv2, self).__init__()
 
         assert encoder in ["vits", "vitb", "vitl"]
 
         # # in case the Internet connection is not stable, please load the DINOv2 locally
-        # if localhub:
+        # if use_local:
         #     self.pretrained = torch.hub.load(
         #         torchhub_path / "facebookresearch_dinov2_main",
         #         "dinov2_{:}14".format(encoder),
@@ -170,7 +167,7 @@ class DPT_DINOv2(nn.Module):
 
         dim = self.pretrained.blocks[0].attn.qkv.in_features
 
-        self.depth_head = DPTHead(1, dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)
+        self.depth_head = DPTHead(1, dim, features, out_channels=out_channels, use_bn=use_bn, use_clstoken=use_clstoken)
 
     def forward(self, x):
         h, w = x.shape[-2:]
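
Note: after this change, features and out_channels become required arguments on both DPTHead and DPT_DINOv2 (previously they defaulted to the "large" configuration, with a mutable list as a default value), and the unused localhub flag is removed. A minimal sketch of the new call convention follows; the checkpoint filename is illustrative, since InvokeAI resolves the real path via DEPTH_ANYTHING_MODEL_PATH:

    import torch

    from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2

    # "small" variant: features/out_channels must now be passed explicitly,
    # matching the values used in DepthAnythingDetector above.
    model = DPT_DINOv2(features=64, out_channels=[48, 96, 192, 384], encoder="vits")

    # Load pretrained weights (path is a placeholder for this sketch).
    model.load_state_dict(torch.load("depth_anything_vits14.pth", map_location="cpu"))
    model.eval()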