mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix: lint & other minor issues
This commit is contained in:
parent
8f5e2cbcc7
commit
c859eb865e
@ -65,13 +65,11 @@ class DepthAnythingDetector:
|
|||||||
self.model_size = model_size
|
self.model_size = model_size
|
||||||
|
|
||||||
if self.model_size == "small":
|
if self.model_size == "small":
|
||||||
self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384], localhub=True)
|
self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
|
||||||
if self.model_size == "base":
|
if self.model_size == "base":
|
||||||
self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768], localhub=True)
|
self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
|
||||||
if self.model_size == "large":
|
if self.model_size == "large":
|
||||||
self.model = DPT_DINOv2(
|
self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
|
||||||
encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024], localhub=True
|
|
||||||
)
|
|
||||||
|
|
||||||
self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
|
self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
@ -22,9 +22,7 @@ def _make_fusion_block(features, use_bn, size=None):
|
|||||||
|
|
||||||
|
|
||||||
class DPTHead(nn.Module):
|
class DPTHead(nn.Module):
|
||||||
def __init__(
|
def __init__(self, nclass, in_channels, features, out_channels, use_bn=False, use_clstoken=False):
|
||||||
self, nclass, in_channels, features=256, use_bn=False, out_channels=[256, 512, 1024, 1024], use_clstoken=False
|
|
||||||
):
|
|
||||||
super(DPTHead, self).__init__()
|
super(DPTHead, self).__init__()
|
||||||
|
|
||||||
self.nclass = nclass
|
self.nclass = nclass
|
||||||
@ -138,19 +136,18 @@ class DPTHead(nn.Module):
|
|||||||
class DPT_DINOv2(nn.Module):
|
class DPT_DINOv2(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
features,
|
||||||
|
out_channels,
|
||||||
encoder="vitl",
|
encoder="vitl",
|
||||||
features=256,
|
|
||||||
out_channels=[256, 512, 1024, 1024],
|
|
||||||
use_bn=False,
|
use_bn=False,
|
||||||
use_clstoken=False,
|
use_clstoken=False,
|
||||||
localhub=True,
|
|
||||||
):
|
):
|
||||||
super(DPT_DINOv2, self).__init__()
|
super(DPT_DINOv2, self).__init__()
|
||||||
|
|
||||||
assert encoder in ["vits", "vitb", "vitl"]
|
assert encoder in ["vits", "vitb", "vitl"]
|
||||||
|
|
||||||
# # in case the Internet connection is not stable, please load the DINOv2 locally
|
# # in case the Internet connection is not stable, please load the DINOv2 locally
|
||||||
# if localhub:
|
# if use_local:
|
||||||
# self.pretrained = torch.hub.load(
|
# self.pretrained = torch.hub.load(
|
||||||
# torchhub_path / "facebookresearch_dinov2_main",
|
# torchhub_path / "facebookresearch_dinov2_main",
|
||||||
# "dinov2_{:}14".format(encoder),
|
# "dinov2_{:}14".format(encoder),
|
||||||
@ -170,7 +167,7 @@ class DPT_DINOv2(nn.Module):
|
|||||||
|
|
||||||
dim = self.pretrained.blocks[0].attn.qkv.in_features
|
dim = self.pretrained.blocks[0].attn.qkv.in_features
|
||||||
|
|
||||||
self.depth_head = DPTHead(1, dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)
|
self.depth_head = DPTHead(1, dim, features, out_channels=out_channels, use_bn=use_bn, use_clstoken=use_clstoken)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
h, w = x.shape[-2:]
|
h, w = x.shape[-2:]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user