feat(backend): lift managed model loading out of depthanything class

psychedelicious 2024-04-29 08:56:00 +10:00
parent fcb071f30c
commit 1fe90c357c
2 changed files with 31 additions and 32 deletions
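In short: DepthAnythingDetector no longer receives an InvocationContext and downloads its own checkpoint. Checkpoint acquisition now goes through the invocation context's managed model cache (context.models.load_ckpt_from_url), and model construction becomes a static factory (DepthAnythingDetector.load_model). A minimal sketch of the resulting call shape, condensed from the diff below; the standalone run wrapper is illustrative and not part of the commit:

from pathlib import Path

from PIL import Image

from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector
from invokeai.backend.util.devices import TorchDevice


def run(context, image: Image.Image, model_size: str = "small", resolution: int = 512) -> Image.Image:
    # The loader callback receives the locally cached checkpoint path and
    # must return a fully constructed model; the model manager owns caching.
    def loader(model_path: Path):
        return DepthAnythingDetector.load_model(
            model_path, model_size=model_size, device=TorchDevice.choose_torch_device()
        )

    # load_ckpt_from_url downloads the checkpoint if needed, invokes the
    # loader, and yields the managed model for the duration of the block.
    with context.models.load_ckpt_from_url(source=DEPTH_ANYTHING_MODELS[model_size], loader=loader) as model:
        detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device())
        return detector(image=image, resolution=resolution)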

invokeai/app/invocations/controlnet_image_processors.py

@@ -2,6 +2,7 @@
 # initial implementation by Gregg Helt, 2023
 # heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
 from builtins import bool, float
+from pathlib import Path
 from typing import Dict, List, Literal, Union
 
 import cv2
@@ -37,11 +38,12 @@ from invokeai.app.invocations.util import validate_begin_end_step, validate_weig
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES
 from invokeai.backend.image_util.canny import get_canny_edges
-from invokeai.backend.image_util.depth_anything import DepthAnythingDetector
+from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector
 from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
 from invokeai.backend.image_util.hed import HEDProcessor
 from invokeai.backend.image_util.lineart import LineartProcessor
 from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
+from invokeai.backend.util.devices import TorchDevice
 
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
@@ -603,11 +605,15 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
     resolution: int = InputField(default=512, ge=64, multiple_of=64, description=FieldDescriptions.image_res)
 
     def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
-        depth_anything_detector = DepthAnythingDetector(context)
-        depth_anything_detector.load_model(model_size=self.model_size)
-
-        processed_image = depth_anything_detector(image=image, resolution=self.resolution)
-        return processed_image
+        def loader(model_path: Path):
+            return DepthAnythingDetector.load_model(
+                model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device()
+            )
+
+        with context.models.load_ckpt_from_url(source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader) as model:
+            depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device())
+            processed_image = depth_anything_detector(image=image, resolution=self.resolution)
+            return processed_image
 
 
 @invocation(
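The design choice here: load_model is now a pure static factory taking an explicit checkpoint path and device, so the model manager owns download, caching, and lifecycle, while the detector is reduced to a thin wrapper around an already-loaded model. That also means the factory can be exercised without any InvocationContext. A hypothetical standalone sketch, assuming a Depth Anything ViT-S checkpoint already on disk at a made-up path:

from pathlib import Path

import torch

from invokeai.backend.image_util.depth_anything import DepthAnythingDetector

# Assumption: a Depth Anything ViT-S checkpoint already downloaded to disk.
checkpoint = Path("/models/depth_anything_vits14.pth")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the model once via the new static factory, then wrap it.
model = DepthAnythingDetector.load_model(checkpoint, device=device, model_size="small")
detector = DepthAnythingDetector(model, device)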

invokeai/backend/image_util/depth_anything/__init__.py

@@ -1,4 +1,5 @@
-from typing import Literal, Optional, Union
+from pathlib import Path
+from typing import Literal
 
 import cv2
 import numpy as np
@@ -9,10 +10,8 @@ from PIL import Image
 from torchvision.transforms import Compose
 
 from invokeai.app.services.config.config_default import get_config
-from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
 from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
-from invokeai.backend.util.devices import TorchDevice
 from invokeai.backend.util.logging import InvokeAILogger
 
 config = get_config()
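For reference, DEPTH_ANYTHING_MODELS is defined elsewhere in this module, outside the changed hunks. As its use with download_and_cache_ckpt (removed side) and load_ckpt_from_url (added side) implies, it maps each model size to a checkpoint URL. A rough sketch of its shape, with placeholder URLs:

# Shape implied by DEPTH_ANYTHING_MODELS[model_size] being used as a
# download source; the URLs below are illustrative placeholders only.
DEPTH_ANYTHING_MODELS = {
    "large": "https://example.com/checkpoints/depth_anything_vitl14.pth",
    "base": "https://example.com/checkpoints/depth_anything_vitb14.pth",
    "small": "https://example.com/checkpoints/depth_anything_vits14.pth",
}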
@@ -43,33 +42,27 @@ transform = Compose(
 class DepthAnythingDetector:
-    def __init__(self, context: InvocationContext) -> None:
-        self.context = context
-        self.model: Optional[DPT_DINOv2] = None
-        self.model_size: Union[Literal["large", "base", "small"], None] = None
-        self.device = TorchDevice.choose_torch_device()
+    def __init__(self, model: DPT_DINOv2, device: torch.device) -> None:
+        self.model = model
+        self.device = device
 
-    def load_model(self, model_size: Literal["large", "base", "small"] = "small") -> DPT_DINOv2:
-        depth_anything_model_path = self.context.models.download_and_cache_ckpt(DEPTH_ANYTHING_MODELS[model_size])
-
-        if not self.model or model_size != self.model_size:
-            del self.model
-            self.model_size = model_size
-
-            match self.model_size:
-                case "small":
-                    self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
-                case "base":
-                    self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
-                case "large":
-                    self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
-
-            assert self.model is not None
-            self.model.load_state_dict(torch.load(depth_anything_model_path.as_posix(), map_location="cpu"))
-            self.model.eval()
-
-            self.model.to(self.device)
-            return self.model
+    @staticmethod
+    def load_model(
+        model_path: Path, device: torch.device, model_size: Literal["large", "base", "small"] = "small"
+    ) -> DPT_DINOv2:
+        match model_size:
+            case "small":
+                model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
+            case "base":
+                model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
+            case "large":
+                model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
+
+        model.load_state_dict(torch.load(model_path.as_posix(), map_location="cpu"))
+        model.eval()
+
+        model.to(device)
+        return model
 
     def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:
         if not self.model: