mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat: Add Resolution to DepthAnything
This commit is contained in:
committed by
Kent Keirsey
parent
39fedb090b
commit
7cb49e65bd
@ -64,12 +64,15 @@ class DepthAnythingDetector:
|
||||
del self.model
|
||||
self.model_size = model_size
|
||||
|
||||
if self.model_size == "small":
|
||||
self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
|
||||
if self.model_size == "base":
|
||||
self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
|
||||
if self.model_size == "large":
|
||||
self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
|
||||
match self.model_size:
|
||||
case "small":
|
||||
self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
|
||||
case "base":
|
||||
self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
|
||||
case "large":
|
||||
self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
|
||||
case _:
|
||||
raise TypeError("Not a supported model")
|
||||
|
||||
self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
|
||||
self.model.eval()
|
||||
@ -81,12 +84,11 @@ class DepthAnythingDetector:
|
||||
self.model.to(device)
|
||||
return self
|
||||
|
||||
def __call__(self, image, offload=False):
|
||||
def __call__(self, image, resolution=512, offload=False):
|
||||
image = np.array(image, dtype=np.uint8)
|
||||
original_width, original_height = image.shape[:2]
|
||||
image = image[:, :, ::-1] / 255.0
|
||||
|
||||
image_width, image_height = image.shape[:2]
|
||||
image_height, image_width = image.shape[:2]
|
||||
image = transform({"image": image})["image"]
|
||||
image = torch.from_numpy(image).unsqueeze(0).to(choose_torch_device())
|
||||
|
||||
@ -97,7 +99,9 @@ class DepthAnythingDetector:
|
||||
|
||||
depth_map = repeat(depth, "h w -> h w 3").cpu().numpy().astype(np.uint8)
|
||||
depth_map = Image.fromarray(depth_map)
|
||||
depth_map = depth_map.resize((original_height, original_width))
|
||||
|
||||
new_height = int(image_height * (resolution / image_width))
|
||||
depth_map = depth_map.resize((resolution, new_height))
|
||||
|
||||
if offload:
|
||||
del self.model
|
||||
|
Reference in New Issue
Block a user