InvokeAI/invokeai/backend/model_management/models/stable_diffusion.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

330 lines
11 KiB
Python
Raw Normal View History

2023-06-11 01:49:09 +00:00
import os
2023-06-12 02:52:30 +00:00
import json
2023-06-20 00:30:09 +00:00
from enum import Enum
2023-06-11 01:49:09 +00:00
from pydantic import Field
from pathlib import Path
from typing import Literal, Optional, Union
2023-06-11 01:49:09 +00:00
from .base import (
ModelConfigBase,
BaseModelType,
ModelType,
ModelVariantType,
2023-06-11 01:49:09 +00:00
DiffusersModel,
SilenceWarnings,
read_checkpoint_meta,
classproperty,
2023-07-08 01:09:10 +00:00
InvalidModelException,
2023-07-26 19:02:32 +00:00
ModelNotFoundException,
2023-06-11 01:49:09 +00:00
)
from .sdxl import StableDiffusionXLModel
import invokeai.backend.util.logging as logger
2023-06-11 01:49:09 +00:00
from invokeai.app.services.config import InvokeAIAppConfig
from omegaconf import OmegaConf
2023-06-11 01:49:09 +00:00
2023-06-20 00:30:09 +00:00
class StableDiffusion1ModelFormat(str, Enum):
    # On-disk layouts recognised for Stable Diffusion 1.x models.
    # Members are also plain strings (str mixin), so they compare equal
    # to their literal values.

    # A single weights file (see StableDiffusion1Model.detect_format).
    Checkpoint = "checkpoint"
    # A directory containing a model_index.json.
    Diffusers = "diffusers"
