Migrate the lineart_anime ControlNet processor to the model manager (MM) caching system

This commit is contained in:
Lincoln Stein 2024-07-01 20:42:41 -04:00
parent c6dcbce043
commit 08d7bd2a0b
2 changed files with 16 additions and 16 deletions

View File

@ -41,7 +41,7 @@ from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, De
from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
from invokeai.backend.image_util.hed import HED_MODEL, HEDProcessor
from invokeai.backend.image_util.lineart import COARSE_MODEL, LINEART_MODEL, LineartProcessor
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
from invokeai.backend.image_util.lineart_anime import LINEART_ANIME_MODEL, LineartAnimeProcessor
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
from invokeai.backend.model_manager.load import LoadedModelWithoutConfig
from invokeai.backend.util.devices import TorchDevice
@ -287,12 +287,13 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image) -> Image.Image:
processor = LineartAnimeProcessor()
processed_image = processor.run(
image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
)
with self._context.models.load_remote_model(LINEART_ANIME_MODEL) as model_sd:
processor = LineartAnimeProcessor(model_sd)
processed_image = processor.run(
image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
)
return processed_image

View File

@ -1,14 +1,13 @@
"""Adapted from https://github.com/huggingface/controlnet_aux (Apache-2.0 license)."""
import functools
from typing import Optional
from typing import Dict, Optional
import cv2
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from huggingface_hub import hf_hub_download
from PIL import Image
from invokeai.backend.image_util.util import (
@ -18,6 +17,8 @@ from invokeai.backend.image_util.util import (
resize_image_to_resolution,
)
LINEART_ANIME_MODEL = "lllyasviel/Annotators::/netG.pth"
class UnetGenerator(nn.Module):
"""Create a Unet-based generator"""
@ -142,16 +143,14 @@ class UnetSkipConnectionBlock(nn.Module):
class LineartAnimeProcessor:
"""Processes an image to detect lineart."""
def __init__(self):
model_path = hf_hub_download("lllyasviel/Annotators", "netG.pth")
def __init__(self, model_sd: Dict[str, torch.Tensor]):
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
self.model = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False)
ckpt = torch.load(model_path)
for key in list(ckpt.keys()):
for key in list(model_sd.keys()):
if "module." in key:
ckpt[key.replace("module.", "")] = ckpt[key]
del ckpt[key]
self.model.load_state_dict(ckpt)
model_sd[key.replace("module.", "")] = model_sd[key]
del model_sd[key]
self.model.load_state_dict(model_sd)
self.model.eval()
def to(self, device: torch.device):