controlnet processors use MM cache system

Lincoln Stein 2024-06-30 21:17:04 -04:00
parent 5d1f6db414
commit b000bc2f58
6 changed files with 111 additions and 50 deletions

View File

@@ -1607,7 +1607,7 @@ model configuration to `load_model_by_config()`. It may raise a
Within invocations, the following methods are available from the
`InvocationContext` object:
### context.download_and_cache_model(source) -> Path
### context.download_and_cache_model(source, [preserve_subfolders=False]) -> Path
This method accepts a `source` of a remote model, downloads and caches
it locally, and then returns a Path to the local model. The source can
@@ -1626,6 +1626,16 @@ directory using this syntax:
* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
When requesting a huggingface repo, if the requested file(s) live in a
nested subfolder, the nesting information will be discarded and the
file(s) will be placed in the top level of the returned
directory. Thus, when requesting
`stabilityai/stable-diffusion-v4::vae`, the contents of `vae` will be
found at the top level of the returned path and not in a subdirectory.
This behavior can be changed by passing `preserve_subfolders=True`,
which will preserve the subfolder structure and return the path to the
subdirectory.
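For example, a minimal sketch of the two behaviors, reusing the illustrative repo id from above (the file layout is hypothetical):

```python
# Flattened (default): the contents of "vae" land at the top level of the
# returned directory.
vae_dir = context.download_and_cache_model("stabilityai/stable-diffusion-v4::vae")

# Preserved: the subfolder hierarchy is kept, and the path to the "vae"
# subdirectory itself is returned.
vae_dir = context.download_and_cache_model(
    "stabilityai/stable-diffusion-v4::vae",
    preserve_subfolders=True,
)
```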
### context.load_local_model(model_path, [loader]) -> LoadedModel
This method loads a local model from the indicated path, returning a

View File

@@ -3,7 +3,7 @@
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
from builtins import bool, float
from pathlib import Path
from typing import Dict, List, Literal, Union
from typing import ClassVar, Dict, List, Literal, Union
import cv2
import numpy as np
@@ -43,6 +43,7 @@ from invokeai.backend.image_util.hed import HEDProcessor
from invokeai.backend.image_util.lineart import LineartProcessor
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
from invokeai.backend.model_manager.load import LoadedModelWithoutConfig
from invokeai.backend.util.devices import TorchDevice
from .baseinvocation import BaseInvocation, BaseInvocationOutput, Classification, invocation, invocation_output
@@ -132,6 +133,14 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
image: ImageField = InputField(description="The image to process")
# Map controlnet_aux detector classes to model files in "lllyasviel/Annotators"
CONTROLNET_PROCESSORS: ClassVar[Dict[type, Path]] = {
MidasDetector: Path("dpt_hybrid-midas-501f0c75.pt"),
MLSDdetector: Path("mlsd_large_512_fp32.pth"),
PidiNetDetector: Path("table5_pidinet.pth"),
ZoeDetector: Path("ZoeD_M12_N.pt"),
}
def run_processor(self, image: Image.Image) -> Image.Image:
# superclass just passes through image without processing
return image
@@ -140,6 +149,14 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
# allows override for any special formatting specific to the preprocessor
return context.images.get_pil(self.image.image_name, "RGB")
def load_processor(self, processor: type) -> LoadedModelWithoutConfig:
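"""Load a controlnet_aux processor, downloading its checkpoint into the model manager cache as needed."""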
remote_source = f"lllyasviel/Annotators::/{self.CONTROLNET_PROCESSORS[processor]}"
assert hasattr(processor, "from_pretrained") # no common base class for the controlnet processors!
model = self._context.models.load_remote_model(
source=remote_source, loader=lambda x: processor.from_pretrained(x.parent, filename=x.name)
)
return model
def invoke(self, context: InvocationContext) -> ImageOutput:
self._context = context
raw_image = self.load_image(context)
@@ -288,8 +305,8 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
# depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
def run_processor(self, image: Image.Image) -> Image.Image:
# TODO: replace from_pretrained() calls with context.models.download_and_cache() (or similar)
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
with self.load_processor(MidasDetector) as midas_processor:
assert isinstance(midas_processor, MidasDetector)
processed_image = midas_processor(
image,
a=np.pi * self.a_mult,
@@ -316,7 +333,8 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image) -> Image.Image:
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
with self.load_processor(NormalBaeDetector) as normalbae_processor:
assert isinstance(normalbae_processor, NormalBaeDetector)
processed_image = normalbae_processor(
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
)
@@ -335,7 +353,8 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")
def run_processor(self, image: Image.Image) -> Image.Image:
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
with self.load_processor(MLSDdetector) as mlsd_processor:
assert isinstance(mlsd_processor, MLSDdetector)
processed_image = mlsd_processor(
image,
detect_resolution=self.detect_resolution,
@@ -358,7 +377,8 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
def run_processor(self, image: Image.Image) -> Image.Image:
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
with self.load_processor(PidiNetDetector) as pidi_processor:
assert isinstance(pidi_processor, PidiNetDetector)
processed_image = pidi_processor(
image,
detect_resolution=self.detect_resolution,
@@ -410,8 +430,9 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Zoe depth processing to image"""
def run_processor(self, image: Image.Image) -> Image.Image:
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
processed_image = zoe_depth_processor(image)
with self.load_processor(ZoeDetector) as zoe_depth_processor:
assert isinstance(zoe_depth_processor, ZoeDetector)
processed_image: Image.Image = zoe_depth_processor(image)
return processed_image
@@ -459,6 +480,8 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image) -> Image.Image:
# LeresDetector requires two hard-coded models, which breaks the load_processor() pattern.
# TODO (LS): Modify download_and_cache() to accept multiple downloaded checkpoint files.
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
processed_image = leres_processor(
image,
@@ -525,12 +548,15 @@ class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image) -> Image.Image:
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
"ybelkada/segment-anything", subfolder="checkpoints"
model_path = self._context.models.download_and_cache_model(
source="ybelkada/segment-anything::/checkpoints/sam_vit_h_4b8939.pth", preserve_subfolders=True
)
with self._context.models.load_local_model(
model_path, loader=lambda x: SamDetectorReproducibleColors.from_pretrained(x)
) as segment_anything_processor:
assert isinstance(segment_anything_processor, SamDetectorReproducibleColors)
np_img = np.array(image, dtype=np.uint8)
processed_image = segment_anything_processor(
processed_image: Image.Image = segment_anything_processor(
np_img, image_resolution=self.image_resolution, detect_resolution=self.detect_resolution
)
return processed_image

View File

@@ -243,11 +243,17 @@ class ModelInstallServiceBase(ABC):
"""
@abstractmethod
def download_and_cache_model(self, source: str | AnyHttpUrl) -> Path:
def download_and_cache_model(
self,
source: str | AnyHttpUrl,
preserve_subfolders: bool = False,
) -> Path:
"""
Download the model file located at source to the models cache and return its Path.
:param source: A string representing a URL or repo_id.
:param preserve_subfolders: (optional) If True, the subfolder hierarchy will be preserved;
otherwise flattened.
The model file will be downloaded into the system-wide model cache
(`models/.cache`) if it isn't already there. Note that the model cache

View File

@@ -373,8 +373,10 @@ class ModelInstallService(ModelInstallServiceBase):
def download_and_cache_model(
self,
source: str | AnyHttpUrl,
preserve_subfolders: bool = False,
) -> Path:
"""Download the model file located at source to the models cache and return its Path."""
model_source = self._guess_source(str(source))
model_path = self._download_cache_path(str(source), self._app_config)
# We expect the cache directory to contain one and only one downloaded file or directory.
@@ -386,12 +388,12 @@
return contents[0]
model_path.mkdir(parents=True, exist_ok=True)
model_source = self._guess_source(str(source))
remote_files, _ = self._remote_files_from_source(model_source)
job = self._multifile_download(
dest=model_path,
remote_files=remote_files,
subfolder=model_source.subfolder if isinstance(model_source, HFModelSource) else None,
preserve_subfolders=preserve_subfolders,
)
files_string = "file" if len(remote_files) == 1 else "files"
self._logger.info(f"Queuing model download: {source} ({len(remote_files)} {files_string})")
@@ -773,12 +775,19 @@
subfolder: Optional[Path] = None,
access_token: Optional[str] = None,
submit_job: bool = True,
preserve_subfolders: bool = False,
) -> MultiFileDownloadJob:
# HuggingFace repo subfolders are a little tricky. If the name of the model is "sdxl-turbo", and
# we are installing the "vae" subfolder, we do not want to create an additional folder level, such
# as "sdxl-turbo/vae", nor do we want to put the contents of the vae folder directly into "sdxl-turbo".
# So we synthesize a folder named "sdxl-turbo_vae" here.
# The exception is when preserve_subfolders is true, in which case we keep the hierarchy
# of subfolders and return the path to the last enclosing subfolder.
if subfolder:
if preserve_subfolders:
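# strip only the top-level repo folder and keep the nested subfolder layout intact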
path_to_remove = remote_files[0].path.parts[0]
path_to_add = Path(".")
else:
top = Path(remote_files[0].path.parts[0]) # e.g. "sdxl-turbo/"
path_to_remove = top / subfolder.parts[-1] # sdxl-turbo/vae/
path_to_add = Path(f"{top}_{subfolder}")
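# e.g. "sdxl-turbo/vae/model.safetensors" -> "sdxl-turbo_vae/model.safetensors"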

View File

@@ -447,6 +447,7 @@ class ModelsInterface(InvocationContextInterface):
def download_and_cache_model(
self,
source: str | AnyHttpUrl,
preserve_subfolders: bool = False,
) -> Path:
"""
Download the model file located at source to the models cache and return its Path.
@@ -457,11 +458,14 @@
Args:
source: A URL that points to the model, or a huggingface repo_id.
preserve_subfolders: (optional, default False) If True, preserve the subfolder structure.
Returns:
Path to the downloaded model
Path to the downloaded model (file or directory)
"""
return self._services.model_manager.install.download_and_cache_model(source=source)
return self._services.model_manager.install.download_and_cache_model(
source=source, preserve_subfolders=preserve_subfolders
)
def load_local_model(
self,

View File

@@ -289,7 +289,13 @@ class ModelCache(ModelCacheBase[AnyModel]):
target_device, copy=True, non_blocking=TorchDevice.get_non_blocking(target_device)
)
cache_entry.model.load_state_dict(new_dict, assign=True)
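# Some model classes override .to() without accepting non_blocking; fall back to a plain .to() call on TypeError.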
try:
cache_entry.model.to(target_device, non_blocking=TorchDevice.get_non_blocking(target_device))
except TypeError as e:
if "got an unexpected keyword argument 'non_blocking' in str(e)":
cache_entry.model.to(target_device)
else:
raise e
cache_entry.device = target_device
except Exception as e: # blow away cache entry
self._delete_cache_entry(cache_entry)