From a40fa8e83b5b72ec3ace9c15a367dcfc0230f2e5 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sat, 20 Jul 2024 18:34:57 -0400 Subject: [PATCH] simplify mapping of detectors to controlnet processors; tweaked download cache name algorithm --- .../invocations/controlnet_image_processors.py | 12 ++++++------ invokeai/backend/util/util.py | 2 +- tests/app/services/model_load/test_load_api.py | 15 ++++++++++++++- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/invokeai/app/invocations/controlnet_image_processors.py b/invokeai/app/invocations/controlnet_image_processors.py index 4a74624818..808fbb636e 100644 --- a/invokeai/app/invocations/controlnet_image_processors.py +++ b/invokeai/app/invocations/controlnet_image_processors.py @@ -139,11 +139,11 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard): image: ImageField = InputField(description="The image to process") # Map controlnet_aux detector classes to model files in "lllyasviel/Annotators" - CONTROLNET_PROCESSORS: ClassVar[Dict[type, Path]] = { - MidasDetector: Path("dpt_hybrid-midas-501f0c75.pt"), - MLSDdetector: Path("mlsd_large_512_fp32.pth"), - PidiNetDetector: Path("table5_pidinet.pth"), - ZoeDetector: Path("ZoeD_M12_N.pt"), + CONTROLNET_PROCESSORS: ClassVar[Dict[type, str]] = { + MidasDetector: "lllyasviel/Annotators::/dpt_hybrid-midas-501f0c75.pt", + MLSDdetector: "lllyasviel/Annotators::/mlsd_large_512_fp32.pth", + PidiNetDetector: "lllyasviel/Annotators::/table5_pidinet.pth", + ZoeDetector: "lllyasviel/Annotators::/ZoeD_M12_N.pt", } def run_processor(self, image: Image.Image) -> Image.Image: @@ -155,7 +155,7 @@ class ImageProcessorInvocation(BaseInvocation, WithMetadata, WithBoard): return context.images.get_pil(self.image.image_name, "RGB") def load_processor(self, processor: type) -> LoadedModelWithoutConfig: - remote_source = f"lllyasviel/Annotators::/{self.CONTROLNET_PROCESSORS[processor]}" + remote_source = self.CONTROLNET_PROCESSORS[processor] assert hasattr(processor, "from_pretrained") # no common base class for the controlnet processors! model = self._context.models.load_remote_model( source=remote_source, loader=lambda x: processor.from_pretrained(x.parent, filename=x.name) diff --git a/invokeai/backend/util/util.py b/invokeai/backend/util/util.py index b3466ddba9..c1582bb3b7 100644 --- a/invokeai/backend/util/util.py +++ b/invokeai/backend/util/util.py @@ -26,7 +26,7 @@ def slugify(value: str, allow_unicode: bool = False) -> str: value = unicodedata.normalize("NFKC", value) else: value = unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii") - value = re.sub(r"[/]", "_", value.lower()) + value = re.sub(r"[/:]+", "_", value.lower()) value = re.sub(r"[^.\w\s-]", "", value.lower()) return re.sub(r"[-\s]+", "-", value).strip("-_") diff --git a/tests/app/services/model_load/test_load_api.py b/tests/app/services/model_load/test_load_api.py index 6f2c7bd931..17628cfc53 100644 --- a/tests/app/services/model_load/test_load_api.py +++ b/tests/app/services/model_load/test_load_api.py @@ -81,8 +81,21 @@ def test_download_diffusers(mock_context: InvocationContext) -> None: def test_download_diffusers_subfolder(mock_context: InvocationContext) -> None: - model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo::vae") + model_path = mock_context.models.download_and_cache_model("stabilityai/sdxl-turbo::/vae") assert model_path.is_dir() + assert model_path.name != "vae" # will not create the vae subfolder with preserve_subfolders False assert (model_path / "diffusion_pytorch_model.fp16.safetensors").exists() or ( model_path / "diffusion_pytorch_model.safetensors" ).exists() + +def test_download_diffusers_preserve_subfolders(mock_context: InvocationContext) -> None: + model_path = mock_context.models.download_and_cache_model( + "stabilityai/sdxl-turbo::/vae", + preserve_subfolders=True, + ) + assert model_path.is_dir() + assert model_path.name == "vae" # will create the vae subfolder with preserve_subfolders True + assert (model_path / "diffusion_pytorch_model.fp16.safetensors").exists() or ( + model_path / "diffusion_pytorch_model.safetensors" + ).exists() +