2023-06-11 01:49:09 +00:00
|
|
|
import os
|
2023-06-12 02:52:30 +00:00
|
|
|
import json
|
2023-06-20 00:30:09 +00:00
|
|
|
from enum import Enum
|
2023-06-11 01:49:09 +00:00
|
|
|
from pydantic import Field
|
2023-06-13 15:05:12 +00:00
|
|
|
from pathlib import Path
|
|
|
|
from typing import Literal, Optional, Union
|
2023-06-11 01:49:09 +00:00
|
|
|
from .base import (
|
|
|
|
ModelConfigBase,
|
|
|
|
BaseModelType,
|
|
|
|
ModelType,
|
2023-06-12 20:07:04 +00:00
|
|
|
ModelVariantType,
|
2023-06-11 01:49:09 +00:00
|
|
|
DiffusersModel,
|
2023-06-13 15:05:12 +00:00
|
|
|
SilenceWarnings,
|
|
|
|
read_checkpoint_meta,
|
2023-06-14 00:12:12 +00:00
|
|
|
classproperty,
|
2023-07-08 01:09:10 +00:00
|
|
|
InvalidModelException,
|
2023-07-26 19:02:32 +00:00
|
|
|
ModelNotFoundException,
|
2023-06-11 01:49:09 +00:00
|
|
|
)
|
2023-07-23 00:12:16 +00:00
|
|
|
from .sdxl import StableDiffusionXLModel
|
2023-07-23 13:31:14 +00:00
|
|
|
import invokeai.backend.util.logging as logger
|
2023-06-11 01:49:09 +00:00
|
|
|
from invokeai.app.services.config import InvokeAIAppConfig
|
2023-06-13 15:05:12 +00:00
|
|
|
from omegaconf import OmegaConf
|
2023-06-11 01:49:09 +00:00
|
|
|
|
2023-07-23 13:31:14 +00:00
|
|
|
|
2023-06-20 00:30:09 +00:00
|
|
|
class StableDiffusion1ModelFormat(str, Enum):
    """On-disk storage formats recognised for Stable Diffusion 1.x main models.

    Members are ``str`` subclasses, so they compare equal to their plain string values.
    """

    # a single weights file (.safetensors / .ckpt / .pt)
    Checkpoint = "checkpoint"
    # a diffusers folder containing model_index.json
    Diffusers = "diffusers"
