added method to convert vaes

This commit is contained in:
Lincoln Stein 2023-05-18 13:31:11 -04:00
parent fd82763412
commit b1a99d772c
3 changed files with 87 additions and 35 deletions

View File

@ -33,6 +33,7 @@ from .model_cache import ModelCache
try:
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
except ImportError:
raise ImportError(
"OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`."
@ -614,16 +615,29 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False
return new_checkpoint
def convert_ldm_vae_checkpoint(checkpoint, config):
# extract state dict for VAE
vae_state_dict = {}
vae_key = "first_stage_model."
keys = list(checkpoint.keys())
for key in keys:
if key.startswith(vae_key):
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
# Extract state dict for VAE. Works both with burnt-in
# VAEs, and with standalone VAEs.
# checkpoint can either be a all-in-one stable diffusion
# model, or an isolated vae .ckpt. This tests for
# a key that will be present in the all-in-one model
# that isn't present in the isolated ckpt.
probe_key = "first_stage_model.encoder.conv_in.weight"
if probe_key in checkpoint:
vae_state_dict = {}
vae_key = "first_stage_model."
keys = list(checkpoint.keys())
for key in keys:
if key.startswith(vae_key):
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
else:
vae_state_dict = checkpoint
new_checkpoint = convert_ldm_vae_state_dict(vae_state_dict,config)
return new_checkpoint
def convert_ldm_vae_state_dict(vae_state_dict, config):
new_checkpoint = {}
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
@ -1049,6 +1063,19 @@ def replace_checkpoint_vae(checkpoint, vae_path:str):
new_key = f'first_stage_model.{vae_key}'
checkpoint[new_key] = state_dict[vae_key]
def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size: int)->AutoencoderKL:
vae_config = create_vae_diffusers_config(
vae_config, image_size=image_size
)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(
checkpoint, vae_config
)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
return vae
def load_pipeline_from_original_stable_diffusion_ckpt(
checkpoint_path: str,
original_config_file: str = None,
@ -1244,15 +1271,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
if vae:
logger.debug("Using replacement diffusers VAE")
else: # convert the original or replacement VAE
vae_config = create_vae_diffusers_config(
original_config, image_size=image_size
)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(
checkpoint, vae_config
)
vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)
vae = convert_ldm_vae_to_diffusers(
checkpoint,
original_config,
image_size)
# Convert the text model.
model_type = pipeline_type

View File

@ -32,7 +32,7 @@ import safetensors.torch
from diffusers import DiffusionPipeline, SchedulerMixin, ConfigMixin
from diffusers import logging as diffusers_logging
from huggingface_hub import HfApi
from huggingface_hub import HfApi, scan_cache_dir
from picklescan.scanner import scan_file_path
from pydantic import BaseModel
from transformers import logging as transformers_logging

View File

@ -131,9 +131,7 @@ from __future__ import annotations
import os
import re
import sys
import textwrap
from contextlib import suppress
from dataclasses import dataclass
from enum import Enum, auto
from packaging import version
@ -145,6 +143,7 @@ import safetensors
import safetensors.torch
import torch
from diffusers import AutoencoderKL
from diffusers.utils import is_safetensors_available
from huggingface_hub import scan_cache_dir
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
@ -157,7 +156,6 @@ from invokeai.backend.util import download_with_resume
from ..util import CUDA_DEVICE
from .model_cache import (ModelCache, ModelLocker, SDModelType,
SilenceWarnings)
# We are only starting to number the config file with release 3.
# The config file version doesn't have to start at release version, but it will help
# reduce confusion.
@ -307,17 +305,12 @@ class ModelManager(object):
"""
# This is a temporary workaround for callers that use "type/name" as the model name
# Commented-out workaround for callers that use "type/name" as the model name
# because they haven't adjusted to the new return format of `list_models()`
if "/" in model_name:
model_key = model_name
else:
model_key = self.create_key(model_name, model_type)
# TODO: delete default model or add check that this stable diffusion model
# if not model_name:
# model_name = self.default_model()
# if "/" in model_name:
# model_key = model_name
# else:
model_key = self.create_key(model_name, model_type)
if model_key not in self.config:
raise InvalidModelError(
f'"{model_key}" is not a known model name. Please check your models.yaml file'
@ -358,7 +351,6 @@ class ModelManager(object):
if model_type == SDModelType.Vae and "vae" in mconfig:
print("NOT_IMPLEMENTED - RETURN CUSTOM VAE")
model_context = self.cache.get_model(
location,
model_type = model_type,
@ -414,7 +406,7 @@ class ModelManager(object):
"""
Given a model name returns the OmegaConf (dict-like) object describing it.
"""
if not self.exists(model_name, model_type):
if not self.model_exists(model_name, model_type):
return None
return self.config[self.create_key(model_name, model_type)]
@ -962,6 +954,44 @@ class ModelManager(object):
)
return diffusers_path
def convert_vae_ckpt_and_cache(self, mconfig: DictConfig) -> Path:
"""
Convert the VAE indicated in mconfig into a diffusers AutoencoderKL
object, cache it to disk, and return Path to converted
file. If already on disk then just returns Path.
"""
weights_file = global_resolve_path(mconfig.weights)
config_file = global_resolve_path(mconfig.config)
diffusers_path = global_resolve_path(Path('models',Globals.converted_ckpts_dir)) / weights_file.stem
image_size = mconfig.get('width') or mconfig.get('height') or 512
# return cached version if it exists
if diffusers_path.exists():
return diffusers_path
# this avoids circular import error
from .convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
checkpoint = torch.load(weights_file, map_location="cpu")\
if weights_file.suffix in ['.ckpt','.pt'] \
else safetensors.torch.load_file(weights_file)
# sometimes weights are hidden under "state_dict", and sometimes not
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
config = OmegaConf.load(config_file)
vae_model = convert_ldm_vae_to_diffusers(
checkpoint = checkpoint,
vae_config = config,
image_size = image_size
)
vae_model.save_pretrained(
diffusers_path,
safe_serialization=is_safetensors_available()
)
return diffusers_path
def _get_vae_for_conversion(
self,
weights: Path,
@ -1130,7 +1160,7 @@ class ModelManager(object):
@classmethod
def _delete_model_from_cache(cls,repo_id):
cache_info = scan_cache_dir(global_cache_dir("hub"))
cache_info = scan_cache_dir(global_cache_dir("hub"))
# I'm sure there is a way to do this with comprehensions
# but the code quickly became incomprehensible!
@ -1178,7 +1208,7 @@ class ModelManager(object):
current_version = self.config.get("_version","1.0.0")
if version.parse(current_version) < version.parse(CONFIG_FILE_VERSION):
self.logger.warning(f'models.yaml version {current_version} detected. Updating to {CONFIG_FILE_VERSION}')
self.logger.warning(f'The original file will be renamed models.yaml.orig')
self.logger.warning('The original file will be renamed models.yaml.orig')
if self.config_path:
old_file = Path(self.config_path)
new_name = old_file.parent / 'models.yaml.orig'