2023-06-11 01:49:09 +00:00
class StableDiffusion1Model(DiffusersModel):
    """Stable Diffusion 1.x main model.

    Supports two on-disk formats (see ``StableDiffusion1ModelFormat``): a
    diffusers directory, or a single checkpoint file that is converted to
    diffusers format on demand by ``convert_if_required``.
    """

    class DiffusersConfig(ModelConfigBase):
        # Configuration record for a model stored in diffusers layout.
        model_format: Literal[StableDiffusion1ModelFormat.Diffusers]
        vae: Optional[str] = Field(None)
        variant: ModelVariantType

    class CheckpointConfig(ModelConfigBase):
        # Configuration record for a model stored as a single checkpoint file.
        model_format: Literal[StableDiffusion1ModelFormat.Checkpoint]
        vae: Optional[str] = Field(None)
        # Path of the legacy YAML config that describes the checkpoint.
        config: str
        variant: ModelVariantType

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        """Initialise the model; only (StableDiffusion1, Main) is accepted."""
        assert base_model == BaseModelType.StableDiffusion1
        assert model_type == ModelType.Main
        super().__init__(
            model_path=model_path,
            base_model=BaseModelType.StableDiffusion1,
            model_type=ModelType.Main,
        )

    @classmethod
    def probe_config(cls, path: str, **kwargs):
        """Inspect the model at ``path`` and build a configuration for it.

        The UNet input channel count determines the variant:
        9 -> Inpaint, 4 -> Normal.

        Args:
            path: Model location (checkpoint file or diffusers directory).
            **kwargs: Only ``config`` is read: an optional path to the legacy
                YAML config describing a checkpoint.

        Returns:
            Whatever ``cls.create_config`` returns for the probed values.

        Raises:
            NotImplementedError: Diffusers directory without
                ``unet/config.json``, or an unrecognised model format.
            Exception: The channel count matches no known variant.
        """
        model_format = cls.detect_format(path)
        ckpt_config_path = kwargs.get("config", None)

        if model_format == StableDiffusion1ModelFormat.Checkpoint:
            if ckpt_config_path:
                ckpt_config = OmegaConf.load(ckpt_config_path)
                # The value must be bound here; the variant check below reads it.
                in_channels = ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
            else:
                checkpoint = read_checkpoint_meta(path)
                # Some checkpoints nest their tensors under "state_dict".
                checkpoint = checkpoint.get("state_dict", checkpoint)
                in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]

        elif model_format == StableDiffusion1ModelFormat.Diffusers:
            unet_config_path = os.path.join(path, "unet", "config.json")
            if os.path.exists(unet_config_path):
                with open(unet_config_path, "r") as f:
                    unet_config = json.load(f)
                in_channels = unet_config["in_channels"]
            else:
                raise NotImplementedError(f"{path} is not a supported stable diffusion diffusers format")

        else:
            raise NotImplementedError(f"Unknown stable diffusion 1.* format: {model_format}")

        if in_channels == 9:
            variant = ModelVariantType.Inpaint
        elif in_channels == 4:
            variant = ModelVariantType.Normal
        else:
            raise Exception("Unkown stable diffusion 1.* model format")

        # No explicit legacy config supplied: pick one matching the variant.
        if ckpt_config_path is None:
            ckpt_config_path = _select_ckpt_config(BaseModelType.StableDiffusion1, variant)

        return cls.create_config(
            path=path,
            model_format=model_format,
            config=ckpt_config_path,
            variant=variant,
        )

    @classproperty
    def save_to_config(cls) -> bool:
        """Always True for this model class."""
        return True

    @classmethod
    def detect_format(cls, model_path: str):
        """Classify ``model_path`` as a diffusers directory or a checkpoint file.

        Raises:
            ModelNotFoundException: ``model_path`` does not exist.
            InvalidModelException: The path exists but is neither a directory
                containing ``model_index.json`` nor a file ending in
                ``.safetensors``, ``.ckpt`` or ``.pt``.
        """
        if not os.path.exists(model_path):
            raise ModelNotFoundException()

        if os.path.isdir(model_path):
            if os.path.exists(os.path.join(model_path, "model_index.json")):
                return StableDiffusion1ModelFormat.Diffusers

        if os.path.isfile(model_path):
            if model_path.endswith((".safetensors", ".ckpt", ".pt")):
                return StableDiffusion1ModelFormat.Checkpoint

        raise InvalidModelException(f"Not a valid model: {model_path}")

    @classmethod
    def convert_if_required(
        cls,
        model_path: str,
        output_path: str,
        config: ModelConfigBase,
        base_model: BaseModelType,
    ) -> str:
        """Return the path of a diffusers-format copy of the model.

        Checkpoint configurations are converted (or fetched from the cache at
        ``output_path``) via ``_convert_ckpt_and_cache``; any other
        configuration returns ``model_path`` unchanged.
        """
        if isinstance(config, cls.CheckpointConfig):
            return _convert_ckpt_and_cache(
                version=BaseModelType.StableDiffusion1,
                model_config=config,
                output_path=output_path,
            )
        else:
            return model_path
2023-07-27 14:54:01 +00:00
2023-06-20 00:30:09 +00:00
class StableDiffusion2ModelFormat(str, Enum):
    # On-disk layouts recognised for Stable Diffusion 2.x models.
    # Members are also plain strings (str mixin), so they compare equal
    # to their literal values.

    # A single weights file (see StableDiffusion2Model.detect_format).
    Checkpoint = "checkpoint"
    # A directory containing a model_index.json.
    Diffusers = "diffusers"