|
2023-06-11 01:49:09 +00:00
|
|
|
|
2023-06-13 15:05:12 +00:00
|
|
|
|
|
|
|
class StableDiffusion1Model(DiffusersModel):
    """Stable Diffusion 1.x main model, stored either as a diffusers folder or a single checkpoint file."""

    class DiffusersConfig(ModelConfigBase):
        # Config record for a model stored as a diffusers folder.
        model_format: Literal[StableDiffusion1ModelFormat.Diffusers]
        vae: Optional[str] = Field(None)
        variant: ModelVariantType

    class CheckpointConfig(ModelConfigBase):
        # Config record for a model stored as a single checkpoint file.
        model_format: Literal[StableDiffusion1ModelFormat.Checkpoint]
        vae: Optional[str] = Field(None)
        # Legacy yaml config path; _convert_ckpt_and_cache joins it onto the app root_path.
        config: str
        variant: ModelVariantType

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        # Only SD-1 main models are accepted by this class.
        assert base_model == BaseModelType.StableDiffusion1
        assert model_type == ModelType.Main
        super().__init__(
            model_path=model_path,
            base_model=BaseModelType.StableDiffusion1,
            model_type=ModelType.Main,
        )

    @classmethod
    def probe_config(cls, path: str, **kwargs):
        """
        Inspect the model at ``path`` and build its config record.

        The UNet input channel count determines the variant (9 -> Inpaint, 4 -> Normal).
        For checkpoints, it is read from the legacy yaml passed as ``config=`` when one is
        given, otherwise from the checkpoint weights themselves. For diffusers folders,
        it is read from ``unet/config.json``. When no (or an empty) ``config`` is given,
        a default legacy yaml for the detected variant is chosen via ``_select_ckpt_config``.

        :param path: model file or diffusers folder.
        :param kwargs: optional ``config`` = path to a legacy yaml config.
        :returns: the result of ``cls.create_config(...)``.
        :raises NotImplementedError: diffusers folder without ``unet/config.json``.
        :raises Exception: unsupported UNet input channel count.
        """
        model_format = cls.detect_format(path)
        ckpt_config_path = kwargs.get("config", None)

        if model_format == StableDiffusion1ModelFormat.Checkpoint:
            if ckpt_config_path:
                ckpt_config = OmegaConf.load(ckpt_config_path)
                # Previously this value was evaluated but never assigned, leaving
                # in_channels undefined whenever a config was supplied.
                in_channels = ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
            else:
                checkpoint = read_checkpoint_meta(path)
                # Some checkpoints nest their weights under "state_dict".
                checkpoint = checkpoint.get("state_dict", checkpoint)
                # Assumes this is the first UNet conv weight, whose dim 1 is its input channel count.
                in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]

        elif model_format == StableDiffusion1ModelFormat.Diffusers:
            unet_config_path = os.path.join(path, "unet", "config.json")
            if not os.path.exists(unet_config_path):
                raise NotImplementedError(f"{path} is not a supported stable diffusion diffusers format")
            with open(unet_config_path, "r", encoding="utf-8") as f:
                unet_config = json.load(f)
            in_channels = unet_config["in_channels"]

        else:
            # detect_format only returns the two formats above; kept as a safeguard.
            raise NotImplementedError(f"Unknown stable diffusion 1.* format: {model_format}")

        if in_channels == 9:
            variant = ModelVariantType.Inpaint
        elif in_channels == 4:
            variant = ModelVariantType.Normal
        else:
            raise Exception(f"Unknown stable diffusion 1.* model format (unsupported UNet in_channels={in_channels})")

        # Treat an empty config path the same as a missing one.
        if not ckpt_config_path:
            ckpt_config_path = _select_ckpt_config(BaseModelType.StableDiffusion1, variant)

        return cls.create_config(
            path=path,
            model_format=model_format,
            config=ckpt_config_path,
            variant=variant,
        )

    @classproperty
    def save_to_config(cls) -> bool:
        return True

    @classmethod
    def detect_format(cls, model_path: str):
        """
        Classify ``model_path`` as a diffusers folder or a checkpoint file.

        :raises ModelNotFoundException: the path does not exist.
        :raises InvalidModelException: neither a folder with ``model_index.json``
            nor a ``.safetensors`` / ``.ckpt`` / ``.pt`` file.
        """
        if not os.path.exists(model_path):
            raise ModelNotFoundException()

        if os.path.isdir(model_path):
            if os.path.exists(os.path.join(model_path, "model_index.json")):
                return StableDiffusion1ModelFormat.Diffusers

        if os.path.isfile(model_path):
            if model_path.endswith((".safetensors", ".ckpt", ".pt")):
                return StableDiffusion1ModelFormat.Checkpoint

        raise InvalidModelException(f"Not a valid model: {model_path}")

    @classmethod
    def convert_if_required(
        cls,
        model_path: str,
        output_path: str,
        config: ModelConfigBase,
        base_model: BaseModelType,
    ) -> str:
        """
        Return a diffusers-format path for the model.

        Checkpoint configs are converted (or served from the cache at ``output_path``);
        any other config returns ``model_path`` unchanged. ``base_model`` is unused here.
        """
        if isinstance(config, cls.CheckpointConfig):
            return _convert_ckpt_and_cache(
                version=BaseModelType.StableDiffusion1,
                model_config=config,
                output_path=output_path,
            )
        else:
            return model_path
|
|
|
|
|
2023-07-27 14:54:01 +00:00
|
|
|
|
2023-06-20 00:30:09 +00:00
|
|
|
class StableDiffusion2ModelFormat(str, Enum):
    """On-disk storage formats recognised for Stable Diffusion 2.x main models.

    Members are ``str`` subclasses, so they compare equal to their plain string values.
    """

    # a single weights file (.safetensors / .ckpt / .pt)
    Checkpoint = "checkpoint"
    # a diffusers folder containing model_index.json
    Diffusers = "diffusers"
