mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
chore: cleanup DepthAnything code
parent
7e4b462fca
commit
af660163ca
@@ -574,7 +574,7 @@ DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"]
     title="Depth Anything Processor",
     tags=["controlnet", "depth", "depth anything"],
     category="controlnet",
-    version="1.0.0",
+    version="1.0.1",
 )
 class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
     """Generates a depth map based on the Depth Anything algorithm"""
@@ -583,13 +583,12 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
         default="small", description="The size of the depth model to use"
     )
     resolution: int = InputField(default=512, ge=64, multiple_of=64, description=FieldDescriptions.image_res)
-    offload: bool = InputField(default=False)

     def run_processor(self, image: Image.Image):
         depth_anything_detector = DepthAnythingDetector()
         depth_anything_detector.load_model(model_size=self.model_size)

-        processed_image = depth_anything_detector(image=image, resolution=self.resolution, offload=self.offload)
+        processed_image = depth_anything_detector(image=image, resolution=self.resolution)
         return processed_image
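With `offload` gone from the invocation, the detector is driven with just an image and a target resolution. A minimal usage sketch of the cleaned-up API; the input filename is a hypothetical stand-in, and the import path is assumed from the module layout visible in the diff below:

from PIL import Image

from invokeai.backend.image_util.depth_anything import DepthAnythingDetector  # assumed path

detector = DepthAnythingDetector()
detector.load_model(model_size="small")  # "large", "base", or "small"

image = Image.open("input.png")  # hypothetical input image
depth_map = detector(image=image, resolution=512)  # returns a PIL Image depth map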
@@ -13,9 +13,11 @@ from invokeai.app.services.config.config_default import InvokeAIAppConfig
 from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
 from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
 from invokeai.backend.util.devices import choose_torch_device
+from invokeai.backend.util.logging import InvokeAILogger
 from invokeai.backend.util.util import download_with_progress_bar

 config = InvokeAIAppConfig.get_config()
+logger = InvokeAILogger.get_logger(config=config)

 DEPTH_ANYTHING_MODELS = {
     "large": {
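The new `logger` wiring exists so that `__call__` (later in the diff) can warn and return the original image instead of crashing when `load_model` was never called. The `DEPTH_ANYTHING_MODELS` registry that `load_model` consumes maps each size to a `url` and a `local` path; both keys appear in the next hunk, but the values are elided here, so this sketch uses placeholders only:

DEPTH_ANYTHING_MODELS = {
    "large": {
        "url": "https://example.com/depth_anything_vitl14.pth",  # placeholder, not the real URL
        "local": "any/annotators/depth_anything/depth_anything_vitl14.pth",  # placeholder path
    },
    # "base" and "small" entries follow the same shape
}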
@@ -54,8 +56,9 @@ class DepthAnythingDetector:
     def __init__(self) -> None:
         self.model = None
         self.model_size: Union[Literal["large", "base", "small"], None] = None
+        self.device = choose_torch_device()

-    def load_model(self, model_size=Literal["large", "base", "small"]):
+    def load_model(self, model_size: Literal["large", "base", "small"] = "small"):
         DEPTH_ANYTHING_MODEL_PATH = pathlib.Path(config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"])
         if not DEPTH_ANYTHING_MODEL_PATH.exists():
             download_with_progress_bar(DEPTH_ANYTHING_MODELS[model_size]["url"], DEPTH_ANYTHING_MODEL_PATH)
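The signature change here fixes a real bug, not just style: the old `model_size=Literal[...]` used `=` where `:` was intended, so the `Literal` type object itself became the parameter's default value and the parameter was left unannotated. Side by side:

from typing import Literal

# Before: "=" makes the Literal type object the default value (and no annotation).
def load_model(self, model_size=Literal["large", "base", "small"]): ...

# After: ":" annotates the parameter; "small" is the actual default.
def load_model(self, model_size: Literal["large", "base", "small"] = "small"): ...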
@@ -71,8 +74,6 @@ class DepthAnythingDetector:
                 self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
             case "large":
                 self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
-            case _:
-                raise TypeError("Not a supported model")

         self.model.load_state_dict(torch.load(DEPTH_ANYTHING_MODEL_PATH.as_posix(), map_location="cpu"))
         self.model.eval()
@@ -80,20 +81,20 @@ class DepthAnythingDetector:
         self.model.to(choose_torch_device())
         return self.model

     def to(self, device):
         self.model.to(device)
         return self

-    def __call__(self, image, resolution=512, offload=False):
-        image = np.array(image, dtype=np.uint8)
-        image = image[:, :, ::-1] / 255.0
+    def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:
+        if not self.model:
+            logger.warn("DepthAnything model was not loaded. Returning original image")
+            return image

-        image_height, image_width = image.shape[:2]
-        image = transform({"image": image})["image"]
-        image = torch.from_numpy(image).unsqueeze(0).to(choose_torch_device())
+        np_image = np.array(image, dtype=np.uint8)
+        np_image = np_image[:, :, ::-1] / 255.0
+
+        image_height, image_width = np_image.shape[:2]
+        np_image = transform({"image": np_image})["image"]
+        tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(choose_torch_device())

         with torch.no_grad():
-            depth = self.model(image)
+            depth = self.model(tensor_image)
             depth = F.interpolate(depth[None], (image_height, image_width), mode="bilinear", align_corners=False)[0, 0]
             depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
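The last line of the hunk min-max normalizes the raw depth tensor into the 0-255 range before it is converted to an 8-bit image. A toy example of the same arithmetic, with illustrative values only:

import torch

depth = torch.tensor([0.2, 0.5, 1.0])  # illustrative raw depth values
depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
print(depth)  # tensor([  0.0000,  95.6250, 255.0000])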
@@ -103,7 +104,4 @@ class DepthAnythingDetector:
         new_height = int(image_height * (resolution / image_width))
         depth_map = depth_map.resize((resolution, new_height))

-        if offload:
-            del self.model
-
         return depth_map
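`new_height` keeps the output aspect ratio locked to the input while the width is forced to `resolution`; for a hypothetical 1024x768 source at the default `resolution=512`:

image_width, image_height = 1024, 768  # hypothetical input dimensions
resolution = 512
new_height = int(image_height * (resolution / image_width))  # int(768 * 0.5) == 384
# depth_map.resize((512, 384)) preserves the 4:3 aspect ratio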