From b4cf78a95d2257f396fc4d05ba99736cdb11130d Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Wed, 31 Jul 2024 21:14:49 +0530 Subject: [PATCH] fix: make DA Pipeline a subclass of RawModel --- .../depth_anything/depth_anything_pipeline.py | 10 ++++++++-- invokeai/backend/model_manager/config.py | 5 +---- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/invokeai/backend/image_util/depth_anything/depth_anything_pipeline.py b/invokeai/backend/image_util/depth_anything/depth_anything_pipeline.py index 35f555c27b..2217e0b285 100644 --- a/invokeai/backend/image_util/depth_anything/depth_anything_pipeline.py +++ b/invokeai/backend/image_util/depth_anything/depth_anything_pipeline.py @@ -1,10 +1,13 @@ -from typing import cast +from typing import Optional, cast +import torch from PIL import Image from transformers.pipelines import DepthEstimationPipeline +from invokeai.backend.raw_model import RawModel -class DepthAnythingPipeline: + +class DepthAnythingPipeline(RawModel): """Custom wrapper for the Depth Estimation pipeline from transformers adding compatibility for Invoke's Model Management System""" @@ -20,6 +23,9 @@ class DepthAnythingPipeline: depth_map = depth_map.resize((resolution, new_height)) return depth_map + def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None): + pass + def calc_size(self) -> int: from invokeai.backend.model_manager.load.model_util import calc_module_size diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 0ec57fc538..332ac6c8fa 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -31,16 +31,13 @@ from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapt from typing_extensions import Annotated, Any, Dict from invokeai.app.util.misc import uuid_string -from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline from invokeai.backend.model_hash.hash_validator import validate_hash from invokeai.backend.raw_model import RawModel from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES # ModelMixin is the base class for all diffusers and transformers models # RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime -AnyModel = Union[ - ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor], diffusers.DiffusionPipeline, DepthAnythingPipeline -] +AnyModel = Union[ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor], diffusers.DiffusionPipeline] class InvalidModelConfigException(Exception):