mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
make hed controlnet processor use MM ram cache
This commit is contained in:
parent
b000bc2f58
commit
af274bedc1
@ -230,7 +230,10 @@ class HedImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
|
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
|
||||||
|
|
||||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||||
hed_processor = HEDProcessor()
|
hed_weights = self._context.models.load_remote_model("lllyasviel/Annotators::/ControlNetHED.pth")
|
||||||
|
with hed_weights as weights:
|
||||||
|
assert isinstance(weights, dict)
|
||||||
|
hed_processor = HEDProcessor(weights)
|
||||||
processed_image = hed_processor.run(
|
processed_image = hed_processor.run(
|
||||||
image,
|
image,
|
||||||
detect_resolution=self.detect_resolution,
|
detect_resolution=self.detect_resolution,
|
||||||
|
@ -4,8 +4,9 @@ import cv2
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from huggingface_hub import hf_hub_download
|
from invokeai.backend.model_manager.config import AnyModel
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
from invokeai.backend.image_util.util import (
|
from invokeai.backend.image_util.util import (
|
||||||
nms,
|
nms,
|
||||||
@ -76,16 +77,11 @@ class HEDProcessor:
|
|||||||
On instantiation, loads the HED model from the HuggingFace Hub.
|
On instantiation, loads the HED model from the HuggingFace Hub.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, state_dict: Dict[str, torch.Tensor]):
|
||||||
model_path = hf_hub_download("lllyasviel/Annotators", "ControlNetHED.pth")
|
|
||||||
self.network = ControlNetHED_Apache2()
|
self.network = ControlNetHED_Apache2()
|
||||||
self.network.load_state_dict(torch.load(model_path, map_location="cpu"))
|
self.network.load_state_dict(state_dict)
|
||||||
self.network.float().eval()
|
self.network.float().eval()
|
||||||
|
|
||||||
def to(self, device: torch.device):
|
|
||||||
self.network.to(device)
|
|
||||||
return self
|
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
input_image: Image.Image,
|
input_image: Image.Image,
|
||||||
|
Loading…
Reference in New Issue
Block a user