controlnet processors use MM cache system

This commit is contained in:
Lincoln Stein 2024-06-30 21:17:04 -04:00
parent 5d1f6db414
commit b000bc2f58
6 changed files with 111 additions and 50 deletions

View File

@ -1607,7 +1607,7 @@ model configuration to `load_model_by_config()`. It may raise a
Within invocations, the following methods are available from the Within invocations, the following methods are available from the
`InvocationContext` object: `InvocationContext` object:
### context.download_and_cache_model(source) -> Path ### context.download_and_cache_model(source, [preserve_subfolders=False]) -> Path
This method accepts a `source` of a remote model, downloads and caches This method accepts a `source` of a remote model, downloads and caches
it locally, and then returns a Path to the local model. The source can it locally, and then returns a Path to the local model. The source can
@ -1626,6 +1626,16 @@ directory using this syntax:
* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors * stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
When requesting a huggingface repo, if the requested file(s) live in a
nested subfolder, the nesting information will be discarded and the
file(s) will be placed in the top level of the returned
directory. Thus, when requesting
`stabilityai/stable-diffusion-v4::vae`, the contents of `vae` will be
found at the top level of the returned path and not in a subdirectory.
This behavior can be changed by passing `preserve_subfolders=True`,
which will preserve the subfolder structure and return the path to the
subdirectory.
### context.load_local_model(model_path, [loader]) -> LoadedModel ### context.load_local_model(model_path, [loader]) -> LoadedModel
This method loads a local model from the indicated path, returning a This method loads a local model from the indicated path, returning a

View File