|
2023-06-11 01:49:09 +00:00
|
|
|
|
2023-06-13 15:05:12 +00:00
|
|
|
|
|
|
|
class StableDiffusion2Model(DiffusersModel):
    """Stable Diffusion 2.x main model, stored either as a diffusers folder or a single checkpoint file."""

    # TODO: check that configs are overwritten properly
    class DiffusersConfig(ModelConfigBase):
        # Config record for a model stored as a diffusers folder.
        model_format: Literal[StableDiffusion2ModelFormat.Diffusers]
        vae: Optional[str] = Field(None)
        variant: ModelVariantType

    class CheckpointConfig(ModelConfigBase):
        # Config record for a model stored as a single checkpoint file.
        model_format: Literal[StableDiffusion2ModelFormat.Checkpoint]
        vae: Optional[str] = Field(None)
        # Legacy yaml config path; _convert_ckpt_and_cache joins it onto the app root_path.
        config: str
        variant: ModelVariantType

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        # Only SD-2 main models are accepted by this class.
        assert base_model == BaseModelType.StableDiffusion2
        assert model_type == ModelType.Main
        super().__init__(
            model_path=model_path,
            base_model=BaseModelType.StableDiffusion2,
            model_type=ModelType.Main,
        )

    @classmethod
    def probe_config(cls, path: str, **kwargs):
        """
        Inspect the model at ``path`` and build its config record.

        The UNet input channel count determines the variant
        (9 -> Inpaint, 5 -> Depth, 4 -> Normal).
        For checkpoints, it is read from the legacy yaml passed as ``config=`` when one is
        given, otherwise from the checkpoint weights themselves. For diffusers folders,
        it is read from ``unet/config.json``. When no (or an empty) ``config`` is given,
        a default legacy yaml for the detected variant is chosen via ``_select_ckpt_config``.

        :param path: model file or diffusers folder.
        :param kwargs: optional ``config`` = path to a legacy yaml config.
        :returns: the result of ``cls.create_config(...)``.
        :raises Exception: diffusers folder without ``unet/config.json``, or an
            unsupported UNet input channel count.
        """
        model_format = cls.detect_format(path)
        ckpt_config_path = kwargs.get("config", None)

        if model_format == StableDiffusion2ModelFormat.Checkpoint:
            if ckpt_config_path:
                ckpt_config = OmegaConf.load(ckpt_config_path)
                # Previously this value was evaluated but never assigned, leaving
                # in_channels undefined whenever a config was supplied.
                in_channels = ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
            else:
                checkpoint = read_checkpoint_meta(path)
                # Some checkpoints nest their weights under "state_dict".
                checkpoint = checkpoint.get("state_dict", checkpoint)
                # Assumes this is the first UNet conv weight, whose dim 1 is its input channel count.
                in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]

        elif model_format == StableDiffusion2ModelFormat.Diffusers:
            unet_config_path = os.path.join(path, "unet", "config.json")
            if not os.path.exists(unet_config_path):
                raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")
            with open(unet_config_path, "r", encoding="utf-8") as f:
                unet_config = json.load(f)
            in_channels = unet_config["in_channels"]

        else:
            # detect_format only returns the two formats above; kept as a safeguard.
            raise NotImplementedError(f"Unknown stable diffusion 2.* format: {model_format}")

        if in_channels == 9:
            variant = ModelVariantType.Inpaint
        elif in_channels == 5:
            variant = ModelVariantType.Depth
        elif in_channels == 4:
            variant = ModelVariantType.Normal
        else:
            raise Exception(f"Unknown stable diffusion 2.* model format (unsupported UNet in_channels={in_channels})")

        # Treat an empty config path the same as a missing one.
        if not ckpt_config_path:
            ckpt_config_path = _select_ckpt_config(BaseModelType.StableDiffusion2, variant)

        return cls.create_config(
            path=path,
            model_format=model_format,
            config=ckpt_config_path,
            variant=variant,
        )

    @classproperty
    def save_to_config(cls) -> bool:
        return True

    @classmethod
    def detect_format(cls, model_path: str):
        """
        Classify ``model_path`` as a diffusers folder or a checkpoint file.

        :raises ModelNotFoundException: the path does not exist.
        :raises InvalidModelException: neither a folder with ``model_index.json``
            nor a ``.safetensors`` / ``.ckpt`` / ``.pt`` file.
        """
        if not os.path.exists(model_path):
            raise ModelNotFoundException()

        if os.path.isdir(model_path):
            if os.path.exists(os.path.join(model_path, "model_index.json")):
                return StableDiffusion2ModelFormat.Diffusers

        if os.path.isfile(model_path):
            if model_path.endswith((".safetensors", ".ckpt", ".pt")):
                return StableDiffusion2ModelFormat.Checkpoint

        raise InvalidModelException(f"Not a valid model: {model_path}")

    @classmethod
    def convert_if_required(
        cls,
        model_path: str,
        output_path: str,
        config: ModelConfigBase,
        base_model: BaseModelType,
    ) -> str:
        """
        Return a diffusers-format path for the model.

        Checkpoint configs are converted (or served from the cache at ``output_path``);
        any other config returns ``model_path`` unchanged. ``base_model`` is unused here.
        """
        if isinstance(config, cls.CheckpointConfig):
            return _convert_ckpt_and_cache(
                version=BaseModelType.StableDiffusion2,
                model_config=config,
                output_path=output_path,
            )
        else:
            return model_path
