mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
controlnet processors use MM cache system
This commit is contained in:
parent
5d1f6db414
commit
b000bc2f58
@ -1607,7 +1607,7 @@ model configuration to `load_model_by_config()`. It may raise a
|
|||||||
Within invocations, the following methods are available from the
|
Within invocations, the following methods are available from the
|
||||||
`InvocationContext` object:
|
`InvocationContext` object:
|
||||||
|
|
||||||
### context.download_and_cache_model(source) -> Path
|
### context.download_and_cache_model(source, [preserve_subfolders=False]) -> Path
|
||||||
|
|
||||||
This method accepts a `source` of a remote model, downloads and caches
|
This method accepts a `source` of a remote model, downloads and caches
|
||||||
it locally, and then returns a Path to the local model. The source can
|
it locally, and then returns a Path to the local model. The source can
|
||||||
@ -1626,6 +1626,16 @@ directory using this syntax:
|
|||||||
|
|
||||||
* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
|
* stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
|
||||||
|
|
||||||
|
When requesting a huggingface repo, if the requested file(s) live in a
|
||||||
|
nested subfolder, the nesting information will be discarded and the
|
||||||
|
file(s) will be placed in the top level of the returned
|
||||||
|
directory. Thus, when requesting
|
||||||
|
`stabilityai/stable-diffusion-v4::vae`, the contents of `vae` will be
|
||||||
|
found at the top level of the returned path and not in a subdirectory.
|
||||||
|
This behavior can be changed by passing `preserve_subfolders=True`,
|
||||||
|
which will preserve the subfolder structure and return the path to the
|
||||||
|
subdirectory.
|
||||||
|
|
||||||
### context.load_local_model(model_path, [loader]) -> LoadedModel
|
### context.load_local_model(model_path, [loader]) -> LoadedModel
|
||||||
|
|
||||||
This method loads a local model from the indicated path, returning a
|
This method loads a local model from the indicated path, returning a
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
|
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
|
||||||
from builtins import bool, float
|
from builtins import bool, float
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Literal, Union
|
from typing import ClassVar, Dict, List, Literal, Union
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -43,6 +43,7 @@ from invokeai.backend.image_util.hed import HEDProcessor
|
|||||||
from invokeai.backend.image_util.lineart import LineartProcessor
|
from invokeai.backend.image_util.lineart import LineartProcessor
|
||||||
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
|
from invokeai.backend.image_util.lineart_anime import LineartAnimeProcessor
|
||||||
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
|
from invokeai.backend.image_util.util import np_to_pil, pil_to_np
|
||||||
|
from invokeai.backend.model_manager.load import LoadedModelWithoutConfig
|
||||||
from invokeai.backend.util.devices import TorchDevice
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, Classification, invocation, invocation_output
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, Classification, invocation, invocation_output
|
||||||
@ -132,6 +133,14 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
|
|
||||||
image: ImageField = InputField(description="The image to process")
|
image: ImageField = InputField(description="The image to process")
|
||||||
|
|
||||||
|
# Map controlnet_aux detector classes to model files in "lllyasviel/Annotators".
# load_processor() indexes this dict with the detector class it is given, so every
# class passed to load_processor() needs an entry here; NormalBaeDetector was missing
# even though NormalbaeImageProcessorInvocation.run_processor() loads it that way.
# NOTE(review): "scannet.pt" is the default NormalBae checkpoint name in controlnet_aux —
# confirm against the installed controlnet_aux version.
CONTROLNET_PROCESSORS: ClassVar[Dict[type, Path]] = {
    MidasDetector: Path("dpt_hybrid-midas-501f0c75.pt"),
    MLSDdetector: Path("mlsd_large_512_fp32.pth"),
    NormalBaeDetector: Path("scannet.pt"),
    PidiNetDetector: Path("table5_pidinet.pth"),
    ZoeDetector: Path("ZoeD_M12_N.pt"),
}
|
||||||
|
|
||||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||||
# superclass just passes through image without processing
|
# superclass just passes through image without processing
|
||||||
return image
|
return image
|
||||||
@ -140,6 +149,14 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard):
|
|||||||
# allows override for any special formatting specific to the preprocessor
|
# allows override for any special formatting specific to the preprocessor
|
||||||
return context.images.get_pil(self.image.image_name, "RGB")
|
return context.images.get_pil(self.image.image_name, "RGB")
|
||||||
|
|
||||||
|
def load_processor(self, processor: type) -> LoadedModelWithoutConfig:
    """Load a controlnet_aux detector via the invocation context's model interface.

    The checkpoint filename registered for `processor` in `CONTROLNET_PROCESSORS` is
    turned into a "lllyasviel/Annotators::/<filename>" source string and passed to
    `self._context.models.load_remote_model()`. The loader callback instantiates the
    detector with `processor.from_pretrained(<parent directory>, filename=<file name>)`.

    `self._context` is assigned in `invoke()`, so this must be called after that point.

    Args:
        processor: Detector class. Must be a key of `CONTROLNET_PROCESSORS` and
            provide a `from_pretrained` attribute.

    Returns:
        The object returned by `load_remote_model()`.

    Raises:
        KeyError: If `processor` has no entry in `CONTROLNET_PROCESSORS`.
    """
    try:
        model_file = self.CONTROLNET_PROCESSORS[processor]
    except KeyError:
        # Same exception type as the plain dict lookup, but with a message that
        # says what has to be fixed.
        raise KeyError(
            f"{getattr(processor, '__name__', processor)} has no model file registered in CONTROLNET_PROCESSORS"
        ) from None
    remote_source = f"lllyasviel/Annotators::/{model_file}"
    assert hasattr(processor, "from_pretrained")  # no common base class for the controlnet processors!
    return self._context.models.load_remote_model(
        source=remote_source,
        loader=lambda x: processor.from_pretrained(x.parent, filename=x.name),
    )
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
self._context = context
|
self._context = context
|
||||||
raw_image = self.load_image(context)
|
raw_image = self.load_image(context)
|
||||||
@ -288,17 +305,17 @@ class MidasDepthImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
# depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
|
# depth_and_normal: bool = InputField(default=False, description="whether to use depth and normal mode")
|
||||||
|
|
||||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||||
# TODO: replace from_pretrained() calls with context.models.download_and_cache() (or similar)
|
with self.load_processor(MidasDetector) as midas_processor:
|
||||||
midas_processor = MidasDetector.from_pretrained("lllyasviel/Annotators")
|
assert isinstance(midas_processor, MidasDetector)
|
||||||
processed_image = midas_processor(
|
processed_image = midas_processor(
|
||||||
image,
|
image,
|
||||||
a=np.pi * self.a_mult,
|
a=np.pi * self.a_mult,
|
||||||
bg_th=self.bg_th,
|
bg_th=self.bg_th,
|
||||||
image_resolution=self.image_resolution,
|
image_resolution=self.image_resolution,
|
||||||
detect_resolution=self.detect_resolution,
|
detect_resolution=self.detect_resolution,
|
||||||
# dept_and_normal not supported in controlnet_aux v0.0.3
|
# dept_and_normal not supported in controlnet_aux v0.0.3
|
||||||
# depth_and_normal=self.depth_and_normal,
|
# depth_and_normal=self.depth_and_normal,
|
||||||
)
|
)
|
||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
@ -316,10 +333,11 @@ class NormalbaeImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||||
|
|
||||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||||
normalbae_processor = NormalBaeDetector.from_pretrained("lllyasviel/Annotators")
|
with self.load_processor(NormalBaeDetector) as normalbae_processor:
|
||||||
processed_image = normalbae_processor(
|
assert isinstance(normalbae_processor, NormalBaeDetector)
|
||||||
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
|
processed_image = normalbae_processor(
|
||||||
)
|
image, detect_resolution=self.detect_resolution, image_resolution=self.image_resolution
|
||||||
|
)
|
||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
@ -335,14 +353,15 @@ class MlsdImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")
|
thr_d: float = InputField(default=0.1, ge=0, description="MLSD parameter `thr_d`")
|
||||||
|
|
||||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||||
mlsd_processor = MLSDdetector.from_pretrained("lllyasviel/Annotators")
|
with self.load_processor(MLSDdetector) as mlsd_processor:
|
||||||
processed_image = mlsd_processor(
|
assert isinstance(mlsd_processor, MLSDdetector)
|
||||||
image,
|
processed_image = mlsd_processor(
|
||||||
detect_resolution=self.detect_resolution,
|
image,
|
||||||
image_resolution=self.image_resolution,
|
detect_resolution=self.detect_resolution,
|
||||||
thr_v=self.thr_v,
|
image_resolution=self.image_resolution,
|
||||||
thr_d=self.thr_d,
|
thr_v=self.thr_v,
|
||||||
)
|
thr_d=self.thr_d,
|
||||||
|
)
|
||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
@ -358,14 +377,15 @@ class PidiImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
|
scribble: bool = InputField(default=False, description=FieldDescriptions.scribble_mode)
|
||||||
|
|
||||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||||
pidi_processor = PidiNetDetector.from_pretrained("lllyasviel/Annotators")
|
with self.load_processor(PidiNetDetector) as pidi_processor:
|
||||||
processed_image = pidi_processor(
|
assert isinstance(pidi_processor, PidiNetDetector)
|
||||||
image,
|
processed_image = pidi_processor(
|
||||||
detect_resolution=self.detect_resolution,
|
image,
|
||||||
image_resolution=self.image_resolution,
|
detect_resolution=self.detect_resolution,
|
||||||
safe=self.safe,
|
image_resolution=self.image_resolution,
|
||||||
scribble=self.scribble,
|
safe=self.safe,
|
||||||
)
|
scribble=self.scribble,
|
||||||
|
)
|
||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
@ -410,8 +430,9 @@ class ZoeDepthImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
"""Applies Zoe depth processing to image"""
|
"""Applies Zoe depth processing to image"""
|
||||||
|
|
||||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||||
zoe_depth_processor = ZoeDetector.from_pretrained("lllyasviel/Annotators")
|
with self.load_processor(ZoeDetector) as zoe_depth_processor:
|
||||||
processed_image = zoe_depth_processor(image)
|
assert isinstance(zoe_depth_processor, ZoeDetector)
|
||||||
|
processed_image: Image.Image = zoe_depth_processor(image)
|
||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
@ -459,6 +480,8 @@ class LeresImageProcessorInvocation(ImageProcessorInvocation):
|
|||||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||||
|
|
||||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||||
|
# LeresDetector requires two hard-coded models, which breaks the load_processor() pattern.
|
||||||
|
# TODO (LS): Modify download_and_cache() to accept multiple downloaded checkpoint files.
|
||||||
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
|
leres_processor = LeresDetector.from_pretrained("lllyasviel/Annotators")
|
||||||
processed_image = leres_processor(
|
processed_image = leres_processor(
|
||||||
image,
|
image,
|
||||||
@ -525,14 +548,17 @@ class SegmentAnythingProcessorInvocation(ImageProcessorInvocation):
|
|||||||
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
image_resolution: int = InputField(default=512, ge=1, description=FieldDescriptions.image_res)
|
||||||
|
|
||||||
def run_processor(self, image: Image.Image) -> Image.Image:
|
def run_processor(self, image: Image.Image) -> Image.Image:
|
||||||
# segment_anything_processor = SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")
|
model_path = self._context.models.download_and_cache_model(
|
||||||
segment_anything_processor = SamDetectorReproducibleColors.from_pretrained(
|
source="ybelkada/segment-anything::/checkpoints/sam_vit_h_4b8939.pth", preserve_subfolders=True
|
||||||
"ybelkada/segment-anything", subfolder="checkpoints"
|
|
||||||
)
|
|
||||||
np_img = np.array(image, dtype=np.uint8)
|
|
||||||
processed_image = segment_anything_processor(
|
|
||||||
np_img, image_resolution=self.image_resolution, detect_resolution=self.detect_resolution
|
|
||||||
)
|
)
|
||||||
|
with self._context.models.load_local_model(
|
||||||
|
model_path, loader=lambda x: SamDetectorReproducibleColors.from_pretrained(x)
|
||||||
|
) as segment_anything_processor:
|
||||||
|
assert isinstance(segment_anything_processor, SamDetectorReproducibleColors)
|
||||||
|
np_img = np.array(image, dtype=np.uint8)
|
||||||
|
processed_image: Image.Image = segment_anything_processor(
|
||||||
|
np_img, image_resolution=self.image_resolution, detect_resolution=self.detect_resolution
|
||||||
|
)
|
||||||
return processed_image
|
return processed_image
|
||||||
|
|
||||||
|
|
||||||
|
@ -243,11 +243,17 @@ class ModelInstallServiceBase(ABC):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def download_and_cache_model(self, source: str | AnyHttpUrl) -> Path:
|
def download_and_cache_model(
|
||||||
|
self,
|
||||||
|
source: str | AnyHttpUrl,
|
||||||
|
preserve_subfolders: bool = False,
|
||||||
|
) -> Path:
|
||||||
"""
|
"""
|
||||||
Download the model file located at source to the models cache and return its Path.
|
Download the model file located at source to the models cache and return its Path.
|
||||||
|
|
||||||
:param source: A string representing a URL or repo_id.
|
:param source: A string representing a URL or repo_id.
|
||||||
|
:param preserve_subfolders: (optional) If True, the subfolder hierarchy will be preserved;
|
||||||
|
otherwise flattened.
|
||||||
|
|
||||||
The model file will be downloaded into the system-wide model cache
|
The model file will be downloaded into the system-wide model cache
|
||||||
(`models/.cache`) if it isn't already there. Note that the model cache
|
(`models/.cache`) if it isn't already there. Note that the model cache
|
||||||
|
@ -373,8 +373,10 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
def download_and_cache_model(
|
def download_and_cache_model(
|
||||||
self,
|
self,
|
||||||
source: str | AnyHttpUrl,
|
source: str | AnyHttpUrl,
|
||||||
|
preserve_subfolders: bool = False,
|
||||||
) -> Path:
|
) -> Path:
|
||||||
"""Download the model file located at source to the models cache and return its Path."""
|
"""Download the model file located at source to the models cache and return its Path."""
|
||||||
|
model_source = self._guess_source(str(source))
|
||||||
model_path = self._download_cache_path(str(source), self._app_config)
|
model_path = self._download_cache_path(str(source), self._app_config)
|
||||||
|
|
||||||
# We expect the cache directory to contain one and only one downloaded file or directory.
|
# We expect the cache directory to contain one and only one downloaded file or directory.
|
||||||
@ -386,12 +388,12 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
return contents[0]
|
return contents[0]
|
||||||
|
|
||||||
model_path.mkdir(parents=True, exist_ok=True)
|
model_path.mkdir(parents=True, exist_ok=True)
|
||||||
model_source = self._guess_source(str(source))
|
|
||||||
remote_files, _ = self._remote_files_from_source(model_source)
|
remote_files, _ = self._remote_files_from_source(model_source)
|
||||||
job = self._multifile_download(
|
job = self._multifile_download(
|
||||||
dest=model_path,
|
dest=model_path,
|
||||||
remote_files=remote_files,
|
remote_files=remote_files,
|
||||||
subfolder=model_source.subfolder if isinstance(model_source, HFModelSource) else None,
|
subfolder=model_source.subfolder if isinstance(model_source, HFModelSource) else None,
|
||||||
|
preserve_subfolders=preserve_subfolders,
|
||||||
)
|
)
|
||||||
files_string = "file" if len(remote_files) == 1 else "files"
|
files_string = "file" if len(remote_files) == 1 else "files"
|
||||||
self._logger.info(f"Queuing model download: {source} ({len(remote_files)} {files_string})")
|
self._logger.info(f"Queuing model download: {source} ({len(remote_files)} {files_string})")
|
||||||
@ -773,15 +775,22 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
subfolder: Optional[Path] = None,
|
subfolder: Optional[Path] = None,
|
||||||
access_token: Optional[str] = None,
|
access_token: Optional[str] = None,
|
||||||
submit_job: bool = True,
|
submit_job: bool = True,
|
||||||
|
preserve_subfolders: bool = False,
|
||||||
) -> MultiFileDownloadJob:
|
) -> MultiFileDownloadJob:
|
||||||
# HuggingFace repo subfolders are a little tricky. If the name of the model is "sdxl-turbo", and
|
# HuggingFace repo subfolders are a little tricky. If the name of the model is "sdxl-turbo", and
|
||||||
# we are installing the "vae" subfolder, we do not want to create an additional folder level, such
|
# we are installing the "vae" subfolder, we do not want to create an additional folder level, such
|
||||||
# as "sdxl-turbo/vae", nor do we want to put the contents of the vae folder directly into "sdxl-turbo".
|
# as "sdxl-turbo/vae", nor do we want to put the contents of the vae folder directly into "sdxl-turbo".
|
||||||
# So what we do is to synthesize a folder named "sdxl-turbo_vae" here.
|
# So what we do is to synthesize a folder named "sdxl-turbo_vae" here.
|
||||||
|
# The exception is when preserve_subfolders is true, in which case we keep the hierarchy
|
||||||
|
# of subfolders and return the path to the last enclosing subfolder.
|
||||||
if subfolder:
|
if subfolder:
|
||||||
top = Path(remote_files[0].path.parts[0]) # e.g. "sdxl-turbo/"
|
if preserve_subfolders:
|
||||||
path_to_remove = top / subfolder.parts[-1] # sdxl-turbo/vae/
|
path_to_remove = remote_files[0].path.parts[0]
|
||||||
path_to_add = Path(f"{top}_{subfolder}")
|
path_to_add = Path(".")
|
||||||
|
else:
|
||||||
|
top = Path(remote_files[0].path.parts[0]) # e.g. "sdxl-turbo/"
|
||||||
|
path_to_remove = top / subfolder.parts[-1] # sdxl-turbo/vae/
|
||||||
|
path_to_add = Path(f"{top}_{subfolder}")
|
||||||
else:
|
else:
|
||||||
path_to_remove = Path(".")
|
path_to_remove = Path(".")
|
||||||
path_to_add = Path(".")
|
path_to_add = Path(".")
|
||||||
|
@ -447,6 +447,7 @@ class ModelsInterface(InvocationContextInterface):
|
|||||||
def download_and_cache_model(
|
def download_and_cache_model(
|
||||||
self,
|
self,
|
||||||
source: str | AnyHttpUrl,
|
source: str | AnyHttpUrl,
|
||||||
|
preserve_subfolders: bool = False,
|
||||||
) -> Path:
|
) -> Path:
|
||||||
"""
|
"""
|
||||||
Download the model file located at source to the models cache and return its Path.
|
Download the model file located at source to the models cache and return its Path.
|
||||||
@ -457,11 +458,14 @@ class ModelsInterface(InvocationContextInterface):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
source: A URL that points to the model, or a huggingface repo_id.
|
source: A URL that points to the model, or a huggingface repo_id.
|
||||||
|
preserve_subfolders: (optional, False) If True, then preserve subfolder structure.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Path to the downloaded model
|
Path to the downloaded model (file or directory)
|
||||||
"""
|
"""
|
||||||
return self._services.model_manager.install.download_and_cache_model(source=source)
|
return self._services.model_manager.install.download_and_cache_model(
|
||||||
|
source=source, preserve_subfolders=preserve_subfolders
|
||||||
|
)
|
||||||
|
|
||||||
def load_local_model(
|
def load_local_model(
|
||||||
self,
|
self,
|
||||||
|
@ -289,7 +289,13 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
|||||||
target_device, copy=True, non_blocking=TorchDevice.get_non_blocking(target_device)
|
target_device, copy=True, non_blocking=TorchDevice.get_non_blocking(target_device)
|
||||||
)
|
)
|
||||||
cache_entry.model.load_state_dict(new_dict, assign=True)
|
cache_entry.model.load_state_dict(new_dict, assign=True)
|
||||||
cache_entry.model.to(target_device, non_blocking=TorchDevice.get_non_blocking(target_device))
|
try:
    cache_entry.model.to(target_device, non_blocking=TorchDevice.get_non_blocking(target_device))
except TypeError as e:
    # Retry without `non_blocking` only when .to() rejected that keyword.
    # The substring test must sit outside the string literal: with `in str(e)`
    # inside the quotes the condition is a non-empty string, which is always
    # truthy, so unrelated TypeErrors would be retried instead of re-raised.
    if "got an unexpected keyword argument 'non_blocking'" in str(e):
        cache_entry.model.to(target_device)
    else:
        raise
|
||||||
cache_entry.device = target_device
|
cache_entry.device = target_device
|
||||||
except Exception as e: # blow away cache entry
|
except Exception as e: # blow away cache entry
|
||||||
self._delete_cache_entry(cache_entry)
|
self._delete_cache_entry(cache_entry)
|
||||||
|
Loading…
Reference in New Issue
Block a user