make lineart controlnet processor use MM ram cache; move model locations into processor class files

This commit is contained in:
Lincoln Stein
2024-07-01 19:17:15 -04:00
parent af274bedc1
commit c6dcbce043
3 changed files with 25 additions and 20 deletions

View File

@ -39,8 +39,8 @@ from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNE
from invokeai.backend.image_util.canny import get_canny_edges from invokeai.backend.image_util.canny import get_canny_edges
from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector
from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
from invokeai.backend.image_util.hed import HEDProcessor from invokeai.backend.image_util.hed import HED_MODEL, HEDProcessor
from invokeai.backend.image_util.lineart import LineartProcessor from invokeai.backend.image_util.lineart import COARSE_MODEL, LINEART_MODEL, LineartProcessor
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
from invokeai.backend.image_util.util import np_to_pil, pil_to_np from invokeai.backend.image_util.util import np_to_pil, pil_to_np
from invokeai.backend.model_manager.load import LoadedModelWithoutConfig from invokeai.backend.model_manager.load import LoadedModelWithoutConfig
@ -230,7 +230,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode) scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
def run_processor(self, image: Image.Image) -> Image.Image: def run_processor(self, image: Image.Image) -> Image.Image:
hed_weights = self._context.models.load_remote_model("lllyasviel/Annotators::/ControlNetHED.pth") hed_weights = self._context.models.load_remote_model(HED_MODEL)
with hed_weights as weights: with hed_weights as weights:
assert isinstance(weights, dict) assert isinstance(weights, dict)
hed_processor = HEDProcessor(weights) hed_processor = HEDProcessor(weights)
@ -260,10 +260,16 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
coarse: bool = InputField(default=False, description="Whether to use coarse mode") coarse: bool = InputField(default=False, description="Whether to use coarse mode")
def run_processor(self, image: Image.Image) -> Image.Image: def run_processor(self, image: Image.Image) -> Image.Image:
lineart_processor = LineartProcessor() model_info = self._context.models.load_remote_model(LINEART_MODEL)
processed_image = lineart_processor.run( model_coarse_info = self._context.models.load_remote_model(COARSE_MODEL)
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution, coarse=self.coarse with model_info as model_sd, model_coarse_info as coarse_sd:
) lineart_processor = LineartProcessor(model_sd=model_sd, coarse_sd=coarse_sd)
processed_image = lineart_processor.run(
image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
coarse=self.coarse,
)
return processed_image return processed_image

View File

@ -1,12 +1,12 @@
"""Adapted from https://github.com/huggingface/controlnet_aux (Apache-2.0 license).""" """Adapted from https://github.com/huggingface/controlnet_aux (Apache-2.0 license)."""
from typing import Dict
import cv2 import cv2
import numpy as np import numpy as np
import torch import torch
from einops import rearrange from einops import rearrange
from invokeai.backend.model_manager.config import AnyModel
from PIL import Image from PIL import Image
from typing import Dict
from invokeai.backend.image_util.util import ( from invokeai.backend.image_util.util import (
nms, nms,
@ -17,6 +17,8 @@ from invokeai.backend.image_util.util import (
safe_step, safe_step,
) )
HED_MODEL = "lllyasviel/Annotators::/ControlNetHED.pth"
class DoubleConvBlock(torch.nn.Module): class DoubleConvBlock(torch.nn.Module):
def __init__(self, input_channel, output_channel, layer_number): def __init__(self, input_channel, output_channel, layer_number):

View File

@ -1,11 +1,12 @@
"""Adapted from https://github.com/huggingface/controlnet_aux (Apache-2.0 license).""" """Adapted from https://github.com/huggingface/controlnet_aux (Apache-2.0 license)."""
from typing import Dict
import cv2 import cv2
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from einops import rearrange from einops import rearrange
from huggingface_hub import hf_hub_download
from PIL import Image from PIL import Image
from invokeai.backend.image_util.util import ( from invokeai.backend.image_util.util import (
@ -15,6 +16,9 @@ from invokeai.backend.image_util.util import (
resize_image_to_resolution, resize_image_to_resolution,
) )
LINEART_MODEL = "lllyasviel/Annotators::/sk_model.pth"
COARSE_MODEL = "lllyasviel/Annotators::/sk_model2.pth"
class ResidualBlock(nn.Module): class ResidualBlock(nn.Module):
def __init__(self, in_features): def __init__(self, in_features):
@ -97,22 +101,15 @@ class Generator(nn.Module):
class LineartProcessor: class LineartProcessor:
"""Processor for lineart detection.""" """Processor for lineart detection."""
def __init__(self): def __init__(self, model_sd: Dict[str, torch.Tensor], coarse_sd: Dict[str, torch.Tensor]):
model_path = hf_hub_download("lllyasviel/Annotators", "sk_model.pth")
self.model = Generator(3, 1, 3) self.model = Generator(3, 1, 3)
self.model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu"))) self.model.load_state_dict(model_sd)
self.model.eval() self.model.eval()
coarse_model_path = hf_hub_download("lllyasviel/Annotators", "sk_model2.pth")
self.model_coarse = Generator(3, 1, 3) self.model_coarse = Generator(3, 1, 3)
self.model_coarse.load_state_dict(torch.load(coarse_model_path, map_location=torch.device("cpu"))) self.model_coarse.load_state_dict(coarse_sd)
self.model_coarse.eval() self.model_coarse.eval()
def to(self, device: torch.device):
self.model.to(device)
self.model_coarse.to(device)
return self
def run( def run(
self, input_image: Image.Image, coarse: bool = False, detect_resolution: int = 512, image_resolution: int = 512 self, input_image: Image.Image, coarse: bool = False, detect_resolution: int = 512, image_resolution: int = 512
) -> Image.Image: ) -> Image.Image: