diff --git a/invokeai/backend/image_util/depth_anything/__init__.py b/invokeai/backend/image_util/depth_anything/__init__.py
index ddb0aaa0c4..78b828d816 100644
--- a/invokeai/backend/image_util/depth_anything/__init__.py
+++ b/invokeai/backend/image_util/depth_anything/__init__.py
@@ -17,6 +17,8 @@ from invokeai.backend.util.util import download_with_progress_bar
 
 config = InvokeAIAppConfig.get_config()
 
+DEPTH_ANYTHING_MODEL_SIZES = Literal["large", "base", "small"]
+
 DEPTH_ANYTHING_MODELS = {
     "large": {
         "url": "https://huggingface.co/spaces/LiheYoung/Depth-Anything/resolve/main/checkpoints/depth_anything_vitl14.pth?download=true",
@@ -53,9 +55,9 @@ transform = Compose(
 class DepthAnythingDetector:
     def __init__(self) -> None:
         self.model = None
-        self.model_size: Union[Literal["large", "base", "small"], None] = None
+        self.model_size: Union[DEPTH_ANYTHING_MODEL_SIZES, None] = None
 
-    def load_model(self, model_size=Literal["large", "base", "small"]):
+    def load_model(self, model_size: DEPTH_ANYTHING_MODEL_SIZES = "small"):
         DEPTH_ANYTHING_MODEL_PATH = pathlib.Path(config.models_path / DEPTH_ANYTHING_MODELS[model_size]["local"])
         if not DEPTH_ANYTHING_MODEL_PATH.exists():
             download_with_progress_bar(DEPTH_ANYTHING_MODELS[model_size]["url"], DEPTH_ANYTHING_MODEL_PATH)
@@ -84,16 +86,19 @@ class DepthAnythingDetector:
         self.model.to(device)
         return self
 
-    def __call__(self, image, resolution=512):
-        image = np.array(image, dtype=np.uint8)
-        image = image[:, :, ::-1] / 255.0
+    def __call__(self, image: Image.Image, resolution: int = 512):
+        if self.model is None:
+            raise Exception("Depth Anything Model not loaded")
 
-        image_height, image_width = image.shape[:2]
-        image = transform({"image": image})["image"]
-        image = torch.from_numpy(image).unsqueeze(0).to(choose_torch_device())
+        np_image = np.array(image, dtype=np.uint8)
+        np_image = np_image[:, :, ::-1] / 255.0
+
+        image_height, image_width = np_image.shape[:2]
+        np_image = transform({"image": np_image})["image"]
+        tensor_image = torch.from_numpy(np_image).unsqueeze(0).to(choose_torch_device())
 
         with torch.no_grad():
-            depth = self.model(tensor_image)
+            depth = self.model(tensor_image)
             depth = F.interpolate(depth[None], (image_height, image_width), mode="bilinear", align_corners=False)[0, 0]
             depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
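
For reference, a minimal usage sketch of the API after this change (not part of the diff). It assumes the import path shown in the diff header; the "input.png" filename is hypothetical, and since the excerpt cuts off before __call__ returns, the exact type of depth_map is not claimed here.

    # Usage sketch; "input.png" is a hypothetical file, import path taken from the diff header.
    from PIL import Image

    from invokeai.backend.image_util.depth_anything import DepthAnythingDetector

    detector = DepthAnythingDetector()
    detector.load_model()  # model_size now defaults to "small"; "base" and "large" also accepted

    image = Image.open("input.png").convert("RGB")
    depth_map = detector(image, resolution=512)  # raises if load_model() was skipped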