SDXL & SDXL-refiner models convert correctly

This commit is contained in:
Lincoln Stein 2023-07-23 09:31:14 -04:00
parent 5e59edfaf1
commit f2a6f0cf21
5 changed files with 25 additions and 10 deletions

View File

@ -1121,7 +1121,7 @@ def download_from_original_stable_diffusion_ckpt(
prediction_type: str = None,
model_type: str = None,
extract_ema: bool = False,
precision: torch.dtype = torch.float16,
precision: torch.dtype = torch.float32,
scheduler_type: str = "pndm",
num_in_channels: Optional[int] = None,
upcast_attention: Optional[bool] = None,
@ -1250,7 +1250,7 @@ def download_from_original_stable_diffusion_ckpt(
while "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
print(f'DEBUG: model_type = {model_type}; original_config_file = {original_config_file}')
logger.debug(f'model_type = {model_type}; original_config_file = {original_config_file}')
if original_config_file is None:
key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
@ -1624,7 +1624,7 @@ def download_controlnet_from_original_ckpt(
original_config_file: str,
image_size: int = 512,
extract_ema: bool = False,
precision: torch.dtype = torch.float16,
precision: torch.dtype = torch.float32,
num_in_channels: Optional[int] = None,
upcast_attention: Optional[bool] = None,
device: str = None,
@ -1702,7 +1702,7 @@ def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size:
def convert_ckpt_to_diffusers(
checkpoint_path: Union[str, Path],
dump_path: Union[str, Path],
no_safetensors: bool = False,
use_safetensors: bool=True,
**kwargs,
):
"""
@ -1714,7 +1714,7 @@ def convert_ckpt_to_diffusers(
pipe.save_pretrained(
dump_path,
safe_serialization=is_safetensors_available() and not no_safetensors,
safe_serialization=use_safetensors and is_safetensors_available(),
)
def convert_controlnet_to_diffusers(

View File

@ -253,11 +253,13 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
return BaseModelType.StableDiffusion1
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
return BaseModelType.StableDiffusion2
# TODO: This is just a guess based on a single sample (N=1)
key_name = 'model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight'
if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
return BaseModelType.StableDiffusionXL
raise InvalidModelException("Cannot determine base type")
elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
return BaseModelType.StableDiffusionXLRefiner
else:
raise InvalidModelException("Cannot determine base type")
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
type = self.get_base_type()

View File

@ -1,5 +1,6 @@
import os
import json
import invokeai.backend.util.logging as logger
from enum import Enum
from pydantic import Field
from typing import Literal, Optional
@ -108,14 +109,20 @@ class StableDiffusionXLModel(DiffusersModel):
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
# The convert script, adapted from the diffusers package, uses
# strings for the base model type. To avoid making too many
# source code changes, we simply translate here.
model_base_to_model_type = {BaseModelType.StableDiffusionXL: 'SDXL',
BaseModelType.StableDiffusionXLRefiner: 'SDXL-Refiner',
}
if isinstance(config, cls.CheckpointConfig):
from invokeai.backend.model_management.models.stable_diffusion import _convert_ckpt_and_cache
return _convert_ckpt_and_cache(
version=base_model,
model_config=config,
output_path=output_path,
model_type='SDXL',
no_safetensors=True, # giving errors for some reason
model_type=model_base_to_model_type[base_model],
use_safetensors=False, # corrupts sdxl models for some reason
)
else:
return model_path

View File

@ -16,9 +16,11 @@ from .base import (
InvalidModelException,
)
from .sdxl import StableDiffusionXLModel
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from omegaconf import OmegaConf
class StableDiffusion1ModelFormat(str, Enum):
Checkpoint = "checkpoint"
Diffusers = "diffusers"
@ -265,6 +267,9 @@ def _convert_ckpt_and_cache(
# to avoid circular import errors
from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
from ...util.devices import choose_torch_device, torch_dtype
logger.info(f'Converting {weights} to diffusers format')
with SilenceWarnings():
convert_ckpt_to_diffusers(
weights,
@ -275,6 +280,7 @@ def _convert_ckpt_and_cache(
extract_ema=True,
scan_needed=True,
from_safetensors = weights.suffix == ".safetensors",
precision = torch_dtype(choose_torch_device()),
**kwargs,
)
return output_path

View File

@ -1,7 +1,7 @@
# Copyright (c) 2023 Lincoln D. Stein and The InvokeAI Development Team
"""
invokeai.util.logging
invokeai.backend.util.logging
Logging class for InvokeAI that produces console messages