fix: make DA Pipeline a subclass of RawModel

This commit is contained in:
blessedcoolant 2024-07-31 21:14:49 +05:30
parent 18f89ed5ed
commit b4cf78a95d
2 changed files with 9 additions and 6 deletions

View File

@ -1,10 +1,13 @@
from typing import cast from typing import Optional, cast
import torch
from PIL import Image from PIL import Image
from transformers.pipelines import DepthEstimationPipeline from transformers.pipelines import DepthEstimationPipeline
from invokeai.backend.raw_model import RawModel
class DepthAnythingPipeline:
class DepthAnythingPipeline(RawModel):
"""Custom wrapper for the Depth Estimation pipeline from transformers adding compatibility """Custom wrapper for the Depth Estimation pipeline from transformers adding compatibility
for Invoke's Model Management System""" for Invoke's Model Management System"""
@ -20,6 +23,9 @@ class DepthAnythingPipeline:
depth_map = depth_map.resize((resolution, new_height)) depth_map = depth_map.resize((resolution, new_height))
return depth_map return depth_map
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
pass
def calc_size(self) -> int: def calc_size(self) -> int:
from invokeai.backend.model_manager.load.model_util import calc_module_size from invokeai.backend.model_manager.load.model_util import calc_module_size

View File

@ -31,16 +31,13 @@ from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapt
from typing_extensions import Annotated, Any, Dict from typing_extensions import Annotated, Any, Dict
from invokeai.app.util.misc import uuid_string from invokeai.app.util.misc import uuid_string
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
from invokeai.backend.model_hash.hash_validator import validate_hash from invokeai.backend.model_hash.hash_validator import validate_hash
from invokeai.backend.raw_model import RawModel from invokeai.backend.raw_model import RawModel
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
# ModelMixin is the base class for all diffusers and transformers models # ModelMixin is the base class for all diffusers and transformers models
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime # RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
AnyModel = Union[ AnyModel = Union[ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor], diffusers.DiffusionPipeline]
ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor], diffusers.DiffusionPipeline, DepthAnythingPipeline
]
class InvalidModelConfigException(Exception): class InvalidModelConfigException(Exception):