make lineart controlnet processor use MM ram cache; move model locations into processor class files

This commit is contained in:
Lincoln Stein 2024-07-01 19:17:15 -04:00
parent af274bedc1
commit c6dcbce043
3 changed files with 25 additions and 20 deletions

View File

@ -39,8 +39,8 @@ from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNE
from invokeai.backend.image_util.canny import get_canny_edges
from invokeai.backend.image_util.depth_anything import DEPTH_ANYTHING_MODELS, DepthAnythingDetector
from invokeai.backend.image_util.dw_openpose import DWPOSE_MODELS, DWOpenposeDetector
from invokeai.backend.image_util.hed import HEDProcessor
from invokeai.backend.image_util.lineart import LineartProcessor
from invokeai.backend.image_util.hed import HED_MODEL, HEDProcessor
from invokeai.backend.image_util.lineart import COARSE_MODEL, LINEART_MODEL, LineartProcessor
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
from invokeai.backend.model_manager.load import LoadedModelWithoutConfig
@ -230,7 +230,7 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
def run_processor(self, image: Image.Image) -> Image.Image:
hed_weights = self._context.models.load_remote_model("lllyasviel/Annotators::/ControlNetHED.pth")
hed_weights = self._context.models.load_remote_model(HED_MODEL)
with hed_weights as weights:
assert isinstance(weights, dict)
hed_processor = HEDProcessor(weights)
@ -260,10 +260,16 @@ class LineartImageProcessorInvocation(ImageProcessorInvocation):
coarse: bool = InputField(default=False, description="Whether to use coarse mode")
def run_processor(self, image: Image.Image) -> Image.Image:
lineart_processor = LineartProcessor()
processed_image = lineart_processor.run(
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution, coarse=self.coarse
)
model_info = self._context.models.load_remote_model(LINEART_MODEL)
model_coarse_info = self._context.models.load_remote_model(COARSE_MODEL)
with model_info as model_sd, model_coarse_info as coarse_sd:
lineart_processor = LineartProcessor(model_sd=model_sd, coarse_sd=coarse_sd)
processed_image = lineart_processor.run(
image,
detect_resolution=self.detect_resolution,
image_resolution=self.image_resolution,
coarse=self.coarse,
)
return processed_image

View File

@ -1,12 +1,12 @@
"""Adapted from https://github.com/huggingface/controlnet_aux (Apache-2.0 license)."""
from typing import Dict
import cv2
import numpy as np
import torch
from einops import rearrange
from invokeai.backend.model_manager.config import AnyModel
from PIL import Image
from typing import Dict
from invokeai.backend.image_util.util import (
nms,
@ -17,6 +17,8 @@ from invokeai.backend.image_util.util import (
safe_step,
)
# Remote-model spec ("repo_id::/path") consumed by context.models.load_remote_model()
# (see the invocation-side hunk above: hed_weights = ..._remote_model(HED_MODEL)).
HED_MODEL = "lllyasviel/Annotators::/ControlNetHED.pth"
class DoubleConvBlock(torch.nn.Module):
def __init__(self, input_channel, output_channel, layer_number):

View File

@ -1,11 +1,12 @@
"""Adapted from https://github.com/huggingface/controlnet_aux (Apache-2.0 license)."""
from typing import Dict
import cv2
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from huggingface_hub import hf_hub_download
from PIL import Image
from invokeai.backend.image_util.util import (
@ -15,6 +16,9 @@ from invokeai.backend.image_util.util import (
resize_image_to_resolution,
)
# Remote-model specs ("repo_id::/path") for the lineart detector weights, loaded
# through the model manager's RAM cache via load_remote_model() instead of a
# direct hf_hub_download call.
LINEART_MODEL = "lllyasviel/Annotators::/sk_model.pth"
# Variant state dict used when the processor runs in coarse mode.
COARSE_MODEL = "lllyasviel/Annotators::/sk_model2.pth"
class ResidualBlock(nn.Module):
def __init__(self, in_features):
@ -97,22 +101,15 @@ class Generator(nn.Module):
class LineartProcessor:
"""Processor for lineart detection."""
def __init__(self):
model_path = hf_hub_download("lllyasviel/Annotators", "sk_model.pth")
def __init__(self, model_sd: Dict[str, torch.Tensor], coarse_sd: Dict[str, torch.Tensor]):
self.model = Generator(3, 1, 3)
self.model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
self.model.load_state_dict(model_sd)
self.model.eval()
coarse_model_path = hf_hub_download("lllyasviel/Annotators", "sk_model2.pth")
self.model_coarse = Generator(3, 1, 3)
self.model_coarse.load_state_dict(torch.load(coarse_model_path, map_location=torch.device("cpu")))
self.model_coarse.load_state_dict(coarse_sd)
self.model_coarse.eval()
def to(self, device: torch.device):
    """Move both detector networks (fine and coarse) to *device*.

    Returns self so the call can be chained, mirroring torch.nn.Module.to().
    """
    self.model.to(device)
    self.model_coarse.to(device)
    return self
def run(
self, input_image: Image.Image, coarse: bool = False, detect_resolution: int = 512, image_resolution: int = 512
) -> Image.Image: