mirror of https://github.com/invoke-ai/InvokeAI
added method to convert vaes

commit b1a99d772c (parent fd82763412)
@@ -33,6 +33,7 @@ from .model_cache import ModelCache
 try:
     from omegaconf import OmegaConf
+    from omegaconf.dictconfig import DictConfig
 except ImportError:
     raise ImportError(
         "OmegaConf is required to convert the LDM checkpoints. Please install it with `pip install OmegaConf`."
@@ -614,16 +615,29 @@ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False

     return new_checkpoint


 def convert_ldm_vae_checkpoint(checkpoint, config):
-    # extract state dict for VAE
-    vae_state_dict = {}
-    vae_key = "first_stage_model."
-    keys = list(checkpoint.keys())
-    for key in keys:
-        if key.startswith(vae_key):
-            vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
+    # Extract state dict for VAE. Works both with burnt-in
+    # VAEs and with standalone VAEs.
+
+    # The checkpoint can either be an all-in-one stable diffusion
+    # model, or an isolated VAE .ckpt. This tests for
+    # a key that will be present in the all-in-one model
+    # but isn't present in the isolated ckpt.
+    probe_key = "first_stage_model.encoder.conv_in.weight"
+    if probe_key in checkpoint:
+        vae_state_dict = {}
+        vae_key = "first_stage_model."
+        keys = list(checkpoint.keys())
+        for key in keys:
+            if key.startswith(vae_key):
+                vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
+    else:
+        vae_state_dict = checkpoint
+
+    new_checkpoint = convert_ldm_vae_state_dict(vae_state_dict, config)
+    return new_checkpoint
+
+
+def convert_ldm_vae_state_dict(vae_state_dict, config):
     new_checkpoint = {}

     new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
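As a quick illustration of the probe-key test above (not part of the diff): the file names below are placeholders, and the top-level "state_dict" wrapper is assumed, as handled elsewhere in this commit.

import torch

# Placeholder file names, for illustration only.
sd_ckpt = torch.load("all-in-one-sd-model.ckpt", map_location="cpu")["state_dict"]
vae_ckpt = torch.load("standalone-vae.ckpt", map_location="cpu")["state_dict"]

probe_key = "first_stage_model.encoder.conv_in.weight"
print(probe_key in sd_ckpt)    # True:  full SD checkpoint, VAE keys carry the "first_stage_model." prefix
print(probe_key in vae_ckpt)   # False: standalone VAE, keys begin at "encoder." / "decoder."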
@@ -1049,6 +1063,19 @@ def replace_checkpoint_vae(checkpoint, vae_path:str):

         new_key = f'first_stage_model.{vae_key}'
         checkpoint[new_key] = state_dict[vae_key]

+def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size: int) -> AutoencoderKL:
+    vae_config = create_vae_diffusers_config(
+        vae_config, image_size=image_size
+    )
+
+    converted_vae_checkpoint = convert_ldm_vae_checkpoint(
+        checkpoint, vae_config
+    )
+
+    vae = AutoencoderKL(**vae_config)
+    vae.load_state_dict(converted_vae_checkpoint)
+    return vae
+
 def load_pipeline_from_original_stable_diffusion_ckpt(
     checkpoint_path: str,
     original_config_file: str = None,
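A minimal usage sketch for the new helper (not part of the diff), assuming a standalone VAE .ckpt and its original LDM YAML config; paths and the import location are placeholders:

import torch
from omegaconf import OmegaConf

# The helper is defined in the diff above; adjust the import path to wherever
# convert_ckpt_to_diffusers lives in your tree.
from convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers

checkpoint = torch.load("standalone-vae.ckpt", map_location="cpu")   # placeholder path
checkpoint = checkpoint.get("state_dict", checkpoint)                # weights are sometimes wrapped
original_config = OmegaConf.load("v1-inference.yaml")                # original LDM config (placeholder)

vae = convert_ldm_vae_to_diffusers(
    checkpoint=checkpoint,
    vae_config=original_config,
    image_size=512,
)
vae.save_pretrained("converted-vae")   # AutoencoderKL supports standard diffusers serialization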
@@ -1244,15 +1271,10 @@ def load_pipeline_from_original_stable_diffusion_ckpt(

     if vae:
         logger.debug("Using replacement diffusers VAE")
     else:  # convert the original or replacement VAE
-        vae_config = create_vae_diffusers_config(
-            original_config, image_size=image_size
-        )
-        converted_vae_checkpoint = convert_ldm_vae_checkpoint(
-            checkpoint, vae_config
-        )
-
-        vae = AutoencoderKL(**vae_config)
-        vae.load_state_dict(converted_vae_checkpoint)
+        vae = convert_ldm_vae_to_diffusers(
+            checkpoint,
+            original_config,
+            image_size)

     # Convert the text model.
     model_type = pipeline_type
@@ -32,7 +32,7 @@ import safetensors.torch

 from diffusers import DiffusionPipeline, SchedulerMixin, ConfigMixin
 from diffusers import logging as diffusers_logging
-from huggingface_hub import HfApi
+from huggingface_hub import HfApi, scan_cache_dir
 from picklescan.scanner import scan_file_path
 from pydantic import BaseModel
 from transformers import logging as transformers_logging
@@ -131,9 +131,7 @@ from __future__ import annotations

 import os
 import re
 import sys
 import textwrap
 from contextlib import suppress
 from dataclasses import dataclass
 from enum import Enum, auto
 from packaging import version
@@ -145,6 +143,7 @@ import safetensors

 import safetensors.torch
 import torch
 from diffusers import AutoencoderKL
 from diffusers.utils import is_safetensors_available
 from huggingface_hub import scan_cache_dir
 from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig
@@ -157,7 +156,6 @@ from invokeai.backend.util import download_with_resume

 from ..util import CUDA_DEVICE
 from .model_cache import (ModelCache, ModelLocker, SDModelType,
                           SilenceWarnings)

 # We are only starting to number the config file with release 3.
 # The config file version doesn't have to start at release version, but it will help
 # reduce confusion.
@@ -307,17 +305,12 @@ class ModelManager(object):

         """

-        # This is a temporary workaround for callers that use "type/name" as the model name
+        # Commented-out workaround for callers that use "type/name" as the model name
         # because they haven't adjusted to the new return format of `list_models()`
-        if "/" in model_name:
-            model_key = model_name
-        else:
-            model_key = self.create_key(model_name, model_type)

         # TODO: delete default model or add check that this stable diffusion model
         # if not model_name:
         #     model_name = self.default_model()

+        # if "/" in model_name:
+        #     model_key = model_name
+        # else:
         model_key = self.create_key(model_name, model_type)
         if model_key not in self.config:
             raise InvalidModelError(
                 f'"{model_key}" is not a known model name. Please check your models.yaml file'
@@ -358,7 +351,6 @@ class ModelManager(object):

         if model_type == SDModelType.Vae and "vae" in mconfig:
             print("NOT_IMPLEMENTED - RETURN CUSTOM VAE")


         model_context = self.cache.get_model(
             location,
             model_type = model_type,
@@ -414,7 +406,7 @@ class ModelManager(object):

         """
         Given a model name returns the OmegaConf (dict-like) object describing it.
         """
-        if not self.exists(model_name, model_type):
+        if not self.model_exists(model_name, model_type):
             return None
         return self.config[self.create_key(model_name, model_type)]
@@ -962,6 +954,44 @@ class ModelManager(object):

         )
         return diffusers_path

+    def convert_vae_ckpt_and_cache(self, mconfig: DictConfig) -> Path:
+        """
+        Convert the VAE indicated in mconfig into a diffusers AutoencoderKL
+        object, cache it to disk, and return Path to converted
+        file. If already on disk then just returns Path.
+        """
+        weights_file = global_resolve_path(mconfig.weights)
+        config_file = global_resolve_path(mconfig.config)
+        diffusers_path = global_resolve_path(Path('models',Globals.converted_ckpts_dir)) / weights_file.stem
+        image_size = mconfig.get('width') or mconfig.get('height') or 512
+
+        # return cached version if it exists
+        if diffusers_path.exists():
+            return diffusers_path
+
+        # this avoids circular import error
+        from .convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
+        checkpoint = torch.load(weights_file, map_location="cpu")\
+            if weights_file.suffix in ['.ckpt','.pt'] \
+            else safetensors.torch.load_file(weights_file)
+
+        # sometimes weights are hidden under "state_dict", and sometimes not
+        if "state_dict" in checkpoint:
+            checkpoint = checkpoint["state_dict"]
+
+        config = OmegaConf.load(config_file)
+
+        vae_model = convert_ldm_vae_to_diffusers(
+            checkpoint = checkpoint,
+            vae_config = config,
+            image_size = image_size
+        )
+        vae_model.save_pretrained(
+            diffusers_path,
+            safe_serialization=is_safetensors_available()
+        )
+        return diffusers_path
+
     def _get_vae_for_conversion(
         self,
         weights: Path,
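For orientation, a hedged sketch of how the new method might be driven from a models.yaml-style stanza. The field names follow what the code above reads (weights, config, width/height); the values and the manager instance are hypothetical:

from omegaconf import OmegaConf

mconfig = OmegaConf.create({
    "weights": "models/ldm/vae/standalone-vae.ckpt",          # placeholder
    "config": "configs/stable-diffusion/v1-inference.yaml",   # placeholder
    "width": 512,
})

# manager is assumed to be an existing ModelManager instance.
diffusers_path = manager.convert_vae_ckpt_and_cache(mconfig)
# -> resolves under models/<Globals.converted_ckpts_dir>/standalone-vae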
@@ -1130,7 +1160,7 @@ class ModelManager(object):

     @classmethod
     def _delete_model_from_cache(cls,repo_id):
-        cache_info = scan_cache_dir(global_cache_dir("hub"))
+        cache_info = scan_cache_dir(global_cache_dir("hub"))

         # I'm sure there is a way to do this with comprehensions
         # but the code quickly became incomprehensible!
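For reference, a small sketch of the huggingface_hub cache-scanning API that this helper builds on; this is not the commit's code, and the repo id is a placeholder:

from huggingface_hub import scan_cache_dir

cache_info = scan_cache_dir()   # or scan_cache_dir(global_cache_dir("hub")) as above

# Gather every cached revision hash belonging to one repo id.
hashes = [
    rev.commit_hash
    for repo in cache_info.repos
    if repo.repo_id == "some-org/some-model"   # placeholder
    for rev in repo.revisions
]

strategy = cache_info.delete_revisions(*hashes)
print(strategy.expected_freed_size_str)   # how much disk space would be reclaimed
strategy.execute()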
@@ -1178,7 +1208,7 @@ class ModelManager(object):

         current_version = self.config.get("_version","1.0.0")
         if version.parse(current_version) < version.parse(CONFIG_FILE_VERSION):
             self.logger.warning(f'models.yaml version {current_version} detected. Updating to {CONFIG_FILE_VERSION}')
-            self.logger.warning(f'The original file will be renamed models.yaml.orig')
+            self.logger.warning('The original file will be renamed models.yaml.orig')
             if self.config_path:
                 old_file = Path(self.config_path)
                 new_name = old_file.parent / 'models.yaml.orig'
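The version gate above relies on packaging.version; a trivial sketch of the comparison it performs (the constant value shown is illustrative, not taken from the module):

from packaging import version

CONFIG_FILE_VERSION = "3.0.0"   # illustrative; the real constant is defined in the module
current_version = "1.0.0"       # e.g. what an older models.yaml reports

if version.parse(current_version) < version.parse(CONFIG_FILE_VERSION):
    print("models.yaml needs migration; the original will be saved as models.yaml.orig")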