From 18f89ed5edd36fcb9ecf4109eb5975b189fc1ff2 Mon Sep 17 00:00:00 2001
From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com>
Date: Wed, 31 Jul 2024 03:57:54 +0530
Subject: [PATCH] fix: Make DepthAnything work with Invoke's Model Management

---
 .../controlnet_image_processors.py            | 16 +++++++++---
 .../depth_anything/depth_anything_pipeline.py | 26 +++++++++++++++++++
 invokeai/backend/model_manager/config.py      |  5 +++-
 .../backend/model_manager/load/model_util.py  |  5 +++-
 4 files changed, 47 insertions(+), 5 deletions(-)
 create mode 100644 invokeai/backend/image_util/depth_anything/depth_anything_pipeline.py

diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py
index 1b022a071d..cd56681c4a 100644
--- a/invokeai/app/invocations/controlnet_image_processors.py
+++ b/invokeai/app/invocations/controlnet_image_processors.py
@@ -2,6 +2,7 @@
 # initial implementation by Gregg Helt, 2023
 # heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
 from builtins import bool, float
+from pathlib import Path
 from typing import Dict, List, Literal, Union

 import cv2
@@ -21,6 +22,7 @@ from controlnet_aux.util import HWC3, ade_palette
 from PIL import Image
 from pydantic import BaseModel, Field, field_validator, model_validator
 from transformers import pipeline
+from transformers.pipelines import DepthEstimationPipeline

 from invokeai.app.invocations.baseinvocation import (
     BaseInvocation,
@@ -44,6 +46,7 @@ from invokeai.app.invocations.util import validate_begin_end_step, validate_weig
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
 from invokeai.backend.image_util.canny import get_canny_edges
+from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
 from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
 from invokeai.backend.image_util.hed import HEDProcessor
 from invokeai.backend.image_util.lineart import LineartProcessor
@@ -614,9 +617,16 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
     resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)

     def run_processor(self, image: Image.Image) -> Image.Image:
-        depth_anything_pipeline = pipeline(task="depth-estimation", model=DEPTH_ANYTHING_MODELS[self.model_size])
-        depth_map = depth_anything_pipeline(image)["depth"]
-        return depth_map
+        def load_depth_anything(model_path: Path):
+            depth_anything_pipeline = pipeline(model=str(model_path), task="depth-estimation", local_files_only=True)
+            assert isinstance(depth_anything_pipeline, DepthEstimationPipeline)
+            return DepthAnythingPipeline(depth_anything_pipeline)
+
+        with self._context.models.load_remote_model(
+            source=DEPTH_ANYTHING_MODELS[self.model_size], loader=load_depth_anything
+        ) as depth_anything_detector:
+            assert isinstance(depth_anything_detector, DepthAnythingPipeline)
+            return depth_anything_detector.generate_depth(image, self.resolution)


 @invocation(
diff --git a/invokeai/backend/image_util/depth_anything/depth_anything_pipeline.py b/invokeai/backend/image_util/depth_anything/depth_anything_pipeline.py
new file mode 100644
index 0000000000..35f555c27b
--- /dev/null
+++ b/invokeai/backend/image_util/depth_anything/depth_anything_pipeline.py
@@ -0,0 +1,26 @@
+from typing import cast
+
+from PIL import Image
+from transformers.pipelines import DepthEstimationPipeline
+
+
+class DepthAnythingPipeline:
+    """Custom wrapper for the Depth Estimation pipeline from transformers adding compatibility
+    for Invoke's Model Management System"""
+
+    def __init__(self, pipeline: DepthEstimationPipeline) -> None:
+        self.pipeline = pipeline
+
+    def generate_depth(self, image: Image.Image, resolution: int = 512):
+        image_width, image_height = image.size
+        depth_map = self.pipeline(image)["depth"]
+        depth_map = cast(Image.Image, depth_map)
+
+        new_height = int(image_height * (resolution / image_width))
+        depth_map = depth_map.resize((resolution, new_height))
+        return depth_map
+
+    def calc_size(self) -> int:
+        from invokeai.backend.model_manager.load.model_util import calc_module_size
+
+        return calc_module_size(self.pipeline.model)
diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py
index 332ac6c8fa..0ec57fc538 100644
--- a/invokeai/backend/model_manager/config.py
+++ b/invokeai/backend/model_manager/config.py
@@ -31,13 +31,16 @@ from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapt
 from typing_extensions import Annotated, Any, Dict

 from invokeai.app.util.misc import uuid_string
+from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
 from invokeai.backend.model_hash.hash_validator import validate_hash
 from invokeai.backend.raw_model import RawModel
 from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES

 # ModelMixin is the base class for all diffusers and transformers models
 # RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
-AnyModel = Union[ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor], diffusers.DiffusionPipeline]
+AnyModel = Union[
+    ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor], diffusers.DiffusionPipeline, DepthAnythingPipeline
+]


 class InvalidModelConfigException(Exception):
diff --git a/invokeai/backend/model_manager/load/model_util.py b/invokeai/backend/model_manager/load/model_util.py
index f070a42965..32790a2465 100644
--- a/invokeai/backend/model_manager/load/model_util.py
+++ b/invokeai/backend/model_manager/load/model_util.py
@@ -11,6 +11,7 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers.scheduling_utils import SchedulerMixin
 from transformers import CLIPTokenizer

+from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
 from invokeai.backend.lora import LoRAModelRaw
 from invokeai.backend.model_manager.config import AnyModel
@@ -34,7 +35,9 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
     elif isinstance(model, CLIPTokenizer):
         # TODO(ryand): Accurately calculate the tokenizer's size. It's small enough that it shouldn't matter for now.
         return 0
-    elif isinstance(model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw, SpandrelImageToImageModel)):
+    elif isinstance(
+        model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw, SpandrelImageToImageModel, DepthAnythingPipeline)
+    ):
         return model.calc_size()
     else:
         # TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
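-- 

For reviewers (not part of the patch): a minimal standalone sketch of how the
new wrapper is meant to be used, mirroring the `load_depth_anything` loader in
`run_processor` above. The checkpoint path and file names below are
hypothetical placeholders; inside the node, the path is supplied by the model
manager via `context.models.load_remote_model`, which is what makes the model
cacheable and evictable like any other managed model.

    from pathlib import Path

    from PIL import Image
    from transformers import pipeline
    from transformers.pipelines import DepthEstimationPipeline

    from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline

    # Hypothetical local snapshot of a Depth Anything checkpoint.
    model_path = Path("/models/depth-anything-small-hf")

    # Same shape as the loader in run_processor: build the raw transformers
    # pipeline, then wrap it so the model manager can size and cache it.
    raw = pipeline(model=str(model_path), task="depth-estimation", local_files_only=True)
    assert isinstance(raw, DepthEstimationPipeline)
    wrapped = DepthAnythingPipeline(raw)

    image = Image.open("input.png")
    depth = wrapped.generate_depth(image, resolution=512)  # PIL image, resized to width 512
    depth.save("depth.png")

    # calc_size() is what calc_model_size_by_data dispatches to; it lets the
    # RAM cache account for the wrapped model's memory footprint.
    print(f"model footprint: {wrapped.calc_size()} bytes")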