@ -3,7 +3,7 @@
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux # heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
from builtins import bool, float from builtins import bool, float
from pathlib import Path from pathlib import Path
from typing import Dict, List, Literal, Union from typing import ClassVar, Dict, List, Literal, Union
import cv2 import cv2
import numpy as np import numpy as np
@ -43,6 +43,7 @@ from invokeai.backend.image_util.hed import HEDProcessor
from invokeai.backend.image_util.lineart import LineartProcessor from invokeai.backend.image_util.lineart import LineartProcessor
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
from invokeai.backend.image_util.util import np_to_pil, pil_to_np from invokeai.backend.image_util.util import np_to_pil, pil_to_np
from invokeai.backend.model_manager.load import LoadedModelWithoutConfig
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
from .baseinvocation import BaseInvocation, BaseInvocationOutput, Classification, invocation, invocation_output from .baseinvocation import BaseInvocation, BaseInvocationOutput, Classification, invocation, invocation_output
@ -132,6 +133,14 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
image: ImageField = InputField(description="The image to process") image: ImageField = InputField(description="The image to process")
# Map controlnet_aux detector classes to model files in "lllyasviel/Annotators"
CONTROLNET_PROCESSORS: ClassVar[Dict[type, Path]] = {
MidasDetector: Path("dpt_hybrid-midas-501f0c75.pt"),
MLSDdetector: Path("mlsd_large_512_fp32.pth"),
PidiNetDetector: Path("table5_pidinet.pth"),
ZoeDetector: Path("ZoeD_M12_N.pt"),
}
def run_processor(self, image: Image.Image) -> Image.Image: def run_processor(self, image: Image.Image) -> Image.Image:
# superclass just passes through image without processing # superclass just passes through image without processing
return image return image
@ -140,6 +149,14 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
# allows override for any special formatting specific to the preprocessor # allows override for any special formatting specific to the preprocessor
return context.images.get_pil(self.image.image_name, "RGB") return context.images.get_pil(self.image.image_name, "RGB")
def load_processor(self, processor: type) -> LoadedModelWithoutConfig:
remote_source = f"lllyasviel/Annotators::/{self.CONTROLNET_PROCESSORS[processor]}"
assert hasattr(processor, "from_pretrained") # no common base class for the controlnet processors!
model = self._context.models.load_remote_model(
source=remote_source, loader=lambda x: processor.from_pretrained(x.parent, filename=x.name)
)
return model
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
self._context = context self._context = context
raw_image = self.load_image(context) raw_image = self.load_image(context)
@ -288,17 +305,17 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
# depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode") # depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
def run_processor(self, image: Image.Image) -> Image.Image: def run_processor(self, image: Image.Image) -> Image.Image:
# TODO: replace from_pretrained() calls with context.models.download_and_cache() (or similar) with self.load_processor(MidasDetector) as midas_processor:
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators") assert isinstance(midas_processor, MidasDetector)
processed_image = midas_processor( processed_image = midas_processor(
image, image,
a=np.pi * self.a_mult, a=np.pi * self.a_mult,
bg_th=self.bg_th, bg_th=self.bg_th,
image_resolution=self.image_resolution, image_resolution=self.image_resolution,
detect_resolution=self.detect_resolution, detect_resolution=self.detect_resolution,
# dept_and_normal not supported in controlnet_aux v0.0.3 # dept_and_normal not supported in controlnet_aux v0.0.3
# depth_and_normal=self.depth_and_normal, # depth_and_normal=self.depth_and_normal,
) )
return processed_image return processed_image
@ -316,10 +333,11 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res) image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image) -> Image.Image: def run_processor(self, image: Image.Image) -> Image.Image:
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators") with self.load_processor(NormalBaeDetector) as normalbae_processor:
processed_image = normalbae_processor( assert isinstance(normalbae_processor, NormalBaeDetector)
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution processed_image = normalbae_processor(
) image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
)
return processed_image return processed_image
@ -335,14 +353,15 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`") thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")
def run_processor(self, image: Image.Image) -> Image.Image: def run_processor(self, image: Image.Image) -> Image.Image:
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators") with self.load_processor(MLSDdetector) as mlsd_processor:
processed_image = mlsd_processor( assert isinstance(mlsd_processor, MLSDdetector)
image, processed_image = mlsd_processor(
detect_resolution=self.detect_resolution, image,
image_resolution=self.image_resolution, detect_resolution=self.detect_resolution,
thr_v=self.thr_v, image_resolution=self.image_resolution,
thr_d=self.thr_d, thr_v=self.thr_v,
) thr_d=self.thr_d,
)
return processed_image return processed_image
@ -358,14 +377,15 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode) scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
def run_processor(self, image: Image.Image) -> Image.Image: def run_processor(self, image: Image.Image) -> Image.Image:
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators") with self.load_processor(PidiNetDetector) as pidi_processor:
processed_image = pidi_processor( assert isinstance(pidi_processor, PidiNetDetector)
image, processed_image = pidi_processor(
detect_resolution=self.detect_resolution, image,
image_resolution=self.image_resolution, detect_resolution=self.detect_resolution,
safe=self.safe, image_resolution=self.image_resolution,
scribble=self.scribble, safe=self.safe,
) scribble=self.scribble,
)
return processed_image return processed_image
@ -410,8 +430,9 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
"""Applies Zoe depth processing to image""" """Applies Zoe depth processing to image"""
def run_processor(self, image: Image.Image) -> Image.Image: def run_processor(self, image: Image.Image) -> Image.Image:
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators") with self.load_processor(ZoeDetector) as zoe_depth_processor:
processed_image = zoe_depth_processor(image) assert isinstance(zoe_depth_processor, ZoeDetector)
processed_image: Image.Image = zoe_depth_processor(image)
return processed_image return processed_image
@ -459,6 +480,8 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res) image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image) -> Image.Image: def run_processor(self, image: Image.Image) -> Image.Image:
# LeresDetector requires two hard-coded models, which breaks the load_processor() pattern.
# TODO (LS): Modify download_and_cache() to accept multiple downloaded checkpoint files.
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators") leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
processed_image = leres_processor( processed_image = leres_processor(
image, image,
@ -525,14 +548,17 @@ class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res) image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
def run_processor(self, image: Image.Image) -> Image.Image: def run_processor(self, image: Image.Image) -> Image.Image:
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints") model_path = self._context.models.download_and_cache_model(
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained( source="ybelkada/segment-anything::/checkpoints/sam_vit_h_4b8939.pth", preserve_subfolders=True
"ybelkada/segment-anything", subfolder="checkpoints"
)
np_img = np.array(image, dtype=np.uint8)
processed_image = segment_anything_processor(
np_img, image_resolution=self.image_resolution, detect_resolution=self.detect_resolution
) )
with self._context.models.load_local_model(
model_path, loader=lambda x: SamDetectorReproducibleColors.from_pretrained(x)
) as segment_anything_processor:
assert isinstance(segment_anything_processor, SamDetectorReproducibleColors)
np_img = np.array(image, dtype=np.uint8)
processed_image: Image.Image = segment_anything_processor(
np_img, image_resolution=self.image_resolution, detect_resolution=self.detect_resolution
)
return processed_image return processed_image

View File

@ -243,11 +243,17 @@ class ModelInstallServiceBase(ABC):
""" """
@abstractmethod @abstractmethod
def download_and_cache_model(self, source: str | AnyHttpUrl) -> Path: def download_and_cache_model(
self,
source: str | AnyHttpUrl,
preserve_subfolders: bool = False,
) -> Path:
""" """
Download the model file located at source to the models cache and return its Path. Download the model file located at source to the models cache and return its Path.
:param source: A string representing a URL or repo_id. :param source: A string representing a URL or repo_id.
:param preserve_subfolders: (optional) If True, the subfolder hierarchy will be preserved;
otherwise flattened.
The model file will be downloaded into the system-wide model cache The model file will be downloaded into the system-wide model cache
(`models/.cache`) if it isn't already there. Note that the model cache (`models/.cache`) if it isn't already there. Note that the model cache

View File

@ -373,8 +373,10 @@ class ModelInstallService(ModelInstallServiceBase):
def download_and_cache_model( def download_and_cache_model(
self, self,
source: str | AnyHttpUrl, source: str | AnyHttpUrl,
preserve_subfolders: bool = False,
) -> Path: ) -> Path:
"""Download the model file located at source to the models cache and return its Path.""" """Download the model file located at source to the models cache and return its Path."""
model_source = self._guess_source(str(source))
model_path = self._download_cache_path(str(source), self._app_config) model_path = self._download_cache_path(str(source), self._app_config)
# We expect the cache directory to contain one and only one downloaded file or directory. # We expect the cache directory to contain one and only one downloaded file or directory.
@ -386,12 +388,12 @@ class ModelInstallService(ModelInstallServiceBase):
return contents[0] return contents[0]
model_path.mkdir(parents=True, exist_ok=True) model_path.mkdir(parents=True, exist_ok=True)
model_source = self._guess_source(str(source))
remote_files, _ = self._remote_files_from_source(model_source) remote_files, _ = self._remote_files_from_source(model_source)
job = self._multifile_download( job = self._multifile_download(
dest=model_path, dest=model_path,
remote_files=remote_files, remote_files=remote_files,
subfolder=model_source.subfolder if isinstance(model_source, HFModelSource) else None, subfolder=model_source.subfolder if isinstance(model_source, HFModelSource) else None,
preserve_subfolders=preserve_subfolders,
) )
files_string = "file" if len(remote_files) == 1 else "files" files_string = "file" if len(remote_files) == 1 else "files"
self._logger.info(f"Queuing model download: {source} ({len(remote_files)} {files_string})") self._logger.info(f"Queuing model download: {source} ({len(remote_files)} {files_string})")
@ -773,15 +775,22 @@ class ModelInstallService(ModelInstallServiceBase):
subfolder: Optional[Path] = None, subfolder: Optional[Path] = None,
access_token: Optional[str] = None, access_token: Optional[str] = None,
submit_job: bool = True, submit_job: bool = True,
preserve_subfolders: bool = False,
) -> MultiFileDownloadJob: ) -> MultiFileDownloadJob:
# HuggingFace repo subfolders are a little tricky. If the name of the model is "sdxl-turbo", and # HuggingFace repo subfolders are a little tricky. If the name of the model is "sdxl-turbo", and
# we are installing the "vae" subfolder, we do not want to create an additional folder level, such # we are installing the "vae" subfolder, we do not want to create an additional folder level, such
# as "sdxl-turbo/vae", nor do we want to put the contents of the vae folder directly into "sdxl-turbo". # as "sdxl-turbo/vae", nor do we want to put the contents of the vae folder directly into "sdxl-turbo".
# So what we do is to synthesize a folder named "sdxl-turbo_vae" here. # So what we do is to synthesize a folder named "sdxl-turbo_vae" here.
# The exception is when preserve_subfolders is true, in which case we keep the hierarchy
# of subfolders and return the path to the last enclosing subfolder.
if subfolder: if subfolder:
top = Path(remote_files[0].path.parts[0]) # e.g. "sdxl-turbo/" if preserve_subfolders:
path_to_remove = top / subfolder.parts[-1] # sdxl-turbo/vae/ path_to_remove = remote_files[0].path.parts[0]
path_to_add = Path(f"{top}_{subfolder}") path_to_add = Path(".")
else:
top = Path(remote_files[0].path.parts[0]) # e.g. "sdxl-turbo/"
path_to_remove = top / subfolder.parts[-1] # sdxl-turbo/vae/
path_to_add = Path(f"{top}_{subfolder}")
else: else:
path_to_remove = Path(".") path_to_remove = Path(".")
path_to_add = Path(".") path_to_add = Path(".")

View File

@ -447,6 +447,7 @@ class ModelsInterface(InvocationContextInterface):
def download_and_cache_model( def download_and_cache_model(
self, self,
source: str | AnyHttpUrl, source: str | AnyHttpUrl,
preserve_subfolders: bool = False,
) -> Path: ) -> Path:
""" """
Download the model file located at source to the models cache and return its Path. Download the model file located at source to the models cache and return its Path.
@ -457,11 +458,14 @@ class ModelsInterface(InvocationContextInterface):
Args: Args:
source: A URL that points to the model, or a huggingface repo_id. source: A URL that points to the model, or a huggingface repo_id.
preserve_subfolders: (optional, False) If True, then preserve subfolder structure.
Returns: Returns:
Path to the downloaded model Path to the downloaded model (file or directory)
""" """
return self._services.model_manager.install.download_and_cache_model(source=source) return self._services.model_manager.install.download_and_cache_model(
source=source, preserve_subfolders=preserve_subfolders
)
def load_local_model( def load_local_model(
self, self,

View File

@ -289,7 +289,13 @@ class ModelCache(ModelCacheBase[AnyModel]):
target_device, copy=True, non_blocking=TorchDevice.get_non_blocking(target_device) target_device, copy=True, non_blocking=TorchDevice.get_non_blocking(target_device)
) )
cache_entry.model.load_state_dict(new_dict, assign=True) cache_entry.model.load_state_dict(new_dict, assign=True)
cache_entry.model.to(target_device, non_blocking=TorchDevice.get_non_blocking(target_device)) try:
cache_entry.model.to(target_device, non_blocking=TorchDevice.get_non_blocking(target_device))
except TypeError as e:
if "got an unexpected keyword argument 'non_blocking' in str(e)":
cache_entry.model.to(target_device)
else:
raise e
cache_entry.device = target_device cache_entry.device = target_device
except Exception as e: # blow away cache entry except Exception as e: # blow away cache entry
self._delete_cache_entry(cache_entry) self._delete_cache_entry(cache_entry)