mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
fix: lint & other minor issues
This commit is contained in:
parent 8f5e2cbcc7
commit c859eb865e
@@ -65,13 +65,11 @@ class DepthAnythingDetector:
         self.model_size = model_size
 
         if self.model_size == "small":
-            self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384], localhub=True)
+            self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
         if self.model_size == "base":
-            self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768], localhub=True)
+            self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
         if self.model_size == "large":
-            self.model = DPT_DINOv2(
-                encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024], localhub=True
-            )
+            self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
 
         self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
         self.model.eval()
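The hunk above drops the now-unused localhub argument from each DPT_DINOv2 call while keeping the per-size encoder, features, and out_channels values. A minimal sketch of that size-to-arguments mapping, assuming DPT_DINOv2 is importable from the module this diff touches; the import path, DPT_CONFIGS, and build_dpt_model are illustrative and not part of the commit:

# Sketch only: the import path below is an assumption, not shown in the diff.
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2

# Constructor arguments that remain for each model size after localhub was dropped.
DPT_CONFIGS = {
    "small": {"encoder": "vits", "features": 64, "out_channels": [48, 96, 192, 384]},
    "base": {"encoder": "vitb", "features": 128, "out_channels": [96, 192, 384, 768]},
    "large": {"encoder": "vitl", "features": 256, "out_channels": [256, 512, 1024, 1024]},
}

def build_dpt_model(model_size: str) -> DPT_DINOv2:
    # Keyword arguments match the updated signature: features and out_channels
    # are required, encoder keeps its default and is passed explicitly here.
    return DPT_DINOv2(**DPT_CONFIGS[model_size])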
@@ -22,9 +22,7 @@ def _make_fusion_block(features, use_bn, size=None):
 
 
 class DPTHead(nn.Module):
-    def __init__(
-        self, nclass, in_channels, features=256, use_bn=False, out_channels=[256, 512, 1024, 1024], use_clstoken=False
-    ):
+    def __init__(self, nclass, in_channels, features, out_channels, use_bn=False, use_clstoken=False):
         super(DPTHead, self).__init__()
 
         self.nclass = nclass
@@ -138,19 +136,18 @@ class DPTHead(nn.Module):
 class DPT_DINOv2(nn.Module):
     def __init__(
         self,
+        features,
+        out_channels,
         encoder="vitl",
-        features=256,
-        out_channels=[256, 512, 1024, 1024],
         use_bn=False,
         use_clstoken=False,
-        localhub=True,
     ):
         super(DPT_DINOv2, self).__init__()
 
         assert encoder in ["vits", "vitb", "vitl"]
 
         # # in case the Internet connection is not stable, please load the DINOv2 locally
-        # if localhub:
+        # if use_local:
         #     self.pretrained = torch.hub.load(
         #         torchhub_path / "facebookresearch_dinov2_main",
         #         "dinov2_{:}14".format(encoder),
@@ -170,7 +167,7 @@ class DPT_DINOv2(nn.Module):
 
         dim = self.pretrained.blocks[0].attn.qkv.in_features
 
-        self.depth_head = DPTHead(1, dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)
+        self.depth_head = DPTHead(1, dim, features, out_channels=out_channels, use_bn=use_bn, use_clstoken=use_clstoken)
 
     def forward(self, x):
         h, w = x.shape[-2:]
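The last three hunks make features and out_channels required parameters, drop the mutable list defaults, and switch the DPTHead call to keyword arguments so it matches the reordered signature. A small usage sketch of the updated DPTHead constructor; the import path and the example in_channels value are assumptions, not part of the diff:

# Sketch only: exercises the updated DPTHead signature.
from invokeai.backend.image_util.depth_anything.model.dpt import DPTHead

in_channels = 1024  # e.g. the embedding width of a ViT-L backbone (assumed value)
head = DPTHead(
    nclass=1,
    in_channels=in_channels,
    features=256,
    out_channels=[256, 512, 1024, 1024],
    use_bn=False,
    use_clstoken=False,
)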