mirror of https://github.com/invoke-ai/InvokeAI
enhance model autodetection during import
- Imported V2 legacy models will now autoconvert into diffusers at load
  time, regardless of the setting of --ckpt_convert.

- The model manager's `heuristic_import()` function now looks for
  side-by-side yaml and vae files for custom configuration and VAE
  respectively. An example of this:

      illuminati-v1.1.safetensors
      illuminati-v1.1.vae.safetensors
      illuminati-v1.1.yaml

  When the user tries to import `illuminati-v1.1.safetensors`, the yaml
  file will be used for its configuration and the vae file for its VAE.
  Conversion to diffusers will happen if needed, and the yaml file will
  be used to determine which V2 format (if any) to apply.
parent 1cb88960fe
commit dcb21c0f46
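A minimal sketch of the side-by-side lookup this commit adds (file names taken from the example above; the real logic lives in `heuristic_import()` in the model_manager.py diff below). Note that `Path.with_suffix()` replaces only the final suffix, which is what makes the `.vae.safetensors` naming convention work:

```python
from pathlib import Path

# Hypothetical import target, using the file names from the commit message.
model_path = Path("illuminati-v1.1.safetensors")

# with_suffix() swaps only the trailing ".safetensors", so these resolve to
# "illuminati-v1.1.yaml" and "illuminati-v1.1.vae.<ext>" respectively.
config_candidate = model_path.with_suffix(".yaml")
vae_candidates = [
    model_path.with_suffix(f".vae.{ext}") for ext in ("pt", "ckpt", "safetensors")
]

model_config_file = config_candidate if config_candidate.exists() else None
vae_path = next((p for p in vae_candidates if p.exists()), None)
print(model_config_file, vae_path)
```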
@@ -211,6 +211,26 @@ description for the model, whether to make this the default model that
 is loaded at InvokeAI startup time, and whether to replace its
 VAE. Generally the answer to the latter question is "no".
 
+### Specifying a configuration file for legacy checkpoints
+
+Some checkpoint files come with instructions to use a specific .yaml
+configuration file. For InvokeAI to load this file correctly, please put
+the config file in the same directory as the corresponding `.ckpt` or
+`.safetensors` file and make sure the file has the same basename as
+the weights file. Here is an example:
+
+```bash
+wonderful-model-v2.ckpt
+wonderful-model-v2.yaml
+```
+
+Similarly, to use a custom VAE, name the VAE like this:
+
+```bash
+wonderful-model-v2.vae.pt
+```
+
+
 ### Converting legacy models into `diffusers`
 
 The CLI `!convert_model` will convert a `.safetensors` or `.ckpt`
(File diff suppressed because it is too large.)
@@ -19,7 +19,7 @@ import warnings
 from enum import Enum
 from pathlib import Path
 from shutil import move, rmtree
-from typing import Any, Optional, Union, Callable
+from typing import Any, Callable, Optional, Union
 
 import safetensors
 import safetensors.torch
@@ -35,12 +35,7 @@ from picklescan.scanner import scan_file_path
 from ldm.invoke.devices import CPU_DEVICE
 from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
 from ldm.invoke.globals import Globals, global_cache_dir
-from ldm.util import (
-    ask_user,
-    download_with_resume,
-    instantiate_from_config,
-    url_attachment_name,
-)
+from ldm.util import ask_user, download_with_resume, instantiate_from_config, url_attachment_name
 
 
 class SDLegacyType(Enum):
@@ -384,15 +379,16 @@ class ModelManager(object):
         if not os.path.isabs(weights):
             weights = os.path.normpath(os.path.join(Globals.root, weights))
 
+        # check whether this is a v2 file and force conversion
+        convert = Globals.ckpt_convert or self.is_v2_config(config)
+
         # if converting automatically to diffusers, then we do the conversion and return
         # a diffusers pipeline
-        if Globals.ckpt_convert:
+        if convert:
             print(
                 f">> Converting legacy checkpoint {model_name} into a diffusers model..."
             )
-            from ldm.invoke.ckpt_to_diffuser import (
-                load_pipeline_from_original_stable_diffusion_ckpt,
-            )
+            from ldm.invoke.ckpt_to_diffuser import load_pipeline_from_original_stable_diffusion_ckpt
 
             self.offload_model(self.current_model)
             if vae_config := self._choose_diffusers_vae(model_name):
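In effect, `--ckpt_convert` is no longer the only trigger for conversion; a V2 config now forces it on its own. An illustrative decision table for the `convert` flag computed above:

```python
# Illustrative only: how the `convert` flag above combines its two inputs.
#
#   Globals.ckpt_convert   is_v2_config(config)   convert
#   --------------------   --------------------   -------
#   False                  False                  False   (legacy ckpt loader used)
#   False                  True                   True    (V2 forced to diffusers)
#   True                   False or True          True
```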
@@ -433,13 +429,13 @@ class ModelManager(object):
             weight_bytes = f.read()
         model_hash = self._cached_sha256(weights, weight_bytes)
         sd = None
-
+
         if weights.endswith(".ckpt"):
             self.scan_model(model_name, weights)
             sd = torch.load(io.BytesIO(weight_bytes), map_location="cpu")
         else:
             sd = safetensors.torch.load(weight_bytes)
-
+
         del weight_bytes
         # merged models from auto11 merge board are flat for some reason
         if "state_dict" in sd:
@@ -462,8 +458,8 @@ class ModelManager(object):
                 vae = os.path.normpath(os.path.join(Globals.root, vae))
             if os.path.exists(vae):
                 print(f" | Loading VAE weights from: {vae}")
-                if vae.endswith((".ckpt",".pt")):
-                    self.scan_model(vae,vae)
+                if vae.endswith((".ckpt", ".pt")):
+                    self.scan_model(vae, vae)
                     vae_ckpt = torch.load(vae, map_location="cpu")
                 else:
                     vae_ckpt = safetensors.torch.load_file(vae)
@@ -547,6 +543,15 @@ class ModelManager(object):
 
         return pipeline, width, height, model_hash
 
+    def is_v2_config(self, config: Path) -> bool:
+        try:
+            mconfig = OmegaConf.load(config)
+            return (
+                mconfig["model"]["params"]["unet_config"]["params"]["context_dim"] > 768
+            )
+        except:
+            return False
+
     def model_name_or_path(self, model_name: Union[str, DictConfig]) -> str | Path:
         if isinstance(model_name, DictConfig) or isinstance(model_name, dict):
             mconfig = model_name
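The probe relies on the fact that stock SD-1 legacy configs declare `context_dim: 768` (CLIP embeddings) while SD-2 configs declare `context_dim: 1024` (OpenCLIP embeddings). A small self-contained demonstration of the same check, with the configs built inline rather than loaded from disk:

```python
from omegaconf import OmegaConf

def looks_like_v2(mconfig) -> bool:
    # Same test as is_v2_config() above: V2 models cross-attend to 1024-dim
    # OpenCLIP embeddings, V1 models to 768-dim CLIP embeddings.
    try:
        return mconfig["model"]["params"]["unet_config"]["params"]["context_dim"] > 768
    except Exception:
        return False

v1_like = OmegaConf.create(
    {"model": {"params": {"unet_config": {"params": {"context_dim": 768}}}}}
)
v2_like = OmegaConf.create(
    {"model": {"params": {"unet_config": {"params": {"context_dim": 1024}}}}}
)

assert not looks_like_v2(v1_like)
assert looks_like_v2(v2_like)
```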
@@ -724,7 +729,7 @@ class ModelManager(object):
           SDLegacyType.V2_v (V2 using 'v_prediction' prediction type)
           SDLegacyType.UNKNOWN
         """
-        global_step = checkpoint.get('global_step')
+        global_step = checkpoint.get("global_step")
         state_dict = checkpoint.get("state_dict") or checkpoint
 
         try:
@@ -751,14 +756,14 @@ class ModelManager(object):
         return SDLegacyType.UNKNOWN
 
     def heuristic_import(
-            self,
-            path_url_or_repo: str,
-            convert: bool = False,
-            model_name: str = None,
-            description: str = None,
-            model_config_file: Path = None,
-            commit_to_conf: Path = None,
-            config_file_callback: Callable[[Path],Path] = None,
+        self,
+        path_url_or_repo: str,
+        convert: bool = False,
+        model_name: str = None,
+        description: str = None,
+        model_config_file: Path = None,
+        commit_to_conf: Path = None,
+        config_file_callback: Callable[[Path], Path] = None,
     ) -> str:
         """
         Accept a string which could be:
@@ -833,10 +838,10 @@ class ModelManager(object):
                 Path(thing).rglob("*.safetensors")
             ):
                 if model_name := self.heuristic_import(
-                        str(m),
-                        convert,
-                        commit_to_conf=commit_to_conf,
-                        config_file_callback=config_file_callback,
+                    str(m),
+                    convert,
+                    commit_to_conf=commit_to_conf,
+                    config_file_callback=config_file_callback,
                 ):
                     print(f" >> {model_name} successfully imported")
                     return model_name
@@ -864,57 +869,66 @@ class ModelManager(object):
 
         # another round of heuristics to guess the correct config file.
         checkpoint = None
-        if model_path.suffix.endswith((".ckpt",".pt")):
-            self.scan_model(model_path,model_path)
+        if model_path.suffix.endswith((".ckpt", ".pt")):
+            self.scan_model(model_path, model_path)
             checkpoint = torch.load(model_path)
         else:
             checkpoint = safetensors.torch.load_file(model_path)
 
-        model_type = self.probe_model_type(checkpoint)
-        if model_type == SDLegacyType.V1:
-            print(" | SD-v1 model detected")
-            model_config_file = Path(
-                Globals.root, "configs/stable-diffusion/v1-inference.yaml"
-            )
-        elif model_type == SDLegacyType.V1_INPAINT:
-            print(" | SD-v1 inpainting model detected")
-            model_config_file = Path(
-                Globals.root, "configs/stable-diffusion/v1-inpainting-inference.yaml"
-            )
-        elif model_type == SDLegacyType.V2_v:
-            print(
-                " | SD-v2-v model detected"
-            )
-            model_config_file = Path(
-                Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
-            )
-        elif model_type == SDLegacyType.V2_e:
-            print(
-                " | SD-v2-e model detected"
-            )
-            model_config_file = Path(
-                Globals.root, "configs/stable-diffusion/v2-inference.yaml"
-            )
-        elif model_type == SDLegacyType.V2:
-            print(
-                f"** {thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
-            )
-        else:
-            print(
-                f"** {thing} is a legacy checkpoint file but not a known Stable Diffusion model. Please provide configuration file path."
-            )
+        # additional probing needed if no config file provided
+        if model_config_file is None:
+            # Is there a like-named .yaml file in the same directory as the
+            # weights file? If so, we treat this as our model
+            if model_path.with_suffix(".yaml").exists():
+                model_config_file = model_path.with_suffix(".yaml")
+                print(f" | Using config file {model_config_file.name}")
+            else:
+                model_type = self.probe_model_type(checkpoint)
+                if model_type == SDLegacyType.V1:
+                    print(" | SD-v1 model detected")
+                    model_config_file = Path(
+                        Globals.root, "configs/stable-diffusion/v1-inference.yaml"
+                    )
+                elif model_type == SDLegacyType.V1_INPAINT:
+                    print(" | SD-v1 inpainting model detected")
+                    model_config_file = Path(
+                        Globals.root,
+                        "configs/stable-diffusion/v1-inpainting-inference.yaml",
+                    )
+                elif model_type == SDLegacyType.V2_v:
+                    print(" | SD-v2-v model detected")
+                    model_config_file = Path(
+                        Globals.root, "configs/stable-diffusion/v2-inference-v.yaml"
+                    )
+                elif model_type == SDLegacyType.V2_e:
+                    print(" | SD-v2-e model detected")
+                    model_config_file = Path(
+                        Globals.root, "configs/stable-diffusion/v2-inference.yaml"
+                    )
+                elif model_type == SDLegacyType.V2:
+                    print(
+                        f"** {thing} is a V2 checkpoint file, but its parameterization cannot be determined. Please provide configuration file path."
+                    )
+                else:
+                    print(
+                        f"** {thing} is a legacy checkpoint file but not a known Stable Diffusion model. Please provide configuration file path."
+                    )
 
         if not model_config_file and config_file_callback:
             model_config_file = config_file_callback(model_path)
         if not model_config_file:
             return
 
-        if model_config_file.name.startswith('v2'):
+        if self.is_v2_config(model_config_file):
             convert = True
-            print(
-                " | This SD-v2 model will be converted to diffusers format for use"
-            )
+            print(" | This SD-v2 model will be converted to diffusers format for use")
+
+        # look for a custom vae
+        vae_path = None
+        for suffix in ["pt", "ckpt", "safetensors"]:
+            if (model_path.with_suffix(f".vae.{suffix}")).exists():
+                vae_path = model_path.with_suffix(f".vae.{suffix}")
+                print(f" | Using VAE file {vae_path.name}")
+        vae = None if vae_path else dict(repo_id="stabilityai/sd-vae-ft-mse")
 
         if convert:
             diffuser_path = Path(
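Note the switch from a name-based to a content-based V2 test: the old `startswith('v2')` check only fired for configs named `v2*.yaml`, so a side-by-side config such as `illuminati-v1.1.yaml` describing a V2 model would never trigger conversion. A sketch of the difference (the file name is hypothetical):

```python
from pathlib import Path

config = Path("illuminati-v1.1.yaml")  # hypothetical side-by-side SD-v2 config

# Old, name-based test: misses any V2 config not named v2*.yaml.
assert not config.name.startswith("v2")

# New, content-based test: is_v2_config() (defined earlier in this diff)
# inspects context_dim inside the yaml, so the file name no longer matters.
# would_convert = manager.is_v2_config(config)  # True for any SD-2 config
```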
@@ -923,7 +937,8 @@ class ModelManager(object):
             model_name = self.convert_and_import(
                 model_path,
                 diffusers_path=diffuser_path,
-                vae=dict(repo_id="stabilityai/sd-vae-ft-mse"),
+                vae=vae,
+                vae_path=vae_path,
                 model_name=model_name,
                 model_description=description,
                 original_config_file=model_config_file,
@@ -941,7 +956,8 @@ class ModelManager(object):
                 model_name=model_name,
                 model_description=description,
                 vae=str(
-                    Path(
+                    vae_path
+                    or Path(
                         Globals.root,
                         "models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt",
                     )
@@ -953,15 +969,16 @@ class ModelManager(object):
         return model_name
 
     def convert_and_import(
-            self,
-            ckpt_path: Path,
-            diffusers_path: Path,
-            model_name=None,
-            model_description=None,
-            vae=None,
-            original_config_file: Path = None,
-            commit_to_conf: Path = None,
-            scan_needed: bool=True,
+        self,
+        ckpt_path: Path,
+        diffusers_path: Path,
+        model_name=None,
+        model_description=None,
+        vae: dict = None,
+        vae_path: Path = None,
+        original_config_file: Path = None,
+        commit_to_conf: Path = None,
+        scan_needed: bool = True,
     ) -> str:
         """
         Convert a legacy ckpt weights file to diffuser model and import
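A hedged usage sketch of the updated signature (the helper below is hypothetical; `mgr` stands for an already-initialized `ModelManager`). When a side-by-side VAE exists, it is passed through the new `vae_path` argument instead of the default `repo_id` dict:

```python
from pathlib import Path

def import_with_sidecar_files(mgr, base: Path) -> str:
    """Hypothetical helper: convert `base` (a legacy checkpoint) using its
    like-named .yaml config and .vae.safetensors VAE, if they exist."""
    vae_path = base.with_suffix(".vae.safetensors")
    return mgr.convert_and_import(
        base,
        diffusers_path=base.with_suffix(""),
        vae=None,  # skip the stabilityai/sd-vae-ft-mse fallback
        vae_path=vae_path if vae_path.exists() else None,
        original_config_file=base.with_suffix(".yaml"),
    )
```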
@@ -975,7 +992,7 @@ class ModelManager(object):
 
         new_config = None
 
-        from ldm.invoke.ckpt_to_diffuser import convert_ckpt_to_diffuser
+        from ldm.invoke.ckpt_to_diffuser import convert_ckpt_to_diffusers
 
         if diffusers_path.exists():
             print(
@@ -990,12 +1007,13 @@ class ModelManager(object):
             # By passing the specified VAE to the conversion function, the autoencoder
             # will be built into the model rather than tacked on afterward via the config file
             vae_model = self._load_vae(vae) if vae else None
-            convert_ckpt_to_diffuser(
+            convert_ckpt_to_diffusers(
                 ckpt_path,
                 diffusers_path,
                 extract_ema=True,
                 original_config_file=original_config_file,
                 vae=vae_model,
+                vae_path=str(vae_path) if vae_path else None,
                 scan_needed=scan_needed,
             )
             print(
@@ -1048,7 +1066,7 @@ class ModelManager(object):
         # In the event that the original entry is using a custom ckpt VAE, we try to
         # map that VAE onto a diffuser VAE using a hard-coded dictionary.
         # I would prefer to do this differently: We load the ckpt model into memory, swap the
-        # VAE in memory, and then pass that to convert_ckpt_to_diffuser() so that the swapped
+        # VAE in memory, and then pass that to convert_ckpt_to_diffusers() so that the swapped
         # VAE is built into the model. However, when I tried this I got obscure key errors.
         if vae:
             return vae
@@ -1134,14 +1152,14 @@ class ModelManager(object):
         legacy_locations = [
             Path(
                 models_dir,
-                "CompVis/stable-diffusion-safety-checker/models--CompVis--stable-diffusion-safety-checker"
+                "CompVis/stable-diffusion-safety-checker/models--CompVis--stable-diffusion-safety-checker",
             ),
             Path("bert-base-uncased/models--bert-base-uncased"),
             Path(
                 "openai/clip-vit-large-patch14/models--openai--clip-vit-large-patch14"
             ),
         ]
-        legacy_locations.extend(list(global_cache_dir("diffusers").glob('*')))
+        legacy_locations.extend(list(global_cache_dir("diffusers").glob("*")))
        legacy_layout = False
        for model in legacy_locations:
            legacy_layout = legacy_layout or model.exists()
@@ -1185,7 +1203,7 @@ class ModelManager(object):
                     source.unlink()
                 else:
                     move(source, dest)
 
-
         # now clean up by removing any empty directories
         empty = [
             root