mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(mm): add config
field to ckpt vaes
This commit is contained in:
parent
76cbc745e1
commit
16a5d718bf
@ -13,6 +13,7 @@ from invokeai.backend.model_manager import (
|
|||||||
ModelFormat,
|
ModelFormat,
|
||||||
ModelType,
|
ModelType,
|
||||||
)
|
)
|
||||||
|
from invokeai.backend.model_manager.config import CheckpointConfigBase
|
||||||
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
|
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
|
||||||
|
|
||||||
from .. import ModelLoaderRegistry
|
from .. import ModelLoaderRegistry
|
||||||
@ -26,7 +27,7 @@ class VaeLoader(GenericDiffusersLoader):
|
|||||||
"""Class to load VAE models."""
|
"""Class to load VAE models."""
|
||||||
|
|
||||||
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
|
||||||
if config.format != ModelFormat.Checkpoint:
|
if not isinstance(config, CheckpointConfigBase):
|
||||||
return False
|
return False
|
||||||
elif (
|
elif (
|
||||||
dest_path.exists()
|
dest_path.exists()
|
||||||
@ -38,13 +39,12 @@ class VaeLoader(GenericDiffusersLoader):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
def _convert_model(self, config: AnyModelConfig, model_path: Path, output_path: Path) -> Path:
|
||||||
# TO DO: check whether sdxl VAE models convert.
|
# TODO(MM2): check whether sdxl VAE models convert.
|
||||||
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
|
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
|
||||||
raise Exception(f"Vae conversion not supported for model type: {config.base}")
|
raise Exception(f"VAE conversion not supported for model type: {config.base}")
|
||||||
else:
|
else:
|
||||||
config_file = (
|
assert isinstance(config, CheckpointConfigBase)
|
||||||
"v1-inference.yaml" if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml"
|
config_file = config.config
|
||||||
)
|
|
||||||
|
|
||||||
if model_path.suffix == ".safetensors":
|
if model_path.suffix == ".safetensors":
|
||||||
checkpoint = safetensors_load_file(model_path, device="cpu")
|
checkpoint = safetensors_load_file(model_path, device="cpu")
|
||||||
@ -55,7 +55,7 @@ class VaeLoader(GenericDiffusersLoader):
|
|||||||
if "state_dict" in checkpoint:
|
if "state_dict" in checkpoint:
|
||||||
checkpoint = checkpoint["state_dict"]
|
checkpoint = checkpoint["state_dict"]
|
||||||
|
|
||||||
ckpt_config = OmegaConf.load(self._app_config.legacy_conf_path / config_file)
|
ckpt_config = OmegaConf.load(self._app_config.root_path / config_file)
|
||||||
assert isinstance(ckpt_config, DictConfig)
|
assert isinstance(ckpt_config, DictConfig)
|
||||||
|
|
||||||
vae_model = convert_ldm_vae_to_diffusers(
|
vae_model = convert_ldm_vae_to_diffusers(
|
||||||
|
@ -137,7 +137,7 @@ class ModelProbe(object):
|
|||||||
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
|
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
|
||||||
model_info = None
|
model_info = None
|
||||||
model_type = None
|
model_type = None
|
||||||
if format_type == "diffusers":
|
if format_type is ModelFormat.Diffusers:
|
||||||
model_type = cls.get_model_type_from_folder(model_path)
|
model_type = cls.get_model_type_from_folder(model_path)
|
||||||
else:
|
else:
|
||||||
model_type = cls.get_model_type_from_checkpoint(model_path)
|
model_type = cls.get_model_type_from_checkpoint(model_path)
|
||||||
@ -168,7 +168,7 @@ class ModelProbe(object):
|
|||||||
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
|
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
|
||||||
|
|
||||||
# additional fields needed for main and controlnet models
|
# additional fields needed for main and controlnet models
|
||||||
if fields["type"] in [ModelType.Main, ModelType.ControlNet] and fields["format"] == ModelFormat.Checkpoint:
|
if fields["type"] in [ModelType.Main, ModelType.ControlNet, ModelType.Vae] and fields["format"] == ModelFormat.Checkpoint:
|
||||||
fields["config"] = cls._get_checkpoint_config_path(
|
fields["config"] = cls._get_checkpoint_config_path(
|
||||||
model_path,
|
model_path,
|
||||||
model_type=fields["type"],
|
model_type=fields["type"],
|
||||||
@ -285,13 +285,21 @@ class ModelProbe(object):
|
|||||||
if possible_conf.exists():
|
if possible_conf.exists():
|
||||||
return possible_conf.absolute()
|
return possible_conf.absolute()
|
||||||
|
|
||||||
if model_type == ModelType.Main:
|
if model_type is ModelType.Main:
|
||||||
config_file = LEGACY_CONFIGS[base_type][variant_type]
|
config_file = LEGACY_CONFIGS[base_type][variant_type]
|
||||||
if isinstance(config_file, dict): # need another tier for sd-2.x models
|
if isinstance(config_file, dict): # need another tier for sd-2.x models
|
||||||
config_file = config_file[prediction_type]
|
config_file = config_file[prediction_type]
|
||||||
elif model_type == ModelType.ControlNet:
|
elif model_type is ModelType.ControlNet:
|
||||||
config_file = (
|
config_file = (
|
||||||
"../controlnet/cldm_v15.yaml" if base_type == BaseModelType("sd-1") else "../controlnet/cldm_v21.yaml"
|
"../controlnet/cldm_v15.yaml"
|
||||||
|
if base_type is BaseModelType.StableDiffusion1
|
||||||
|
else "../controlnet/cldm_v21.yaml"
|
||||||
|
)
|
||||||
|
elif model_type is ModelType.Vae:
|
||||||
|
config_file = (
|
||||||
|
"../stable-diffusion/v1-inference.yaml"
|
||||||
|
if base_type is BaseModelType.StableDiffusion1
|
||||||
|
else "../stable-diffusion/v2-inference.yaml"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise InvalidModelConfigException(
|
raise InvalidModelConfigException(
|
||||||
|
Loading…
Reference in New Issue
Block a user