added method to convert vaes

commit b1a99d772c (parent fd82763412)
mirror of https://github.com/invoke-ai/InvokeAI
@@ -33,6 +33,7 @@ from .model_cache import ModelCache
 try:
     from omegaconf import OmegaConf
+    from omegaconf.dictconfig import DictConfig
 except ImportError:
     raise ImportError(
         "OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`."
@@ -614,16 +615,29 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
 
     return new_checkpoint
 
 
 def convert_ldm_vae_checkpoint(checkpoint, config):
-    # extract state dict for VAE
-    vae_state_dict = {}
-    vae_key = "first_stage_model."
-    keys = list(checkpoint.keys())
-    for key in keys:
-        if key.startswith(vae_key):
-            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
+    # Extract state dict for VAE. Works both with burnt-in
+    # VAEs and with standalone VAEs.
+
+    # checkpoint can either be an all-in-one stable diffusion
+    # model, or an isolated vae .ckpt. This tests for
+    # a key that will be present in the all-in-one model
+    # that isn't present in the isolated ckpt.
+    probe_key = "first_stage_model.encoder.conv_in.weight"
+    if probe_key in checkpoint:
+        vae_state_dict = {}
+        vae_key = "first_stage_model."
+        keys = list(checkpoint.keys())
+        for key in keys:
+            if key.startswith(vae_key):
+                vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
+    else:
+        vae_state_dict = checkpoint
+
+    new_checkpoint = convert_ldm_vae_state_dict(vae_state_dict, config)
+    return new_checkpoint
+
+
+def convert_ldm_vae_state_dict(vae_state_dict, config):
     new_checkpoint = {}
 
     new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
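The probe-key dispatch added above is the core of this hunk: the same entry point now accepts either a full Stable Diffusion checkpoint or an isolated VAE checkpoint. A minimal, self-contained sketch of that logic; the toy dicts below are hypothetical stand-ins for real tensor-valued state dicts:

    # Toy state dicts standing in for real checkpoints (values would be tensors).
    all_in_one = {
        "first_stage_model.encoder.conv_in.weight": "vae tensor",
        "model.diffusion_model.input_blocks.0.0.weight": "unet tensor",
    }
    standalone = {"encoder.conv_in.weight": "vae tensor"}

    prefix = "first_stage_model."
    probe_key = prefix + "encoder.conv_in.weight"

    for ckpt in (all_in_one, standalone):
        if probe_key in ckpt:
            # All-in-one model: keep only the keys namespaced under the VAE prefix.
            vae_sd = {k[len(prefix):]: v for k, v in ckpt.items() if k.startswith(prefix)}
        else:
            # Isolated VAE ckpt: the state dict is already in VAE layout.
            vae_sd = ckpt
        print(list(vae_sd))  # ['encoder.conv_in.weight'] in both cases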
@@ -1049,6 +1063,19 @@ def replace_checkpoint_vae(checkpoint, vae_path:str):
         new_key = f'first_stage_model.{vae_key}'
         checkpoint[new_key] = state_dict[vae_key]
 
+def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size: int)->AutoencoderKL:
+    vae_config = create_vae_diffusers_config(
+        vae_config, image_size=image_size
+    )
+
+    converted_vae_checkpoint = convert_ldm_vae_checkpoint(
+        checkpoint, vae_config
+    )
+
+    vae = AutoencoderKL(**vae_config)
+    vae.load_state_dict(converted_vae_checkpoint)
+    return vae
+
 def load_pipeline_from_original_stable_diffusion_ckpt(
     checkpoint_path: str,
     original_config_file: str = None,
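With this helper in place, going from a raw checkpoint plus its original LDM OmegaConf config to a ready AutoencoderKL is a single call. A hedged usage sketch; the file paths are placeholders, not files shipped with this commit:

    import torch
    from omegaconf import OmegaConf

    # Placeholder paths; substitute a real checkpoint and its original LDM config.
    checkpoint = torch.load("v1-5-pruned-emaonly.ckpt", map_location="cpu")
    if "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    original_config = OmegaConf.load("v1-inference.yaml")

    # Returns a diffusers AutoencoderKL loaded with the converted weights.
    vae = convert_ldm_vae_to_diffusers(checkpoint, original_config, image_size=512)
    vae.save_pretrained("converted-vae")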
@@ -1244,15 +1271,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
     if vae:
         logger.debug("Using replacement diffusers VAE")
     else: # convert the original or replacement VAE
-        vae_config = create_vae_diffusers_config(
-            original_config, image_size=image_size
-        )
-        converted_vae_checkpoint = convert_ldm_vae_checkpoint(
-            checkpoint, vae_config
-        )
-        vae = AutoencoderKL(**vae_config)
-        vae.load_state_dict(converted_vae_checkpoint)
+        vae = convert_ldm_vae_to_diffusers(
+            checkpoint,
+            original_config,
+            image_size)
 
     # Convert the text model.
     model_type = pipeline_type
@@ -32,7 +32,7 @@ import safetensors.torch
 
 from diffusers import DiffusionPipeline, SchedulerMixin, ConfigMixin
 from diffusers import logging as diffusers_logging
-from huggingface_hub import HfApi
+from huggingface_hub import HfApi, scan_cache_dir
 from picklescan.scanner import scan_file_path
 from pydantic import BaseModel
 from transformers import logging as transformers_logging
@@ -131,9 +131,7 @@ from __future__ import annotations
 
 import os
 import re
-import sys
 import textwrap
-from contextlib import suppress
 from dataclasses import dataclass
 from enum import Enum, auto
 from packaging import version
@@ -145,6 +143,7 @@ import safetensors
 import safetensors.torch
 import torch
 from diffusers import AutoencoderKL
+from diffusers.utils import is_safetensors_available
 from huggingface_hub import scan_cache_dir
 from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig
@@ -157,7 +156,6 @@ from invokeai.backend.util import download_with_resume
 from ..util import CUDA_DEVICE
 from .model_cache import (ModelCache, ModelLocker, SDModelType,
                           SilenceWarnings)
-
 # We are only starting to number the config file with release 3.
 # The config file version doesn't have to start at release version, but it will help
 # reduce confusion.
@@ -307,17 +305,12 @@ class ModelManager(object):
 
         """
-        # This is a temporary workaround for callers that use "type/name" as the model name
+        # Commented-out workaround for callers that use "type/name" as the model name
         # because they haven't adjusted to the new return format of `list_models()`
-        if "/" in model_name:
-            model_key = model_name
-        else:
-            model_key = self.create_key(model_name, model_type)
-
-        # TODO: delete default model or add check that this stable diffusion model
-        # if not model_name:
-        #     model_name = self.default_model()
+        # if "/" in model_name:
+        #     model_key = model_name
+        # else:
+        model_key = self.create_key(model_name, model_type)
 
         if model_key not in self.config:
             raise InvalidModelError(
                 f'"{model_key}" is not a known model name. Please check your models.yaml file'
@@ -358,7 +351,6 @@ class ModelManager(object):
         if model_type == SDModelType.Vae and "vae" in mconfig:
             print("NOT_IMPLEMENTED - RETURN CUSTOM VAE")
 
-
         model_context = self.cache.get_model(
             location,
             model_type = model_type,
@@ -414,7 +406,7 @@ class ModelManager(object):
         """
         Given a model name returns the OmegaConf (dict-like) object describing it.
         """
-        if not self.exists(model_name, model_type):
+        if not self.model_exists(model_name, model_type):
             return None
         return self.config[self.create_key(model_name, model_type)]
 
@@ -962,6 +954,44 @@ class ModelManager(object):
         )
         return diffusers_path
 
+    def convert_vae_ckpt_and_cache(self, mconfig: DictConfig) -> Path:
+        """
+        Convert the VAE indicated in mconfig into a diffusers AutoencoderKL
+        object, cache it to disk, and return the Path to the converted
+        file. If it is already on disk, just return the Path.
+        """
+        weights_file = global_resolve_path(mconfig.weights)
+        config_file = global_resolve_path(mconfig.config)
+        diffusers_path = global_resolve_path(Path('models', Globals.converted_ckpts_dir)) / weights_file.stem
+        image_size = mconfig.get('width') or mconfig.get('height') or 512
+
+        # return the cached version if it exists
+        if diffusers_path.exists():
+            return diffusers_path
+
+        # deferred import avoids a circular import error
+        from .convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
+        checkpoint = torch.load(weights_file, map_location="cpu") \
+            if weights_file.suffix in ['.ckpt', '.pt'] \
+            else safetensors.torch.load_file(weights_file)
+
+        # sometimes the weights are hidden under "state_dict", and sometimes not
+        if "state_dict" in checkpoint:
+            checkpoint = checkpoint["state_dict"]
+
+        config = OmegaConf.load(config_file)
+
+        vae_model = convert_ldm_vae_to_diffusers(
+            checkpoint = checkpoint,
+            vae_config = config,
+            image_size = image_size
+        )
+        vae_model.save_pretrained(
+            diffusers_path,
+            safe_serialization=is_safetensors_available()
+        )
+        return diffusers_path
+
     def _get_vae_for_conversion(
         self,
         weights: Path,
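The new method takes the VAE's stanza from models.yaml as a DictConfig. A sketch of the expected shape, built to match only the fields the method actually reads (weights, config, and optional width/height); the stanza values and the ModelManager construction are assumptions, and the real signatures in a given install may differ:

    from omegaconf import OmegaConf

    # Hypothetical VAE stanza; paths resolve relative to the InvokeAI root
    # via global_resolve_path().
    mconfig = OmegaConf.create({
        "weights": "models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt",
        "config": "configs/stable-diffusion/v1-inference.yaml",
        "width": 512,
    })

    manager = ModelManager(OmegaConf.load("configs/models.yaml"))  # assumed constructor
    diffusers_dir = manager.convert_vae_ckpt_and_cache(mconfig)
    # The first call converts and caches; later calls return the cached directory.
    print(diffusers_dir)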
@@ -1178,7 +1208,7 @@ class ModelManager(object):
         current_version = self.config.get("_version","1.0.0")
         if version.parse(current_version) < version.parse(CONFIG_FILE_VERSION):
             self.logger.warning(f'models.yaml version {current_version} detected. Updating to {CONFIG_FILE_VERSION}')
-            self.logger.warning(f'The original file will be renamed models.yaml.orig')
+            self.logger.warning('The original file will be renamed models.yaml.orig')
             if self.config_path:
                 old_file = Path(self.config_path)
                 new_name = old_file.parent / 'models.yaml.orig'