mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
migrate lineart_anime controlnet processor to MM caching system
This commit is contained in:
parent
c6dcbce043
commit
08d7bd2a0b
@ -41,7 +41,7 @@ from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, De
|
||||
from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
|
||||
from invokeai.backend.image_util.hed import HED_MODEL, HEDProcessor
|
||||
from invokeai.backend.image_util.lineart import COARSE_MODEL, LINEART_MODEL, LineartProcessor
|
||||
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
|
||||
from invokeai.backend.image_util.lineart_anime import LINEART_ANIME_MODEL, LineartAnimeProcessor
|
||||
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
|
||||
from invokeai.backend.model_manager.load import LoadedModelWithoutConfig
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
@ -287,12 +287,13 @@ class LineartAnimeImageProcessorInvocation(ImageProcessorInvocation):
|
||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
processor = LineartAnimeProcessor()
|
||||
processed_image = processor.run(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
)
|
||||
with self._context.models.load_remote_model(LINEART_ANIME_MODEL) as model_sd:
|
||||
processor = LineartAnimeProcessor(model_sd)
|
||||
processed_image = processor.run(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
|
@ -1,14 +1,13 @@
|
||||
"""Adapted from https://github.com/huggingface/controlnet_aux (Apache-2.0 license)."""
|
||||
|
||||
import functools
|
||||
from typing import Optional
|
||||
from typing import Dict, Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.backend.image_util.util import (
|
||||
@ -18,6 +17,8 @@ from invokeai.backend.image_util.util import (
|
||||
resize_image_to_resolution,
|
||||
)
|
||||
|
||||
LINEART_ANIME_MODEL = "lllyasviel/Annotators::/netG.pth"
|
||||
|
||||
|
||||
class UnetGenerator(nn.Module):
|
||||
"""Create a Unet-based generator"""
|
||||
@ -142,16 +143,14 @@ class UnetSkipConnectionBlock(nn.Module):
|
||||
class LineartAnimeProcessor:
|
||||
"""Processes an image to detect lineart."""
|
||||
|
||||
def __init__(self):
|
||||
model_path = hf_hub_download("lllyasviel/Annotators", "netG.pth")
|
||||
def __init__(self, model_sd: Dict[str, torch.Tensor]):
|
||||
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
|
||||
self.model = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False)
|
||||
ckpt = torch.load(model_path)
|
||||
for key in list(ckpt.keys()):
|
||||
for key in list(model_sd.keys()):
|
||||
if "module." in key:
|
||||
ckpt[key.replace("module.", "")] = ckpt[key]
|
||||
del ckpt[key]
|
||||
self.model.load_state_dict(ckpt)
|
||||
model_sd[key.replace("module.", "")] = model_sd[key]
|
||||
del model_sd[key]
|
||||
self.model.load_state_dict(model_sd)
|
||||
self.model.eval()
|
||||
|
||||
def to(self, device: torch.device):
|
||||
|
Loading…
Reference in New Issue
Block a user