mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(mm): add control adapter default settings while probing
This commit is contained in:
parent
80c2a4b925
commit
13bb3c5e15
@ -14,6 +14,7 @@ from invokeai.backend.util.util import SilenceWarnings
|
||||
from .config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ControlAdapterDefaultSettings,
|
||||
InvalidModelConfigException,
|
||||
ModelConfigFactory,
|
||||
ModelFormat,
|
||||
@ -159,6 +160,12 @@ class ModelProbe(object):
|
||||
fields["format"] = fields.get("format") or probe.get_format()
|
||||
fields["hash"] = fields.get("hash") or ModelHash().hash(model_path)
|
||||
|
||||
fields["default_settings"] = (
|
||||
fields.get("default_settings") or probe.get_default_settings(fields["name"])
|
||||
if isinstance(probe, ControlAdapterProbe)
|
||||
else None
|
||||
)
|
||||
|
||||
if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase):
|
||||
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
|
||||
|
||||
@ -329,6 +336,36 @@ class ModelProbe(object):
|
||||
raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")
|
||||
|
||||
|
||||
class ControlAdapterProbe(ProbeBase):
|
||||
"""Adds `get_default_settings` for ControlNet and T2IAdapter probes"""
|
||||
|
||||
# TODO(psyche): It would be nice to get these from the invocations, but that creates circular dependencies.
|
||||
# "canny": CannyImageProcessorInvocation.get_type()
|
||||
MODEL_NAME_TO_PREPROCESSOR = {
|
||||
"canny": "canny_image_processor",
|
||||
"mlsd": "mlsd_image_processor",
|
||||
"depth": "depth_anything_image_processor",
|
||||
"bae": "normalbae_image_processor",
|
||||
"normal_bae": "normalbae_image_processor",
|
||||
"sketch": "pidi_image_processor",
|
||||
"scribble": "lineart_image_processor",
|
||||
"lineart": "lineart_image_processor",
|
||||
"lineart_anime": "lineart_anime_image_processor",
|
||||
"softedge": "hed_image_processor",
|
||||
"shuffle": "content_shuffle_image_processor",
|
||||
"openpose": "dw_openpose_image_processor",
|
||||
"mediapipe": "mediapipe_face_processor",
|
||||
"pidi": "pidi_image_processor",
|
||||
"zoe": "zoe_depth_image_processor",
|
||||
"color": "color_map_image_processor",
|
||||
}
|
||||
|
||||
def get_default_settings(self, model_name: str) -> Optional[ControlAdapterDefaultSettings]:
|
||||
if model_name in self.MODEL_NAME_TO_PREPROCESSOR:
|
||||
return ControlAdapterDefaultSettings(preprocessor=self.MODEL_NAME_TO_PREPROCESSOR[model_name])
|
||||
return None
|
||||
|
||||
|
||||
# ##################################################3
|
||||
# Checkpoint probing
|
||||
# ##################################################3
|
||||
@ -452,7 +489,7 @@ class TextualInversionCheckpointProbe(CheckpointProbeBase):
|
||||
raise InvalidModelConfigException(f"{self.model_path}: Could not determine base type")
|
||||
|
||||
|
||||
class ControlNetCheckpointProbe(CheckpointProbeBase):
|
||||
class ControlNetCheckpointProbe(CheckpointProbeBase, ControlAdapterProbe):
|
||||
"""Class for probing controlnets."""
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
@ -480,7 +517,7 @@ class CLIPVisionCheckpointProbe(CheckpointProbeBase):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class T2IAdapterCheckpointProbe(CheckpointProbeBase):
|
||||
class T2IAdapterCheckpointProbe(CheckpointProbeBase, ControlAdapterProbe):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
raise NotImplementedError()
|
||||
|
||||
@ -618,7 +655,7 @@ class ONNXFolderProbe(PipelineFolderProbe):
|
||||
return ModelVariantType.Normal
|
||||
|
||||
|
||||
class ControlNetFolderProbe(FolderProbeBase):
|
||||
class ControlNetFolderProbe(FolderProbeBase, ControlAdapterProbe):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
config_file = self.model_path / "config.json"
|
||||
if not config_file.exists():
|
||||
@ -692,7 +729,7 @@ class CLIPVisionFolderProbe(FolderProbeBase):
|
||||
return BaseModelType.Any
|
||||
|
||||
|
||||
class T2IAdapterFolderProbe(FolderProbeBase):
|
||||
class T2IAdapterFolderProbe(FolderProbeBase, ControlAdapterProbe):
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
config_file = self.model_path / "config.json"
|
||||
if not config_file.exists():
|
||||
|
Loading…
Reference in New Issue
Block a user