diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py
index 1392f43767..ba7d1c8265 100644
--- a/invokeai/app/invocations/controlnet_image_processors.py
+++ b/invokeai/app/invocations/controlnet_image_processors.py
@@ -230,15 +230,18 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
     scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
 
     def run_processor(self, image: Image.Image) -> Image.Image:
-        hed_processor = HEDProcessor()
-        processed_image = hed_processor.run(
-            image,
-            detect_resolution=self.detect_resolution,
-            image_resolution=self.image_resolution,
-            # safe not supported in controlnet_aux v0.0.3
-            # safe=self.safe,
-            scribble=self.scribble,
-        )
+        hed_weights = self._context.models.load_remote_model("lllyasviel/Annotators::/ControlNetHED.pth")
+        with hed_weights as weights:
+            assert isinstance(weights, dict)
+            hed_processor = HEDProcessor(weights)
+            processed_image = hed_processor.run(
+                image,
+                detect_resolution=self.detect_resolution,
+                image_resolution=self.image_resolution,
+                # safe not supported in controlnet_aux v0.0.3
+                # safe=self.safe,
+                scribble=self.scribble,
+            )
         return processed_image
 
 
diff --git a/invokeai/backend/image_util/hed.py b/invokeai/backend/image_util/hed.py
index 97706df8b9..8a7aea8403 100644
--- a/invokeai/backend/image_util/hed.py
+++ b/invokeai/backend/image_util/hed.py
@@ -4,8 +4,8 @@
 import cv2
 import numpy as np
 import torch
 from einops import rearrange
-from huggingface_hub import hf_hub_download
 from PIL import Image
+from typing import Dict
 from invokeai.backend.image_util.util import (
     nms,
@@ -76,16 +76,11 @@ class HEDProcessor:
-    On instantiation, loads the HED model from the HuggingFace Hub.
+    The HED model weights (state dict) are supplied by the caller on instantiation.
     """
 
-    def __init__(self):
-        model_path = hf_hub_download("lllyasviel/Annotators", "ControlNetHED.pth")
+    def __init__(self, state_dict: Dict[str, torch.Tensor]):
         self.network = ControlNetHED_Apache2()
-        self.network.load_state_dict(torch.load(model_path, map_location="cpu"))
+        self.network.load_state_dict(state_dict)
         self.network.float().eval()
 
-    def to(self, device: torch.device):
-        self.network.to(device)
-        return self
-
     def run(
         self,
         input_image: Image.Image,