From 13bb3c5e15407639b5872d0deec3fc4155c09b4f Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Fri, 8 Mar 2024 16:01:43 +1100 Subject: [PATCH] feat(mm): add control adapter default settings while probing --- invokeai/backend/model_manager/probe.py | 45 ++++++++++++++++++++++--- 1 file changed, 41 insertions(+), 4 deletions(-) diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index 75925dcf0b..49c25fa843 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -14,6 +14,7 @@ from invokeai.backend.util.util import SilenceWarnings from .config import ( AnyModelConfig, BaseModelType, + ControlAdapterDefaultSettings, InvalidModelConfigException, ModelConfigFactory, ModelFormat, @@ -159,6 +160,12 @@ class ModelProbe(object): fields["format"] = fields.get("format") or probe.get_format() fields["hash"] = fields.get("hash") or ModelHash().hash(model_path) + fields["default_settings"] = ( + fields.get("default_settings") or probe.get_default_settings(fields["name"]) + if isinstance(probe, ControlAdapterProbe) + else None + ) + if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase): fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant() @@ -329,6 +336,36 @@ class ModelProbe(object): raise Exception("The model {model_name} is potentially infected by malware. Aborting import.") +class ControlAdapterProbe(ProbeBase): + """Adds `get_default_settings` for ControlNet and T2IAdapter probes""" + + # TODO(psyche): It would be nice to get these from the invocations, but that creates circular dependencies. + # "canny": CannyImageProcessorInvocation.get_type() + MODEL_NAME_TO_PREPROCESSOR = { + "canny": "canny_image_processor", + "mlsd": "mlsd_image_processor", + "depth": "depth_anything_image_processor", + "bae": "normalbae_image_processor", + "normal_bae": "normalbae_image_processor", + "sketch": "pidi_image_processor", + "scribble": "lineart_image_processor", + "lineart": "lineart_image_processor", + "lineart_anime": "lineart_anime_image_processor", + "softedge": "hed_image_processor", + "shuffle": "content_shuffle_image_processor", + "openpose": "dw_openpose_image_processor", + "mediapipe": "mediapipe_face_processor", + "pidi": "pidi_image_processor", + "zoe": "zoe_depth_image_processor", + "color": "color_map_image_processor", + } + + def get_default_settings(self, model_name: str) -> Optional[ControlAdapterDefaultSettings]: + if model_name in self.MODEL_NAME_TO_PREPROCESSOR: + return ControlAdapterDefaultSettings(preprocessor=self.MODEL_NAME_TO_PREPROCESSOR[model_name]) + return None + + # ##################################################3 # Checkpoint probing # ##################################################3 @@ -452,7 +489,7 @@ class TextualInversionCheckpointProbe(CheckpointProbeBase): raise InvalidModelConfigException(f"{self.model_path}: Could not determine base type") -class ControlNetCheckpointProbe(CheckpointProbeBase): +class ControlNetCheckpointProbe(CheckpointProbeBase, ControlAdapterProbe): """Class for probing controlnets.""" def get_base_type(self) -> BaseModelType: @@ -480,7 +517,7 @@ class CLIPVisionCheckpointProbe(CheckpointProbeBase): raise NotImplementedError() -class T2IAdapterCheckpointProbe(CheckpointProbeBase): +class T2IAdapterCheckpointProbe(CheckpointProbeBase, ControlAdapterProbe): def get_base_type(self) -> BaseModelType: raise NotImplementedError() @@ -618,7 +655,7 @@ class ONNXFolderProbe(PipelineFolderProbe): return ModelVariantType.Normal -class ControlNetFolderProbe(FolderProbeBase): +class ControlNetFolderProbe(FolderProbeBase, ControlAdapterProbe): def get_base_type(self) -> BaseModelType: config_file = self.model_path / "config.json" if not config_file.exists(): @@ -692,7 +729,7 @@ class CLIPVisionFolderProbe(FolderProbeBase): return BaseModelType.Any -class T2IAdapterFolderProbe(FolderProbeBase): +class T2IAdapterFolderProbe(FolderProbeBase, ControlAdapterProbe): def get_base_type(self) -> BaseModelType: config_file = self.model_path / "config.json" if not config_file.exists():