added method to convert vaes

Lincoln Stein 2023-05-18 13:31:11 -04:00
parent fd82763412
commit b1a99d772c
3 changed files with 87 additions and 35 deletions

View File

@@ -33,6 +33,7 @@ from .model_cache import ModelCache
 try:
     from omegaconf import OmegaConf
+    from omegaconf.dictconfig import DictConfig
 except ImportError:
     raise ImportError(
         "OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`."
@@ -614,16 +615,29 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
     return new_checkpoint

 def convert_ldm_vae_checkpoint(checkpoint, config):
-    # extract state dict for VAE
-    vae_state_dict = {}
-    vae_key = "first_stage_model."
-    keys = list(checkpoint.keys())
-    for key in keys:
-        if key.startswith(vae_key):
-            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
+    # Extract state dict for VAE. Works both with burnt-in
+    # VAEs and with standalone VAEs.
+    # checkpoint can either be an all-in-one stable diffusion
+    # model or an isolated vae .ckpt. This tests for
+    # a key that will be present in the all-in-one model
+    # but isn't present in the isolated ckpt.
+    probe_key = "first_stage_model.encoder.conv_in.weight"
+    if probe_key in checkpoint:
+        vae_state_dict = {}
+        vae_key = "first_stage_model."
+        keys = list(checkpoint.keys())
+        for key in keys:
+            if key.startswith(vae_key):
+                vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
+    else:
+        vae_state_dict = checkpoint
+
+    new_checkpoint = convert_ldm_vae_state_dict(vae_state_dict, config)
+    return new_checkpoint
+
+def convert_ldm_vae_state_dict(vae_state_dict, config):
     new_checkpoint = {}
     new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
@@ -1049,6 +1063,19 @@ def replace_checkpoint_vae(checkpoint, vae_path:str):
         new_key = f'first_stage_model.{vae_key}'
         checkpoint[new_key] = state_dict[vae_key]

+def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size: int) -> AutoencoderKL:
+    vae_config = create_vae_diffusers_config(
+        vae_config, image_size=image_size
+    )
+
+    converted_vae_checkpoint = convert_ldm_vae_checkpoint(
+        checkpoint, vae_config
+    )
+
+    vae = AutoencoderKL(**vae_config)
+    vae.load_state_dict(converted_vae_checkpoint)
+    return vae
+
 def load_pipeline_from_original_stable_diffusion_ckpt(
     checkpoint_path: str,
     original_config_file: str = None,
@ -1244,15 +1271,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
if vae: if vae:
logger.debug("Using replacement diffusers VAE") logger.debug("Using replacement diffusers VAE")
else: # convert the original or replacement VAE else: # convert the original or replacement VAE
vae_config = create_vae_diffusers_config( vae = convert_ldm_vae_to_diffusers(
original_config, image_size=image_size checkpoint,
) original_config,
converted_vae_checkpoint = convert_ldm_vae_checkpoint( image_size)
checkpoint, vae_config
)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
# Convert the text model. # Convert the text model.
model_type = pipeline_type model_type = pipeline_type
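
With the conversion logic factored into convert_ldm_vae_to_diffusers, converting an isolated VAE checkpoint outside the pipeline takes only a few lines. A hedged sketch, assuming the function is in scope (imported from the converter module patched above) and using illustrative file names:

import torch
from omegaconf import OmegaConf

# assumes convert_ldm_vae_to_diffusers has been imported; paths are illustrative
checkpoint = torch.load("vae.ckpt", map_location="cpu")
if "state_dict" in checkpoint:
    checkpoint = checkpoint["state_dict"]

original_config = OmegaConf.load("v1-inference.yaml")
vae = convert_ldm_vae_to_diffusers(checkpoint, original_config, image_size=512)
vae.save_pretrained("converted-vae")
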

View File

@@ -32,7 +32,7 @@ import safetensors.torch
 from diffusers import DiffusionPipeline, SchedulerMixin, ConfigMixin
 from diffusers import logging as diffusers_logging
-from huggingface_hub import HfApi
+from huggingface_hub import HfApi, scan_cache_dir
 from picklescan.scanner import scan_file_path
 from pydantic import BaseModel
 from transformers import logging as transformers_logging
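
scan_cache_dir is huggingface_hub's utility for inspecting the local Hugging Face cache. The commit only adds the import; a short sketch of the kind of query it enables (illustrative, not from this commit):

from huggingface_hub import scan_cache_dir

# Summarize what is already downloaded into the local HF cache.
cache_info = scan_cache_dir()
print(f"{len(cache_info.repos)} repos, {cache_info.size_on_disk / 1e9:.2f} GB on disk")
for repo in cache_info.repos:
    print(repo.repo_id, repo.size_on_disk)
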

View File

@@ -131,9 +131,7 @@ from __future__ import annotations
 import os
 import re
-import sys
 import textwrap
-from contextlib import suppress
 from dataclasses import dataclass
 from enum import Enum, auto
 from packaging import version
@@ -145,6 +143,7 @@ import safetensors
 import safetensors.torch
 import torch
 from diffusers import AutoencoderKL
+from diffusers.utils import is_safetensors_available
 from huggingface_hub import scan_cache_dir
 from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig
@@ -157,7 +156,6 @@ from invokeai.backend.util import download_with_resume
 from ..util import CUDA_DEVICE
 from .model_cache import (ModelCache, ModelLocker, SDModelType,
                           SilenceWarnings)
-
 # We are only starting to number the config file with release 3.
 # The config file version doesn't have to start at release version, but it will help
 # reduce confusion.
@@ -307,17 +305,12 @@ class ModelManager(object):
         """
-        # This is a temporary workaround for callers that use "type/name" as the model name
+        # Commented-out workaround for callers that use "type/name" as the model name
         # because they haven't adjusted to the new return format of `list_models()`
-        if "/" in model_name:
-            model_key = model_name
-        else:
-            model_key = self.create_key(model_name, model_type)
-
-        # TODO: delete default model or add check that this stable diffusion model
-        # if not model_name:
-        #     model_name = self.default_model()
-
+        # if "/" in model_name:
+        #     model_key = model_name
+        # else:
+        model_key = self.create_key(model_name, model_type)
         if model_key not in self.config:
             raise InvalidModelError(
                 f'"{model_key}" is not a known model name. Please check your models.yaml file'
@@ -358,7 +351,6 @@ class ModelManager(object):
         if model_type == SDModelType.Vae and "vae" in mconfig:
             print("NOT_IMPLEMENTED - RETURN CUSTOM VAE")
-
         model_context = self.cache.get_model(
             location,
             model_type = model_type,
@@ -414,7 +406,7 @@ class ModelManager(object):
         """
         Given a model name returns the OmegaConf (dict-like) object describing it.
         """
-        if not self.exists(model_name, model_type):
+        if not self.model_exists(model_name, model_type):
             return None
         return self.config[self.create_key(model_name, model_type)]
@@ -962,6 +954,44 @@ class ModelManager(object):
         )
         return diffusers_path

+    def convert_vae_ckpt_and_cache(self, mconfig: DictConfig) -> Path:
+        """
+        Convert the VAE indicated in mconfig into a diffusers AutoencoderKL
+        object, cache it to disk, and return the Path to the converted
+        file. If it is already on disk, just return the Path.
+        """
+        weights_file = global_resolve_path(mconfig.weights)
+        config_file = global_resolve_path(mconfig.config)
+        diffusers_path = global_resolve_path(Path('models', Globals.converted_ckpts_dir)) / weights_file.stem
+        image_size = mconfig.get('width') or mconfig.get('height') or 512
+
+        # return the cached version if it exists
+        if diffusers_path.exists():
+            return diffusers_path
+
+        # deferred import avoids a circular import error
+        from .convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
+        checkpoint = torch.load(weights_file, map_location="cpu") \
+            if weights_file.suffix in ['.ckpt', '.pt'] \
+            else safetensors.torch.load_file(weights_file)
+
+        # sometimes the weights are hidden under "state_dict", and sometimes not
+        if "state_dict" in checkpoint:
+            checkpoint = checkpoint["state_dict"]
+
+        config = OmegaConf.load(config_file)
+        vae_model = convert_ldm_vae_to_diffusers(
+            checkpoint=checkpoint,
+            vae_config=config,
+            image_size=image_size,
+        )
+        vae_model.save_pretrained(
+            diffusers_path,
+            safe_serialization=is_safetensors_available()
+        )
+        return diffusers_path
+
     def _get_vae_for_conversion(
         self,
         weights: Path,
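
A sketch of how a caller might drive the new method, assuming a models.yaml stanza for a standalone VAE (the field names mirror what the method reads: weights, config, and optionally width/height; paths and the manager instance are illustrative):

from omegaconf import OmegaConf

mconfig = OmegaConf.create({
    "weights": "models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema.ckpt",
    "config": "configs/stable-diffusion/v1-inference.yaml",
    "width": 512,
})

# manager: an initialized ModelManager instance (hypothetical)
diffusers_path = manager.convert_vae_ckpt_and_cache(mconfig)
print(f"converted VAE cached at {diffusers_path}")
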
@@ -1178,7 +1208,7 @@ class ModelManager(object):
         current_version = self.config.get("_version","1.0.0")
         if version.parse(current_version) < version.parse(CONFIG_FILE_VERSION):
             self.logger.warning(f'models.yaml version {current_version} detected. Updating to {CONFIG_FILE_VERSION}')
-            self.logger.warning(f'The original file will be renamed models.yaml.orig')
+            self.logger.warning('The original file will be renamed models.yaml.orig')
         if self.config_path:
             old_file = Path(self.config_path)
             new_name = old_file.parent / 'models.yaml.orig'
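
The migration check above leans on packaging.version, which compares version numbers numerically rather than as raw strings. A quick illustration of its semantics (not from the commit):

from packaging import version

assert version.parse("1.0.0") < version.parse("3.0.0")    # numeric comparison
assert version.parse("1.10.0") > version.parse("1.9.0")   # plain string comparison would get this wrong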