Fix ckpt and vae conversion, migrate script, remove sd2-base

This commit is contained in:
Sergey Borisov 2023-06-13 18:05:12 +03:00
parent a6af7e8824
commit e7db6d8120
13 changed files with 543 additions and 381 deletions

View File

@ -43,10 +43,10 @@ class ModelLoaderOutput(BaseInvocationOutput):
#fmt: on #fmt: on
class ModelLoaderInvocation(BaseInvocation): class SD1ModelLoaderInvocation(BaseInvocation):
"""Loading submodels of selected model.""" """Loading submodels of selected model."""
type: Literal["model_loader"] = "model_loader" type: Literal["sd1_model_loader"] = "sd1_model_loader"
model_name: str = Field(default="", description="Model to load") model_name: str = Field(default="", description="Model to load")
# TODO: precision? # TODO: precision?
@ -64,7 +64,110 @@ class ModelLoaderInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ModelLoaderOutput: def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
base_model = BaseModelType.StableDiffusion1_5 # TODO: base_model = BaseModelType.StableDiffusion1 # TODO:
# TODO: not found exceptions
if not context.services.model_manager.model_exists(
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
):
raise Exception(f"Unkown model name: {self.model_name}!")
"""
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.Tokenizer,
):
raise Exception(
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
)
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.TextEncoder,
):
raise Exception(
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
)
if not context.services.model_manager.model_exists(
model_name=self.model_name,
model_type=SDModelType.Diffusers,
submodel=SDModelType.UNet,
):
raise Exception(
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
)
"""
return ModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SubModelType.UNet,
),
scheduler=ModelInfo(
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SubModelType.Scheduler,
),
loras=[],
),
clip=ClipField(
tokenizer=ModelInfo(
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SubModelType.Tokenizer,
),
text_encoder=ModelInfo(
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SubModelType.TextEncoder,
),
loras=[],
),
vae=VaeField(
vae=ModelInfo(
model_name=self.model_name,
base_model=base_model,
model_type=ModelType.Pipeline,
submodel=SubModelType.Vae,
),
)
)
# TODO: optimize(less code copy)
class SD2ModelLoaderInvocation(BaseInvocation):
"""Loading submodels of selected model."""
type: Literal["sd2_model_loader"] = "sd2_model_loader"
model_name: str = Field(default="", description="Model to load")
# TODO: precision?
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["model", "loader"],
"type_hints": {
"model_name": "model" # TODO: rename to model_name?
}
},
}
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
base_model = BaseModelType.StableDiffusion2 # TODO:
# TODO: not found exceptions # TODO: not found exceptions
if not context.services.model_manager.model_exists( if not context.services.model_manager.model_exists(

View File

@ -28,8 +28,9 @@ from safetensors.torch import load_file
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from .model_manager import ModelManager, SDLegacyType from .model_manager import ModelManager
from .model_cache import ModelCache from .model_cache import ModelCache
from .models import SchedulerPredictionType, BaseModelType, ModelVariantType
try: try:
from omegaconf import OmegaConf from omegaconf import OmegaConf
@ -58,10 +59,6 @@ from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import (
LDMBertConfig, LDMBertConfig,
LDMBertModel, LDMBertModel,
) )
from diffusers.pipelines.paint_by_example import (
PaintByExampleImageEncoder,
PaintByExamplePipeline,
)
from diffusers.pipelines.stable_diffusion.safety_checker import ( from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker, StableDiffusionSafetyChecker,
) )
@ -911,8 +908,10 @@ textenc_pattern = re.compile("|".join(protected.keys()))
def convert_open_clip_checkpoint(checkpoint): def convert_open_clip_checkpoint(checkpoint):
cache_dir = InvokeAIAppConfig.get_config().cache_dir text_model = CLIPTextModel.from_pretrained(
text_model = CLIPTextModel.from_pretrained(MODEL_ROOT / 'stable-diffusion-2-text_encoder') MODEL_ROOT / 'stable-diffusion-2-clip',
subfolder='text_encoder',
)
keys = list(checkpoint.keys()) keys = list(checkpoint.keys())
@ -1002,20 +1001,15 @@ def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size:
def load_pipeline_from_original_stable_diffusion_ckpt( def load_pipeline_from_original_stable_diffusion_ckpt(
checkpoint_path: str, checkpoint_path: str,
original_config_file: str = None, model_version: BaseModelType,
num_in_channels: int = None, model_variant: ModelVariantType,
scheduler_type: str = "pndm", original_config_file: str,
pipeline_type: str = None,
image_size: int = None,
prediction_type: str = None,
extract_ema: bool = True, extract_ema: bool = True,
upcast_attn: bool = False,
vae: AutoencoderKL = None,
vae_path: str = None,
precision: torch.dtype = torch.float32, precision: torch.dtype = torch.float32,
return_generator_pipeline: bool = False, upcast_attention: bool = False,
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon,
scan_needed: bool = True, scan_needed: bool = True,
) -> Union[StableDiffusionPipeline, StableDiffusionGeneratorPipeline]: ) -> StableDiffusionPipeline:
""" """
Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml` Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
config file. config file.
@ -1027,148 +1021,68 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
:param checkpoint_path: Path to `.ckpt` file. :param checkpoint_path: Path to `.ckpt` file.
:param original_config_file: Path to `.yaml` config file corresponding to the original architecture. :param original_config_file: Path to `.yaml` config file corresponding to the original architecture.
If `None`, will be automatically inferred by looking for a key that only exists in SD2.0 models. If `None`, will be automatically inferred by looking for a key that only exists in SD2.0 models.
:param image_size: The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2
Base. Use 768 for Stable Diffusion v2.
:param prediction_type: The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion :param prediction_type: The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion
v1.X and Stable Diffusion v2 Base. Use `'v-prediction'` for Stable Diffusion v2. v1.X and Stable Diffusion v2 Base. Use `'v-prediction'` for Stable Diffusion v2.
:param num_in_channels: The number of input channels. If `None` number of input channels will be automatically
inferred.
:param scheduler_type: Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler", :param scheduler_type: Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler",
"euler-ancestral", "dpm", "ddim"]`. :param model_type: The pipeline type. `None` to automatically infer, or one of "euler-ancestral", "dpm", "ddim"]`. :param model_type: The pipeline type. `None` to automatically infer, or one of
`["FrozenOpenCLIPEmbedder", "FrozenCLIPEmbedder", "PaintByExample"]`. :param extract_ema: Only relevant for `["FrozenOpenCLIPEmbedder", "FrozenCLIPEmbedder"]`. :param extract_ema: Only relevant for
checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights
or not. Defaults to `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher or not. Defaults to `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher
quality images for inference. Non-EMA weights are usually better to continue fine-tuning. quality images for inference. Non-EMA weights are usually better to continue fine-tuning.
:param precision: precision to use - torch.float16, torch.float32 or torch.autocast :param precision: precision to use - torch.float16, torch.float32 or torch.autocast
:param upcast_attention: Whether the attention computation should always be upcasted. This is necessary when :param upcast_attention: Whether the attention computation should always be upcasted. This is necessary when
running stable diffusion 2.1. running stable diffusion 2.1.
:param vae: A diffusers VAE to load into the pipeline.
:param vae_path: Path to a checkpoint VAE that will be converted into diffusers and loaded into the pipeline.
""" """
config = InvokeAIAppConfig.get_config() config = InvokeAIAppConfig.get_config()
cache_dir = config.cache_dir
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.simplefilter("ignore") warnings.simplefilter("ignore")
verbosity = dlogging.get_verbosity() verbosity = dlogging.get_verbosity()
dlogging.set_verbosity_error() dlogging.set_verbosity_error()
if Path(checkpoint_path).suffix == '.ckpt': if str(checkpoint_path).endswith(".safetensors"):
checkpoint = load_file(checkpoint_path)
else:
if scan_needed: if scan_needed:
ModelCache.scan_model(checkpoint_path, checkpoint_path) ModelCache.scan_model(checkpoint_path, checkpoint_path)
checkpoint = torch.load(checkpoint_path) checkpoint = torch.load(checkpoint_path)
else:
checkpoint = load_file(checkpoint_path)
pipeline_class = (
StableDiffusionGeneratorPipeline
if return_generator_pipeline
else StableDiffusionPipeline
)
# Sometimes models don't have the global_step item
if "global_step" in checkpoint:
global_step = checkpoint["global_step"]
else:
logger.debug("global_step key not found in model")
global_step = None
# sometimes there is a state_dict key and sometimes not # sometimes there is a state_dict key and sometimes not
if "state_dict" in checkpoint: if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"] checkpoint = checkpoint["state_dict"]
upcast_attention = False
if original_config_file is None:
model_type = ModelManager.probe_model_type(checkpoint)
if model_type == SDLegacyType.V2_v:
original_config_file = (
config.legacy_conf_path / "v2-inference-v.yaml"
)
if global_step == 110000:
# v2.1 needs to upcast attention
upcast_attention = True
elif model_type == SDLegacyType.V2_e:
original_config_file = (
config.legacy_conf_path / "v2-inference.yaml"
)
elif model_type == SDLegacyType.V1_INPAINT:
original_config_file = (
config.legacy_conf_path / "v1-inpainting-inference.yaml"
)
elif model_type == SDLegacyType.V1:
original_config_file = (
config.legacy_conf_path / "v1-inference.yaml"
)
else:
raise Exception("Unknown checkpoint type")
original_config = OmegaConf.load(original_config_file) original_config = OmegaConf.load(original_config_file)
if num_in_channels is not None: if model_version == BaseModelType.StableDiffusion2 and prediction_type == SchedulerPredictionType.VPrediction:
original_config["model"]["params"]["unet_config"]["params"][ image_size = 768
"in_channels"
] = num_in_channels
if (
"parameterization" in original_config["model"]["params"]
and original_config["model"]["params"]["parameterization"] == "v"
):
if prediction_type is None:
# NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"`
# as it relies on a brittle global step parameter here
prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
if image_size is None:
# NOTE: For stable diffusion 2 base one has to pass `image_size==512`
# as it relies on a brittle global step parameter here
image_size = 512 if global_step == 875000 else 768
else: else:
if prediction_type is None:
prediction_type = "epsilon"
if image_size is None:
image_size = 512 image_size = 512
#
# convert scheduler
#
num_train_timesteps = original_config.model.params.timesteps num_train_timesteps = original_config.model.params.timesteps
beta_start = original_config.model.params.linear_start beta_start = original_config.model.params.linear_start
beta_end = original_config.model.params.linear_end beta_end = original_config.model.params.linear_end
scheduler = DDIMScheduler( scheduler = PNDMScheduler(
beta_end=beta_end, beta_end=beta_end,
beta_schedule="scaled_linear", beta_schedule="scaled_linear",
beta_start=beta_start, beta_start=beta_start,
num_train_timesteps=num_train_timesteps, num_train_timesteps=num_train_timesteps,
steps_offset=1, steps_offset=1,
clip_sample=False,
set_alpha_to_one=False, set_alpha_to_one=False,
prediction_type=prediction_type, prediction_type=prediction_type,
skip_prk_steps=True
) )
# make sure scheduler works correctly with DDIM # make sure scheduler works correctly with DDIM
scheduler.register_to_config(clip_sample=False) scheduler.register_to_config(clip_sample=False)
if scheduler_type == "pndm": #
config = dict(scheduler.config) # convert unet
config["skip_prk_steps"] = True #
scheduler = PNDMScheduler.from_config(config)
elif scheduler_type == "lms":
scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
elif scheduler_type == "heun":
scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
elif scheduler_type == "euler":
scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
elif scheduler_type == "euler-ancestral":
scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
elif scheduler_type == "dpm":
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
elif scheduler_type == 'unipc':
scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
elif scheduler_type == "ddim":
scheduler = scheduler
else:
raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
# Convert the UNet2DConditionModel model.
unet_config = create_unet_diffusers_config( unet_config = create_unet_diffusers_config(
original_config, image_size=image_size original_config, image_size=image_size
) )
@ -1181,35 +1095,25 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
unet.load_state_dict(converted_unet_checkpoint) unet.load_state_dict(converted_unet_checkpoint)
# If a replacement VAE path was specified, we'll incorporate that into #
# the checkpoint model and then convert it # convert vae
if vae_path: #
logger.debug(f"Converting VAE {vae_path}")
replace_checkpoint_vae(checkpoint,vae_path)
# otherwise we use the original VAE, provided that
# an externally loaded diffusers VAE was not passed
elif not vae:
logger.debug("Using checkpoint model's original VAE")
if vae:
logger.debug("Using replacement diffusers VAE")
else: # convert the original or replacement VAE
vae = convert_ldm_vae_to_diffusers( vae = convert_ldm_vae_to_diffusers(
checkpoint, checkpoint,
original_config, original_config,
image_size) image_size,
)
# Convert the text model. # Convert the text model.
model_type = pipeline_type model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
if model_type is None:
model_type = original_config.model.params.cond_stage_config.target.split(
"."
)[-1]
if model_type == "FrozenOpenCLIPEmbedder": if model_type == "FrozenOpenCLIPEmbedder":
text_model = convert_open_clip_checkpoint(checkpoint) text_model = convert_open_clip_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained(MODEL_ROOT / 'stable-diffusion-2-tokenizer') tokenizer = CLIPTokenizer.from_pretrained(
pipe = pipeline_class( MODEL_ROOT / 'stable-diffusion-2-clip',
subfolder='tokenizer',
)
pipe = StableDiffusionPipeline(
vae=vae.to(precision), vae=vae.to(precision),
text_encoder=text_model.to(precision), text_encoder=text_model.to(precision),
tokenizer=tokenizer, tokenizer=tokenizer,
@ -1219,20 +1123,22 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
feature_extractor=None, feature_extractor=None,
requires_safety_checker=False, requires_safety_checker=False,
) )
elif model_type in ["FrozenCLIPEmbedder", "WeightedFrozenCLIPEmbedder"]: elif model_type in ["FrozenCLIPEmbedder", "WeightedFrozenCLIPEmbedder"]:
text_model = convert_ldm_clip_checkpoint(checkpoint) text_model = convert_ldm_clip_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained(MODEL_ROOT / 'clip-vit-large-patch14') tokenizer = CLIPTokenizer.from_pretrained(MODEL_ROOT / 'clip-vit-large-patch14')
safety_checker = StableDiffusionSafetyChecker.from_pretrained(MODEL_ROOT / 'stable-diffusion-safety-checker') safety_checker = StableDiffusionSafetyChecker.from_pretrained(MODEL_ROOT / 'stable-diffusion-safety-checker')
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ROOT / 'stable-diffusion-safety-checker-extractor') feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ROOT / 'stable-diffusion-safety-checker')
pipe = pipeline_class( pipe = StableDiffusionPipeline(
vae=vae.to(precision), vae=vae.to(precision),
text_encoder=text_model.to(precision), text_encoder=text_model.to(precision),
tokenizer=tokenizer, tokenizer=tokenizer,
unet=unet.to(precision), unet=unet.to(precision),
scheduler=scheduler, scheduler=scheduler,
safety_checker=None if return_generator_pipeline else safety_checker.to(precision), safety_checker=safety_checker.to(precision),
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
) )
else: else:
text_config = create_ldm_bert_config(original_config) text_config = create_ldm_bert_config(original_config)
text_model = convert_ldm_bert_checkpoint(checkpoint, text_config) text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)