2023-06-11 01:49:09 +00:00
class StableDiffusion2Model(DiffusersModel):
    """Stable Diffusion 2.x main model.

    Supports two on-disk formats (see ``StableDiffusion2ModelFormat``): a
    diffusers directory, or a single checkpoint file that is converted to
    diffusers format on demand by ``convert_if_required``.
    """

    # TODO: check that configs overwritten properly
    class DiffusersConfig(ModelConfigBase):
        # Configuration record for a model stored in diffusers layout.
        model_format: Literal[StableDiffusion2ModelFormat.Diffusers]
        vae: Optional[str] = Field(None)
        variant: ModelVariantType

    class CheckpointConfig(ModelConfigBase):
        # Configuration record for a model stored as a single checkpoint file.
        model_format: Literal[StableDiffusion2ModelFormat.Checkpoint]
        vae: Optional[str] = Field(None)
        # Path of the legacy YAML config that describes the checkpoint.
        config: str
        variant: ModelVariantType

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        """Initialise the model; only (StableDiffusion2, Main) is accepted."""
        assert base_model == BaseModelType.StableDiffusion2
        assert model_type == ModelType.Main
        super().__init__(
            model_path=model_path,
            base_model=BaseModelType.StableDiffusion2,
            model_type=ModelType.Main,
        )

    @classmethod
    def probe_config(cls, path: str, **kwargs):
        """Inspect the model at ``path`` and build a configuration for it.

        The UNet input channel count determines the variant:
        9 -> Inpaint, 5 -> Depth, 4 -> Normal.

        Args:
            path: Model location (checkpoint file or diffusers directory).
            **kwargs: Only ``config`` is read: an optional path to the legacy
                YAML config describing a checkpoint.

        Returns:
            Whatever ``cls.create_config`` returns for the probed values.

        Raises:
            NotImplementedError: The model format is not recognised.
            Exception: Diffusers directory without ``unet/config.json``, or
                the channel count matches no known variant.
        """
        model_format = cls.detect_format(path)
        ckpt_config_path = kwargs.get("config", None)

        if model_format == StableDiffusion2ModelFormat.Checkpoint:
            if ckpt_config_path:
                ckpt_config = OmegaConf.load(ckpt_config_path)
                # The value must be bound here; the variant check below reads it.
                in_channels = ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
            else:
                checkpoint = read_checkpoint_meta(path)
                # Some checkpoints nest their tensors under "state_dict".
                checkpoint = checkpoint.get("state_dict", checkpoint)
                in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]

        elif model_format == StableDiffusion2ModelFormat.Diffusers:
            unet_config_path = os.path.join(path, "unet", "config.json")
            if os.path.exists(unet_config_path):
                with open(unet_config_path, "r") as f:
                    unet_config = json.load(f)
                in_channels = unet_config["in_channels"]
            else:
                raise Exception("Not supported stable diffusion diffusers format(possibly onnx?)")

        else:
            raise NotImplementedError(f"Unknown stable diffusion 2.* format: {model_format}")

        if in_channels == 9:
            variant = ModelVariantType.Inpaint
        elif in_channels == 5:
            variant = ModelVariantType.Depth
        elif in_channels == 4:
            variant = ModelVariantType.Normal
        else:
            raise Exception("Unkown stable diffusion 2.* model format")

        # No explicit legacy config supplied: pick one matching the variant.
        if ckpt_config_path is None:
            ckpt_config_path = _select_ckpt_config(BaseModelType.StableDiffusion2, variant)

        return cls.create_config(
            path=path,
            model_format=model_format,
            config=ckpt_config_path,
            variant=variant,
        )

    @classproperty
    def save_to_config(cls) -> bool:
        """Always True for this model class."""
        return True

    @classmethod
    def detect_format(cls, model_path: str):
        """Classify ``model_path`` as a diffusers directory or a checkpoint file.

        Raises:
            ModelNotFoundException: ``model_path`` does not exist.
            InvalidModelException: The path exists but is neither a directory
                containing ``model_index.json`` nor a file ending in
                ``.safetensors``, ``.ckpt`` or ``.pt``.
        """
        if not os.path.exists(model_path):
            raise ModelNotFoundException()

        if os.path.isdir(model_path):
            if os.path.exists(os.path.join(model_path, "model_index.json")):
                return StableDiffusion2ModelFormat.Diffusers

        if os.path.isfile(model_path):
            if model_path.endswith((".safetensors", ".ckpt", ".pt")):
                return StableDiffusion2ModelFormat.Checkpoint

        raise InvalidModelException(f"Not a valid model: {model_path}")

    @classmethod
    def convert_if_required(
        cls,
        model_path: str,
        output_path: str,
        config: ModelConfigBase,
        base_model: BaseModelType,
    ) -> str:
        """Return the path of a diffusers-format copy of the model.

        Checkpoint configurations are converted (or fetched from the cache at
        ``output_path``) via ``_convert_ckpt_and_cache``; any other
        configuration returns ``model_path`` unchanged.
        """
        if isinstance(config, cls.CheckpointConfig):
            return _convert_ckpt_and_cache(
                version=BaseModelType.StableDiffusion2,
                model_config=config,
                output_path=output_path,
            )
        else:
            return model_path
