mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(mm): probe for main model default settings
Currently, this is just the width and height, derived from the model base.
This commit is contained in:
parent
2584a950aa
commit
3fd824306c
@ -17,6 +17,7 @@ from .config import (
|
|||||||
BaseModelType,
|
BaseModelType,
|
||||||
ControlAdapterDefaultSettings,
|
ControlAdapterDefaultSettings,
|
||||||
InvalidModelConfigException,
|
InvalidModelConfigException,
|
||||||
|
MainModelDefaultSettings,
|
||||||
ModelConfigFactory,
|
ModelConfigFactory,
|
||||||
ModelFormat,
|
ModelFormat,
|
||||||
ModelRepoVariant,
|
ModelRepoVariant,
|
||||||
@ -160,11 +161,13 @@ class ModelProbe(object):
|
|||||||
fields["format"] = fields.get("format") or probe.get_format()
|
fields["format"] = fields.get("format") or probe.get_format()
|
||||||
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
|
fields["hash"] = fields.get("hash") or ModelHash(algorithm=hash_algo).hash(model_path)
|
||||||
|
|
||||||
fields["default_settings"] = (
|
fields["default_settings"] = fields.get("default_settings")
|
||||||
fields.get("default_settings") or probe.get_default_settings(fields["name"])
|
|
||||||
if isinstance(probe, ControlAdapterProbe)
|
if not fields["default_settings"]:
|
||||||
else None
|
if fields["type"] in {ModelType.ControlNet, ModelType.T2IAdapter}:
|
||||||
)
|
fields["default_settings"] = get_default_settings_controlnet_t2i_adapter(fields["name"])
|
||||||
|
elif fields["type"] is ModelType.Main:
|
||||||
|
fields["default_settings"] = get_default_settings_main(fields["base"])
|
||||||
|
|
||||||
if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase):
|
if format_type == ModelFormat.Diffusers and isinstance(probe, FolderProbeBase):
|
||||||
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
|
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
|
||||||
@ -336,36 +339,41 @@ class ModelProbe(object):
|
|||||||
raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")
|
raise Exception("The model {model_name} is potentially infected by malware. Aborting import.")
|
||||||
|
|
||||||
|
|
||||||
class ControlAdapterProbe(ProbeBase):
|
# Probing utilities
|
||||||
"""Adds `get_default_settings` for ControlNet and T2IAdapter probes"""
|
MODEL_NAME_TO_PREPROCESSOR = {
|
||||||
|
"canny": "canny_image_processor",
|
||||||
|
"mlsd": "mlsd_image_processor",
|
||||||
|
"depth": "depth_anything_image_processor",
|
||||||
|
"bae": "normalbae_image_processor",
|
||||||
|
"normal": "normalbae_image_processor",
|
||||||
|
"sketch": "pidi_image_processor",
|
||||||
|
"scribble": "lineart_image_processor",
|
||||||
|
"lineart": "lineart_image_processor",
|
||||||
|
"lineart_anime": "lineart_anime_image_processor",
|
||||||
|
"softedge": "hed_image_processor",
|
||||||
|
"shuffle": "content_shuffle_image_processor",
|
||||||
|
"pose": "dw_openpose_image_processor",
|
||||||
|
"mediapipe": "mediapipe_face_processor",
|
||||||
|
"pidi": "pidi_image_processor",
|
||||||
|
"zoe": "zoe_depth_image_processor",
|
||||||
|
"color": "color_map_image_processor",
|
||||||
|
}
|
||||||
|
|
||||||
# TODO(psyche): It would be nice to get these from the invocations, but that creates circular dependencies.
|
|
||||||
# "canny": CannyImageProcessorInvocation.get_type()
|
|
||||||
MODEL_NAME_TO_PREPROCESSOR = {
|
|
||||||
"canny": "canny_image_processor",
|
|
||||||
"mlsd": "mlsd_image_processor",
|
|
||||||
"depth": "depth_anything_image_processor",
|
|
||||||
"bae": "normalbae_image_processor",
|
|
||||||
"normal": "normalbae_image_processor",
|
|
||||||
"sketch": "pidi_image_processor",
|
|
||||||
"scribble": "lineart_image_processor",
|
|
||||||
"lineart": "lineart_image_processor",
|
|
||||||
"lineart_anime": "lineart_anime_image_processor",
|
|
||||||
"softedge": "hed_image_processor",
|
|
||||||
"shuffle": "content_shuffle_image_processor",
|
|
||||||
"pose": "dw_openpose_image_processor",
|
|
||||||
"mediapipe": "mediapipe_face_processor",
|
|
||||||
"pidi": "pidi_image_processor",
|
|
||||||
"zoe": "zoe_depth_image_processor",
|
|
||||||
"color": "color_map_image_processor",
|
|
||||||
}
|
|
||||||
|
|
||||||
@classmethod
|
def get_default_settings_controlnet_t2i_adapter(model_name: str) -> Optional[ControlAdapterDefaultSettings]:
|
||||||
def get_default_settings(cls, model_name: str) -> Optional[ControlAdapterDefaultSettings]:
|
for k, v in MODEL_NAME_TO_PREPROCESSOR.items():
|
||||||
for k, v in cls.MODEL_NAME_TO_PREPROCESSOR.items():
|
if k in model_name:
|
||||||
if k in model_name:
|
return ControlAdapterDefaultSettings(preprocessor=v)
|
||||||
return ControlAdapterDefaultSettings(preprocessor=v)
|
return None
|
||||||
return None
|
|
||||||
|
|
||||||
|
def get_default_settings_main(model_base: BaseModelType) -> Optional[MainModelDefaultSettings]:
|
||||||
|
if model_base is BaseModelType.StableDiffusion1 or model_base is BaseModelType.StableDiffusion2:
|
||||||
|
return MainModelDefaultSettings(width=512, height=512)
|
||||||
|
elif model_base is BaseModelType.StableDiffusionXL:
|
||||||
|
return MainModelDefaultSettings(width=1024, height=1024)
|
||||||
|
# We don't provide defaults for BaseModelType.StableDiffusionXLRefiner, as they are not standalone models.
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
# ##################################################3
|
# ##################################################3
|
||||||
@ -491,7 +499,7 @@ class TextualInversionCheckpointProbe(CheckpointProbeBase):
|
|||||||
raise InvalidModelConfigException(f"{self.model_path}: Could not determine base type")
|
raise InvalidModelConfigException(f"{self.model_path}: Could not determine base type")
|
||||||
|
|
||||||
|
|
||||||
class ControlNetCheckpointProbe(CheckpointProbeBase, ControlAdapterProbe):
|
class ControlNetCheckpointProbe(CheckpointProbeBase):
|
||||||
"""Class for probing controlnets."""
|
"""Class for probing controlnets."""
|
||||||
|
|
||||||
def get_base_type(self) -> BaseModelType:
|
def get_base_type(self) -> BaseModelType:
|
||||||
@ -519,7 +527,7 @@ class CLIPVisionCheckpointProbe(CheckpointProbeBase):
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
class T2IAdapterCheckpointProbe(CheckpointProbeBase, ControlAdapterProbe):
|
class T2IAdapterCheckpointProbe(CheckpointProbeBase):
|
||||||
def get_base_type(self) -> BaseModelType:
|
def get_base_type(self) -> BaseModelType:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@ -657,7 +665,7 @@ class ONNXFolderProbe(PipelineFolderProbe):
|
|||||||
return ModelVariantType.Normal
|
return ModelVariantType.Normal
|
||||||
|
|
||||||
|
|
||||||
class ControlNetFolderProbe(FolderProbeBase, ControlAdapterProbe):
|
class ControlNetFolderProbe(FolderProbeBase):
|
||||||
def get_base_type(self) -> BaseModelType:
|
def get_base_type(self) -> BaseModelType:
|
||||||
config_file = self.model_path / "config.json"
|
config_file = self.model_path / "config.json"
|
||||||
if not config_file.exists():
|
if not config_file.exists():
|
||||||
@ -731,7 +739,7 @@ class CLIPVisionFolderProbe(FolderProbeBase):
|
|||||||
return BaseModelType.Any
|
return BaseModelType.Any
|
||||||
|
|
||||||
|
|
||||||
class T2IAdapterFolderProbe(FolderProbeBase, ControlAdapterProbe):
|
class T2IAdapterFolderProbe(FolderProbeBase):
|
||||||
def get_base_type(self) -> BaseModelType:
|
def get_base_type(self) -> BaseModelType:
|
||||||
config_file = self.model_path / "config.json"
|
config_file = self.model_path / "config.json"
|
||||||
if not config_file.exists():
|
if not config_file.exists():
|
||||||
|
Loading…
Reference in New Issue
Block a user