View File

@ -20,15 +20,12 @@ import gc
import os import os
import sys import sys
import hashlib import hashlib
import warnings
from contextlib import suppress from contextlib import suppress
from pathlib import Path from pathlib import Path
from typing import Dict, Union, types, Optional, Type, Any from typing import Dict, Union, types, Optional, Type, Any
import torch import torch
from diffusers import logging as diffusers_logging
from transformers import logging as transformers_logging
import logging import logging
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config from invokeai.app.services.config import get_invokeai_config
@ -382,21 +379,6 @@ class ModelCache(object):
f.write(hash) f.write(hash)
return hash return hash
class SilenceWarnings(object):
def __init__(self):
self.transformers_verbosity = transformers_logging.get_verbosity()
self.diffusers_verbosity = diffusers_logging.get_verbosity()
def __enter__(self):
transformers_logging.set_verbosity_error()
diffusers_logging.set_verbosity_error()
warnings.simplefilter('ignore')
def __exit__(self,type,value,traceback):
transformers_logging.set_verbosity(self.transformers_verbosity)
diffusers_logging.set_verbosity(self.diffusers_verbosity)
warnings.simplefilter('default')
class VRAMUsage(object): class VRAMUsage(object):
def __init__(self): def __init__(self):
self.vram = None self.vram = None

