mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tidy(mm): use canonical capitalization for all model-related enums, classes
For example, "Lora" -> "LoRA", "Vae" -> "VAE".
This commit is contained in:
@ -133,7 +133,7 @@ class MainModelLoaderInvocation(BaseInvocation):
|
|||||||
vae=VaeField(
|
vae=VaeField(
|
||||||
vae=ModelInfo(
|
vae=ModelInfo(
|
||||||
key=key,
|
key=key,
|
||||||
submodel_type=SubModelType.Vae,
|
submodel_type=SubModelType.VAE,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@ -85,7 +85,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
|||||||
vae=VaeField(
|
vae=VaeField(
|
||||||
vae=ModelInfo(
|
vae=ModelInfo(
|
||||||
key=model_key,
|
key=model_key,
|
||||||
submodel_type=SubModelType.Vae,
|
submodel_type=SubModelType.VAE,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@ -142,7 +142,7 @@ class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
|||||||
vae=VaeField(
|
vae=VaeField(
|
||||||
vae=ModelInfo(
|
vae=ModelInfo(
|
||||||
key=model_key,
|
key=model_key,
|
||||||
submodel_type=SubModelType.Vae,
|
submodel_type=SubModelType.VAE,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@ -59,8 +59,8 @@ class ModelType(str, Enum):
|
|||||||
|
|
||||||
ONNX = "onnx"
|
ONNX = "onnx"
|
||||||
Main = "main"
|
Main = "main"
|
||||||
Vae = "vae"
|
VAE = "vae"
|
||||||
Lora = "lora"
|
LoRA = "lora"
|
||||||
ControlNet = "controlnet" # used by model_probe
|
ControlNet = "controlnet" # used by model_probe
|
||||||
TextualInversion = "embedding"
|
TextualInversion = "embedding"
|
||||||
IPAdapter = "ip_adapter"
|
IPAdapter = "ip_adapter"
|
||||||
@ -76,9 +76,9 @@ class SubModelType(str, Enum):
|
|||||||
TextEncoder2 = "text_encoder_2"
|
TextEncoder2 = "text_encoder_2"
|
||||||
Tokenizer = "tokenizer"
|
Tokenizer = "tokenizer"
|
||||||
Tokenizer2 = "tokenizer_2"
|
Tokenizer2 = "tokenizer_2"
|
||||||
Vae = "vae"
|
VAE = "vae"
|
||||||
VaeDecoder = "vae_decoder"
|
VAEDecoder = "vae_decoder"
|
||||||
VaeEncoder = "vae_encoder"
|
VAEEncoder = "vae_encoder"
|
||||||
Scheduler = "scheduler"
|
Scheduler = "scheduler"
|
||||||
SafetyChecker = "safety_checker"
|
SafetyChecker = "safety_checker"
|
||||||
|
|
||||||
@ -96,8 +96,8 @@ class ModelFormat(str, Enum):
|
|||||||
|
|
||||||
Diffusers = "diffusers"
|
Diffusers = "diffusers"
|
||||||
Checkpoint = "checkpoint"
|
Checkpoint = "checkpoint"
|
||||||
Lycoris = "lycoris"
|
LyCORIS = "lycoris"
|
||||||
Onnx = "onnx"
|
ONNX = "onnx"
|
||||||
Olive = "olive"
|
Olive = "olive"
|
||||||
EmbeddingFile = "embedding_file"
|
EmbeddingFile = "embedding_file"
|
||||||
EmbeddingFolder = "embedding_folder"
|
EmbeddingFolder = "embedding_folder"
|
||||||
@ -115,12 +115,12 @@ class SchedulerPredictionType(str, Enum):
|
|||||||
class ModelRepoVariant(str, Enum):
|
class ModelRepoVariant(str, Enum):
|
||||||
"""Various hugging face variants on the diffusers format."""
|
"""Various hugging face variants on the diffusers format."""
|
||||||
|
|
||||||
DEFAULT = "" # model files without "fp16" or other qualifier - empty str
|
Default = "" # model files without "fp16" or other qualifier - empty str
|
||||||
FP16 = "fp16"
|
FP16 = "fp16"
|
||||||
FP32 = "fp32"
|
FP32 = "fp32"
|
||||||
ONNX = "onnx"
|
ONNX = "onnx"
|
||||||
OPENVINO = "openvino"
|
OpenVINO = "openvino"
|
||||||
FLAX = "flax"
|
Flax = "flax"
|
||||||
|
|
||||||
|
|
||||||
class ModelSourceType(str, Enum):
|
class ModelSourceType(str, Enum):
|
||||||
@ -183,51 +183,51 @@ class DiffusersConfigBase(ModelConfigBase):
|
|||||||
"""Model config for diffusers-style models."""
|
"""Model config for diffusers-style models."""
|
||||||
|
|
||||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||||
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.DEFAULT
|
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.Default
|
||||||
|
|
||||||
|
|
||||||
class LoRALycorisConfig(ModelConfigBase):
|
class LoRALyCORISConfig(ModelConfigBase):
|
||||||
"""Model config for LoRA/Lycoris models."""
|
"""Model config for LoRA/Lycoris models."""
|
||||||
|
|
||||||
type: Literal[ModelType.Lora] = ModelType.Lora
|
type: Literal[ModelType.LoRA] = ModelType.LoRA
|
||||||
format: Literal[ModelFormat.Lycoris] = ModelFormat.Lycoris
|
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_tag() -> Tag:
|
def get_tag() -> Tag:
|
||||||
return Tag(f"{ModelType.Lora.value}.{ModelFormat.Lycoris.value}")
|
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}")
|
||||||
|
|
||||||
|
|
||||||
class LoRADiffusersConfig(ModelConfigBase):
|
class LoRADiffusersConfig(ModelConfigBase):
|
||||||
"""Model config for LoRA/Diffusers models."""
|
"""Model config for LoRA/Diffusers models."""
|
||||||
|
|
||||||
type: Literal[ModelType.Lora] = ModelType.Lora
|
type: Literal[ModelType.LoRA] = ModelType.LoRA
|
||||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_tag() -> Tag:
|
def get_tag() -> Tag:
|
||||||
return Tag(f"{ModelType.Lora.value}.{ModelFormat.Diffusers.value}")
|
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.Diffusers.value}")
|
||||||
|
|
||||||
|
|
||||||
class VaeCheckpointConfig(CheckpointConfigBase):
|
class VAECheckpointConfig(CheckpointConfigBase):
|
||||||
"""Model config for standalone VAE models."""
|
"""Model config for standalone VAE models."""
|
||||||
|
|
||||||
type: Literal[ModelType.Vae] = ModelType.Vae
|
type: Literal[ModelType.VAE] = ModelType.VAE
|
||||||
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
format: Literal[ModelFormat.Checkpoint] = ModelFormat.Checkpoint
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_tag() -> Tag:
|
def get_tag() -> Tag:
|
||||||
return Tag(f"{ModelType.Vae.value}.{ModelFormat.Checkpoint.value}")
|
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Checkpoint.value}")
|
||||||
|
|
||||||
|
|
||||||
class VaeDiffusersConfig(ModelConfigBase):
|
class VAEDiffusersConfig(ModelConfigBase):
|
||||||
"""Model config for standalone VAE models (diffusers version)."""
|
"""Model config for standalone VAE models (diffusers version)."""
|
||||||
|
|
||||||
type: Literal[ModelType.Vae] = ModelType.Vae
|
type: Literal[ModelType.VAE] = ModelType.VAE
|
||||||
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_tag() -> Tag:
|
def get_tag() -> Tag:
|
||||||
return Tag(f"{ModelType.Vae.value}.{ModelFormat.Diffusers.value}")
|
return Tag(f"{ModelType.VAE.value}.{ModelFormat.Diffusers.value}")
|
||||||
|
|
||||||
|
|
||||||
class ControlNetDiffusersConfig(DiffusersConfigBase):
|
class ControlNetDiffusersConfig(DiffusersConfigBase):
|
||||||
@ -356,11 +356,11 @@ AnyModelConfig = Annotated[
|
|||||||
Union[
|
Union[
|
||||||
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
|
Annotated[MainDiffusersConfig, MainDiffusersConfig.get_tag()],
|
||||||
Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
|
Annotated[MainCheckpointConfig, MainCheckpointConfig.get_tag()],
|
||||||
Annotated[VaeDiffusersConfig, VaeDiffusersConfig.get_tag()],
|
Annotated[VAEDiffusersConfig, VAEDiffusersConfig.get_tag()],
|
||||||
Annotated[VaeCheckpointConfig, VaeCheckpointConfig.get_tag()],
|
Annotated[VAECheckpointConfig, VAECheckpointConfig.get_tag()],
|
||||||
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
|
Annotated[ControlNetDiffusersConfig, ControlNetDiffusersConfig.get_tag()],
|
||||||
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
|
Annotated[ControlNetCheckpointConfig, ControlNetCheckpointConfig.get_tag()],
|
||||||
Annotated[LoRALycorisConfig, LoRALycorisConfig.get_tag()],
|
Annotated[LoRALyCORISConfig, LoRALyCORISConfig.get_tag()],
|
||||||
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
|
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
|
||||||
Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
|
Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
|
||||||
Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
|
Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
|
||||||
|
@ -22,8 +22,8 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod
|
|||||||
from .. import ModelLoader, ModelLoaderRegistry
|
from .. import ModelLoader, ModelLoaderRegistry
|
||||||
|
|
||||||
|
|
||||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers)
|
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.Diffusers)
|
||||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris)
|
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.LoRA, format=ModelFormat.LyCORIS)
|
||||||
class LoraLoader(ModelLoader):
|
class LoraLoader(ModelLoader):
|
||||||
"""Class to load LoRA models."""
|
"""Class to load LoRA models."""
|
||||||
|
|
||||||
|
@ -18,7 +18,7 @@ from .. import ModelLoaderRegistry
|
|||||||
from .generic_diffusers import GenericDiffusersLoader
|
from .generic_diffusers import GenericDiffusersLoader
|
||||||
|
|
||||||
|
|
||||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Onnx)
|
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.ONNX)
|
||||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive)
|
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.ONNX, format=ModelFormat.Olive)
|
||||||
class OnnyxDiffusersModel(GenericDiffusersLoader):
|
class OnnyxDiffusersModel(GenericDiffusersLoader):
|
||||||
"""Class to load onnx models."""
|
"""Class to load onnx models."""
|
||||||
|
@ -20,9 +20,9 @@ from .. import ModelLoaderRegistry
|
|||||||
from .generic_diffusers import GenericDiffusersLoader
|
from .generic_diffusers import GenericDiffusersLoader
|
||||||
|
|
||||||
|
|
||||||
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
|
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.VAE, format=ModelFormat.Diffusers)
|
||||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint)
|
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion1, type=ModelType.VAE, format=ModelFormat.Checkpoint)
|
||||||
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint)
|
@ModelLoaderRegistry.register(base=BaseModelType.StableDiffusion2, type=ModelType.VAE, format=ModelFormat.Checkpoint)
|
||||||
class VaeLoader(GenericDiffusersLoader):
|
class VaeLoader(GenericDiffusersLoader):
|
||||||
"""Class to load VAE models."""
|
"""Class to load VAE models."""
|
||||||
|
|
||||||
|
@ -97,8 +97,8 @@ class ModelProbe(object):
|
|||||||
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
|
"StableDiffusionXLImg2ImgPipeline": ModelType.Main,
|
||||||
"StableDiffusionXLInpaintPipeline": ModelType.Main,
|
"StableDiffusionXLInpaintPipeline": ModelType.Main,
|
||||||
"LatentConsistencyModelPipeline": ModelType.Main,
|
"LatentConsistencyModelPipeline": ModelType.Main,
|
||||||
"AutoencoderKL": ModelType.Vae,
|
"AutoencoderKL": ModelType.VAE,
|
||||||
"AutoencoderTiny": ModelType.Vae,
|
"AutoencoderTiny": ModelType.VAE,
|
||||||
"ControlNetModel": ModelType.ControlNet,
|
"ControlNetModel": ModelType.ControlNet,
|
||||||
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
|
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
|
||||||
"T2IAdapter": ModelType.T2IAdapter,
|
"T2IAdapter": ModelType.T2IAdapter,
|
||||||
@ -143,7 +143,7 @@ class ModelProbe(object):
|
|||||||
model_type = cls.get_model_type_from_folder(model_path)
|
model_type = cls.get_model_type_from_folder(model_path)
|
||||||
else:
|
else:
|
||||||
model_type = cls.get_model_type_from_checkpoint(model_path)
|
model_type = cls.get_model_type_from_checkpoint(model_path)
|
||||||
format_type = ModelFormat.Onnx if model_type == ModelType.ONNX else format_type
|
format_type = ModelFormat.ONNX if model_type == ModelType.ONNX else format_type
|
||||||
|
|
||||||
probe_class = cls.PROBES[format_type].get(model_type)
|
probe_class = cls.PROBES[format_type].get(model_type)
|
||||||
if not probe_class:
|
if not probe_class:
|
||||||
@ -172,7 +172,7 @@ class ModelProbe(object):
|
|||||||
|
|
||||||
# additional fields needed for main and controlnet models
|
# additional fields needed for main and controlnet models
|
||||||
if (
|
if (
|
||||||
fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.Vae]
|
fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.VAE]
|
||||||
and fields["format"] is ModelFormat.Checkpoint
|
and fields["format"] is ModelFormat.Checkpoint
|
||||||
):
|
):
|
||||||
fields["config_path"] = cls._get_checkpoint_config_path(
|
fields["config_path"] = cls._get_checkpoint_config_path(
|
||||||
@ -185,7 +185,7 @@ class ModelProbe(object):
|
|||||||
|
|
||||||
# additional fields needed for main non-checkpoint models
|
# additional fields needed for main non-checkpoint models
|
||||||
elif fields["type"] == ModelType.Main and fields["format"] in [
|
elif fields["type"] == ModelType.Main and fields["format"] in [
|
||||||
ModelFormat.Onnx,
|
ModelFormat.ONNX,
|
||||||
ModelFormat.Olive,
|
ModelFormat.Olive,
|
||||||
ModelFormat.Diffusers,
|
ModelFormat.Diffusers,
|
||||||
]:
|
]:
|
||||||
@ -219,11 +219,11 @@ class ModelProbe(object):
|
|||||||
if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
|
if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
|
||||||
return ModelType.Main
|
return ModelType.Main
|
||||||
elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
|
elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
|
||||||
return ModelType.Vae
|
return ModelType.VAE
|
||||||
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
|
elif any(key.startswith(v) for v in {"lora_te_", "lora_unet_"}):
|
||||||
return ModelType.Lora
|
return ModelType.LoRA
|
||||||
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
|
elif any(key.endswith(v) for v in {"to_k_lora.up.weight", "to_q_lora.down.weight"}):
|
||||||
return ModelType.Lora
|
return ModelType.LoRA
|
||||||
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
|
elif any(key.startswith(v) for v in {"control_model", "input_blocks"}):
|
||||||
return ModelType.ControlNet
|
return ModelType.ControlNet
|
||||||
elif key in {"emb_params", "string_to_param"}:
|
elif key in {"emb_params", "string_to_param"}:
|
||||||
@ -245,7 +245,7 @@ class ModelProbe(object):
|
|||||||
if (folder_path / f"learned_embeds.{suffix}").exists():
|
if (folder_path / f"learned_embeds.{suffix}").exists():
|
||||||
return ModelType.TextualInversion
|
return ModelType.TextualInversion
|
||||||
if (folder_path / f"pytorch_lora_weights.{suffix}").exists():
|
if (folder_path / f"pytorch_lora_weights.{suffix}").exists():
|
||||||
return ModelType.Lora
|
return ModelType.LoRA
|
||||||
if (folder_path / "unet/model.onnx").exists():
|
if (folder_path / "unet/model.onnx").exists():
|
||||||
return ModelType.ONNX
|
return ModelType.ONNX
|
||||||
if (folder_path / "image_encoder.txt").exists():
|
if (folder_path / "image_encoder.txt").exists():
|
||||||
@ -301,7 +301,7 @@ class ModelProbe(object):
|
|||||||
if base_type is BaseModelType.StableDiffusion1
|
if base_type is BaseModelType.StableDiffusion1
|
||||||
else "../controlnet/cldm_v21.yaml"
|
else "../controlnet/cldm_v21.yaml"
|
||||||
)
|
)
|
||||||
elif model_type is ModelType.Vae:
|
elif model_type is ModelType.VAE:
|
||||||
config_file = (
|
config_file = (
|
||||||
"../stable-diffusion/v1-inference.yaml"
|
"../stable-diffusion/v1-inference.yaml"
|
||||||
if base_type is BaseModelType.StableDiffusion1
|
if base_type is BaseModelType.StableDiffusion1
|
||||||
@ -511,12 +511,12 @@ class FolderProbeBase(ProbeBase):
|
|||||||
if ".fp16" in x.suffixes:
|
if ".fp16" in x.suffixes:
|
||||||
return ModelRepoVariant.FP16
|
return ModelRepoVariant.FP16
|
||||||
if "openvino_model" in x.name:
|
if "openvino_model" in x.name:
|
||||||
return ModelRepoVariant.OPENVINO
|
return ModelRepoVariant.OpenVINO
|
||||||
if "flax_model" in x.name:
|
if "flax_model" in x.name:
|
||||||
return ModelRepoVariant.FLAX
|
return ModelRepoVariant.Flax
|
||||||
if x.suffix == ".onnx":
|
if x.suffix == ".onnx":
|
||||||
return ModelRepoVariant.ONNX
|
return ModelRepoVariant.ONNX
|
||||||
return ModelRepoVariant.DEFAULT
|
return ModelRepoVariant.Default
|
||||||
|
|
||||||
|
|
||||||
class PipelineFolderProbe(FolderProbeBase):
|
class PipelineFolderProbe(FolderProbeBase):
|
||||||
@ -722,8 +722,8 @@ class T2IAdapterFolderProbe(FolderProbeBase):
|
|||||||
|
|
||||||
############## register probe classes ######
|
############## register probe classes ######
|
||||||
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
|
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
|
||||||
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
|
ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe)
|
||||||
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
|
ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)
|
||||||
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
|
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
|
||||||
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
|
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
|
||||||
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
|
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
|
||||||
@ -731,8 +731,8 @@ ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderPro
|
|||||||
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
|
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
|
||||||
|
|
||||||
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
|
||||||
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.VAE, VaeCheckpointProbe)
|
||||||
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.LoRA, LoRACheckpointProbe)
|
||||||
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
|
||||||
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
|
||||||
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
|
||||||
|
@ -35,7 +35,7 @@ def filter_files(
|
|||||||
The file list can be obtained from the `files` field of HuggingFaceMetadata,
|
The file list can be obtained from the `files` field of HuggingFaceMetadata,
|
||||||
as defined in `invokeai.backend.model_manager.metadata.metadata_base`.
|
as defined in `invokeai.backend.model_manager.metadata.metadata_base`.
|
||||||
"""
|
"""
|
||||||
variant = variant or ModelRepoVariant.DEFAULT
|
variant = variant or ModelRepoVariant.Default
|
||||||
paths: List[Path] = []
|
paths: List[Path] = []
|
||||||
root = files[0].parts[0]
|
root = files[0].parts[0]
|
||||||
|
|
||||||
@ -90,11 +90,11 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
|
|||||||
result.add(path)
|
result.add(path)
|
||||||
|
|
||||||
elif "openvino_model" in path.name:
|
elif "openvino_model" in path.name:
|
||||||
if variant == ModelRepoVariant.OPENVINO:
|
if variant == ModelRepoVariant.OpenVINO:
|
||||||
result.add(path)
|
result.add(path)
|
||||||
|
|
||||||
elif "flax_model" in path.name:
|
elif "flax_model" in path.name:
|
||||||
if variant == ModelRepoVariant.FLAX:
|
if variant == ModelRepoVariant.Flax:
|
||||||
result.add(path)
|
result.add(path)
|
||||||
|
|
||||||
elif path.suffix in [".json", ".txt"]:
|
elif path.suffix in [".json", ".txt"]:
|
||||||
@ -103,7 +103,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
|
|||||||
elif variant in [
|
elif variant in [
|
||||||
ModelRepoVariant.FP16,
|
ModelRepoVariant.FP16,
|
||||||
ModelRepoVariant.FP32,
|
ModelRepoVariant.FP32,
|
||||||
ModelRepoVariant.DEFAULT,
|
ModelRepoVariant.Default,
|
||||||
] and path.suffix in [".bin", ".safetensors", ".pt", ".ckpt"]:
|
] and path.suffix in [".bin", ".safetensors", ".pt", ".ckpt"]:
|
||||||
# For weights files, we want to select the best one for each subfolder. For example, we may have multiple
|
# For weights files, we want to select the best one for each subfolder. For example, we may have multiple
|
||||||
# text encoders:
|
# text encoders:
|
||||||
@ -127,7 +127,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
|
|||||||
# Some special handling is needed here if there is not an exact match and if we cannot infer the variant
|
# Some special handling is needed here if there is not an exact match and if we cannot infer the variant
|
||||||
# from the file name. In this case, we only give this file a point if the requested variant is FP32 or DEFAULT.
|
# from the file name. In this case, we only give this file a point if the requested variant is FP32 or DEFAULT.
|
||||||
if candidate_variant_label == f".{variant}" or (
|
if candidate_variant_label == f".{variant}" or (
|
||||||
not candidate_variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.DEFAULT]
|
not candidate_variant_label and variant in [ModelRepoVariant.FP32, ModelRepoVariant.Default]
|
||||||
):
|
):
|
||||||
score += 1
|
score += 1
|
||||||
|
|
||||||
@ -148,7 +148,7 @@ def _filter_by_variant(files: List[Path], variant: ModelRepoVariant) -> Set[Path
|
|||||||
# config and text files then we return an empty list
|
# config and text files then we return an empty list
|
||||||
if (
|
if (
|
||||||
variant
|
variant
|
||||||
and variant in [ModelRepoVariant.ONNX, ModelRepoVariant.OPENVINO, ModelRepoVariant.FLAX]
|
and variant in [ModelRepoVariant.ONNX, ModelRepoVariant.OpenVINO, ModelRepoVariant.Flax]
|
||||||
and not any(variant.value in x.name for x in result)
|
and not any(variant.value in x.name for x in result)
|
||||||
):
|
):
|
||||||
return set()
|
return set()
|
||||||
|
@ -144,7 +144,7 @@ class addModelsForm(CyclingForm, npyscreen.FormMultiPage):
|
|||||||
|
|
||||||
self.nextrely = top_of_table
|
self.nextrely = top_of_table
|
||||||
self.lora_models = self.add_model_widgets(
|
self.lora_models = self.add_model_widgets(
|
||||||
model_type=ModelType.Lora,
|
model_type=ModelType.LoRA,
|
||||||
window_width=window_width,
|
window_width=window_width,
|
||||||
)
|
)
|
||||||
bottom_of_table = max(bottom_of_table, self.nextrely)
|
bottom_of_table = max(bottom_of_table, self.nextrely)
|
||||||
|
@ -23,7 +23,7 @@ from invokeai.backend.model_manager.config import (
|
|||||||
ModelSourceType,
|
ModelSourceType,
|
||||||
ModelType,
|
ModelType,
|
||||||
TextualInversionFileConfig,
|
TextualInversionFileConfig,
|
||||||
VaeDiffusersConfig,
|
VAEDiffusersConfig,
|
||||||
)
|
)
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||||
@ -141,12 +141,12 @@ def test_filter(store: ModelRecordServiceBase):
|
|||||||
source="test/source",
|
source="test/source",
|
||||||
source_type=ModelSourceType.Path,
|
source_type=ModelSourceType.Path,
|
||||||
)
|
)
|
||||||
config3 = VaeDiffusersConfig(
|
config3 = VAEDiffusersConfig(
|
||||||
key="config3",
|
key="config3",
|
||||||
path="/tmp/config3",
|
path="/tmp/config3",
|
||||||
name="config3",
|
name="config3",
|
||||||
base=BaseModelType("sd-2"),
|
base=BaseModelType("sd-2"),
|
||||||
type=ModelType.Vae,
|
type=ModelType.VAE,
|
||||||
hash="CONFIG3HASH",
|
hash="CONFIG3HASH",
|
||||||
source="test/source",
|
source="test/source",
|
||||||
source_type=ModelSourceType.Path,
|
source_type=ModelSourceType.Path,
|
||||||
@ -157,7 +157,7 @@ def test_filter(store: ModelRecordServiceBase):
|
|||||||
assert len(matches) == 2
|
assert len(matches) == 2
|
||||||
assert matches[0].name in {"config1", "config2"}
|
assert matches[0].name in {"config1", "config2"}
|
||||||
|
|
||||||
matches = store.search_by_attr(model_type=ModelType.Vae)
|
matches = store.search_by_attr(model_type=ModelType.VAE)
|
||||||
assert len(matches) == 1
|
assert len(matches) == 1
|
||||||
assert matches[0].name == "config3"
|
assert matches[0].name == "config3"
|
||||||
assert matches[0].key == "config3"
|
assert matches[0].key == "config3"
|
||||||
@ -190,10 +190,10 @@ def test_unique(store: ModelRecordServiceBase):
|
|||||||
source="test/source/",
|
source="test/source/",
|
||||||
source_type=ModelSourceType.Path,
|
source_type=ModelSourceType.Path,
|
||||||
)
|
)
|
||||||
config3 = VaeDiffusersConfig(
|
config3 = VAEDiffusersConfig(
|
||||||
path="/tmp/config3",
|
path="/tmp/config3",
|
||||||
base=BaseModelType("sd-2"),
|
base=BaseModelType("sd-2"),
|
||||||
type=ModelType.Vae,
|
type=ModelType.VAE,
|
||||||
name="nonuniquename",
|
name="nonuniquename",
|
||||||
hash="CONFIG1HASH",
|
hash="CONFIG1HASH",
|
||||||
source="test/source/",
|
source="test/source/",
|
||||||
@ -257,11 +257,11 @@ def test_filter_2(store: ModelRecordServiceBase):
|
|||||||
source="test/source/",
|
source="test/source/",
|
||||||
source_type=ModelSourceType.Path,
|
source_type=ModelSourceType.Path,
|
||||||
)
|
)
|
||||||
config5 = VaeDiffusersConfig(
|
config5 = VAEDiffusersConfig(
|
||||||
path="/tmp/config5",
|
path="/tmp/config5",
|
||||||
name="dup_name1",
|
name="dup_name1",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
type=ModelType.Vae,
|
type=ModelType.VAE,
|
||||||
hash="CONFIG3HASH",
|
hash="CONFIG3HASH",
|
||||||
source="test/source/",
|
source="test/source/",
|
||||||
source_type=ModelSourceType.Path,
|
source_type=ModelSourceType.Path,
|
||||||
@ -283,7 +283,7 @@ def test_filter_2(store: ModelRecordServiceBase):
|
|||||||
|
|
||||||
matches = store.search_by_attr(
|
matches = store.search_by_attr(
|
||||||
base_model=BaseModelType.StableDiffusion1,
|
base_model=BaseModelType.StableDiffusion1,
|
||||||
model_type=ModelType.Vae,
|
model_type=ModelType.VAE,
|
||||||
model_name="dup_name1",
|
model_name="dup_name1",
|
||||||
)
|
)
|
||||||
assert len(matches) == 1
|
assert len(matches) == 1
|
||||||
|
@ -28,7 +28,7 @@ from invokeai.backend.model_manager.config import (
|
|||||||
ModelSourceType,
|
ModelSourceType,
|
||||||
ModelType,
|
ModelType,
|
||||||
ModelVariantType,
|
ModelVariantType,
|
||||||
VaeDiffusersConfig,
|
VAEDiffusersConfig,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
|
from invokeai.backend.model_manager.load import ModelCache, ModelConvertCache
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
@ -162,13 +162,13 @@ def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBas
|
|||||||
db = create_mock_sqlite_database(mm2_app_config, logger)
|
db = create_mock_sqlite_database(mm2_app_config, logger)
|
||||||
store = ModelRecordServiceSQL(db)
|
store = ModelRecordServiceSQL(db)
|
||||||
# add five simple config records to the database
|
# add five simple config records to the database
|
||||||
config1 = VaeDiffusersConfig(
|
config1 = VAEDiffusersConfig(
|
||||||
key="test_config_1",
|
key="test_config_1",
|
||||||
path="/tmp/foo1",
|
path="/tmp/foo1",
|
||||||
format=ModelFormat.Diffusers,
|
format=ModelFormat.Diffusers,
|
||||||
name="test2",
|
name="test2",
|
||||||
base=BaseModelType.StableDiffusion2,
|
base=BaseModelType.StableDiffusion2,
|
||||||
type=ModelType.Vae,
|
type=ModelType.VAE,
|
||||||
hash="111222333444",
|
hash="111222333444",
|
||||||
source="stabilityai/sdxl-vae",
|
source="stabilityai/sdxl-vae",
|
||||||
source_type=ModelSourceType.HFRepoID,
|
source_type=ModelSourceType.HFRepoID,
|
||||||
@ -204,7 +204,7 @@ def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBas
|
|||||||
format=ModelFormat.Diffusers,
|
format=ModelFormat.Diffusers,
|
||||||
name="test4",
|
name="test4",
|
||||||
base=BaseModelType.StableDiffusionXL,
|
base=BaseModelType.StableDiffusionXL,
|
||||||
type=ModelType.Lora,
|
type=ModelType.LoRA,
|
||||||
hash="111222333444",
|
hash="111222333444",
|
||||||
source="author4/model4",
|
source="author4/model4",
|
||||||
source_type=ModelSourceType.HFRepoID,
|
source_type=ModelSourceType.HFRepoID,
|
||||||
@ -215,7 +215,7 @@ def mm2_record_store(mm2_app_config: InvokeAIAppConfig) -> ModelRecordServiceBas
|
|||||||
format=ModelFormat.Diffusers,
|
format=ModelFormat.Diffusers,
|
||||||
name="test5",
|
name="test5",
|
||||||
base=BaseModelType.StableDiffusion1,
|
base=BaseModelType.StableDiffusion1,
|
||||||
type=ModelType.Lora,
|
type=ModelType.LoRA,
|
||||||
hash="111222333444",
|
hash="111222333444",
|
||||||
source="author4/model5",
|
source="author4/model5",
|
||||||
source_type=ModelSourceType.HFRepoID,
|
source_type=ModelSourceType.HFRepoID,
|
||||||
|
@ -104,7 +104,7 @@ def sdxl_base_files() -> List[Path]:
|
|||||||
],
|
],
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
ModelRepoVariant.DEFAULT,
|
ModelRepoVariant.Default,
|
||||||
[
|
[
|
||||||
"model_index.json",
|
"model_index.json",
|
||||||
"scheduler/scheduler_config.json",
|
"scheduler/scheduler_config.json",
|
||||||
@ -129,7 +129,7 @@ def sdxl_base_files() -> List[Path]:
|
|||||||
],
|
],
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
ModelRepoVariant.OPENVINO,
|
ModelRepoVariant.OpenVINO,
|
||||||
[
|
[
|
||||||
"model_index.json",
|
"model_index.json",
|
||||||
"scheduler/scheduler_config.json",
|
"scheduler/scheduler_config.json",
|
||||||
@ -211,7 +211,7 @@ def sdxl_base_files() -> List[Path]:
|
|||||||
],
|
],
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
ModelRepoVariant.FLAX,
|
ModelRepoVariant.Flax,
|
||||||
[
|
[
|
||||||
"model_index.json",
|
"model_index.json",
|
||||||
"scheduler/scheduler_config.json",
|
"scheduler/scheduler_config.json",
|
||||||
|
@ -21,7 +21,7 @@ def test_get_base_type(vae_path: str, expected_type: BaseModelType, datadir: Pat
|
|||||||
base_type = probe.get_base_type()
|
base_type = probe.get_base_type()
|
||||||
assert base_type == expected_type
|
assert base_type == expected_type
|
||||||
repo_variant = probe.get_repo_variant()
|
repo_variant = probe.get_repo_variant()
|
||||||
assert repo_variant == ModelRepoVariant.DEFAULT
|
assert repo_variant == ModelRepoVariant.Default
|
||||||
|
|
||||||
|
|
||||||
def test_repo_variant(datadir: Path):
|
def test_repo_variant(datadir: Path):
|
||||||
|
Reference in New Issue
Block a user