# InvokeAI/invokeai/backend/model_management/models/stable_diffusion.py

import os
import json
import torch
import safetensors.torch
from omegaconf import OmegaConf  # needed by build_config() below; missing from the original imports
from pydantic import Field
from typing import Literal, Optional
from .base import (
    ModelBase,
    ModelConfigBase,
    BaseModelType,
    ModelType,
    SubModelType,
    ModelVariantType,
    DiffusersModel,
)
from invokeai.app.services.config import InvokeAIAppConfig


# TODO: how to name this properly?
class StableDiffusion15Model(DiffusersModel):
    # TODO: str -> Path?
    class DiffusersConfig(ModelConfigBase):
        format: Literal["diffusers"]
        vae: Optional[str] = Field(None)
        variant: ModelVariantType

    class CheckpointConfig(ModelConfigBase):
        format: Literal["checkpoint"]
        vae: Optional[str] = Field(None)
        config: Optional[str] = Field(None)
        variant: ModelVariantType

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        assert base_model == BaseModelType.StableDiffusion1_5
        assert model_type == ModelType.Pipeline
        super().__init__(
            model_path=model_path,
            base_model=BaseModelType.StableDiffusion1_5,
            model_type=ModelType.Pipeline,
        )
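
    # The reader below exploits the safetensors on-disk layout: the file
    # begins with an 8-byte little-endian header length, followed by a JSON
    # header mapping each tensor name to its dtype, shape, and byte offsets.
    # By parsing only the header and allocating empty tensors on the "meta"
    # device, it can expose tensor shapes without reading any weight data
    # into memory.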
    @staticmethod
    def _fast_safetensors_reader(path: str):
        checkpoint = dict()
        device = torch.device("meta")
        with open(path, "rb") as f:
            definition_len = int.from_bytes(f.read(8), "little")
            definition_json = f.read(definition_len)
            definition = json.loads(definition_json)

            if "__metadata__" in definition and definition["__metadata__"].get("format", "pt") not in {
                "pt",
                "torch",
                "pytorch",
            }:
                raise Exception("Only PyTorch safetensors files are supported")
            definition.pop("__metadata__", None)

            for key, info in definition.items():
                dtype = {
                    "I8": torch.int8,
                    "I16": torch.int16,
                    "I32": torch.int32,
                    "I64": torch.int64,
                    "F16": torch.float16,
                    "F32": torch.float32,
                    "F64": torch.float64,
                }[info["dtype"]]
                checkpoint[key] = torch.empty(info["shape"], dtype=dtype, device=device)
        return checkpoint
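
    # Reads checkpoint metadata without materializing weights where possible:
    # the fast header-only reader for safetensors files (with a full CPU load
    # as a fallback), and torch.load onto the "meta" device for legacy pickle
    # checkpoints.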
    @classmethod
    def read_checkpoint_meta(cls, path: str):
        if path.endswith(".safetensors"):
            try:
                checkpoint = cls._fast_safetensors_reader(path)
            except Exception:
                # TODO: create issue for support "meta"?
                checkpoint = safetensors.torch.load_file(path, device="cpu")
        else:
            checkpoint = torch.load(path, map_location=torch.device("meta"))
        return checkpoint
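
    # Variant detection relies on the UNet's input channel count: 4 latent
    # channels for a normal model, 4 + 1 depth channel for a depth model,
    # and 4 + 4 masked-image latent channels + 1 mask channel (= 9) for an
    # inpainting model.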
    @classmethod
    def build_config(cls, **kwargs):
        if "format" not in kwargs:
            kwargs["format"] = cls.detect_format(kwargs["path"])

        if "variant" not in kwargs:
            if kwargs["format"] == "checkpoint":
                if "config" in kwargs:
                    ckpt_config = OmegaConf.load(kwargs["config"])
                    in_channels = ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
                else:
                    checkpoint = cls.read_checkpoint_meta(kwargs["path"])
                    checkpoint = checkpoint.get("state_dict", checkpoint)
                    in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]

            elif kwargs["format"] == "diffusers":
                unet_config_path = os.path.join(kwargs["path"], "unet", "config.json")
                if os.path.exists(unet_config_path):
                    with open(unet_config_path, "r") as f:
                        unet_config = json.loads(f.read())
                        in_channels = unet_config["in_channels"]
                else:
                    raise Exception("Unsupported Stable Diffusion diffusers format (possibly ONNX?)")

            else:
                raise NotImplementedError(f"Unknown stable diffusion format: {kwargs['format']}")

            if in_channels == 9:
                kwargs["variant"] = ModelVariantType.Inpaint
            elif in_channels == 5:
                kwargs["variant"] = ModelVariantType.Depth
            elif in_channels == 4:
                kwargs["variant"] = ModelVariantType.Normal
            else:
                raise Exception(f"Unknown Stable Diffusion model variant (in_channels == {in_channels})")

        return super().build_config(**kwargs)
    @classmethod
    def save_to_config(cls) -> bool:
        return True
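
    # Heuristic: diffusers models are directories of per-submodel folders,
    # while a single file is assumed to be a legacy checkpoint.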
    @classmethod
    def detect_format(cls, model_path: str):
        if os.path.isdir(model_path):
            return "diffusers"
        else:
            return "checkpoint"
    @classmethod
    def convert_if_required(cls, model_path: str, dst_cache_path: str, config: ModelConfigBase) -> str:
        if isinstance(config, cls.CheckpointConfig):
            # TODO: args
            return _convert_ckpt_and_cache(
                version=BaseModelType.StableDiffusion1_5,
                config=config.dict(),
                in_path=model_path,
                out_path=dst_cache_path,
            )
        else:
            return model_path
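

# NOTE: the SD2 subclasses below call super(StableDiffusion15Model, self).__init__(),
# which skips StableDiffusion15Model.__init__ (whose asserts are SD1.5-specific)
# and dispatches straight to DiffusersModel.__init__.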
# all same
class StableDiffusion2BaseModel(StableDiffusion15Model):
    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        # skip StableDiffusion15Model.__init__
        assert base_model == BaseModelType.StableDiffusion2Base
        assert model_type == ModelType.Pipeline
        super(StableDiffusion15Model, self).__init__(
            model_path=model_path,
            base_model=BaseModelType.StableDiffusion2Base,
            model_type=ModelType.Pipeline,
        )
    @classmethod
    def convert_if_required(cls, model_path: str, dst_cache_path: str, config: ModelConfigBase) -> str:
        if isinstance(config, cls.CheckpointConfig):
            # TODO: args
            return _convert_ckpt_and_cache(
                version=BaseModelType.StableDiffusion2Base,
                config=config.dict(),
                in_path=model_path,
                out_path=dst_cache_path,
            )
        else:
            return model_path


class StableDiffusion2Model(StableDiffusion15Model):
    # TODO: str -> Path?
    # TODO: check that these configs actually override the parent's
    class DiffusersConfig(ModelConfigBase):
        format: Literal["diffusers"]
        vae: Optional[str] = Field(None)
        attention_upscale: bool = Field(True)

    class CheckpointConfig(ModelConfigBase):
        format: Literal["checkpoint"]
        vae: Optional[str] = Field(None)
        config: Optional[str] = Field(None)
        attention_upscale: bool = Field(True)

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        # skip StableDiffusion15Model.__init__
        assert base_model == BaseModelType.StableDiffusion2
        assert model_type == ModelType.Pipeline
        super(StableDiffusion15Model, self).__init__(
            model_path=model_path,
            base_model=BaseModelType.StableDiffusion2,
            model_type=ModelType.Pipeline,
        )
    @classmethod
    def convert_if_required(cls, model_path: str, dst_cache_path: str, config: ModelConfigBase) -> str:
        if isinstance(config, cls.CheckpointConfig):
            # TODO: args
            return _convert_ckpt_and_cache(
                version=BaseModelType.StableDiffusion2,
                config=config.dict(),
                in_path=model_path,
                out_path=dst_cache_path,
            )
        else:
            return model_path


# TODO: rework
DictConfig = dict


def _convert_ckpt_and_cache(
    version: BaseModelType,
    config: dict,  # TODO: stricter type; parameter named to match the call sites above
    in_path: str,
    out_path: str,
) -> str:
    """
    Convert the checkpoint model described by `config` into diffusers
    format, cache it to disk, and return the path to the converted model.
    If it is already on disk, just return the path.
    """
    raise NotImplementedError()
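
    # NOTE: everything below the raise above is unreachable work-in-progress
    # code, kept as a sketch of the intended conversion flow.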
    app_config = InvokeAIAppConfig.get_config()

    # if "config" not in config:
    #     if version == BaseModelType.StableDiffusion1_5:
    #         config["config"] = app_config.config_dir / "stable-diffusion" / "v1-inference.yaml"

    weights = app_config.root_dir / config["path"]
    config_file = app_config.root_dir / config["config"]
    diffusers_path = app_config.converted_ckpts_dir / weights.stem

    # return cached version if it exists
    if diffusers_path.exists():
        return diffusers_path

    # TODO: it may be more correct to convert with the embedded VAE;
    # otherwise, if the user later deletes their custom VAE, the converted
    # model would still contain the custom VAE rather than the embedded one.
    # vae_ckpt_path, vae_model = _get_vae_for_conversion(weights, config)
    vae_ckpt_path, vae_model = None, None

    # imported here to avoid circular import errors
    from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers

    # NOTE: this import path for SilenceWarnings is an assumption; the
    # original module never imports it.
    from invokeai.backend.util.silence_warnings import SilenceWarnings

    with SilenceWarnings():
        convert_ckpt_to_diffusers(
            weights,
            diffusers_path,
            extract_ema=True,
            original_config_file=config_file,
            vae=vae_model,
            vae_path=str(app_config.root_dir / vae_ckpt_path) if vae_ckpt_path else None,
            scan_needed=True,
        )
    return diffusers_path
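

# Usage sketch (illustrative only, not part of the original module): how
# build_config() is expected to classify a checkpoint. The file path is
# hypothetical.
#
#   config = StableDiffusion15Model.build_config(
#       path="models/sd-1-5/sd-v1-5-inpainting.safetensors",
#       format="checkpoint",
#   )
#   assert config.variant == ModelVariantType.Inpaint  # 9 UNet input channels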