make hed controlnet processor use MM ram cache

This commit is contained in:
Lincoln Stein 2024-07-01 19:00:13 -04:00
parent b000bc2f58
commit af274bedc1
2 changed files with 16 additions and 17 deletions

View File

@ -230,7 +230,10 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode) scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
def run_processor(self, image: Image.Image) -> Image.Image: def run_processor(self, image: Image.Image) -> Image.Image:
hed_processor = HEDProcessor() hed_weights = self._context.models.load_remote_model("lllyasviel/Annotators::/ControlNetHED.pth")
with hed_weights as weights:
assert isinstance(weights, dict)
hed_processor = HEDProcessor(weights)
processed_image = hed_processor.run( processed_image = hed_processor.run(
image, image,
detect_resolution=self.detect_resolution, detect_resolution=self.detect_resolution,

View File

@ -4,8 +4,9 @@ import cv2
import numpy as np import numpy as np
import torch import torch
from einops import rearrange from einops import rearrange
from huggingface_hub import hf_hub_download from invokeai.backend.model_manager.config import AnyModel
from PIL import Image from PIL import Image
from typing import Dict
from invokeai.backend.image_util.util import ( from invokeai.backend.image_util.util import (
nms, nms,
@ -76,16 +77,11 @@ class HEDProcessor:
On instantiation, loads the HED model from the HuggingFace Hub. On instantiation, loads the HED model from the HuggingFace Hub.
""" """
def __init__(self): def __init__(self, state_dict: Dict[str, torch.Tensor]):
model_path = hf_hub_download("lllyasviel/Annotators", "ControlNetHED.pth")
self.network = ControlNetHED_Apache2() self.network = ControlNetHED_Apache2()
self.network.load_state_dict(torch.load(model_path, map_location="cpu")) self.network.load_state_dict(state_dict)
self.network.float().eval() self.network.float().eval()
def to(self, device: torch.device):
self.network.to(device)
return self
def run( def run(
self, self,
input_image: Image.Image, input_image: Image.Image,