2023-07-27 14:54:01 +00:00
2023-06-11 01:49:09 +00:00
# TODO: rework
# pass precision - currently defaulting to fp16
2023-06-12 13:14:09 +00:00
def _convert_ckpt_and_cache(
    version: BaseModelType,
    model_config: Union[
        StableDiffusion1Model.CheckpointConfig,
        StableDiffusion2Model.CheckpointConfig,
        StableDiffusionXLModel.CheckpointConfig,
    ],
    output_path: str,
    use_save_model: bool = False,
    **kwargs,
) -> str:
    """
    Convert the checkpoint described by ``model_config`` into diffusers
    format at ``output_path`` and return that location.

    If ``output_path`` already exists it is treated as a cached conversion
    and returned without converting again. Extra keyword arguments are
    forwarded to ``convert_ckpt_to_diffusers``. ``use_save_model`` is
    accepted but not used in this function.

    NOTE(review): the value returned is a ``pathlib.Path`` even though the
    annotation says ``str``.
    """
    settings = InvokeAIAppConfig.get_config()

    # Resolve the checkpoint and its legacy config against the app directories.
    weights = settings.models_path / model_config.path
    legacy_config = settings.root_path / model_config.config
    output_path = Path(output_path)

    # A previous conversion is already on disk: reuse it.
    if output_path.exists():
        return output_path

    # Imported here to avoid circular import errors.
    from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
    from ...util.devices import choose_torch_device, torch_dtype

    # Value handed to the converter as ``model_type`` for each base model.
    converter_type_by_base = {
        BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder",
        BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder",
        BaseModelType.StableDiffusionXL: "SDXL",
        BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner",
    }

    logger.info(f"Converting {weights} to diffusers format")
    with SilenceWarnings():
        converter_type = converter_type_by_base[version]
        variant = model_config.variant
        is_safetensors = weights.suffix == ".safetensors"
        precision = torch_dtype(choose_torch_device())
        convert_ckpt_to_diffusers(
            weights,
            output_path,
            model_type=converter_type,
            model_version=version,
            model_variant=variant,
            original_config_file=legacy_config,
            extract_ema=True,
            scan_needed=True,
            from_safetensors=is_safetensors,
            precision=precision,
            **kwargs,
        )
    return output_path
2023-07-27 14:54:01 +00:00
def _select_ckpt_config(version: BaseModelType, variant: ModelVariantType) -> Optional[str]:
    """Return the legacy YAML config path for a base model and variant.

    The path is built under ``app_config.legacy_conf_path`` and, when it lies
    inside ``app_config.root_path``, is made relative to that root.

    Args:
        version: Base model family to look up.
        variant: Model variant within that family.

    Returns:
        The config path as a string, or ``None`` when the table has no
        config file for the requested combination.
    """
    ckpt_configs = {
        BaseModelType.StableDiffusion1: {
            ModelVariantType.Normal: "v1-inference.yaml",
            ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
        },
        BaseModelType.StableDiffusion2: {
            ModelVariantType.Normal: "v2-inference-v.yaml",  # best guess, as we can't differentiate with base(512)
            ModelVariantType.Inpaint: "v2-inpainting-inference.yaml",
            ModelVariantType.Depth: "v2-midas-inference.yaml",
        },
        BaseModelType.StableDiffusionXL: {
            ModelVariantType.Normal: "sd_xl_base.yaml",
            ModelVariantType.Inpaint: None,
            ModelVariantType.Depth: None,
        },
        BaseModelType.StableDiffusionXLRefiner: {
            ModelVariantType.Normal: "sd_xl_refiner.yaml",
            ModelVariantType.Inpaint: None,
            ModelVariantType.Depth: None,
        },
    }

    app_config = InvokeAIAppConfig.get_config()

    # Handle "unknown version/variant" and "explicitly None" up front rather
    # than catching every exception, which would also hide unrelated failures.
    config_name = ckpt_configs.get(version, {}).get(variant)
    if config_name is None:
        return None

    config_path = app_config.legacy_conf_path / config_name
    if config_path.is_relative_to(app_config.root_path):
        config_path = config_path.relative_to(app_config.root_path)
    return str(config_path)