|
|
|
|
|
2023-07-27 14:54:01 +00:00
|
|
|
|
2023-06-11 01:49:09 +00:00
|
|
|
# TODO: rework
def _convert_ckpt_and_cache(
    version: BaseModelType,
    model_config: Union[
        StableDiffusion1Model.CheckpointConfig,
        StableDiffusion2Model.CheckpointConfig,
        StableDiffusionXLModel.CheckpointConfig,
    ],
    output_path: str,
    use_save_model: bool = False,
    **kwargs,
) -> Path:
    """
    Convert the checkpoint described by ``model_config`` to diffusers format and cache it.

    The converted model is written to ``output_path``, and that path is returned as a
    ``Path``. If ``output_path`` already exists, it is returned immediately without
    re-converting.

    :param version: base model type; selects the text-encoder type passed to the converter.
    :param model_config: checkpoint config; ``path`` is resolved against the app
        ``models_path`` and ``config`` against the app ``root_path``.
    :param output_path: destination (and cache location) of the diffusers folder.
    :param use_save_model: accepted for backward compatibility; currently unused.
    :param kwargs: forwarded to ``convert_ckpt_to_diffusers``. A ``precision`` entry,
        if present, overrides the dtype chosen for the current torch device.
    """
    app_config = InvokeAIAppConfig.get_config()

    weights = app_config.models_path / model_config.path
    config_file = app_config.root_path / model_config.config
    output_path = Path(output_path)

    # return cached version if it exists
    if output_path.exists():
        return output_path

    # to avoid circular import errors
    from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
    from ...util.devices import choose_torch_device, torch_dtype

    model_base_to_model_type = {
        BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder",
        BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder",
        BaseModelType.StableDiffusionXL: "SDXL",
        BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner",
    }

    # A caller-supplied precision used to collide with the explicit keyword below
    # ("got multiple values for keyword argument"); honour it as an override instead.
    precision = kwargs.pop("precision", None)
    if precision is None:
        precision = torch_dtype(choose_torch_device())

    logger.info(f"Converting {weights} to diffusers format")
    with SilenceWarnings():
        convert_ckpt_to_diffusers(
            weights,
            output_path,
            model_type=model_base_to_model_type[version],
            model_version=version,
            model_variant=model_config.variant,
            original_config_file=config_file,
            extract_ema=True,
            scan_needed=True,
            from_safetensors=weights.suffix == ".safetensors",
            precision=precision,
            **kwargs,
        )
    return output_path
|
2023-07-23 00:12:16 +00:00
|
|
|
|
2023-07-27 14:54:01 +00:00
|
|
|
|
2023-07-23 00:12:16 +00:00
|
|
|
def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType) -> Optional[str]:
    """
    Return the default legacy yaml config path for a base model / variant pair.

    The file lives under ``app_config.legacy_conf_path``. The returned path is made
    relative to ``app_config.root_path`` when it lies under it.

    :returns: the config path as a string, or None when no default is known for the
        combination (unknown base model, unknown variant, or a variant explicitly
        mapped to None, such as the SDXL inpaint/depth entries).
    """
    ckpt_configs = {
        BaseModelType.StableDiffusion1: {
            ModelVariantType.Normal: "v1-inference.yaml",
            ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
        },
        BaseModelType.StableDiffusion2: {
            ModelVariantType.Normal: "v2-inference-v.yaml",  # best guess, as we can't differentiate with base(512)
            ModelVariantType.Inpaint: "v2-inpainting-inference.yaml",
            ModelVariantType.Depth: "v2-midas-inference.yaml",
        },
        BaseModelType.StableDiffusionXL: {
            ModelVariantType.Normal: "sd_xl_base.yaml",
            ModelVariantType.Inpaint: None,
            ModelVariantType.Depth: None,
        },
        BaseModelType.StableDiffusionXLRefiner: {
            ModelVariantType.Normal: "sd_xl_refiner.yaml",
            ModelVariantType.Inpaint: None,
            ModelVariantType.Depth: None,
        },
    }

    app_config = InvokeAIAppConfig.get_config()

    # Explicit lookup instead of a bare `except:`, which also swallowed
    # KeyboardInterrupt and genuine bugs. Missing or None-mapped entries mean "no default".
    config_name = ckpt_configs.get(version, {}).get(variant)
    if config_name is None:
        return None

    config_path = app_config.legacy_conf_path / config_name
    if config_path.is_relative_to(app_config.root_path):
        config_path = config_path.relative_to(app_config.root_path)
    return str(config_path)
|