SDXL & SDXL-refiner models convert correctly

This commit is contained in:
Lincoln Stein 2023-07-23 09:31:14 -04:00
parent 5e59edfaf1
commit f2a6f0cf21
5 changed files with 25 additions and 10 deletions

View File

@ -1121,7 +1121,7 @@ def download_from_original_stable_diffusion_ckpt(
prediction_type: str = None, prediction_type: str = None,
model_type: str = None, model_type: str = None,
extract_ema: bool = False, extract_ema: bool = False,
precision: torch.dtype = torch.float16, precision: torch.dtype = torch.float32,
scheduler_type: str = "pndm", scheduler_type: str = "pndm",
num_in_channels: Optional[int] = None, num_in_channels: Optional[int] = None,
upcast_attention: Optional[bool] = None, upcast_attention: Optional[bool] = None,
@ -1250,7 +1250,7 @@ def download_from_original_stable_diffusion_ckpt(
while "state_dict" in checkpoint: while "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"] checkpoint = checkpoint["state_dict"]
print(f'DEBUG: model_type = {model_type}; original_config_file = {original_config_file}') logger.debug(f'model_type = {model_type}; original_config_file = {original_config_file}')
if original_config_file is None: if original_config_file is None:
key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
@ -1624,7 +1624,7 @@ def download_controlnet_from_original_ckpt(
original_config_file: str, original_config_file: str,
image_size: int = 512, image_size: int = 512,
extract_ema: bool = False, extract_ema: bool = False,
precision: torch.dtype = torch.float16, precision: torch.dtype = torch.float32,
num_in_channels: Optional[int] = None, num_in_channels: Optional[int] = None,
upcast_attention: Optional[bool] = None, upcast_attention: Optional[bool] = None,
device: str = None, device: str = None,
@ -1702,7 +1702,7 @@ def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size:
def convert_ckpt_to_diffusers( def convert_ckpt_to_diffusers(
checkpoint_path: Union[str, Path], checkpoint_path: Union[str, Path],
dump_path: Union[str, Path], dump_path: Union[str, Path],
no_safetensors: bool = False, use_safetensors: bool=True,
**kwargs, **kwargs,
): ):
""" """
@ -1714,7 +1714,7 @@ def convert_ckpt_to_diffusers(
pipe.save_pretrained( pipe.save_pretrained(
dump_path, dump_path,
safe_serialization=is_safetensors_available() and not no_safetensors, safe_serialization=use_safetensors and is_safetensors_available(),
) )
def convert_controlnet_to_diffusers( def convert_controlnet_to_diffusers(

View File

@ -253,10 +253,12 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
return BaseModelType.StableDiffusion1 return BaseModelType.StableDiffusion1
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024: if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
return BaseModelType.StableDiffusion2 return BaseModelType.StableDiffusion2
# TODO: This is just a guess based on N=1
key_name = 'model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight' key_name = 'model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight'
if key_name in state_dict and state_dict[key_name].shape[-1] == 2048: if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
return BaseModelType.StableDiffusionXL return BaseModelType.StableDiffusionXL
elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
return BaseModelType.StableDiffusionXLRefiner
else:
raise InvalidModelException("Cannot determine base type") raise InvalidModelException("Cannot determine base type")
def get_scheduler_prediction_type(self)->SchedulerPredictionType: def get_scheduler_prediction_type(self)->SchedulerPredictionType:

View File

@ -1,5 +1,6 @@
import os import os
import json import json
import invokeai.backend.util.logging as logger
from enum import Enum from enum import Enum
from pydantic import Field from pydantic import Field
from typing import Literal, Optional from typing import Literal, Optional
@ -108,14 +109,20 @@ class StableDiffusionXLModel(DiffusersModel):
config: ModelConfigBase, config: ModelConfigBase,
base_model: BaseModelType, base_model: BaseModelType,
) -> str: ) -> str:
# The convert script adapted from the diffusers package uses
# strings for the base model type. To avoid making too many
# source code changes, we simply translate here
model_base_to_model_type = {BaseModelType.StableDiffusionXL: 'SDXL',
BaseModelType.StableDiffusionXLRefiner: 'SDXL-Refiner',
}
if isinstance(config, cls.CheckpointConfig): if isinstance(config, cls.CheckpointConfig):
from invokeai.backend.model_management.models.stable_diffusion import _convert_ckpt_and_cache from invokeai.backend.model_management.models.stable_diffusion import _convert_ckpt_and_cache
return _convert_ckpt_and_cache( return _convert_ckpt_and_cache(
version=base_model, version=base_model,
model_config=config, model_config=config,
output_path=output_path, output_path=output_path,
model_type='SDXL', model_type=model_base_to_model_type[base_model],
no_safetensors=True, # giving errors for some reason use_safetensors=False, # corrupts sdxl models for some reason
) )
else: else:
return model_path return model_path

View File

@ -16,9 +16,11 @@ from .base import (
InvalidModelException, InvalidModelException,
) )
from .sdxl import StableDiffusionXLModel from .sdxl import StableDiffusionXLModel
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from omegaconf import OmegaConf from omegaconf import OmegaConf
class StableDiffusion1ModelFormat(str, Enum): class StableDiffusion1ModelFormat(str, Enum):
Checkpoint = "checkpoint" Checkpoint = "checkpoint"
Diffusers = "diffusers" Diffusers = "diffusers"
@ -265,6 +267,9 @@ def _convert_ckpt_and_cache(
# to avoid circular import errors # to avoid circular import errors
from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
from ...util.devices import choose_torch_device, torch_dtype
logger.info(f'Converting {weights} to diffusers format')
with SilenceWarnings(): with SilenceWarnings():
convert_ckpt_to_diffusers( convert_ckpt_to_diffusers(
weights, weights,
@ -275,6 +280,7 @@ def _convert_ckpt_and_cache(
extract_ema=True, extract_ema=True,
scan_needed=True, scan_needed=True,
from_safetensors = weights.suffix == ".safetensors", from_safetensors = weights.suffix == ".safetensors",
precision = torch_dtype(choose_torch_device()),
**kwargs, **kwargs,
) )
return output_path return output_path

View File

@ -1,7 +1,7 @@
# Copyright (c) 2023 Lincoln D. Stein and The InvokeAI Development Team # Copyright (c) 2023 Lincoln D. Stein and The InvokeAI Development Team
""" """
invokeai.util.logging invokeai.backend.util.logging
Logging class for InvokeAI that produces console messages Logging class for InvokeAI that produces console messages