View File

@ -148,6 +148,7 @@ into the model when downloaded or converted.
from __future__ import annotations from __future__ import annotations
import os import os
import hashlib
import textwrap import textwrap
from dataclasses import dataclass from dataclasses import dataclass
from packaging import version from packaging import version
@ -166,7 +167,7 @@ import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util import CUDA_DEVICE, download_with_resume from invokeai.backend.util import CUDA_DEVICE, download_with_resume
from .model_cache import ModelCache, ModelLocker from .model_cache import ModelCache, ModelLocker
from .models import BaseModelType, ModelType, SubModelType, MODEL_CLASSES from .models import BaseModelType, ModelType, SubModelType, ModelError, MODEL_CLASSES
# We are only starting to number the config file with release 3. # We are only starting to number the config file with release 3.
# The config file version doesn't have to start at release version, but it will help # The config file version doesn't have to start at release version, but it will help
@ -299,7 +300,7 @@ class ModelManager(object):
for model_key, model_config in config.items(): for model_key, model_config in config.items():
model_name, base_model, model_type = self.parse_key(model_key) model_name, base_model, model_type = self.parse_key(model_key)
model_class = MODEL_CLASSES[base_model][model_type] model_class = MODEL_CLASSES[base_model][model_type]
self.models[model_key] = model_class.build_config(**model_config) self.models[model_key] = model_class.create_config(**model_config)
# check config version number and update on disk/RAM if necessary # check config version number and update on disk/RAM if necessary
self.globals = InvokeAIAppConfig.get_config() self.globals = InvokeAIAppConfig.get_config()
@ -406,9 +407,7 @@ class ModelManager(object):
path_mask = f"/models/{base_model}/{model_type}/{model_name}*" path_mask = f"/models/{base_model}/{model_type}/{model_name}*"
if False: # model_path = next(find_by_mask(path_mask)): if False: # model_path = next(find_by_mask(path_mask)):
model_path = None # TODO: model_path = None # TODO:
model_config = model_class.build_config( model_config = model_class.probe_config(model_path)
path=model_path,
)
self.models[model_key] = model_config self.models[model_key] = model_config
else: else:
raise Exception(f"Model not found - {model_key}") raise Exception(f"Model not found - {model_key}")
@ -437,16 +436,22 @@ class ModelManager(object):
# vae/movq override # vae/movq override
# TODO: # TODO:
if submodel_type is not None and submodel_type in model_config: if submodel_type is not None and hasattr(model_config, submodel_type):
model_path = model_config[submodel_type] override_path = getattr(model_config, submodel_type)
if override_path:
model_path = override_path
model_type = submodel_type model_type = submodel_type
submodel_type = None submodel_type = None
model_class = MODEL_CLASSES[base_model][model_type]
dst_convert_path = None # TODO: # TODO: path
# TODO: is it accurate to use path as id
dst_convert_path = self.globals.models_dir / ".cache" / hashlib.md5(model_path.encode()).hexdigest()
model_path = model_class.convert_if_required( model_path = model_class.convert_if_required(
model_path, base_model=base_model,
dst_convert_path, model_path=model_path,
model_config, output_path=dst_convert_path,
config=model_config,
) )
model_context = self.cache.get_model( model_context = self.cache.get_model(
@ -457,14 +462,14 @@ class ModelManager(object):
submodel=submodel_type, submodel=submodel_type,
) )
hash = "<NO_HASH>" # TODO: model_hash = "<NO_HASH>" # TODO:
return ModelInfo( return ModelInfo(
context = model_context, context = model_context,
name = model_name, name = model_name,
base_model = base_model, base_model = base_model,
type = submodel_type or model_type, type = submodel_type or model_type,
hash = hash, hash = model_hash,
location = model_path, # TODO: location = model_path, # TODO:
precision = self.cache.precision, precision = self.cache.precision,
_cache = self.cache, _cache = self.cache,
@ -633,7 +638,7 @@ class ModelManager(object):
""" """
model_class = MODEL_CLASSES[base_model][model_type] model_class = MODEL_CLASSES[base_model][model_type]
model_config = model_class.build_config(**model_attributes) model_config = model_class.create_config(**model_attributes)
model_key = self.create_key(model_name, base_model, model_type) model_key = self.create_key(model_name, base_model, model_type)
assert ( assert (
@ -749,13 +754,15 @@ class ModelManager(object):
for model_key, model_config in list(self.models.items()): for model_key, model_config in list(self.models.items()):
model_name, base_model, model_type = self.parse_key(model_key) model_name, base_model, model_type = self.parse_key(model_key)
if not os.path.exists(model_config.path): model_path = str(self.globals.root / model_config.path)
if not os.path.exists(model_path):
model_class = MODEL_CLASSES[base_model][model_type]
if model_class.save_to_config: if model_class.save_to_config:
model_config.error = ModelError.NotFound model_config.error = ModelError.NotFound
else: else:
self.models.pop(model_key, None) self.models.pop(model_key, None)
else: else:
loaded_files.add(model_config.path) loaded_files.add(model_path)
for base_model in BaseModelType: for base_model in BaseModelType:
for model_type in ModelType: for model_type in ModelType:
@ -774,7 +781,5 @@ class ModelManager(object):
if model_key in self.models: if model_key in self.models:
raise Exception(f"Model with key {model_key} added twice") raise Exception(f"Model with key {model_key} added twice")
model_config: ModelConfigBase = model_class.build_config( model_config: ModelConfigBase = model_class.probe_config(model_path)
path=model_path,
)
self.models[model_key] = model_config self.models[model_key] = model_config

View File

@ -12,8 +12,7 @@ from typing import Callable, Literal, Union, Dict
from picklescan.scanner import scan_file_path from picklescan.scanner import scan_file_path
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from .models import BaseModelType, ModelType, ModelVariantType, SchedulerPredictionType from .models import BaseModelType, ModelType, ModelVariantType, SchedulerPredictionType, SilenceWarnings
from .model_cache import SilenceWarnings
@dataclass @dataclass
class ModelVariantInfo(object): class ModelVariantInfo(object):
@ -21,6 +20,7 @@ class ModelVariantInfo(object):
base_type: BaseModelType base_type: BaseModelType
variant_type: ModelVariantType variant_type: ModelVariantType
prediction_type: SchedulerPredictionType prediction_type: SchedulerPredictionType
upcast_attention: bool
format: Literal['folder','checkpoint'] format: Literal['folder','checkpoint']
image_size: int image_size: int
@ -95,10 +95,12 @@ class ModelProbe(object):
base_type = base_type, base_type = base_type,
variant_type = variant_type, variant_type = variant_type,
prediction_type = prediction_type, prediction_type = prediction_type,
upcast_attention = (base_type==BaseModelType.StableDiffusion2 \
and prediction_type==SchedulerPredictionType.VPrediction),
format = format, format = format,
image_size = 768 if (base_type==BaseModelType.StableDiffusion2 \ image_size = 768 if (base_type==BaseModelType.StableDiffusion2 \
and prediction_type==SchedulerPredictionType.VPrediction \ and prediction_type==SchedulerPredictionType.VPrediction \
) else 512 ) else 512,
) )
except Exception as e: except Exception as e:
return None return None

View File

@ -1,5 +1,5 @@
from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings
from .stable_diffusion import StableDiffusion15Model, StableDiffusion2Model, StableDiffusion2BaseModel from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
from .vae import VaeModel from .vae import VaeModel
from .lora import LoRAModel from .lora import LoRAModel
#from .controlnet import ControlNetModel # TODO: #from .controlnet import ControlNetModel # TODO:
@ -11,7 +11,7 @@ class ControlNetModel:
MODEL_CLASSES = { MODEL_CLASSES = {
BaseModelType.StableDiffusion1: { BaseModelType.StableDiffusion1: {
ModelType.Pipeline: StableDiffusion15Model, ModelType.Pipeline: StableDiffusion1Model,
ModelType.Vae: VaeModel, ModelType.Vae: VaeModel,
ModelType.Lora: LoRAModel, ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel, ModelType.ControlNet: ControlNetModel,

View File

@ -10,13 +10,9 @@ from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Type, Literal from typing import List, Dict, Optional, Type, Literal
class BaseModelType(str, Enum): class BaseModelType(str, Enum):
#StableDiffusion1_5 = "stable_diffusion_1_5"
#StableDiffusion2 = "stable_diffusion_2"
#StableDiffusion2Base = "stable_diffusion_2_base"
# TODO: maybe then add sample size(512/768)?
StableDiffusion1 = "sd-1" StableDiffusion1 = "sd-1"
StableDiffusion2 = "sd-2" # 768 pixels; this will have v-prediction parameterization StableDiffusion2 = "sd-2"
#Kandinsky2_1 = "kandinsky_2_1" #Kandinsky2_1 = "kandinsky-2.1"
class ModelType(str, Enum): class ModelType(str, Enum):
Pipeline = "pipeline" Pipeline = "pipeline"
@ -107,7 +103,7 @@ class ModelBase:
continue continue
fields = inspect.get_annotations(value) fields = inspect.get_annotations(value)
if "format" not in fields or typing.get_origin(fields["format"]) != Literal: if "format" not in fields:
raise Exception("Invalid config definition - format field not found") raise Exception("Invalid config definition - format field not found")
format_type = typing.get_origin(fields["format"]) format_type = typing.get_origin(fields["format"])
@ -125,13 +121,20 @@ class ModelBase:
return cls.__configs return cls.__configs
@classmethod @classmethod
def build_config(cls, **kwargs): def create_config(cls, **kwargs):
if "format" not in kwargs: if "format" not in kwargs:
kwargs["format"] = cls.detect_format(kwargs["path"]) raise Exception("Field 'format' not found in model config")
configs = cls._get_configs() configs = cls._get_configs()
return configs[kwargs["format"]](**kwargs) return configs[kwargs["format"]](**kwargs)
@classmethod
def probe_config(cls, path: str, **kwargs):
return cls.create_config(
path=path,
format=cls.detect_format(path),
)
@classmethod @classmethod
def detect_format(cls, path: str) -> str: def detect_format(cls, path: str) -> str:
raise NotImplementedError() raise NotImplementedError()
@ -304,3 +307,62 @@ def _calc_model_by_data(model) -> int:
mem_bufs = sum([buf.nelement()*buf.element_size() for buf in model.buffers()]) mem_bufs = sum([buf.nelement()*buf.element_size() for buf in model.buffers()])
mem = mem_params + mem_bufs # in bytes mem = mem_params + mem_bufs # in bytes
return mem return mem
def _fast_safetensors_reader(path: str):
checkpoint = dict()
device = torch.device("meta")
with open(path, "rb") as f:
definition_len = int.from_bytes(f.read(8), 'little')
definition_json = f.read(definition_len)
definition = json.loads(definition_json)
if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {"pt", "torch", "pytorch"}:
raise Exception("Supported only pytorch safetensors files")
definition.pop("__metadata__", None)
for key, info in definition.items():
dtype = {
"I8": torch.int8,
"I16": torch.int16,
"I32": torch.int32,
"I64": torch.int64,
"F16": torch.float16,
"F32": torch.float32,
"F64": torch.float64,
}[info["dtype"]]
checkpoint[key] = torch.empty(info["shape"], dtype=dtype, device=device)
return checkpoint
def read_checkpoint_meta(path: str):
if path.endswith(".safetensors"):
try:
checkpoint = _fast_safetensors_reader(path)
except:
# TODO: create issue for support "meta"?
checkpoint = safetensors.torch.load_file(path, device="cpu")
else:
checkpoint = torch.load(path, map_location=torch.device("meta"))
return checkpoint
import warnings
from diffusers import logging as diffusers_logging
from transformers import logging as transformers_logging
class SilenceWarnings(object):
def __init__(self):
self.transformers_verbosity = transformers_logging.get_verbosity()
self.diffusers_verbosity = diffusers_logging.get_verbosity()
def __enter__(self):
transformers_logging.set_verbosity_error()
diffusers_logging.set_verbosity_error()
warnings.simplefilter('ignore')
def __exit__(self, type, value, traceback):
transformers_logging.set_verbosity(self.transformers_verbosity)
diffusers_logging.set_verbosity(self.diffusers_verbosity)
warnings.simplefilter('default')

View File

@ -54,8 +54,14 @@ class LoRAModel(ModelBase):
else: else:
return "lycoris" return "lycoris"
@staticmethod @classmethod
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: Optional[dict]) -> str: def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) == "diffusers": if cls.detect_format(model_path) == "diffusers":
# TODO: add diffusers lora when it stabilizes a bit # TODO: add diffusers lora when it stabilizes a bit
raise NotImplementedError("Diffusers lora not supported") raise NotImplementedError("Diffusers lora not supported")

View File

@ -3,7 +3,8 @@ import json
import torch import torch
import safetensors.torch import safetensors.torch
from pydantic import Field from pydantic import Field
from typing import Literal, Optional from pathlib import Path
from typing import Literal, Optional, Union
from .base import ( from .base import (
ModelBase, ModelBase,
ModelConfigBase, ModelConfigBase,
@ -12,13 +13,16 @@ from .base import (
SubModelType, SubModelType,
ModelVariantType, ModelVariantType,
DiffusersModel, DiffusersModel,
SchedulerPredictionType,
SilenceWarnings,
read_checkpoint_meta,
) )
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from omegaconf import OmegaConf
# TODO: how to name properly
class StableDiffusion15Model(DiffusersModel):
# TODO: str -> Path? class StableDiffusion1Model(DiffusersModel):
class DiffusersConfig(ModelConfigBase): class DiffusersConfig(ModelConfigBase):
format: Literal["diffusers"] format: Literal["diffusers"]
vae: Optional[str] = Field(None) vae: Optional[str] = Field(None)
@ -32,72 +36,30 @@ class StableDiffusion15Model(DiffusersModel):
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model == BaseModelType.StableDiffusion1_5 assert base_model == BaseModelType.StableDiffusion1
assert model_type == ModelType.Pipeline assert model_type == ModelType.Pipeline
super().__init__( super().__init__(
model_path=model_path, model_path=model_path,
base_model=BaseModelType.StableDiffusion1_5, base_model=BaseModelType.StableDiffusion1,
model_type=ModelType.Pipeline, model_type=ModelType.Pipeline,
) )
@staticmethod
def _fast_safetensors_reader(path: str):
checkpoint = dict()
device = torch.device("meta")
with open(path, "rb") as f:
definition_len = int.from_bytes(f.read(8), 'little')
definition_json = f.read(definition_len)
definition = json.loads(definition_json)
if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {"pt", "torch", "pytorch"}:
raise Exception("Supported only pytorch safetensors files")
definition.pop("__metadata__", None)
for key, info in definition.items():
dtype = {
"I8": torch.int8,
"I16": torch.int16,
"I32": torch.int32,
"I64": torch.int64,
"F16": torch.float16,
"F32": torch.float32,
"F64": torch.float64,
}[info["dtype"]]
checkpoint[key] = torch.empty(info["shape"], dtype=dtype, device=device)
return checkpoint
@classmethod @classmethod
def read_checkpoint_meta(cls, path: str): def probe_config(cls, path: str, **kwargs):
if path.endswith(".safetensors"): model_format = cls.detect_format(path)
try: ckpt_config_path = kwargs.get("config", None)
checkpoint = cls._fast_safetensors_reader(path) if model_format == "checkpoint":
except: if ckpt_config_path:
checkpoint = safetensors.torch.load_file(path, device="cpu") # TODO: create issue for support "meta"? ckpt_config = OmegaConf.load(ckpt_config_path)
else: ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
checkpoint = torch.load(path, map_location=torch.device("meta"))
return checkpoint
@classmethod
def build_config(cls, **kwargs):
if "format" not in kwargs:
kwargs["format"] = cls.detect_format(kwargs["path"])
if "variant" not in kwargs:
if kwargs["format"] == "checkpoint":
if "config" in kwargs:
ckpt_config = OmegaConf.load(kwargs["config"])
in_channels = ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
else: else:
checkpoint = cls.read_checkpoint_meta(kwargs["path"]) checkpoint = read_checkpoint_meta(path)
checkpoint = checkpoint.get('state_dict', checkpoint) checkpoint = checkpoint.get('state_dict', checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1] in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
elif kwargs["format"] == "diffusers": elif model_format == "diffusers":
unet_config_path = os.path.join(kwargs["path"], "unet", "config.json") unet_config_path = os.path.join(path, "unet", "config.json")
if os.path.exists(unet_config_path): if os.path.exists(unet_config_path):
with open(unet_config_path, "r") as f: with open(unet_config_path, "r") as f:
unet_config = json.loads(f.read()) unet_config = json.loads(f.read())
@ -107,19 +69,23 @@ class StableDiffusion15Model(DiffusersModel):
raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)") raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")
else: else:
raise NotImplementedError(f"Unknown stable diffusion format: {kwargs['format']}") raise NotImplementedError(f"Unknown stable diffusion 1.* format: {model_format}")
if in_channels == 9: if in_channels == 9:
kwargs["variant"] = ModelVariantType.Inpaint variant = ModelVariantType.Inpaint
elif in_channels == 5:
kwargs["variant"] = ModelVariantType.Depth
elif in_channels == 4: elif in_channels == 4:
kwargs["variant"] = ModelVariantType.Normal variant = ModelVariantType.Normal
else: else:
raise Exception("Unkown stable diffusion model format") raise Exception("Unkown stable diffusion 1.* model format")
return super().build_config(**kwargs) return cls.create_config(
path=path,
format=model_format,
config=ckpt_config_path,
variant=variant,
)
@classmethod @classmethod
def save_to_config(cls) -> bool: def save_to_config(cls) -> bool:
@ -133,128 +99,215 @@ class StableDiffusion15Model(DiffusersModel):
return "checkpoint" return "checkpoint"
@classmethod @classmethod
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: ModelConfigBase) -> str: def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
assert model_path == config.path
if isinstance(config, cls.CheckpointConfig): if isinstance(config, cls.CheckpointConfig):
return _convert_ckpt_and_cache( return _convert_ckpt_and_cache(
version=BaseModelType.StableDiffusion1_5, version=BaseModelType.StableDiffusion1,
config=config.dict(), model_config=config,
in_path=model_path, output_path=output_path,
out_path=dst_cache_path,
) # TODO: args ) # TODO: args
else: else:
return model_path return model_path
# all same
class StableDiffusion2BaseModel(StableDiffusion15Model):
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
# skip StableDiffusion15Model __init__
assert base_model == BaseModelType.StableDiffusion2Base
assert model_type == ModelType.Pipeline
super(StableDiffusion15Model, self).__init__(
model_path=model_path,
base_model=BaseModelType.StableDiffusion2Base,
model_type=ModelType.Pipeline,
)
@classmethod class StableDiffusion2Model(DiffusersModel):
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: ModelConfigBase) -> str:
if isinstance(config, cls.CheckpointConfig):
return _convert_ckpt_and_cache(
version=BaseModelType.StableDiffusion2Base,
config=config.dict(),
in_path=model_path,
out_path=dst_cache_path,
) # TODO: args
else:
return model_path
class StableDiffusion2Model(StableDiffusion15Model): # TODO: check that configs overwriten properly
# TODO: str -> Path?
# TODO: check that configs overwriten
class DiffusersConfig(ModelConfigBase): class DiffusersConfig(ModelConfigBase):
format: Literal["diffusers"] format: Literal["diffusers"]
vae: Optional[str] = Field(None) vae: Optional[str] = Field(None)
attention_upscale: bool = Field(True) prediction_type: SchedulerPredictionType
upcast_attention: bool
class CheckpointConfig(ModelConfigBase): class CheckpointConfig(ModelConfigBase):
format: Literal["checkpoint"] format: Literal["checkpoint"]
vae: Optional[str] = Field(None) vae: Optional[str] = Field(None)
config: Optional[str] = Field(None) config: Optional[str] = Field(None)
attention_upscale: bool = Field(True) prediction_type: SchedulerPredictionType
upcast_attention: bool
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
# skip StableDiffusion15Model __init__
assert base_model == BaseModelType.StableDiffusion2 assert base_model == BaseModelType.StableDiffusion2
assert model_type == ModelType.Pipeline assert model_type == ModelType.Pipeline
# skip StableDiffusion15Model __init__ super().__init__(
super(StableDiffusion15Model, self).__init__(
model_path=model_path, model_path=model_path,
base_model=BaseModelType.StableDiffusion2, base_model=BaseModelType.StableDiffusion2,
model_type=ModelType.Pipeline, model_type=ModelType.Pipeline,
) )
@classmethod @classmethod
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: ModelConfigBase) -> str: def probe_config(cls, path: str, **kwargs):
model_format = cls.detect_format(path)
ckpt_config_path = kwargs.get("config", None)
if model_format == "checkpoint":
if ckpt_config_path:
ckpt_config = OmegaConf.load(ckpt_config_path)
ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
else:
checkpoint = read_checkpoint_meta(path)
checkpoint = checkpoint.get('state_dict', checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
elif model_format == "diffusers":
unet_config_path = os.path.join(path, "unet", "config.json")
if os.path.exists(unet_config_path):
with open(unet_config_path, "r") as f:
unet_config = json.loads(f.read())
in_channels = unet_config['in_channels']
else:
raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")
else:
raise NotImplementedError(f"Unknown stable diffusion 2.* format: {model_format}")
if in_channels == 9:
variant = ModelVariantType.Inpaint
elif in_channels == 5:
variant = ModelVariantType.Depth
elif in_channels == 4:
variant = ModelVariantType.Normal
else:
raise Exception("Unkown stable diffusion 2.* model format")
if variant == ModelVariantType.Normal:
prediction_type = SchedulerPredictionType.VPrediction
upcast_attention = True
else:
prediction_type = SchedulerPredictionType.Epsilon
upcast_attention = False
return cls.create_config(
path=path,
format=model_format,
config=ckpt_config_path,
variant=variant,
prediction_type=prediction_type,
upcast_attention=upcast_attention,
)
@classmethod
def save_to_config(cls) -> bool:
return True
@classmethod
def detect_format(cls, model_path: str):
if os.path.isdir(model_path):
return "diffusers"
else:
return "checkpoint"
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
assert model_path == config.path
if isinstance(config, cls.CheckpointConfig): if isinstance(config, cls.CheckpointConfig):
return _convert_ckpt_and_cache( return _convert_ckpt_and_cache(
version=BaseModelType.StableDiffusion2, version=BaseModelType.StableDiffusion2,
config=config.dict(), model_config=config,
in_path=model_path, output_path=output_path,
out_path=dst_cache_path,
) # TODO: args ) # TODO: args
else: else:
return model_path return model_path
def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
ckpt_configs = {
BaseModelType.StableDiffusion1: {
ModelVariantType.Normal: "v1-inference.yaml",
ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
},
BaseModelType.StableDiffusion2: {
# code further will manually set upcast_attention and v_prediction
ModelVariantType.Normal: "v2-inference.yaml",
ModelVariantType.Inpaint: "v2-inpainting-inference.yaml",
ModelVariantType.Depth: "v2-midas-inference.yaml",
}
}
try:
# TODO: path
#model_config.config = app_config.config_dir / "stable-diffusion" / ckpt_configs[version][model_config.variant]
#return InvokeAIAppConfig.get_config().legacy_conf_dir / ckpt_configs[version][variant]
return InvokeAIAppConfig.get_config().root_dir / "configs" / "stable-diffusion" / ckpt_configs[version][variant]
except:
return None
# TODO: rework # TODO: rework
DictConfig = dict
def _convert_ckpt_and_cache( def _convert_ckpt_and_cache(
self,
version: BaseModelType, version: BaseModelType,
mconfig: dict, # TODO: model_config: Union[StableDiffusion1Model.CheckpointConfig, StableDiffusion2Model.CheckpointConfig],
in_path: str, output_path: str,
out_path: str,
) -> str: ) -> str:
""" """
Convert the checkpoint model indicated in mconfig into a Convert the checkpoint model indicated in mconfig into a
diffusers, cache it to disk, and return Path to converted diffusers, cache it to disk, and return Path to converted
file. If already on disk then just returns Path. file. If already on disk then just returns Path.
""" """
raise NotImplementedError()
app_config = InvokeAIAppConfig.get_config() app_config = InvokeAIAppConfig.get_config()
#if "config" not in mconfig: if model_config.config is None:
# if version == BaseModelType.StableDiffusion1_5: model_config.config = _select_ckpt_config(version, model_config.variant)
#if if model_config.config is None:
#mconfig["config"] = app_config.config_dir / "stable-diffusion" / "v1-inference.yaml" raise Exception(f"Model variant {model_config.variant} not supported for {version}")
weights = app_config.root_dir / mconfig.path weights = app_config.root_dir / model_config.path
config_file = app_config.root_dir / mconfig.config config_file = app_config.root_dir / model_config.config
diffusers_path = app_config.converted_ckpts_dir / weights.stem output_path = Path(output_path)
if version == BaseModelType.StableDiffusion1:
upcast_attention = False
prediction_type = SchedulerPredictionType.Epsilon
elif version == BaseModelType.StableDiffusion2:
upcast_attention = config.upcast_attention
prediction_type = config.prediction_type
else:
raise Exception(f"Unknown model provided: {version}")
# return cached version if it exists # return cached version if it exists
if diffusers_path.exists(): if output_path.exists():
return diffusers_path return output_path
# TODO: I think that it more correctly to convert with embedded vae # TODO: I think that it more correctly to convert with embedded vae
# as if user will delete custom vae he will got not embedded but also custom vae # as if user will delete custom vae he will got not embedded but also custom vae
#vae_ckpt_path, vae_model = self._get_vae_for_conversion(weights, mconfig) #vae_ckpt_path, vae_model = self._get_vae_for_conversion(weights, mconfig)
vae_ckpt_path, vae_model = None, None
# to avoid circular import errors # to avoid circular import errors
from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
with SilenceWarnings(): with SilenceWarnings():
convert_ckpt_to_diffusers( convert_ckpt_to_diffusers(
weights, weights,
diffusers_path, output_path,
extract_ema=True, model_version=version,
model_variant=model_config.variant,
original_config_file=config_file, original_config_file=config_file,
vae=vae_model, extract_ema=True,
vae_path=str(app_config.root_dir / vae_ckpt_path) if vae_ckpt_path else None, upcast_attention=upcast_attention,
prediction_type=prediction_type,
scan_needed=True, scan_needed=True,
model_root=app_config.models_path, model_root=app_config.models_path,
) )
return diffusers_path return output_path

View File

@ -51,6 +51,12 @@ class TextualInversionModel(ModelBase):
def detect_format(cls, path: str): def detect_format(cls, path: str):
return None return None
@staticmethod @classmethod
def convert_if_required(model_path: str, cache_path: str, config: Optional[dict]) -> str: def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
return model_path return model_path

View File

@ -1,5 +1,6 @@
import os import os
import torch import torch
from pathlib import Path
from typing import Optional from typing import Optional
from .base import ( from .base import (
ModelBase, ModelBase,
@ -7,11 +8,14 @@ from .base import (
BaseModelType, BaseModelType,
ModelType, ModelType,
SubModelType, SubModelType,
ModelVariantType,
EmptyConfigLoader, EmptyConfigLoader,
calc_model_size_by_fs, calc_model_size_by_fs,
calc_model_size_by_data, calc_model_size_by_data,
) )
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from diffusers.utils import is_safetensors_available
from omegaconf import OmegaConf
class VaeModel(ModelBase): class VaeModel(ModelBase):
#vae_class: Type #vae_class: Type
@ -70,39 +74,72 @@ class VaeModel(ModelBase):
return "checkpoint" return "checkpoint"
@classmethod @classmethod
def convert_if_required(cls, model_path: str, dst_cache_path: str, config: Optional[dict]) -> str: def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase, # empty config or config of parent model
base_model: BaseModelType,
) -> str:
if cls.detect_format(model_path) != "diffusers": if cls.detect_format(model_path) != "diffusers":
# TODO: return _convert_vae_ckpt_and_cache(
#_convert_vae_ckpt_and_cache weights_path=model_path,
raise NotImplementedError("TODO: vae convert") output_path=output_path,
base_model=base_model,
model_config=config,
)
else: else:
return model_path return model_path
# TODO: rework # TODO: rework
DictConfig = dict def _convert_vae_ckpt_and_cache(
def _convert_vae_ckpt_and_cache(self, mconfig: DictConfig) -> str: weights_path: str,
output_path: str,
base_model: BaseModelType,
model_config: ModelConfigBase,
) -> str:
""" """
Convert the VAE indicated in mconfig into a diffusers AutoencoderKL Convert the VAE indicated in mconfig into a diffusers AutoencoderKL
object, cache it to disk, and return Path to converted object, cache it to disk, and return Path to converted
file. If already on disk then just returns Path. file. If already on disk then just returns Path.
""" """
app_config = InvokeAIAppConfig.get_config() app_config = InvokeAIAppConfig.get_config()
root = app_config.root_dir weights_path = app_config.root_dir / weights_path
weights_file = root / mconfig.path output_path = Path(output_path)
config_file = root / mconfig.config
diffusers_path = app_config.converted_ckpts_dir / weights_file.stem """
image_size = mconfig.get('width') or mconfig.get('height') or 512 this size used only in when tiling enabled to separate input in tiles
sizes in configs from stable diffusion githubs(1 and 2) set to 256
on huggingface it:
1.5 - 512
1.5-inpainting - 256
2-inpainting - 512
2-depth - 256
2-base - 512
2 - 768
2.1-base - 768
2.1 - 768
"""
image_size = 512
# return cached version if it exists # return cached version if it exists
if diffusers_path.exists(): if output_path.exists():
return diffusers_path return output_path
if base_model in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
from .stable_diffusion import _select_ckpt_config
# all sd models use same vae settings
config_file = _select_ckpt_config(base_model, ModelVariantType.Normal)
else:
raise Exception(f"Vae conversion not supported for model type: {base_model}")
# this avoids circular import error # this avoids circular import error
from .convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers from ..convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
if weights_file.suffix == '.safetensors': if weights_path.suffix == '.safetensors':
checkpoint = safetensors.torch.load_file(weights_file) checkpoint = safetensors.torch.load_file(weights_path, device="cpu")
else: else:
checkpoint = torch.load(weights_file, map_location="cpu") checkpoint = torch.load(weights_path, map_location="cpu")
# sometimes weights are hidden under "state_dict", and sometimes not # sometimes weights are hidden under "state_dict", and sometimes not
if "state_dict" in checkpoint: if "state_dict" in checkpoint:
@ -117,7 +154,7 @@ def _convert_vae_ckpt_and_cache(self, mconfig: DictConfig) -> str:
model_root = app_config.models_path, model_root = app_config.models_path,
) )
vae_model.save_pretrained( vae_model.save_pretrained(
diffusers_path, output_path,
safe_serialization=is_safetensors_available() safe_serialization=is_safetensors_available()
) )
return diffusers_path return output_path

