migrate lineart_anime controlnet processor to MM caching system

This commit is contained in:
Lincoln Stein
2024-07-01 20:42:41 -04:00
parent c6dcbce043
commit 08d7bd2a0b
2 changed files with 16 additions and 16 deletions

View File

@@ -41,7 +41,7 @@ from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, De
from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
from invokeai.backend.image_util.hed import HED_MODEL, HEDProcessor from invokeai.backend.image_util.hed import HED_MODEL, HEDProcessor
from invokeai.backend.image_util.lineart import COARSE_MODEL, LINEART_MODEL, LineartProcessor from invokeai.backend.image_util.lineart import COARSE_MODEL, LINEART_MODEL, LineartProcessor
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor from invokeai.backend.image_util.lineart_anime import LINEART_ANIME_MODEL, LineartAnimeProcessor
from invokeai.backend.image_util.util import np_to_pil, pil_to_np from invokeai.backend.image_util.util import np_to_pil, pil_to_np
from invokeai.backend.model_manager.load import LoadedModelWithoutConfig from invokeai.backend.model_manager.load import LoadedModelWithoutConfig
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
@@ -287,12 +287,13 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res) image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image) -> Image.Image: def run_processor(self, image: Image.Image) -> Image.Image:
processor = LineartAnimeProcessor() with self._context.models.load_remote_model(LINEART_ANIME_MODEL) as model_sd:
processed_image = processor.run( processor = LineartAnimeProcessor(model_sd)
image, processed_image = processor.run(
detect_resolution=self.detect_resolution, image,
image_resolution=self.image_resolution, detect_resolution=self.detect_resolution,
) image_resolution=self.image_resolution,
)
return processed_image return processed_image

View File

@@ -1,14 +1,13 @@
"""Adapted from https://github.com/huggingface/controlnet_aux (Apache-2.0 license).""" """Adapted from https://github.com/huggingface/controlnet_aux (Apache-2.0 license)."""
import functools import functools
from typing import Optional from typing import Dict, Optional
import cv2 import cv2
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from einops import rearrange from einops import rearrange
from huggingface_hub import hf_hub_download
from PIL import Image from PIL import Image
from invokeai.backend.image_util.util import ( from invokeai.backend.image_util.util import (
@@ -18,6 +17,8 @@ from invokeai.backend.image_util.util import (
resize_image_to_resolution, resize_image_to_resolution,
) )
LINEART_ANIME_MODEL = "lllyasviel/Annotators::/netG.pth"
class UnetGenerator(nn.Module): class UnetGenerator(nn.Module):
"""Create a Unet-based generator""" """Create a Unet-based generator"""
@@ -142,16 +143,14 @@ class UnetSkipConnectionBlock(nn.Module):
class LineartAnimeProcessor: class LineartAnimeProcessor:
"""Processes an image to detect lineart.""" """Processes an image to detect lineart."""
def __init__(self): def __init__(self, model_sd: Dict[str, torch.Tensor]):
model_path = hf_hub_download("lllyasviel/Annotators", "netG.pth")
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
self.model = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False) self.model = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False)
ckpt = torch.load(model_path) for key in list(model_sd.keys()):
for key in list(ckpt.keys()):
if "module." in key: if "module." in key:
ckpt[key.replace("module.", "")] = ckpt[key] model_sd[key.replace("module.", "")] = model_sd[key]
del ckpt[key] del model_sd[key]
self.model.load_state_dict(ckpt) self.model.load_state_dict(model_sd)
self.model.eval() self.model.eval()
def to(self, device: torch.device): def to(self, device: torch.device):