mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(backend): lift managed model loading out of depthanything class
This commit is contained in:
parent
fcb071f30c
commit
1fe90c357c
@ -2,6 +2,7 @@
|
||||
# initial implementation by Gregg Helt, 2023
|
||||
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
|
||||
from builtins import bool, float
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Literal, Union
|
||||
|
||||
import cv2
|
||||
@ -37,11 +38,12 @@ from invokeai.app.invocations.util import validate_begin_end_step, validate_weig
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES
|
||||
from invokeai.backend.image_util.canny import get_canny_edges
|
||||
from invokeai.backend.image_util.depth_anything import DepthAnythingDetector
|
||||
from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector
|
||||
from invokeai.backend.image_util.dw_openpose import DWOpenposeDetector
|
||||
from invokeai.backend.image_util.hed import HEDProcessor
|
||||
from invokeai.backend.image_util.lineart import LineartProcessor
|
||||
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
|
||||
|
||||
@ -603,11 +605,15 @@ class DepthAnythingImageProcessorInvocation(ImageProcessorInvocation):
|
||||
resolution: int = InputField(default=512, ge=64, multiple_of=64, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image, context: InvocationContext) -> Image.Image:
|
||||
depth_anything_detector = DepthAnythingDetector(context)
|
||||
depth_anything_detector.load_model(model_size=self.model_size)
|
||||
def loader(model_path: Path):
|
||||
return DepthAnythingDetector.load_model(
|
||||
model_path, model_size=self.model_size, device=TorchDevice.choose_torch_device()
|
||||
)
|
||||
|
||||
processed_image = depth_anything_detector(image=image, resolution=self.resolution)
|
||||
return processed_image
|
||||
with context.models.load_ckpt_from_url(source=DEPTH_ANYTHING_MODELS[self.model_size], loader=loader) as model:
|
||||
depth_anything_detector = DepthAnythingDetector(model, TorchDevice.choose_torch_device())
|
||||
processed_image = depth_anything_detector(image=image, resolution=self.resolution)
|
||||
return processed_image
|
||||
|
||||
|
||||
@invocation(
|
||||
|
@ -1,4 +1,5 @@
|
||||
from typing import Literal, Optional, Union
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
@ -9,10 +10,8 @@ from PIL import Image
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
|
||||
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
config = get_config()
|
||||
@ -43,33 +42,27 @@ transform = Compose(
|
||||
|
||||
|
||||
class DepthAnythingDetector:
|
||||
def __init__(self, context: InvocationContext) -> None:
|
||||
self.context = context
|
||||
self.model: Optional[DPT_DINOv2] = None
|
||||
self.model_size: Union[Literal["large", "base", "small"], None] = None
|
||||
self.device = TorchDevice.choose_torch_device()
|
||||
def __init__(self, model: DPT_DINOv2, device: torch.device) -> None:
|
||||
self.model = model
|
||||
self.device = device
|
||||
|
||||
def load_model(self, model_size: Literal["large", "base", "small"] = "small") -> DPT_DINOv2:
|
||||
depth_anything_model_path = self.context.models.download_and_cache_ckpt(DEPTH_ANYTHING_MODELS[model_size])
|
||||
@staticmethod
|
||||
def load_model(
|
||||
model_path: Path, device: torch.device, model_size: Literal["large", "base", "small"] = "small"
|
||||
) -> DPT_DINOv2:
|
||||
match model_size:
|
||||
case "small":
|
||||
model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
|
||||
case "base":
|
||||
model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
|
||||
case "large":
|
||||
model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
|
||||
|
||||
if not self.model or model_size != self.model_size:
|
||||
del self.model
|
||||
self.model_size = model_size
|
||||
model.load_state_dict(torch.load(model_path.as_posix(), map_location="cpu"))
|
||||
model.eval()
|
||||
|
||||
match self.model_size:
|
||||
case "small":
|
||||
self.model = DPT_DINOv2(encoder="vits", features=64, out_channels=[48, 96, 192, 384])
|
||||
case "base":
|
||||
self.model = DPT_DINOv2(encoder="vitb", features=128, out_channels=[96, 192, 384, 768])
|
||||
case "large":
|
||||
self.model = DPT_DINOv2(encoder="vitl", features=256, out_channels=[256, 512, 1024, 1024])
|
||||
|
||||
assert self.model is not None
|
||||
self.model.load_state_dict(torch.load(depth_anything_model_path.as_posix(), map_location="cpu"))
|
||||
self.model.eval()
|
||||
|
||||
self.model.to(self.device)
|
||||
return self.model
|
||||
model.to(device)
|
||||
return model
|
||||
|
||||
def __call__(self, image: Image.Image, resolution: int = 512) -> Image.Image:
|
||||
if not self.model:
|
||||
|
Loading…
Reference in New Issue
Block a user