mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
make hed controlnet processor use MM ram cache
This commit is contained in:
parent
b000bc2f58
commit
af274bedc1
@ -230,15 +230,18 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
|
||||
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
|
||||
|
||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||
hed_processor = HEDProcessor()
|
||||
processed_image = hed_processor.run(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
# safe not supported in controlnet_aux v0.0.3
|
||||
# safe=self.safe,
|
||||
scribble=self.scribble,
|
||||
)
|
||||
hed_weights = self._context.models.load_remote_model("lllyasviel/Annotators::/ControlNetHED.pth")
|
||||
with hed_weights as weights:
|
||||
assert isinstance(weights, dict)
|
||||
hed_processor = HEDProcessor(weights)
|
||||
processed_image = hed_processor.run(
|
||||
image,
|
||||
detect_resolution=self.detect_resolution,
|
||||
image_resolution=self.image_resolution,
|
||||
# safe not supported in controlnet_aux v0.0.3
|
||||
# safe=self.safe,
|
||||
scribble=self.scribble,
|
||||
)
|
||||
return processed_image
|
||||
|
||||
|
||||
|
@ -4,8 +4,9 @@ import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from huggingface_hub import hf_hub_download
|
||||
from invokeai.backend.model_manager.config import AnyModel
|
||||
from PIL import Image
|
||||
from typing import Dict
|
||||
|
||||
from invokeai.backend.image_util.util import (
|
||||
nms,
|
||||
@ -76,16 +77,11 @@ class HEDProcessor:
|
||||
On instantiation, loads the HED model from the HuggingFace Hub.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
model_path = hf_hub_download("lllyasviel/Annotators", "ControlNetHED.pth")
|
||||
def __init__(self, state_dict: Dict[str, torch.Tensor]):
|
||||
self.network = ControlNetHED_Apache2()
|
||||
self.network.load_state_dict(torch.load(model_path, map_location="cpu"))
|
||||
self.network.load_state_dict(state_dict)
|
||||
self.network.float().eval()
|
||||
|
||||
def to(self, device: torch.device):
|
||||
self.network.to(device)
|
||||
return self
|
||||
|
||||
def run(
|
||||
self,
|
||||
input_image: Image.Image,
|
||||
|
Loading…
Reference in New Issue
Block a user