controlnet processors use MM cache system
parent 5d1f6db414
commit b000bc2f58
@@ -1607,7 +1607,7 @@ model configuration to `load_model_by_config()`. It may raise a
 Within invocations, the following methods are available from the
 `InvocationContext` object:
 
-### context.download_and_cache_model(source) -> Path
+### context.download_and_cache_model(source, [preserve_subfolders=False]) -> Path
 
 This method accepts a `source` of a remote model, downloads and caches
 it locally, and then returns a Path to the local model. The source can

@@ -1626,6 +1626,16 @@ directory using this syntax:
 
 * stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
 
+When requesting a huggingface repo, if the requested file(s) live in a
+nested subfolder, the nesting information will be discarded and the
+file(s) will be placed in the top level of the returned
+directory. Thus, when requesting
+`stabilityai/stable-diffusion-v4::vae`, the contents of `vae` will be
+found at the top level of the returned path and not in a subdirectory.
+This behavior can be changed by passing `preserve_subfolders=True`,
+which will preserve the subfolder structure and return the path to the
+subdirectory.
+
 ### context.load_local_model(model_path, [loader]) -> LoadedModel
 
 This method loads a local model from the indicated path, returning a
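
A brief usage sketch combining the two calls documented above, as they might appear inside an invocation. The source string is the one used by the SegmentAnything processor later in this commit; the surrounding body is illustrative, not part of the commit:

    # Download one checkpoint from a HF repo, keeping its subfolder layout,
    # then load it through the model-manager cache (loader argument omitted).
    model_path = context.models.download_and_cache_model(
        source="ybelkada/segment-anything::/checkpoints/sam_vit_h_4b8939.pth",
        preserve_subfolders=True,
    )
    with context.models.load_local_model(model_path) as model:
        ...  # use the model while it is held by the cache
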
@@ -3,7 +3,7 @@
 # heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
 from builtins import bool, float
 from pathlib import Path
-from typing import Dict, List, Literal, Union
+from typing import ClassVar, Dict, List, Literal, Union
 
 import cv2
 import numpy as np

@@ -43,6 +43,7 @@ from invokeai.backend.image_util.hed import HEDProcessor
 from invokeai.backend.image_util.lineart import LineartProcessor
 from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
 from invokeai.backend.image_util.util import np_to_pil, pil_to_np
+from invokeai.backend.model_manager.load import LoadedModelWithoutConfig
 from invokeai.backend.util.devices import TorchDevice
 
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, Classification, invocation, invocation_output

@@ -132,6 +133,14 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
 
     image: ImageField = InputField(description="The image to process")
 
+    # Map controlnet_aux detector classes to model files in "lllyasviel/Annotators"
+    CONTROLNET_PROCESSORS: ClassVar[Dict[type, Path]] = {
+        MidasDetector: Path("dpt_hybrid-midas-501f0c75.pt"),
+        MLSDdetector: Path("mlsd_large_512_fp32.pth"),
+        PidiNetDetector: Path("table5_pidinet.pth"),
+        ZoeDetector: Path("ZoeD_M12_N.pt"),
+    }
+
     def run_processor(self, image: Image.Image) -> Image.Image:
         # superclass just passes through image without processing
         return image

@@ -140,6 +149,14 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
         # allows override for any special formatting specific to the preprocessor
         return context.images.get_pil(self.image.image_name, "RGB")
 
+    def load_processor(self, processor: type) -> LoadedModelWithoutConfig:
+        remote_source = f"lllyasviel/Annotators::/{self.CONTROLNET_PROCESSORS[processor]}"
+        assert hasattr(processor, "from_pretrained")  # no common base class for the controlnet processors!
+        model = self._context.models.load_remote_model(
+            source=remote_source, loader=lambda x: processor.from_pretrained(x.parent, filename=x.name)
+        )
+        return model
+
     def invoke(self, context: InvocationContext) -> ImageOutput:
         self._context = context
         raw_image = self.load_image(context)
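
To make the indirection concrete: for `MidasDetector`, the f-string above expands to a single-file source string, and the loader receives the cached file's `Path`. A sketch using only names that appear in this diff:

    # What load_processor(MidasDetector) resolves to:
    remote_source = "lllyasviel/Annotators::/dpt_hybrid-midas-501f0c75.pt"
    # The loader is handed the cached Path x; controlnet_aux's from_pretrained()
    # takes the containing directory plus the checkpoint filename.
    loader = lambda x: MidasDetector.from_pretrained(x.parent, filename=x.name)
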
@@ -288,17 +305,17 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
     # depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
 
     def run_processor(self, image: Image.Image) -> Image.Image:
-        # TODO: replace from_pretrained() calls with context.models.download_and_cache() (or similar)
-        midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
-        processed_image = midas_processor(
-            image,
-            a=np.pi * self.a_mult,
-            bg_th=self.bg_th,
-            image_resolution=self.image_resolution,
-            detect_resolution=self.detect_resolution,
-            # depth_and_normal not supported in controlnet_aux v0.0.3
-            # depth_and_normal=self.depth_and_normal,
-        )
+        with self.load_processor(MidasDetector) as midas_processor:
+            assert isinstance(midas_processor, MidasDetector)
+            processed_image = midas_processor(
+                image,
+                a=np.pi * self.a_mult,
+                bg_th=self.bg_th,
+                image_resolution=self.image_resolution,
+                detect_resolution=self.detect_resolution,
+                # depth_and_normal not supported in controlnet_aux v0.0.3
+                # depth_and_normal=self.depth_and_normal,
+            )
         return processed_image


@@ -316,10 +333,11 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
     image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
 
     def run_processor(self, image: Image.Image) -> Image.Image:
-        normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
-        processed_image = normalbae_processor(
-            image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
-        )
+        with self.load_processor(NormalBaeDetector) as normalbae_processor:
+            assert isinstance(normalbae_processor, NormalBaeDetector)
+            processed_image = normalbae_processor(
+                image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
+            )
         return processed_image


@@ -335,14 +353,15 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
     thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")
 
     def run_processor(self, image: Image.Image) -> Image.Image:
-        mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
-        processed_image = mlsd_processor(
-            image,
-            detect_resolution=self.detect_resolution,
-            image_resolution=self.image_resolution,
-            thr_v=self.thr_v,
-            thr_d=self.thr_d,
-        )
+        with self.load_processor(MLSDdetector) as mlsd_processor:
+            assert isinstance(mlsd_processor, MLSDdetector)
+            processed_image = mlsd_processor(
+                image,
+                detect_resolution=self.detect_resolution,
+                image_resolution=self.image_resolution,
+                thr_v=self.thr_v,
+                thr_d=self.thr_d,
+            )
         return processed_image


@@ -358,14 +377,15 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
     scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
 
     def run_processor(self, image: Image.Image) -> Image.Image:
-        pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
-        processed_image = pidi_processor(
-            image,
-            detect_resolution=self.detect_resolution,
-            image_resolution=self.image_resolution,
-            safe=self.safe,
-            scribble=self.scribble,
-        )
+        with self.load_processor(PidiNetDetector) as pidi_processor:
+            assert isinstance(pidi_processor, PidiNetDetector)
+            processed_image = pidi_processor(
+                image,
+                detect_resolution=self.detect_resolution,
+                image_resolution=self.image_resolution,
+                safe=self.safe,
+                scribble=self.scribble,
+            )
         return processed_image


@@ -410,8 +430,9 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
     """Applies Zoe depth processing to image"""
 
     def run_processor(self, image: Image.Image) -> Image.Image:
-        zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
-        processed_image = zoe_depth_processor(image)
+        with self.load_processor(ZoeDetector) as zoe_depth_processor:
+            assert isinstance(zoe_depth_processor, ZoeDetector)
+            processed_image: Image.Image = zoe_depth_processor(image)
         return processed_image


@@ -459,6 +480,8 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
     image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
 
     def run_processor(self, image: Image.Image) -> Image.Image:
+        # LeresDetector requires two hard-coded models, which breaks the load_processor() pattern.
+        # TODO (LS): Modify download_and_cache() to accept multiple downloaded checkpoint files.
         leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
         processed_image = leres_processor(
             image,
@@ -525,14 +548,17 @@ class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
     image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
 
     def run_processor(self, image: Image.Image) -> Image.Image:
-        # segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
-        segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
-            "ybelkada/segment-anything", subfolder="checkpoints"
-        )
-        np_img = np.array(image, dtype=np.uint8)
-        processed_image = segment_anything_processor(
-            np_img, image_resolution=self.image_resolution, detect_resolution=self.detect_resolution
-        )
+        model_path = self._context.models.download_and_cache_model(
+            source="ybelkada/segment-anything::/checkpoints/sam_vit_h_4b8939.pth", preserve_subfolders=True
+        )
+        with self._context.models.load_local_model(
+            model_path, loader=lambda x: SamDetectorReproducibleColors.from_pretrained(x)
+        ) as segment_anything_processor:
+            assert isinstance(segment_anything_processor, SamDetectorReproducibleColors)
+            np_img = np.array(image, dtype=np.uint8)
+            processed_image: Image.Image = segment_anything_processor(
+                np_img, image_resolution=self.image_resolution, detect_resolution=self.detect_resolution
+            )
         return processed_image

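
The `preserve_subfolders=True` flag matters here presumably because the SAM loader resolves the checkpoint relative to the repo's original `checkpoints/` folder rather than a flattened directory. A sketch of the assumed cache layout (the `<key>` directory name is illustrative):

    # flattened (default):      models/.cache/<key>/sam_vit_h_4b8939.pth
    # preserve_subfolders=True: models/.cache/<key>/checkpoints/sam_vit_h_4b8939.pth
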
@@ -243,11 +243,17 @@ class ModelInstallServiceBase(ABC):
         """
 
     @abstractmethod
-    def download_and_cache_model(self, source: str | AnyHttpUrl) -> Path:
+    def download_and_cache_model(
+        self,
+        source: str | AnyHttpUrl,
+        preserve_subfolders: bool = False,
+    ) -> Path:
         """
         Download the model file located at source to the models cache and return its Path.
 
         :param source: A string representing a URL or repo_id.
+        :param preserve_subfolders: (optional) If True, the subfolder hierarchy will be preserved;
+            otherwise flattened.
 
         The model file will be downloaded into the system-wide model cache
         (`models/.cache`) if it isn't already there. Note that the model cache

@@ -373,8 +373,10 @@ class ModelInstallService(ModelInstallServiceBase):
     def download_and_cache_model(
         self,
         source: str | AnyHttpUrl,
+        preserve_subfolders: bool = False,
     ) -> Path:
         """Download the model file located at source to the models cache and return its Path."""
+        model_source = self._guess_source(str(source))
         model_path = self._download_cache_path(str(source), self._app_config)
 
         # We expect the cache directory to contain one and only one downloaded file or directory.
@@ -386,12 +388,12 @@
             return contents[0]
 
         model_path.mkdir(parents=True, exist_ok=True)
-        model_source = self._guess_source(str(source))
         remote_files, _ = self._remote_files_from_source(model_source)
         job = self._multifile_download(
             dest=model_path,
             remote_files=remote_files,
             subfolder=model_source.subfolder if isinstance(model_source, HFModelSource) else None,
+            preserve_subfolders=preserve_subfolders,
         )
         files_string = "file" if len(remote_files) == 1 else "files"
         self._logger.info(f"Queuing model download: {source} ({len(remote_files)} {files_string})")
@@ -773,15 +775,22 @@
         subfolder: Optional[Path] = None,
         access_token: Optional[str] = None,
         submit_job: bool = True,
+        preserve_subfolders: bool = False,
     ) -> MultiFileDownloadJob:
         # HuggingFace repo subfolders are a little tricky. If the name of the model is "sdxl-turbo", and
         # we are installing the "vae" subfolder, we do not want to create an additional folder level, such
         # as "sdxl-turbo/vae", nor do we want to put the contents of the vae folder directly into "sdxl-turbo".
         # So what we do is to synthesize a folder named "sdxl-turbo_vae" here.
+        # The exception is when preserve_subfolders is true, in which case we keep the hierarchy
+        # of subfolders and return the path to the last enclosing subfolder.
         if subfolder:
-            top = Path(remote_files[0].path.parts[0])  # e.g. "sdxl-turbo/"
-            path_to_remove = top / subfolder.parts[-1]  # sdxl-turbo/vae/
-            path_to_add = Path(f"{top}_{subfolder}")
+            if preserve_subfolders:
+                path_to_remove = remote_files[0].path.parts[0]
+                path_to_add = Path(".")
+            else:
+                top = Path(remote_files[0].path.parts[0])  # e.g. "sdxl-turbo/"
+                path_to_remove = top / subfolder.parts[-1]  # sdxl-turbo/vae/
+                path_to_add = Path(f"{top}_{subfolder}")
         else:
             path_to_remove = Path(".")
             path_to_add = Path(".")
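
A hand-traced example of the branches above, for the "sdxl-turbo"/"vae" scenario described in the comment. The file name is illustrative, and it assumes the download step strips `path_to_remove` from each remote path and prefixes `path_to_add`:

    # remote_files[0].path == Path("sdxl-turbo/vae/diffusion_pytorch_model.safetensors")
    # subfolder            == Path("vae")
    #
    # default (flatten):         path_to_remove == Path("sdxl-turbo/vae")
    #                            path_to_add    == Path("sdxl-turbo_vae")
    # preserve_subfolders=True:  path_to_remove == "sdxl-turbo"
    #                            path_to_add    == Path(".")
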
@@ -447,6 +447,7 @@ class ModelsInterface(InvocationContextInterface):
     def download_and_cache_model(
         self,
         source: str | AnyHttpUrl,
+        preserve_subfolders: bool = False,
     ) -> Path:
         """
         Download the model file located at source to the models cache and return its Path.
@@ -457,11 +458,14 @@ class ModelsInterface(InvocationContextInterface):
 
         Args:
             source: A URL that points to the model, or a huggingface repo_id.
+            preserve_subfolders: (optional, False) If True, then preserve subfolder structure.
 
         Returns:
-            Path to the downloaded model
+            Path to the downloaded model (file or directory)
         """
-        return self._services.model_manager.install.download_and_cache_model(source=source)
+        return self._services.model_manager.install.download_and_cache_model(
+            source=source, preserve_subfolders=preserve_subfolders
+        )
 
     def load_local_model(
         self,

@@ -289,7 +289,13 @@ class ModelCache(ModelCacheBase[AnyModel]):
                         target_device, copy=True, non_blocking=TorchDevice.get_non_blocking(target_device)
                     )
                 cache_entry.model.load_state_dict(new_dict, assign=True)
-            cache_entry.model.to(target_device, non_blocking=TorchDevice.get_non_blocking(target_device))
+            try:
+                cache_entry.model.to(target_device, non_blocking=TorchDevice.get_non_blocking(target_device))
+            except TypeError as e:
+                if "got an unexpected keyword argument 'non_blocking'" in str(e):
+                    cache_entry.model.to(target_device)
+                else:
+                    raise e
             cache_entry.device = target_device
         except Exception as e:  # blow away cache entry
             self._delete_cache_entry(cache_entry)
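
The inner try/except added above is a version guard: older torch releases do not accept the `non_blocking` keyword in `Module.to()`, so the move is retried without it. The same pattern as a standalone sketch (function and variable names are illustrative, not from this commit):

    import torch

    def move_module(module: torch.nn.Module, device: torch.device) -> None:
        try:
            module.to(device, non_blocking=True)
        except TypeError as e:
            # older torch: Module.to() lacks the non_blocking keyword
            if "non_blocking" in str(e):
                module.to(device)
            else:
                raise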