make hed controlnet processor use MM ram cache

Lincoln Stein 2024-07-01 19:00:13 -04:00
parent b000bc2f58
commit af274bedc1
2 changed files with 16 additions and 17 deletions

@@ -230,15 +230,18 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
     scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)

     def run_processor(self, image: Image.Image) -> Image.Image:
-        hed_processor = HEDProcessor()
-        processed_image = hed_processor.run(
-            image,
-            detect_resolution=self.detect_resolution,
-            image_resolution=self.image_resolution,
-            # safe not supported in controlnet_aux v0.0.3
-            # safe=self.safe,
-            scribble=self.scribble,
-        )
+        hed_weights = self._context.models.load_remote_model("lllyasviel/Annotators::/ControlNetHED.pth")
+        with hed_weights as weights:
+            assert isinstance(weights, dict)
+            hed_processor = HEDProcessor(weights)
+            processed_image = hed_processor.run(
+                image,
+                detect_resolution=self.detect_resolution,
+                image_resolution=self.image_resolution,
+                # safe not supported in controlnet_aux v0.0.3
+                # safe=self.safe,
+                scribble=self.scribble,
+            )
         return processed_image
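Note: `load_remote_model` on the invocation context returns a context manager that yields the checkpoint already resident in the model manager's RAM cache, which is what lets the HED weights be downloaded once and then reused across invocations. Below is a minimal, hypothetical sketch of that caching pattern; the `SimpleRamCache` class is illustrative only and is not InvokeAI's actual model-manager implementation. It reuses the "repo_id::/filename" source format and the dict-of-tensors payload shown in the hunk above.

from contextlib import contextmanager
from typing import Dict, Iterator

import torch
from huggingface_hub import hf_hub_download


class SimpleRamCache:
    """Illustrative in-RAM checkpoint cache (hypothetical, not InvokeAI's model manager)."""

    def __init__(self) -> None:
        self._entries: Dict[str, Dict[str, torch.Tensor]] = {}

    @contextmanager
    def load_remote_model(self, source: str) -> Iterator[Dict[str, torch.Tensor]]:
        # Source format mirrors the invocation above: "<repo_id>::/<filename>".
        if source not in self._entries:
            repo_id, filename = source.split("::/")
            checkpoint_path = hf_hub_download(repo_id, filename)
            self._entries[source] = torch.load(checkpoint_path, map_location="cpu")
        # Later callers get the state dict straight from RAM, with no re-download or re-read.
        yield self._entries[source]

The `assert isinstance(weights, dict)` in the invocation reflects that, for a bare `.pth` checkpoint like `ControlNetHED.pth`, the object handed back by the cache is a plain state dict.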

@@ -4,8 +4,9 @@ import cv2
 import numpy as np
 import torch
 from einops import rearrange
-from huggingface_hub import hf_hub_download
+from invokeai.backend.model_manager.config import AnyModel
 from PIL import Image
+from typing import Dict

 from invokeai.backend.image_util.util import (
     nms,
@@ -76,16 +77,11 @@ class HEDProcessor:
     On instantiation, loads the HED model from the HuggingFace Hub.
     """

-    def __init__(self):
-        model_path = hf_hub_download("lllyasviel/Annotators", "ControlNetHED.pth")
+    def __init__(self, state_dict: Dict[str, torch.Tensor]):
         self.network = ControlNetHED_Apache2()
-        self.network.load_state_dict(torch.load(model_path, map_location="cpu"))
+        self.network.load_state_dict(state_dict)
         self.network.float().eval()

-    def to(self, device: torch.device):
-        self.network.to(device)
-        return self
-
     def run(
         self,
         input_image: Image.Image,
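With the constructor change, `HEDProcessor` no longer reaches out to the Hub itself; the caller supplies the state dict. A hedged sketch of standalone usage follows, reproducing the download the removed `__init__` performed implicitly; the module path `invokeai.backend.image_util.hed`, the input/output file names, and the resolution values are assumptions for illustration only.

import torch
from huggingface_hub import hf_hub_download
from PIL import Image

# Module path assumed; HEDProcessor is the class shown in the diff above.
from invokeai.backend.image_util.hed import HEDProcessor

# Same checkpoint the removed __init__ fetched via hf_hub_download.
model_path = hf_hub_download("lllyasviel/Annotators", "ControlNetHED.pth")
state_dict = torch.load(model_path, map_location="cpu")

processor = HEDProcessor(state_dict)
edge_map = processor.run(
    Image.open("input.png"),  # hypothetical input file
    detect_resolution=512,
    image_resolution=512,
    scribble=False,
)
edge_map.save("hed_edges.png")  # hypothetical output file

Passing the state dict in keeps the processor free of download and caching concerns, which is what allows the invocation above to route the weights through the shared model-manager RAM cache instead.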