Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Commit e7db6d8120 (parent a6af7e8824)
Fix ckpt and vae conversion, migrate script, remove sd2-base
@@ -43,10 +43,10 @@ class ModelLoaderOutput(BaseInvocationOutput):
    #fmt: on

-class ModelLoaderInvocation(BaseInvocation):
+class SD1ModelLoaderInvocation(BaseInvocation):
    """Loading submodels of selected model."""

-    type: Literal["model_loader"] = "model_loader"
+    type: Literal["sd1_model_loader"] = "sd1_model_loader"

    model_name: str = Field(default="", description="Model to load")
    # TODO: precision?
@@ -64,7 +64,110 @@ class ModelLoaderInvocation(BaseInvocation):
    def invoke(self, context: InvocationContext) -> ModelLoaderOutput:

-        base_model = BaseModelType.StableDiffusion1_5 # TODO:
+        base_model = BaseModelType.StableDiffusion1 # TODO:

        # TODO: not found exceptions
        if not context.services.model_manager.model_exists(
            model_name=self.model_name,
            base_model=base_model,
            model_type=ModelType.Pipeline,
        ):
            raise Exception(f"Unkown model name: {self.model_name}!")

        """
        if not context.services.model_manager.model_exists(
            model_name=self.model_name,
            model_type=SDModelType.Diffusers,
            submodel=SDModelType.Tokenizer,
        ):
            raise Exception(
                f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
            )

        if not context.services.model_manager.model_exists(
            model_name=self.model_name,
            model_type=SDModelType.Diffusers,
            submodel=SDModelType.TextEncoder,
        ):
            raise Exception(
                f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
            )

        if not context.services.model_manager.model_exists(
            model_name=self.model_name,
            model_type=SDModelType.Diffusers,
            submodel=SDModelType.UNet,
        ):
            raise Exception(
                f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
            )
        """

        return ModelLoaderOutput(
            unet=UNetField(
                unet=ModelInfo(
                    model_name=self.model_name,
                    base_model=base_model,
                    model_type=ModelType.Pipeline,
                    submodel=SubModelType.UNet,
                ),
                scheduler=ModelInfo(
                    model_name=self.model_name,
                    base_model=base_model,
                    model_type=ModelType.Pipeline,
                    submodel=SubModelType.Scheduler,
                ),
                loras=[],
            ),
            clip=ClipField(
                tokenizer=ModelInfo(
                    model_name=self.model_name,
                    base_model=base_model,
                    model_type=ModelType.Pipeline,
                    submodel=SubModelType.Tokenizer,
                ),
                text_encoder=ModelInfo(
                    model_name=self.model_name,
                    base_model=base_model,
                    model_type=ModelType.Pipeline,
                    submodel=SubModelType.TextEncoder,
                ),
                loras=[],
            ),
            vae=VaeField(
                vae=ModelInfo(
                    model_name=self.model_name,
                    base_model=base_model,
                    model_type=ModelType.Pipeline,
                    submodel=SubModelType.Vae,
                ),
            )
        )

# TODO: optimize(less code copy)
class SD2ModelLoaderInvocation(BaseInvocation):
    """Loading submodels of selected model."""

    type: Literal["sd2_model_loader"] = "sd2_model_loader"

    model_name: str = Field(default="", description="Model to load")
    # TODO: precision?

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "tags": ["model", "loader"],
                "type_hints": {
                    "model_name": "model" # TODO: rename to model_name?
                }
            },
        }

    def invoke(self, context: InvocationContext) -> ModelLoaderOutput:

        base_model = BaseModelType.StableDiffusion2 # TODO:

        # TODO: not found exceptions
        if not context.services.model_manager.model_exists(
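The `# TODO: optimize(less code copy)` comment above flags that the SD1 and SD2 loader invocations differ only in their `type` literal and `base_model`. A minimal sketch of one way the duplication could be factored out; this helper is hypothetical (not part of this commit) and only reuses names that already appear in the hunk above:

    # Hypothetical helper: build a ModelLoaderOutput for any base model.
    def make_pipeline_loader_output(model_name: str, base_model: BaseModelType) -> ModelLoaderOutput:
        def info(submodel: SubModelType) -> ModelInfo:
            # every submodel comes from the same pipeline-type model record
            return ModelInfo(
                model_name=model_name,
                base_model=base_model,
                model_type=ModelType.Pipeline,
                submodel=submodel,
            )

        return ModelLoaderOutput(
            unet=UNetField(unet=info(SubModelType.UNet), scheduler=info(SubModelType.Scheduler), loras=[]),
            clip=ClipField(tokenizer=info(SubModelType.Tokenizer), text_encoder=info(SubModelType.TextEncoder), loras=[]),
            vae=VaeField(vae=info(SubModelType.Vae)),
        )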
@@ -28,8 +28,9 @@ from safetensors.torch import load_file
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig

-from .model_manager import ModelManager, SDLegacyType
+from .model_manager import ModelManager
from .model_cache import ModelCache
from .models import SchedulerPredictionType, BaseModelType, ModelVariantType

try:
    from omegaconf import OmegaConf
@@ -58,10 +59,6 @@ from diffusers.pipelines.latent_diffusion.pipeline_latent_diffusion import (
    LDMBertConfig,
    LDMBertModel,
)
from diffusers.pipelines.paint_by_example import (
    PaintByExampleImageEncoder,
    PaintByExamplePipeline,
)
from diffusers.pipelines.stable_diffusion.safety_checker import (
    StableDiffusionSafetyChecker,
)
@@ -911,8 +908,10 @@ textenc_pattern = re.compile("|".join(protected.keys()))

def convert_open_clip_checkpoint(checkpoint):
    cache_dir = InvokeAIAppConfig.get_config().cache_dir
-    text_model = CLIPTextModel.from_pretrained(MODEL_ROOT / 'stable-diffusion-2-text_encoder')
+    text_model = CLIPTextModel.from_pretrained(
+        MODEL_ROOT / 'stable-diffusion-2-clip',
+        subfolder='text_encoder',
+    )

    keys = list(checkpoint.keys())
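The converted SD2 CLIP weights now live in a single `stable-diffusion-2-clip` folder with `tokenizer/` and `text_encoder/` subfolders, so both halves are loaded with the standard transformers `subfolder=` argument. A small illustrative sketch (the `MODEL_ROOT` value is an assumption):

    from pathlib import Path
    from transformers import CLIPTextModel, CLIPTokenizer

    MODEL_ROOT = Path("models/core/convert")  # assumed location of the converted core models

    tokenizer = CLIPTokenizer.from_pretrained(MODEL_ROOT / "stable-diffusion-2-clip", subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(MODEL_ROOT / "stable-diffusion-2-clip", subfolder="text_encoder")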
@@ -1002,20 +1001,15 @@ def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size:

def load_pipeline_from_original_stable_diffusion_ckpt(
    checkpoint_path: str,
-    original_config_file: str = None,
-    num_in_channels: int = None,
-    scheduler_type: str = "pndm",
-    pipeline_type: str = None,
-    image_size: int = None,
-    prediction_type: str = None,
+    model_version: BaseModelType,
+    model_variant: ModelVariantType,
+    original_config_file: str,
    extract_ema: bool = True,
-    upcast_attn: bool = False,
    vae: AutoencoderKL = None,
    vae_path: str = None,
    precision: torch.dtype = torch.float32,
-    return_generator_pipeline: bool = False,
-    scan_needed:bool=True,
-) -> Union[StableDiffusionPipeline, StableDiffusionGeneratorPipeline]:
+    upcast_attention: bool = False,
+    prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon,
+    scan_needed: bool = True,
+) -> StableDiffusionPipeline:
    """
    Load a Stable Diffusion pipeline object from a CompVis-style `.ckpt`/`.safetensors` file and (ideally) a `.yaml`
    config file.
@@ -1027,148 +1021,68 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
    :param checkpoint_path: Path to `.ckpt` file.
    :param original_config_file: Path to `.yaml` config file corresponding to the original architecture.
        If `None`, will be automatically inferred by looking for a key that only exists in SD2.0 models.
    :param image_size: The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Diffusion v2
        Base. Use 768 for Stable Diffusion v2.
    :param prediction_type: The prediction type that the model was trained on. Use `'epsilon'` for Stable Diffusion
        v1.X and Stable Diffusion v2 Base. Use `'v-prediction'` for Stable Diffusion v2.
    :param num_in_channels: The number of input channels. If `None` number of input channels will be automatically
        inferred.
    :param scheduler_type: Type of scheduler to use. Should be one of `["pndm", "lms", "heun", "euler",
        "euler-ancestral", "dpm", "ddim"]`. :param model_type: The pipeline type. `None` to automatically infer, or one of
        `["FrozenOpenCLIPEmbedder", "FrozenCLIPEmbedder", "PaintByExample"]`. :param extract_ema: Only relevant for
        `["FrozenOpenCLIPEmbedder", "FrozenCLIPEmbedder"]`. :param extract_ema: Only relevant for
        checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights
        or not. Defaults to `False`. Pass `True` to extract the EMA weights. EMA weights usually yield higher
        quality images for inference. Non-EMA weights are usually better to continue fine-tuning.
    :param precision: precision to use - torch.float16, torch.float32 or torch.autocast
    :param upcast_attention: Whether the attention computation should always be upcasted. This is necessary when
        running stable diffusion 2.1.
    :param vae: A diffusers VAE to load into the pipeline.
    :param vae_path: Path to a checkpoint VAE that will be converted into diffusers and loaded into the pipeline.
    """
    config = InvokeAIAppConfig.get_config()
    cache_dir = config.cache_dir

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        verbosity = dlogging.get_verbosity()
        dlogging.set_verbosity_error()

        if Path(checkpoint_path).suffix == '.ckpt':
            if scan_needed:
                ModelCache.scan_model(checkpoint_path,checkpoint_path)
            checkpoint = torch.load(checkpoint_path)
        else:
            if str(checkpoint_path).endswith(".safetensors"):
                checkpoint = load_file(checkpoint_path)

        pipeline_class = (
            StableDiffusionGeneratorPipeline
            if return_generator_pipeline
            else StableDiffusionPipeline
        )

        # Sometimes models don't have the global_step item
        if "global_step" in checkpoint:
            global_step = checkpoint["global_step"]
        else:
            logger.debug("global_step key not found in model")
            global_step = None
            if scan_needed:
                ModelCache.scan_model(checkpoint_path, checkpoint_path)
            checkpoint = torch.load(checkpoint_path)

        # sometimes there is a state_dict key and sometimes not
        if "state_dict" in checkpoint:
            checkpoint = checkpoint["state_dict"]

        upcast_attention = False
        if original_config_file is None:
            model_type = ModelManager.probe_model_type(checkpoint)

            if model_type == SDLegacyType.V2_v:
                original_config_file = (
                    config.legacy_conf_path / "v2-inference-v.yaml"
                )
                if global_step == 110000:
                    # v2.1 needs to upcast attention
                    upcast_attention = True
            elif model_type == SDLegacyType.V2_e:
                original_config_file = (
                    config.legacy_conf_path / "v2-inference.yaml"
                )
            elif model_type == SDLegacyType.V1_INPAINT:
                original_config_file = (
                    config.legacy_conf_path / "v1-inpainting-inference.yaml"
                )

            elif model_type == SDLegacyType.V1:
                original_config_file = (
                    config.legacy_conf_path / "v1-inference.yaml"
                )

            else:
                raise Exception("Unknown checkpoint type")

        original_config = OmegaConf.load(original_config_file)

        if num_in_channels is not None:
            original_config["model"]["params"]["unet_config"]["params"][
                "in_channels"
            ] = num_in_channels

        if (
            "parameterization" in original_config["model"]["params"]
            and original_config["model"]["params"]["parameterization"] == "v"
        ):
            if prediction_type is None:
                # NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type=="epsilon"`
                # as it relies on a brittle global step parameter here
                prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
            if image_size is None:
                # NOTE: For stable diffusion 2 base one has to pass `image_size==512`
                # as it relies on a brittle global step parameter here
                image_size = 512 if global_step == 875000 else 768
        if model_version == BaseModelType.StableDiffusion2 and prediction_type == SchedulerPredictionType.VPrediction:
            image_size = 768
        else:
            if prediction_type is None:
                prediction_type = "epsilon"
            if image_size is None:
                image_size = 512
            image_size = 512

        #
        # convert scheduler
        #

        num_train_timesteps = original_config.model.params.timesteps
        beta_start = original_config.model.params.linear_start
        beta_end = original_config.model.params.linear_end

-        scheduler = DDIMScheduler(
+        scheduler = PNDMScheduler(
            beta_end=beta_end,
            beta_schedule="scaled_linear",
            beta_start=beta_start,
            num_train_timesteps=num_train_timesteps,
            steps_offset=1,
            clip_sample=False,
            set_alpha_to_one=False,
            prediction_type=prediction_type,
            skip_prk_steps=True
        )
        # make sure scheduler works correctly with DDIM
        scheduler.register_to_config(clip_sample=False)

        if scheduler_type == "pndm":
            config = dict(scheduler.config)
            config["skip_prk_steps"] = True
            scheduler = PNDMScheduler.from_config(config)
        elif scheduler_type == "lms":
            scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
        elif scheduler_type == "heun":
            scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
        elif scheduler_type == "euler":
            scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
        elif scheduler_type == "euler-ancestral":
            scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
        elif scheduler_type == "dpm":
            scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
        elif scheduler_type == 'unipc':
            scheduler = UniPCMultistepScheduler.from_config(scheduler.config)
        elif scheduler_type == "ddim":
            scheduler = scheduler
        else:
            raise ValueError(f"Scheduler of type {scheduler_type} doesn't exist!")
        #
        # convert unet
        #

        # Convert the UNet2DConditionModel model.
        unet_config = create_unet_diffusers_config(
            original_config, image_size=image_size
        )
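With the `scheduler_type` parameter removed from the signature, the converter appears to always build a `PNDMScheduler` directly from the LDM yaml's beta schedule; callers can still swap schedulers later because diffusers schedulers share a common config. A hedged sketch of that idea with made-up beta values:

    from diffusers import PNDMScheduler, EulerDiscreteScheduler

    # values normally read from original_config.model.params (assumed here)
    scheduler = PNDMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
        steps_offset=1,
        set_alpha_to_one=False,
        skip_prk_steps=True,
        prediction_type="epsilon",
    )

    # any other diffusers scheduler can be rebuilt later from the same config
    euler = EulerDiscreteScheduler.from_config(scheduler.config)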
@@ -1181,35 +1095,25 @@ def load_pipeline_from_original_stable_diffusion_ckpt(

        unet.load_state_dict(converted_unet_checkpoint)

        # If a replacement VAE path was specified, we'll incorporate that into
        # the checkpoint model and then convert it
        if vae_path:
            logger.debug(f"Converting VAE {vae_path}")
            replace_checkpoint_vae(checkpoint,vae_path)
        # otherwise we use the original VAE, provided that
        # an externally loaded diffusers VAE was not passed
        elif not vae:
            logger.debug("Using checkpoint model's original VAE")

        #
        # convert vae
        #

        if vae:
            logger.debug("Using replacement diffusers VAE")
        else:  # convert the original or replacement VAE
            vae = convert_ldm_vae_to_diffusers(
                checkpoint,
                original_config,
                image_size)
        vae = convert_ldm_vae_to_diffusers(
            checkpoint,
            original_config,
            image_size,
        )

        # Convert the text model.
        model_type = pipeline_type
        if model_type is None:
            model_type = original_config.model.params.cond_stage_config.target.split(
                "."
            )[-1]
        model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
        if model_type == "FrozenOpenCLIPEmbedder":
            text_model = convert_open_clip_checkpoint(checkpoint)
-            tokenizer = CLIPTokenizer.from_pretrained(MODEL_ROOT / 'stable-diffusion-2-tokenizer')
-            pipe = pipeline_class(
+            tokenizer = CLIPTokenizer.from_pretrained(
+                MODEL_ROOT / 'stable-diffusion-2-clip',
+                subfolder='tokenizer',
+            )
+            pipe = StableDiffusionPipeline(
                vae=vae.to(precision),
                text_encoder=text_model.to(precision),
                tokenizer=tokenizer,
@@ -1219,20 +1123,22 @@ def load_pipeline_from_original_stable_diffusion_ckpt(
                feature_extractor=None,
                requires_safety_checker=False,
            )

        elif model_type in ["FrozenCLIPEmbedder", "WeightedFrozenCLIPEmbedder"]:
            text_model = convert_ldm_clip_checkpoint(checkpoint)
            tokenizer = CLIPTokenizer.from_pretrained(MODEL_ROOT / 'clip-vit-large-patch14')
            safety_checker = StableDiffusionSafetyChecker.from_pretrained(MODEL_ROOT / 'stable-diffusion-safety-checker')
-            feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ROOT / 'stable-diffusion-safety-checker-extractor')
-            pipe = pipeline_class(
+            feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ROOT / 'stable-diffusion-safety-checker')
+            pipe = StableDiffusionPipeline(
                vae=vae.to(precision),
                text_encoder=text_model.to(precision),
                tokenizer=tokenizer,
                unet=unet.to(precision),
                scheduler=scheduler,
-                safety_checker=None if return_generator_pipeline else safety_checker.to(precision),
+                safety_checker=safety_checker.to(precision),
                feature_extractor=feature_extractor,
            )

        else:
            text_config = create_ldm_bert_config(original_config)
            text_model = convert_ldm_bert_checkpoint(checkpoint, text_config)
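With `return_generator_pipeline` gone, the converter always assembles a plain `StableDiffusionPipeline`, so the result round-trips through the normal diffusers save/load path. A short illustrative sketch continuing from the `pipe` built above (the output directory is an assumption):

    from diffusers import StableDiffusionPipeline

    # the converted pipeline behaves like any other diffusers pipeline
    pipe.save_pretrained("models/.cache/converted-checkpoint")
    reloaded = StableDiffusionPipeline.from_pretrained("models/.cache/converted-checkpoint")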
@@ -20,15 +20,12 @@ import gc
import os
import sys
import hashlib
import warnings
from contextlib import suppress
from pathlib import Path
from typing import Dict, Union, types, Optional, Type, Any

import torch

from diffusers import logging as diffusers_logging
from transformers import logging as transformers_logging
import logging
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import get_invokeai_config
@@ -382,21 +379,6 @@ class ModelCache(object):
            f.write(hash)
        return hash

class SilenceWarnings(object):
    def __init__(self):
        self.transformers_verbosity = transformers_logging.get_verbosity()
        self.diffusers_verbosity = diffusers_logging.get_verbosity()

    def __enter__(self):
        transformers_logging.set_verbosity_error()
        diffusers_logging.set_verbosity_error()
        warnings.simplefilter('ignore')

    def __exit__(self,type,value,traceback):
        transformers_logging.set_verbosity(self.transformers_verbosity)
        diffusers_logging.set_verbosity(self.diffusers_verbosity)
        warnings.simplefilter('default')

class VRAMUsage(object):
    def __init__(self):
        self.vram = None
@@ -148,6 +148,7 @@ into the model when downloaded or converted.
from __future__ import annotations

import os
import hashlib
import textwrap
from dataclasses import dataclass
from packaging import version
@@ -166,7 +167,7 @@ import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util import CUDA_DEVICE, download_with_resume
from .model_cache import ModelCache, ModelLocker
-from .models import BaseModelType, ModelType, SubModelType, MODEL_CLASSES
+from .models import BaseModelType, ModelType, SubModelType, ModelError, MODEL_CLASSES

# We are only starting to number the config file with release 3.
# The config file version doesn't have to start at release version, but it will help
@@ -299,7 +300,7 @@ class ModelManager(object):
        for model_key, model_config in config.items():
            model_name, base_model, model_type = self.parse_key(model_key)
            model_class = MODEL_CLASSES[base_model][model_type]
-            self.models[model_key] = model_class.build_config(**model_config)
+            self.models[model_key] = model_class.create_config(**model_config)

        # check config version number and update on disk/RAM if necessary
        self.globals = InvokeAIAppConfig.get_config()
@@ -406,9 +407,7 @@ class ModelManager(object):
            path_mask = f"/models/{base_model}/{model_type}/{model_name}*"
            if False: # model_path = next(find_by_mask(path_mask)):
                model_path = None # TODO:
-                model_config = model_class.build_config(
-                    path=model_path,
-                )
+                model_config = model_class.probe_config(model_path)
                self.models[model_key] = model_config
            else:
                raise Exception(f"Model not found - {model_key}")
@@ -437,16 +436,22 @@ class ModelManager(object):

        # vae/movq override
        # TODO:
-        if submodel_type is not None and submodel_type in model_config:
-            model_path = model_config[submodel_type]
-            model_type = submodel_type
-            submodel_type = None
+        if submodel_type is not None and hasattr(model_config, submodel_type):
+            override_path = getattr(model_config, submodel_type)
+            if override_path:
+                model_path = override_path
+                model_type = submodel_type
+                submodel_type = None

        model_class = MODEL_CLASSES[base_model][model_type]

-        dst_convert_path = None # TODO:
+        # TODO: path
+        # TODO: is it accurate to use path as id
+        dst_convert_path = self.globals.models_dir / ".cache" / hashlib.md5(model_path.encode()).hexdigest()
        model_path = model_class.convert_if_required(
-            model_path,
-            dst_convert_path,
-            model_config,
+            base_model=base_model,
+            model_path=model_path,
+            output_path=dst_convert_path,
+            config=model_config,
        )

        model_context = self.cache.get_model(
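Converted checkpoints are now cached under `models/.cache/<md5 of the model path>` instead of the earlier `dst_convert_path = None` placeholder. A short sketch of the hashing scheme (the concrete paths are assumptions):

    import hashlib
    from pathlib import Path

    models_dir = Path("models")                          # assumed value of self.globals.models_dir
    model_path = "sd-1/pipeline/my-model.safetensors"    # assumed relative model path

    dst_convert_path = models_dir / ".cache" / hashlib.md5(model_path.encode()).hexdigest()
    print(dst_convert_path)  # models/.cache/<32-character hex digest>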
@@ -457,14 +462,14 @@ class ModelManager(object):
            submodel=submodel_type,
        )

-        hash = "<NO_HASH>" # TODO:
+        model_hash = "<NO_HASH>" # TODO:

        return ModelInfo(
            context = model_context,
            name = model_name,
            base_model = base_model,
            type = submodel_type or model_type,
-            hash = hash,
+            hash = model_hash,
            location = model_path, # TODO:
            precision = self.cache.precision,
            _cache = self.cache,
@@ -633,7 +638,7 @@ class ModelManager(object):
        """

        model_class = MODEL_CLASSES[base_model][model_type]
-        model_config = model_class.build_config(**model_attributes)
+        model_config = model_class.create_config(**model_attributes)
        model_key = self.create_key(model_name, base_model, model_type)

        assert (
@@ -749,13 +754,15 @@ class ModelManager(object):

        for model_key, model_config in list(self.models.items()):
            model_name, base_model, model_type = self.parse_key(model_key)
-            if not os.path.exists(model_config.path):
+            model_path = str(self.globals.root / model_config.path)
+            if not os.path.exists(model_path):
                model_class = MODEL_CLASSES[base_model][model_type]
                if model_class.save_to_config:
                    model_config.error = ModelError.NotFound
                else:
                    self.models.pop(model_key, None)
            else:
-                loaded_files.add(model_config.path)
+                loaded_files.add(model_path)

        for base_model in BaseModelType:
            for model_type in ModelType:
@@ -774,7 +781,5 @@ class ModelManager(object):
                if model_key in self.models:
                    raise Exception(f"Model with key {model_key} added twice")

-                model_config: ModelConfigBase = model_class.build_config(
-                    path=model_path,
-                )
+                model_config: ModelConfigBase = model_class.probe_config(model_path)
                self.models[model_key] = model_config
@@ -12,8 +12,7 @@ from typing import Callable, Literal, Union, Dict
from picklescan.scanner import scan_file_path

import invokeai.backend.util.logging as logger
-from .models import BaseModelType, ModelType, ModelVariantType, SchedulerPredictionType
-from .model_cache import SilenceWarnings
+from .models import BaseModelType, ModelType, ModelVariantType, SchedulerPredictionType, SilenceWarnings

@dataclass
class ModelVariantInfo(object):
@@ -21,6 +20,7 @@ class ModelVariantInfo(object):
    base_type: BaseModelType
    variant_type: ModelVariantType
    prediction_type: SchedulerPredictionType
    upcast_attention: bool
    format: Literal['folder','checkpoint']
    image_size: int

@@ -95,10 +95,12 @@ class ModelProbe(object):
                base_type = base_type,
                variant_type = variant_type,
                prediction_type = prediction_type,
                upcast_attention = (base_type==BaseModelType.StableDiffusion2 \
                                    and prediction_type==SchedulerPredictionType.VPrediction),
                format = format,
                image_size = 768 if (base_type==BaseModelType.StableDiffusion2 \
                                     and prediction_type==SchedulerPredictionType.VPrediction \
-                                    ) else 512
+                                    ) else 512,
            )
        except Exception as e:
            return None
@@ -1,5 +1,5 @@
-from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType
-from .stable_diffusion import StableDiffusion15Model, StableDiffusion2Model, StableDiffusion2BaseModel
+from .base import BaseModelType, ModelType, SubModelType, ModelBase, ModelConfigBase, ModelVariantType, SchedulerPredictionType, ModelError, SilenceWarnings
+from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
from .vae import VaeModel
from .lora import LoRAModel
#from .controlnet import ControlNetModel # TODO:
@@ -11,7 +11,7 @@ class ControlNetModel:

MODEL_CLASSES = {
    BaseModelType.StableDiffusion1: {
-        ModelType.Pipeline: StableDiffusion15Model,
+        ModelType.Pipeline: StableDiffusion1Model,
        ModelType.Vae: VaeModel,
        ModelType.Lora: LoRAModel,
        ModelType.ControlNet: ControlNetModel,
@@ -10,13 +10,9 @@ from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Type, Literal

class BaseModelType(str, Enum):
    #StableDiffusion1_5 = "stable_diffusion_1_5"
    #StableDiffusion2 = "stable_diffusion_2"
    #StableDiffusion2Base = "stable_diffusion_2_base"
    # TODO: maybe then add sample size(512/768)?
    StableDiffusion1 = "sd-1"
-    StableDiffusion2 = "sd-2" # 768 pixels; this will have v-prediction parameterization
-    #Kandinsky2_1 = "kandinsky_2_1"
+    StableDiffusion2 = "sd-2"
+    #Kandinsky2_1 = "kandinsky-2.1"

class ModelType(str, Enum):
    Pipeline = "pipeline"
@@ -107,7 +103,7 @@ class ModelBase:
                continue

            fields = inspect.get_annotations(value)
-            if "format" not in fields or typing.get_origin(fields["format"]) != Literal:
+            if "format" not in fields:
                raise Exception("Invalid config definition - format field not found")

            format_type = typing.get_origin(fields["format"])
@@ -125,13 +121,20 @@ class ModelBase:
        return cls.__configs

    @classmethod
-    def build_config(cls, **kwargs):
+    def create_config(cls, **kwargs):
        if "format" not in kwargs:
-            kwargs["format"] = cls.detect_format(kwargs["path"])
+            raise Exception("Field 'format' not found in model config")

        configs = cls._get_configs()
        return configs[kwargs["format"]](**kwargs)

    @classmethod
    def probe_config(cls, path: str, **kwargs):
        return cls.create_config(
            path=path,
            format=cls.detect_format(path),
        )

    @classmethod
    def detect_format(cls, path: str) -> str:
        raise NotImplementedError()
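`build_config` is effectively split in two here: `create_config` now requires an explicit `format` (and raises otherwise), while `probe_config` derives the format from the path via `detect_format` and then delegates. A hedged usage sketch against a concrete subclass; the path and the exact set of required fields are assumptions:

    # probing: the format is detected from the path, then passed on to create_config()
    cfg = StableDiffusion1Model.probe_config("/path/to/model.safetensors")

    # direct construction: "format" must be supplied explicitly;
    # the concrete config class may require further fields such as "variant"
    cfg = StableDiffusion1Model.create_config(
        path="/path/to/model.safetensors",
        format="checkpoint",
    )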
@@ -304,3 +307,62 @@ def _calc_model_by_data(model) -> int:
    mem_bufs = sum([buf.nelement()*buf.element_size() for buf in model.buffers()])
    mem = mem_params + mem_bufs # in bytes
    return mem


def _fast_safetensors_reader(path: str):
    checkpoint = dict()
    device = torch.device("meta")
    with open(path, "rb") as f:
        definition_len = int.from_bytes(f.read(8), 'little')
        definition_json = f.read(definition_len)
        definition = json.loads(definition_json)

        if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {"pt", "torch", "pytorch"}:
            raise Exception("Supported only pytorch safetensors files")
        definition.pop("__metadata__", None)

        for key, info in definition.items():
            dtype = {
                "I8": torch.int8,
                "I16": torch.int16,
                "I32": torch.int32,
                "I64": torch.int64,
                "F16": torch.float16,
                "F32": torch.float32,
                "F64": torch.float64,
            }[info["dtype"]]

            checkpoint[key] = torch.empty(info["shape"], dtype=dtype, device=device)

    return checkpoint


def read_checkpoint_meta(path: str):
    if path.endswith(".safetensors"):
        try:
            checkpoint = _fast_safetensors_reader(path)
        except:
            # TODO: create issue for support "meta"?
            checkpoint = safetensors.torch.load_file(path, device="cpu")
    else:
        checkpoint = torch.load(path, map_location=torch.device("meta"))
    return checkpoint

import warnings
from diffusers import logging as diffusers_logging
from transformers import logging as transformers_logging

class SilenceWarnings(object):
    def __init__(self):
        self.transformers_verbosity = transformers_logging.get_verbosity()
        self.diffusers_verbosity = diffusers_logging.get_verbosity()

    def __enter__(self):
        transformers_logging.set_verbosity_error()
        diffusers_logging.set_verbosity_error()
        warnings.simplefilter('ignore')

    def __exit__(self, type, value, traceback):
        transformers_logging.set_verbosity(self.transformers_verbosity)
        diffusers_logging.set_verbosity(self.diffusers_verbosity)
        warnings.simplefilter('default')
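The new `_fast_safetensors_reader` relies on the safetensors file layout: the first 8 bytes are a little-endian header length, followed by a JSON header mapping tensor names to dtype, shape and data offsets, so model metadata can be probed without reading any weights. The `SilenceWarnings` context manager moves here unchanged from `model_cache` so probing and conversion code can share it. A small standalone sketch of the same header trick:

    import json, struct

    def read_safetensors_header(path: str) -> dict:
        # 8-byte little-endian header length, then a JSON header
        with open(path, "rb") as f:
            (header_len,) = struct.unpack("<Q", f.read(8))
            header = json.loads(f.read(header_len))
        header.pop("__metadata__", None)
        return header  # {tensor_name: {"dtype": ..., "shape": ..., "data_offsets": ...}}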
@@ -54,8 +54,14 @@ class LoRAModel(ModelBase):
        else:
            return "lycoris"

-    @staticmethod
-    def convert_if_required(cls, model_path: str, dst_cache_path: str, config: Optional[dict]) -> str:
+    @classmethod
+    def convert_if_required(
+        cls,
+        model_path: str,
+        output_path: str,
+        config: ModelConfigBase,
+        base_model: BaseModelType,
+    ) -> str:
        if cls.detect_format(model_path) == "diffusers":
            # TODO: add diffusers lora when it stabilizes a bit
            raise NotImplementedError("Diffusers lora not supported")
@@ -3,7 +3,8 @@ import json
import torch
import safetensors.torch
from pydantic import Field
-from typing import Literal, Optional
+from pathlib import Path
+from typing import Literal, Optional, Union
from .base import (
    ModelBase,
    ModelConfigBase,
@@ -12,13 +13,16 @@ from .base import (
    SubModelType,
    ModelVariantType,
    DiffusersModel,
    SchedulerPredictionType,
    SilenceWarnings,
    read_checkpoint_meta,
)
from invokeai.app.services.config import InvokeAIAppConfig
from omegaconf import OmegaConf

-# TODO: how to name properly
-class StableDiffusion15Model(DiffusersModel):
+# TODO: str -> Path?
+class StableDiffusion1Model(DiffusersModel):

    class DiffusersConfig(ModelConfigBase):
        format: Literal["diffusers"]
        vae: Optional[str] = Field(None)
@@ -32,94 +36,56 @@ class StableDiffusion15Model(DiffusersModel):

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
-        assert base_model == BaseModelType.StableDiffusion1_5
+        assert base_model == BaseModelType.StableDiffusion1
        assert model_type == ModelType.Pipeline
        super().__init__(
            model_path=model_path,
-            base_model=BaseModelType.StableDiffusion1_5,
+            base_model=BaseModelType.StableDiffusion1,
            model_type=ModelType.Pipeline,
        )

    @staticmethod
    def _fast_safetensors_reader(path: str):
        checkpoint = dict()
        device = torch.device("meta")
        with open(path, "rb") as f:
            definition_len = int.from_bytes(f.read(8), 'little')
            definition_json = f.read(definition_len)
            definition = json.loads(definition_json)

            if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {"pt", "torch", "pytorch"}:
                raise Exception("Supported only pytorch safetensors files")
            definition.pop("__metadata__", None)

            for key, info in definition.items():
                dtype = {
                    "I8": torch.int8,
                    "I16": torch.int16,
                    "I32": torch.int32,
                    "I64": torch.int64,
                    "F16": torch.float16,
                    "F32": torch.float32,
                    "F64": torch.float64,
                }[info["dtype"]]

                checkpoint[key] = torch.empty(info["shape"], dtype=dtype, device=device)

        return checkpoint

    @classmethod
    def read_checkpoint_meta(cls, path: str):
        if path.endswith(".safetensors"):
            try:
                checkpoint = cls._fast_safetensors_reader(path)
            except:
                checkpoint = safetensors.torch.load_file(path, device="cpu") # TODO: create issue for support "meta"?
    def probe_config(cls, path: str, **kwargs):
        model_format = cls.detect_format(path)
        ckpt_config_path = kwargs.get("config", None)
        if model_format == "checkpoint":
            if ckpt_config_path:
                ckpt_config = OmegaConf.load(ckpt_config_path)
                ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]

            else:
                checkpoint = read_checkpoint_meta(path)
                checkpoint = checkpoint.get('state_dict', checkpoint)
                in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]

        elif model_format == "diffusers":
            unet_config_path = os.path.join(path, "unet", "config.json")
            if os.path.exists(unet_config_path):
                with open(unet_config_path, "r") as f:
                    unet_config = json.loads(f.read())
                in_channels = unet_config['in_channels']

            else:
                raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")

        else:
            checkpoint = torch.load(path, map_location=torch.device("meta"))
            return checkpoint
            raise NotImplementedError(f"Unknown stable diffusion 1.* format: {model_format}")

    @classmethod
    def build_config(cls, **kwargs):
        if "format" not in kwargs:
            kwargs["format"] = cls.detect_format(kwargs["path"])

        if "variant" not in kwargs:
            if kwargs["format"] == "checkpoint":
                if "config" in kwargs:
                    ckpt_config = OmegaConf.load(kwargs["config"])
                    in_channels = ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]

                else:
                    checkpoint = cls.read_checkpoint_meta(kwargs["path"])
                    checkpoint = checkpoint.get('state_dict', checkpoint)
                    in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]

            elif kwargs["format"] == "diffusers":
                unet_config_path = os.path.join(kwargs["path"], "unet", "config.json")
                if os.path.exists(unet_config_path):
                    with open(unet_config_path, "r") as f:
                        unet_config = json.loads(f.read())
                    in_channels = unet_config['in_channels']

                else:
                    raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")

            else:
                raise NotImplementedError(f"Unknown stable diffusion format: {kwargs['format']}")

            if in_channels == 9:
                kwargs["variant"] = ModelVariantType.Inpaint
            elif in_channels == 5:
                kwargs["variant"] = ModelVariantType.Depth
            elif in_channels == 4:
                kwargs["variant"] = ModelVariantType.Normal
            else:
                raise Exception("Unkown stable diffusion model format")
        if in_channels == 9:
            variant = ModelVariantType.Inpaint
        elif in_channels == 4:
            variant = ModelVariantType.Normal
        else:
            raise Exception("Unkown stable diffusion 1.* model format")

        return super().build_config(**kwargs)
        return cls.create_config(
            path=path,
            format=model_format,

            config=ckpt_config_path,
            variant=variant,
        )

    @classmethod
    def save_to_config(cls) -> bool:
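`probe_config` distinguishes model variants purely by the UNet's input-channel count (4 = normal, 9 = inpaint, and additionally 5 = depth for SD2), read from the legacy yaml, the checkpoint's first conv weight, or the diffusers `unet/config.json`. A minimal sketch of the diffusers-folder branch, under the assumption that the folder follows the standard diffusers layout:

    import json, os

    def probe_variant_from_diffusers(path: str) -> str:
        with open(os.path.join(path, "unet", "config.json")) as f:
            in_channels = json.load(f)["in_channels"]
        return {4: "normal", 5: "depth", 9: "inpaint"}[in_channels]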
@@ -133,128 +99,215 @@ class StableDiffusion15Model(DiffusersModel):
        return "checkpoint"

    @classmethod
-    def convert_if_required(cls, model_path: str, dst_cache_path: str, config: ModelConfigBase) -> str:
+    def convert_if_required(
+        cls,
+        model_path: str,
+        output_path: str,
+        config: ModelConfigBase,
+        base_model: BaseModelType,
+    ) -> str:
        assert model_path == config.path

        if isinstance(config, cls.CheckpointConfig):
            return _convert_ckpt_and_cache(
-                version=BaseModelType.StableDiffusion1_5,
-                config=config.dict(),
-                in_path=model_path,
-                out_path=dst_cache_path,
+                version=BaseModelType.StableDiffusion1,
+                model_config=config,
+                output_path=output_path,
            ) # TODO: args
        else:
            return model_path

# all same
class StableDiffusion2BaseModel(StableDiffusion15Model):
    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        # skip StableDiffusion15Model __init__
        assert base_model == BaseModelType.StableDiffusion2Base
        assert model_type == ModelType.Pipeline
        super(StableDiffusion15Model, self).__init__(
            model_path=model_path,
            base_model=BaseModelType.StableDiffusion2Base,
            model_type=ModelType.Pipeline,
        )

    @classmethod
    def convert_if_required(cls, model_path: str, dst_cache_path: str, config: ModelConfigBase) -> str:
        if isinstance(config, cls.CheckpointConfig):
            return _convert_ckpt_and_cache(
                version=BaseModelType.StableDiffusion2Base,
                config=config.dict(),
                in_path=model_path,
                out_path=dst_cache_path,
            ) # TODO: args
        else:
            return model_path
class StableDiffusion2Model(DiffusersModel):

class StableDiffusion2Model(StableDiffusion15Model):

    # TODO: str -> Path?
    # TODO: check that configs overwriten
    # TODO: check that configs overwriten properly
    class DiffusersConfig(ModelConfigBase):
        format: Literal["diffusers"]
        vae: Optional[str] = Field(None)
        attention_upscale: bool = Field(True)
        prediction_type: SchedulerPredictionType
        upcast_attention: bool

    class CheckpointConfig(ModelConfigBase):
        format: Literal["checkpoint"]
        vae: Optional[str] = Field(None)
        config: Optional[str] = Field(None)
        attention_upscale: bool = Field(True)
        prediction_type: SchedulerPredictionType
        upcast_attention: bool


    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        # skip StableDiffusion15Model __init__
        assert base_model == BaseModelType.StableDiffusion2
        assert model_type == ModelType.Pipeline
        # skip StableDiffusion15Model __init__
-        super(StableDiffusion15Model, self).__init__(
+        super().__init__(
            model_path=model_path,
            base_model=BaseModelType.StableDiffusion2,
            model_type=ModelType.Pipeline,
        )

    @classmethod
-    def convert_if_required(cls, model_path: str, dst_cache_path: str, config: ModelConfigBase) -> str:
+    def probe_config(cls, path: str, **kwargs):
        model_format = cls.detect_format(path)
        ckpt_config_path = kwargs.get("config", None)
        if model_format == "checkpoint":
            if ckpt_config_path:
                ckpt_config = OmegaConf.load(ckpt_config_path)
                ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]

            else:
                checkpoint = read_checkpoint_meta(path)
                checkpoint = checkpoint.get('state_dict', checkpoint)
                in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]

        elif model_format == "diffusers":
            unet_config_path = os.path.join(path, "unet", "config.json")
            if os.path.exists(unet_config_path):
                with open(unet_config_path, "r") as f:
                    unet_config = json.loads(f.read())
                in_channels = unet_config['in_channels']

            else:
                raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")

        else:
            raise NotImplementedError(f"Unknown stable diffusion 2.* format: {model_format}")

        if in_channels == 9:
            variant = ModelVariantType.Inpaint
        elif in_channels == 5:
            variant = ModelVariantType.Depth
        elif in_channels == 4:
            variant = ModelVariantType.Normal
        else:
            raise Exception("Unkown stable diffusion 2.* model format")

        if variant == ModelVariantType.Normal:
            prediction_type = SchedulerPredictionType.VPrediction
            upcast_attention = True

        else:
            prediction_type = SchedulerPredictionType.Epsilon
            upcast_attention = False

        return cls.create_config(
            path=path,
            format=model_format,

            config=ckpt_config_path,
            variant=variant,
            prediction_type=prediction_type,
            upcast_attention=upcast_attention,
        )

    @classmethod
    def save_to_config(cls) -> bool:
        return True

    @classmethod
    def detect_format(cls, model_path: str):
        if os.path.isdir(model_path):
            return "diffusers"
        else:
            return "checkpoint"

    @classmethod
    def convert_if_required(
        cls,
        model_path: str,
        output_path: str,
        config: ModelConfigBase,
        base_model: BaseModelType,
    ) -> str:
        assert model_path == config.path

        if isinstance(config, cls.CheckpointConfig):
            return _convert_ckpt_and_cache(
                version=BaseModelType.StableDiffusion2,
-                config=config.dict(),
-                in_path=model_path,
-                out_path=dst_cache_path,
+                model_config=config,
+                output_path=output_path,
            ) # TODO: args
        else:
            return model_path
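For SD2 checkpoints the probe cannot tell a 768-pixel v-prediction model from a 512-pixel epsilon "base" model by channel count alone, so the code above simply assumes that every normal-variant SD2 model is v-prediction with upcast attention, while inpaint and depth variants are epsilon. Stated as data (this mapping is an assumption baked into probing, not something read from the file):

    SD2_PROBE_DEFAULTS = {
        "normal":  {"prediction_type": "v_prediction", "upcast_attention": True},
        "inpaint": {"prediction_type": "epsilon",      "upcast_attention": False},
        "depth":   {"prediction_type": "epsilon",      "upcast_attention": False},
    }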
def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType):
    ckpt_configs = {
        BaseModelType.StableDiffusion1: {
            ModelVariantType.Normal: "v1-inference.yaml",
            ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
        },
        BaseModelType.StableDiffusion2: {
            # code further will manually set upcast_attention and v_prediction
            ModelVariantType.Normal: "v2-inference.yaml",
            ModelVariantType.Inpaint: "v2-inpainting-inference.yaml",
            ModelVariantType.Depth: "v2-midas-inference.yaml",
        }
    }

    try:
        # TODO: path
        #model_config.config = app_config.config_dir / "stable-diffusion" / ckpt_configs[version][model_config.variant]
        #return InvokeAIAppConfig.get_config().legacy_conf_dir / ckpt_configs[version][variant]
        return InvokeAIAppConfig.get_config().root_dir / "configs" / "stable-diffusion" / ckpt_configs[version][variant]

    except:
        return None


# TODO: rework
DictConfig = dict
def _convert_ckpt_and_cache(
-    self,
    version: BaseModelType,
-    mconfig: dict, # TODO:
-    in_path: str,
-    out_path: str,
+    model_config: Union[StableDiffusion1Model.CheckpointConfig, StableDiffusion2Model.CheckpointConfig],
+    output_path: str,
) -> str:
    """
    Convert the checkpoint model indicated in mconfig into a
    diffusers, cache it to disk, and return Path to converted
    file. If already on disk then just returns Path.
    """
-    raise NotImplementedError()
    app_config = InvokeAIAppConfig.get_config()

    #if "config" not in mconfig:
    #    if version == BaseModelType.StableDiffusion1_5:
    #if
    #mconfig["config"] = app_config.config_dir / "stable-diffusion" / "v1-inference.yaml"
    if model_config.config is None:
        model_config.config = _select_ckpt_config(version, model_config.variant)
        if model_config.config is None:
            raise Exception(f"Model variant {model_config.variant} not supported for {version}")

-    weights = app_config.root_dir / mconfig.path
-    config_file = app_config.root_dir / mconfig.config
-    diffusers_path = app_config.converted_ckpts_dir / weights.stem
+    weights = app_config.root_dir / model_config.path
+    config_file = app_config.root_dir / model_config.config
+    output_path = Path(output_path)

    if version == BaseModelType.StableDiffusion1:
        upcast_attention = False
        prediction_type = SchedulerPredictionType.Epsilon

    elif version == BaseModelType.StableDiffusion2:
        upcast_attention = config.upcast_attention
        prediction_type = config.prediction_type

    else:
        raise Exception(f"Unknown model provided: {version}")

    # return cached version if it exists
-    if diffusers_path.exists():
-        return diffusers_path
+    if output_path.exists():
+        return output_path

    # TODO: I think that it more correctly to convert with embedded vae
    # as if user will delete custom vae he will got not embedded but also custom vae
    #vae_ckpt_path, vae_model = self._get_vae_for_conversion(weights, mconfig)
    vae_ckpt_path, vae_model = None, None

    # to avoid circular import errors
    from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
    with SilenceWarnings():
        convert_ckpt_to_diffusers(
            weights,
-            diffusers_path,
-            extract_ema=True,
+            output_path,
+            model_version=version,
+            model_variant=model_config.variant,
            original_config_file=config_file,
            vae=vae_model,
            vae_path=str(app_config.root_dir / vae_ckpt_path) if vae_ckpt_path else None,
+            extract_ema=True,
+            upcast_attention=upcast_attention,
+            prediction_type=prediction_type,
+            scan_needed=True,
+            model_root=app_config.models_path,
        )
-    return diffusers_path
+    return output_path
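When a checkpoint config carries no explicit yaml, `_convert_ckpt_and_cache` falls back to `_select_ckpt_config`, which maps (base model, variant) onto the bundled legacy configs. A hedged usage sketch using only names from the code above:

    # e.g. an SD2 depth model resolves to the bundled v2-midas legacy config
    config_path = _select_ckpt_config(BaseModelType.StableDiffusion2, ModelVariantType.Depth)
    # -> <root>/configs/stable-diffusion/v2-midas-inference.yaml, or None for unknown combinations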
@@ -51,6 +51,12 @@ class TextualInversionModel(ModelBase):
    def detect_format(cls, path: str):
        return None

-    @staticmethod
-    def convert_if_required(model_path: str, cache_path: str, config: Optional[dict]) -> str:
+    @classmethod
+    def convert_if_required(
+        cls,
+        model_path: str,
+        output_path: str,
+        config: ModelConfigBase,
+        base_model: BaseModelType,
+    ) -> str:
        return model_path
@@ -1,5 +1,6 @@
import os
import torch
from pathlib import Path
from typing import Optional
from .base import (
    ModelBase,
@@ -7,11 +8,14 @@ from .base import (
    BaseModelType,
    ModelType,
    SubModelType,
    ModelVariantType,
    EmptyConfigLoader,
    calc_model_size_by_fs,
    calc_model_size_by_data,
)
from invokeai.app.services.config import InvokeAIAppConfig
from diffusers.utils import is_safetensors_available
from omegaconf import OmegaConf

class VaeModel(ModelBase):
    #vae_class: Type
@@ -70,39 +74,72 @@ class VaeModel(ModelBase):
        return "checkpoint"

    @classmethod
-    def convert_if_required(cls, model_path: str, dst_cache_path: str, config: Optional[dict]) -> str:
+    def convert_if_required(
+        cls,
+        model_path: str,
+        output_path: str,
+        config: ModelConfigBase, # empty config or config of parent model
+        base_model: BaseModelType,
+    ) -> str:
        if cls.detect_format(model_path) != "diffusers":
-            # TODO:
-            #_convert_vae_ckpt_and_cache
-            raise NotImplementedError("TODO: vae convert")
+            return _convert_vae_ckpt_and_cache(
+                weights_path=model_path,
+                output_path=output_path,
+                base_model=base_model,
+                model_config=config,
+            )
        else:
            return model_path

# TODO: rework
DictConfig = dict
-def _convert_vae_ckpt_and_cache(self, mconfig: DictConfig) -> str:
+def _convert_vae_ckpt_and_cache(
+    weights_path: str,
+    output_path: str,
+    base_model: BaseModelType,
+    model_config: ModelConfigBase,
+) -> str:
    """
    Convert the VAE indicated in mconfig into a diffusers AutoencoderKL
    object, cache it to disk, and return Path to converted
    file. If already on disk then just returns Path.
    """
    app_config = InvokeAIAppConfig.get_config()
-    root = app_config.root_dir
-    weights_file = root / mconfig.path
-    config_file = root / mconfig.config
-    diffusers_path = app_config.converted_ckpts_dir / weights_file.stem
-    image_size = mconfig.get('width') or mconfig.get('height') or 512
+    weights_path = app_config.root_dir / weights_path
+    output_path = Path(output_path)

    """
    this size used only in when tiling enabled to separate input in tiles
    sizes in configs from stable diffusion githubs(1 and 2) set to 256
    on huggingface it:
    1.5 - 512
    1.5-inpainting - 256
    2-inpainting - 512
    2-depth - 256
    2-base - 512
    2 - 768
    2.1-base - 768
    2.1 - 768
    """
    image_size = 512

    # return cached version if it exists
-    if diffusers_path.exists():
-        return diffusers_path
+    if output_path.exists():
+        return output_path

    if base_model in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
        from .stable_diffusion import _select_ckpt_config
        # all sd models use same vae settings
        config_file = _select_ckpt_config(base_model, ModelVariantType.Normal)

    else:
        raise Exception(f"Vae conversion not supported for model type: {base_model}")

    # this avoids circular import error
-    from .convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
-    if weights_file.suffix == '.safetensors':
-        checkpoint = safetensors.torch.load_file(weights_file)
+    from ..convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
+    if weights_path.suffix == '.safetensors':
+        checkpoint = safetensors.torch.load_file(weights_path, device="cpu")
    else:
-        checkpoint = torch.load(weights_file, map_location="cpu")
+        checkpoint = torch.load(weights_path, map_location="cpu")

    # sometimes weights are hidden under "state_dict", and sometimes not
    if "state_dict" in checkpoint:
@@ -117,7 +154,7 @@ def _convert_vae_ckpt_and_cache(self, mconfig: DictConfig) -> str:
        model_root = app_config.models_path,
    )
    vae_model.save_pretrained(
-        diffusers_path,
+        output_path,
        safe_serialization=is_safetensors_available()
    )
-    return diffusers_path
+    return output_path
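The VAE path now mirrors the main checkpoint path: load the raw state dict (safetensors or torch), convert it with the project's `convert_ldm_vae_to_diffusers`, and `save_pretrained` the resulting model into the hashed cache directory. A condensed sketch of the loading half using only public APIs; the concrete paths and the internal conversion call are assumptions left as comments:

    import safetensors.torch, torch
    from pathlib import Path

    weights_path = Path("models/sd-1/vae/my-vae.safetensors")  # assumed input checkpoint
    output_path = Path("models/.cache/0123abcd")               # assumed hashed cache dir

    if weights_path.suffix == ".safetensors":
        checkpoint = safetensors.torch.load_file(weights_path, device="cpu")
    else:
        checkpoint = torch.load(weights_path, map_location="cpu")
    checkpoint = checkpoint.get("state_dict", checkpoint)
    # vae_model = convert_ldm_vae_to_diffusers(...)        # project-internal call, signature omitted
    # vae_model.save_pretrained(output_path, safe_serialization=True)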
@@ -14,7 +14,7 @@ export const receivedModels = createAppAsyncThunk(
  const response = await ModelsService.listModels();

  const deserializedModels = reduce(
-    response.models['sd-1.5']['pipeline'],
+    response.models['sd-1']['pipeline'],
    (modelsAccumulator, model, modelName) => {
      modelsAccumulator[modelName] = { ...model, name: modelName };

@@ -25,7 +25,7 @@ export const receivedModels = createAppAsyncThunk(

  models.info(
    { response },
-    `Received ${size(response.models['sd-1.5']['pipeline'])} models`
+    `Received ${size(response.models['sd-1']['pipeline'])} models`
  );

  return deserializedModels;
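The frontend change only tracks the renamed base-model key. With the new enum values the listing response is keyed roughly as sketched below; the shape is inferred from this diff and the enum values, not from the API schema:

    # rough, illustrative shape of the models listing as consumed above
    response = {
        "models": {
            "sd-1": {"pipeline": {"my-model": {"description": "..."}}},
            "sd-2": {"pipeline": {}},
        }
    }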
@@ -100,45 +100,45 @@ def migrate_conversion_models(dest_directory: Path):
    # These are needed for the conversion script
    kwargs = dict(
        cache_dir = Path('./models/hub'),
-        local_files_only = True
+        #local_files_only = True
    )
    try:
        logger.info('Migrating core tokenizers and text encoders')
-        target_dir = dest_directory/'core/convert'
+        target_dir = dest_directory / 'core' / 'convert'

        # bert
        bert = BertTokenizerFast.from_pretrained("bert-base-uncased", **kwargs)
-        bert.save_pretrained(target_dir/'bert-base-uncased', safe_serialization=True)
+        bert.save_pretrained(target_dir / 'bert-base-uncased', safe_serialization=True)

        # sd-1
        repo_id = 'openai/clip-vit-large-patch14'
        pipeline = CLIPTokenizer.from_pretrained(repo_id, **kwargs)
-        pipeline.save_pretrained(target_dir/'clip-vit-large-patch14', safe_serialization=True)
+        pipeline.save_pretrained(target_dir / 'clip-vit-large-patch14', safe_serialization=True)

        pipeline = CLIPTextModel.from_pretrained(repo_id, **kwargs)
-        pipeline.save_pretrained(target_dir/'clip-vit-large-patch14', safe_serialization=True)
+        pipeline.save_pretrained(target_dir / 'clip-vit-large-patch14', safe_serialization=True)

        # sd-2
        repo_id = "stabilityai/stable-diffusion-2"
-        pipeline = CLIPTokenizer.from_pretrained(repo_id, subfolder = "tokenizer", **kwargs)
-        pipeline.save_pretrained(target_dir/'stable-diffusion-2-tokenizer', safe_serialization=True)
+        pipeline = CLIPTokenizer.from_pretrained(repo_id, subfolder="tokenizer", **kwargs)
+        pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'tokenizer', safe_serialization=True)

-        pipeline = CLIPTextModel.from_pretrained(repo_id,subfolder = "text_encoder", **kwargs)
-        pipeline.save_pretrained(target_dir/'stable-diffusion-2-text_encoder', safe_serialization=True)
+        pipeline = CLIPTextModel.from_pretrained(repo_id, subfolder="text_encoder", **kwargs)
+        pipeline.save_pretrained(target_dir / 'stable-diffusion-2-clip' / 'text_encoder', safe_serialization=True)

        # VAE
        logger.info('Migrating stable diffusion VAE')
        vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse', **kwargs)
-        vae.save_pretrained(target_dir/'sd-vae-ft-mse', safe_serialization=True)
+        vae.save_pretrained(target_dir / 'sd-vae-ft-mse', safe_serialization=True)

        # safety checking
        logger.info('Migrating safety checker')
        repo_id = "CompVis/stable-diffusion-safety-checker"
        pipeline = AutoFeatureExtractor.from_pretrained(repo_id,**kwargs)
-        pipeline.save_pretrained(target_dir/'stable-diffusion-safety-checker-extractor', safe_serialization=True)
+        pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)

        pipeline = StableDiffusionSafetyChecker.from_pretrained(repo_id,**kwargs)
-        pipeline.save_pretrained(target_dir/'stable-diffusion-safety-checker', safe_serialization=True)
+        pipeline.save_pretrained(target_dir / 'stable-diffusion-safety-checker', safe_serialization=True)
    except KeyboardInterrupt:
        raise
    except Exception as e:
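After this change the migrate script lays the conversion models out under a single `core/convert` tree, with the SD2 tokenizer and text encoder sharing one `stable-diffusion-2-clip` folder and the safety checker plus its feature extractor sharing one folder, matching the loader changes earlier in this commit. The resulting layout, as implied by the save calls above:

    <dest_directory>/core/convert/
        bert-base-uncased/
        clip-vit-large-patch14/
        stable-diffusion-2-clip/
            tokenizer/
            text_encoder/
        sd-vae-ft-mse/
        stable-diffusion-safety-checker/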