InvokeAI/invokeai/backend/model_management/models/sdxl.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

141 lines
4.9 KiB
Python
Raw Normal View History

import json
2023-08-18 15:13:28 +00:00
import os
from enum import Enum
from typing import Literal, Optional
2023-08-18 15:13:28 +00:00
from omegaconf import OmegaConf
from pydantic import Field
from .base import (
BaseModelType,
2023-08-18 15:13:28 +00:00
DiffusersModel,
InvalidModelException,
ModelConfigBase,
ModelType,
ModelVariantType,
classproperty,
2023-08-18 15:13:28 +00:00
read_checkpoint_meta,
)
2023-07-27 14:54:01 +00:00
class StableDiffusionXLModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
2023-07-27 14:54:01 +00:00
class StableDiffusionXLModel(DiffusersModel):
# TODO: check that configs overwriten properly
class DiffusersConfig(ModelConfigBase):
model_format: Literal[StableDiffusionXLModelFormat.Diffusers]
vae: Optional[str] = Field(None)
variant: ModelVariantType
class CheckpointConfig(ModelConfigBase):
model_format: Literal[StableDiffusionXLModelFormat.Checkpoint]
vae: Optional[str] = Field(None)
config: str
variant: ModelVariantType
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert base_model in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner}
assert model_type == ModelType.Main
super().__init__(
model_path=model_path,
base_model=BaseModelType.StableDiffusionXL,
model_type=ModelType.Main,
)
@classmethod
def probe_config(cls, path: str, **kwargs):
model_format = cls.detect_format(path)
ckpt_config_path = kwargs.get("config", None)
if model_format == StableDiffusionXLModelFormat.Checkpoint:
if ckpt_config_path:
ckpt_config = OmegaConf.load(ckpt_config_path)
in_channels = ckpt_config["model"]["params"]["unet_config"]["params"]["in_channels"]
else:
checkpoint = read_checkpoint_meta(path)
checkpoint = checkpoint.get("state_dict", checkpoint)
in_channels = checkpoint["model.diffusion_model.input_blocks.0.0.weight"].shape[1]
elif model_format == StableDiffusionXLModelFormat.Diffusers:
unet_config_path = os.path.join(path, "unet", "config.json")
if os.path.exists(unet_config_path):
with open(unet_config_path, "r") as f:
unet_config = json.loads(f.read())
in_channels = unet_config["in_channels"]
else:
raise InvalidModelException(f"{path} is not a recognized Stable Diffusion diffusers model")
else:
raise NotImplementedError(f"Unknown stable diffusion 2.* format: {model_format}")
if in_channels == 9:
variant = ModelVariantType.Inpaint
elif in_channels == 5:
variant = ModelVariantType.Depth
elif in_channels == 4:
variant = ModelVariantType.Normal
else:
raise Exception("Unkown stable diffusion 2.* model format")
if ckpt_config_path is None:
# avoid circular import
from .stable_diffusion import _select_ckpt_config
ckpt_config_path = _select_ckpt_config(kwargs.get("model_base", BaseModelType.StableDiffusionXL), variant)
2023-07-27 14:54:01 +00:00
return cls.create_config(
path=path,
model_format=model_format,
config=ckpt_config_path,
variant=variant,
)
@classproperty
def save_to_config(cls) -> bool:
return True
@classmethod
def detect_format(cls, model_path: str):
if os.path.isdir(model_path):
return StableDiffusionXLModelFormat.Diffusers
else:
return StableDiffusionXLModelFormat.Checkpoint
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
# The convert script adapted from the diffusers package uses
# strings for the base model type. To avoid making too many
# source code changes, we simply translate here
if isinstance(config, cls.CheckpointConfig):
from invokeai.backend.model_management.models.stable_diffusion import _convert_ckpt_and_cache
2023-07-27 14:54:01 +00:00
# Hack in VAE-fp16 fix - If model sdxl-vae-fp16-fix is installed,
# then we bake it into the converted model.
from invokeai.app.services.config import InvokeAIAppConfig
kwargs = dict()
app_config = InvokeAIAppConfig.get_config()
vae_path = app_config.models_path / 'sdxl/vae/sdxl-vae-fp16-fix'
if vae_path.exists():
kwargs['vae_path'] = vae_path
return _convert_ckpt_and_cache(
version=base_model,
model_config=config,
output_path=output_path,
use_safetensors=False, # corrupts sdxl models for some reason
**kwargs,
)
else:
return model_path