mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix: Make DepthAnything work with Invoke's Model Management
This commit is contained in:
parent
f170697ebe
commit
18f89ed5ed
@ -2,6 +2,7 @@
|
|||||||
# initial implementation by Gregg Helt, 2023
|
# initial implementation by Gregg Helt, 2023
|
||||||
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
|
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
|
||||||
from builtins import bool, float
|
from builtins import bool, float
|
||||||
|
from pathlib import Path
|
||||||
from typing import Dict, List, Literal, Union
|
from typing import Dict, List, Literal, Union
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
@ -21,6 +22,7 @@ from controlnet_aux.util import HWC3, ade_palette
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||||
from transformers import pipeline
|
from transformers import pipeline
|
||||||
|
from transformers.pipelines import DepthEstimationPipeline
|
||||||
|
|
||||||
from invokeai.app.invocations.baseinvocation import (
|
from invokeai.app.invocations.baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
@ -44,6 +46,7 @@ from invokeai.app.invocations.util import validate_begin_end_step, validate_weig
|
|||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
|
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, heuristic_resize
|
||||||
from invokeai.backend.image_util.canny import get_canny_edges
|
from invokeai.backend.image_util.canny import get_canny_edges
|
||||||
|
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
|
||||||
from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
|
from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
|
||||||
from invokeai.backend.image_util.hed import HEDProcessor
|
from invokeai.backend.image_util.hed import HEDProcessor
|
||||||
from invokeai.backend.image_util.lineart import LineartProcessor
|
from invokeai.backend.image_util.lineart import LineartProcessor
|
||||||
@ -614,9 +617,16 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||||
|
|
||||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||||
depth_anything_pipeline = pipeline(task="depth-estimation", model=DEPTH_ANYTHING_MODELS[self.model_size])
|
def load_depth_anything(model_path: Path):
|
||||||
depth_map = depth_anything_pipeline(image)["depth"]
|
depth_anything_pipeline = pipeline(model=str(model_path), task="depth-estimation", local_files_only=True)
|
||||||
return depth_map
|
assert isinstance(depth_anything_pipeline, DepthEstimationPipeline)
|
||||||
|
return DepthAnythingPipeline(depth_anything_pipeline)
|
||||||
|
|
||||||
|
with self._context.models.load_remote_model(
|
||||||
|
source=DEPTH_ANYTHING_MODELS[self.model_size], loader=load_depth_anything
|
||||||
|
) as depth_anything_detector:
|
||||||
|
assert isinstance(depth_anything_detector, DepthAnythingPipeline)
|
||||||
|
return depth_anything_detector.generate_depth(image, self.resolution)
|
||||||
|
|
||||||
|
|
||||||
@invocation(
|
@invocation(
|
||||||
|
@ -0,0 +1,26 @@
|
|||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from PIL import Image
|
||||||
|
from transformers.pipelines import DepthEstimationPipeline
|
||||||
|
|
||||||
|
|
||||||
|
class DepthAnythingPipeline:
|
||||||
|
"""Custom wrapper for the Depth Estimation pipeline from transformers adding compatibility
|
||||||
|
for Invoke's Model Management System"""
|
||||||
|
|
||||||
|
def __init__(self, pipeline: DepthEstimationPipeline) -> None:
|
||||||
|
self.pipeline = pipeline
|
||||||
|
|
||||||
|
def generate_depth(self, image: Image.Image, resolution: int = 512):
|
||||||
|
image_width, image_height = image.size
|
||||||
|
depth_map = self.pipeline(image)["depth"]
|
||||||
|
depth_map = cast(Image.Image, depth_map)
|
||||||
|
|
||||||
|
new_height = int(image_height * (resolution / image_width))
|
||||||
|
depth_map = depth_map.resize((resolution, new_height))
|
||||||
|
return depth_map
|
||||||
|
|
||||||
|
def calc_size(self) -> int:
|
||||||
|
from invokeai.backend.model_manager.load.model_util import calc_module_size
|
||||||
|
|
||||||
|
return calc_module_size(self.pipeline.model)
|
@ -31,13 +31,16 @@ from pydantic import BaseModel, ConfigDict, Discriminator, Field, Tag, TypeAdapt
|
|||||||
from typing_extensions import Annotated, Any, Dict
|
from typing_extensions import Annotated, Any, Dict
|
||||||
|
|
||||||
from invokeai.app.util.misc import uuid_string
|
from invokeai.app.util.misc import uuid_string
|
||||||
|
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
|
||||||
from invokeai.backend.model_hash.hash_validator import validate_hash
|
from invokeai.backend.model_hash.hash_validator import validate_hash
|
||||||
from invokeai.backend.raw_model import RawModel
|
from invokeai.backend.raw_model import RawModel
|
||||||
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
|
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
|
||||||
|
|
||||||
# ModelMixin is the base class for all diffusers and transformers models
|
# ModelMixin is the base class for all diffusers and transformers models
|
||||||
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
|
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
|
||||||
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor], diffusers.DiffusionPipeline]
|
AnyModel = Union[
|
||||||
|
ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor], diffusers.DiffusionPipeline, DepthAnythingPipeline
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class InvalidModelConfigException(Exception):
|
class InvalidModelConfigException(Exception):
|
||||||
|
@ -11,6 +11,7 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
|||||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||||
from transformers import CLIPTokenizer
|
from transformers import CLIPTokenizer
|
||||||
|
|
||||||
|
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
|
||||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||||
from invokeai.backend.lora import LoRAModelRaw
|
from invokeai.backend.lora import LoRAModelRaw
|
||||||
from invokeai.backend.model_manager.config import AnyModel
|
from invokeai.backend.model_manager.config import AnyModel
|
||||||
@ -34,7 +35,9 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
|
|||||||
elif isinstance(model, CLIPTokenizer):
|
elif isinstance(model, CLIPTokenizer):
|
||||||
# TODO(ryand): Accurately calculate the tokenizer's size. It's small enough that it shouldn't matter for now.
|
# TODO(ryand): Accurately calculate the tokenizer's size. It's small enough that it shouldn't matter for now.
|
||||||
return 0
|
return 0
|
||||||
elif isinstance(model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw, SpandrelImageToImageModel)):
|
elif isinstance(
|
||||||
|
model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw, SpandrelImageToImageModel, DepthAnythingPipeline)
|
||||||
|
):
|
||||||
return model.calc_size()
|
return model.calc_size()
|
||||||
else:
|
else:
|
||||||
# TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
|
# TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
|
||||||
|
Loading…
Reference in New Issue
Block a user