mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix: make DA Pipeline a subclass of RawModel
This commit is contained in:
parent
18f89ed5ed
commit
b4cf78a95d
@ -1,10 +1,13 @@
|
|||||||
from typing import cast
|
from typing import Optional, cast
|
||||||
|
|
||||||
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from transformers.pipelines import DepthEstimationPipeline
|
from transformers.pipelines import DepthEstimationPipeline
|
||||||
|
|
||||||
|
from invokeai.backend.raw_model import RawModel
|
||||||
|
|
||||||
class DepthAnythingPipeline:
|
|
||||||
|
class DepthAnythingPipeline(RawModel):
|
||||||
"""Custom wrapper for the Depth Estimation pipeline from transformers adding compatibility
|
"""Custom wrapper for the Depth Estimation pipeline from transformers adding compatibility
|
||||||
for Invoke's Model Management System"""
|
for Invoke's Model Management System"""
|
||||||
|
|
||||||
@ -20,6 +23,9 @@ class DepthAnythingPipeline:
|
|||||||
depth_map = depth_map.resize((resolution, new_height))
|
depth_map = depth_map.resize((resolution, new_height))
|
||||||
return depth_map
|
return depth_map
|
||||||
|
|
||||||
|
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
|
||||||
|
pass
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
def calc_size(self) -> int:
|
||||||
from invokeai.backend.model_manager.load.model_util import calc_module_size
|
from invokeai.backend.model_manager.load.model_util import calc_module_size
|
||||||
|
|
||||||
|
@ -31,16 +31,13 @@ from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapt
|
|||||||
from typing_extensions import Annotated, Any, Dict
|
from typing_extensions import Annotated, Any, Dict
|
||||||
|
|
||||||
from invokeai.app.util.misc import uuid_string
|
from invokeai.app.util.misc import uuid_string
|
||||||
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
|
|
||||||
from invokeai.backend.model_hash.hash_validator import validate_hash
|
from invokeai.backend.model_hash.hash_validator import validate_hash
|
||||||
from invokeai.backend.raw_model import RawModel
|
from invokeai.backend.raw_model import RawModel
|
||||||
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
|
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
|
||||||
|
|
||||||
# ModelMixin is the base class for all diffusers and transformers models
|
# ModelMixin is the base class for all diffusers and transformers models
|
||||||
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
|
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
|
||||||
AnyModel = Union[
|
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor], diffusers.DiffusionPipeline]
|
||||||
ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor], diffusers.DiffusionPipeline, DepthAnythingPipeline
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class InvalidModelConfigException(Exception):
|
class InvalidModelConfigException(Exception):
|
||||||
|
Loading…
Reference in New Issue
Block a user