View File

@ -14,7 +14,7 @@ export const receivedModels = createAppAsyncThunk(
const response = await ModelsService.listModels(); const response = await ModelsService.listModels();
const deserializedModels = reduce( const deserializedModels = reduce(
response.models['sd-1.5']['pipeline'], response.models['sd-1']['pipeline'],
(modelsAccumulator, model, modelName) => { (modelsAccumulator, model, modelName) => {
modelsAccumulator[modelName] = { ...model, name: modelName }; modelsAccumulator[modelName] = { ...model, name: modelName };
@ -25,7 +25,7 @@ export const receivedModels = createAppAsyncThunk(
models.info( models.info(
{ response }, { response },
`Received ${size(response.models['sd-1.5']['pipeline'])} models` `Received ${size(response.models['sd-1']['pipeline'])} models`
); );
return deserializedModels; return deserializedModels;

View File

@ -100,11 +100,11 @@ def migrate_conversion_models(dest_directory: Path):
# These are needed for the conversion script # These are needed for the conversion script
kwargs = dict( kwargs = dict(
cache_dir = Path('./models/hub'), cache_dir = Path('./models/hub'),
local_files_only = True #local_files_only = True
) )
try: try:
logger.info('Migrating core tokenizers and text encoders') logger.info('Migrating core tokenizers and text encoders')
target_dir = dest_directory/'core/convert' target_dir = dest_directory / 'core' / 'convert'
# bert # bert
bert = BertTokenizerFast.from_pretrained("bert-base-uncased", **kwargs) bert = BertTokenizerFast.from_pretrained("bert-base-uncased", **kwargs)
@ -121,10 +121,10 @@ def migrate_conversion_models(dest_directory: Path):
# sd-2 # sd-2
repo_id = "stabilityai/stable-diffusion-2" repo_id = "stabilityai/stable-diffusion-2"
pipeline = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer", **kwargs) pipeline = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer", **kwargs)
pipeline.save_pretrained(target_dir/'stable-diffusion-2-tokenizer', safe_serialization=True) pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'tokenizer', safe_serialization=True)
pipeline = CLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder", **kwargs) pipeline = CLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder", **kwargs)
pipeline.save_pretrained(target_dir/'stable-diffusion-2-text_encoder', safe_serialization=True) pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'text_encoder', safe_serialization=True)
# VAE # VAE
logger.info('Migrating stable diffusion VAE') logger.info('Migrating stable diffusion VAE')
@ -135,7 +135,7 @@ def migrate_conversion_models(dest_directory: Path):
logger.info('Migrating safety checker') logger.info('Migrating safety checker')
repo_id = "CompVis/stable-diffusion-safety-checker" repo_id = "CompVis/stable-diffusion-safety-checker"
pipeline = AutoFeatureExtractor.from_pretrained(repo_id,**kwargs) pipeline = AutoFeatureExtractor.from_pretrained(repo_id,**kwargs)
pipeline.save_pretrained(target_dir/'stable-diffusion-safety-checker-extractor', safe_serialization=True) pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)
pipeline = StableDiffusionSafetyChecker.from_pretrained(repo_id,**kwargs) pipeline = StableDiffusionSafetyChecker.from_pretrained(repo_id,**kwargs)
pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True) pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)