From f2a6f0cf2185ebe222ad7b60e80ef91a3d9c8006 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sun, 23 Jul 2023 09:31:14 -0400 Subject: [PATCH] SDXL & SDXL-refiner models convert correctly --- .../model_management/convert_ckpt_to_diffusers.py | 10 +++++----- invokeai/backend/model_management/model_probe.py | 6 ++++-- invokeai/backend/model_management/models/sdxl.py | 11 +++++++++-- .../model_management/models/stable_diffusion.py | 6 ++++++ invokeai/backend/util/logging.py | 2 +- 5 files changed, 25 insertions(+), 10 deletions(-) diff --git a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py index aa4dc0ecbc..0124da7f56 100644 --- a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py +++ b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py @@ -1121,7 +1121,7 @@ def download_from_original_stable_diffusion_ckpt( prediction_type: str = None, model_type: str = None, extract_ema: bool = False, - precision: torch.dtype = torch.float16, + precision: torch.dtype = torch.float32, scheduler_type: str = "pndm", num_in_channels: Optional[int] = None, upcast_attention: Optional[bool] = None, @@ -1250,7 +1250,7 @@ def download_from_original_stable_diffusion_ckpt( while "state_dict" in checkpoint: checkpoint = checkpoint["state_dict"] - print(f'DEBUG: model_type = {model_type}; original_config_file = {original_config_file}') + logger.debug(f'model_type = {model_type}; original_config_file = {original_config_file}') if original_config_file is None: key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight" @@ -1624,7 +1624,7 @@ def download_controlnet_from_original_ckpt( original_config_file: str, image_size: int = 512, extract_ema: bool = False, - precision: torch.dtype = torch.float16, + precision: torch.dtype = torch.float32, num_in_channels: Optional[int] = None, upcast_attention: Optional[bool] = None, device: str = None, @@ -1702,7 +1702,7 @@ def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size: def convert_ckpt_to_diffusers( checkpoint_path: Union[str, Path], dump_path: Union[str, Path], - no_safetensors: bool = False, + use_safetensors: bool=True, **kwargs, ): """ @@ -1714,7 +1714,7 @@ def convert_ckpt_to_diffusers( pipe.save_pretrained( dump_path, - safe_serialization=is_safetensors_available() and not no_safetensors, + safe_serialization=use_safetensors and is_safetensors_available(), ) def convert_controlnet_to_diffusers( diff --git a/invokeai/backend/model_management/model_probe.py b/invokeai/backend/model_management/model_probe.py index d2f20bdef7..f768417506 100644 --- a/invokeai/backend/model_management/model_probe.py +++ b/invokeai/backend/model_management/model_probe.py @@ -253,11 +253,13 @@ class PipelineCheckpointProbe(CheckpointProbeBase): return BaseModelType.StableDiffusion1 if key_name in state_dict and state_dict[key_name].shape[-1] == 1024: return BaseModelType.StableDiffusion2 - # TODO: This is just a guess based on N=1 key_name = 'model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight' if key_name in state_dict and state_dict[key_name].shape[-1] == 2048: return BaseModelType.StableDiffusionXL - raise InvalidModelException("Cannot determine base type") + elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280: + return BaseModelType.StableDiffusionXLRefiner + else: + raise InvalidModelException("Cannot determine base type") def get_scheduler_prediction_type(self)->SchedulerPredictionType: type = self.get_base_type() diff --git a/invokeai/backend/model_management/models/sdxl.py b/invokeai/backend/model_management/models/sdxl.py index f66aa5a87a..ef0dd4c720 100644 --- a/invokeai/backend/model_management/models/sdxl.py +++ b/invokeai/backend/model_management/models/sdxl.py @@ -1,5 +1,6 @@ import os import json +import invokeai.backend.util.logging as logger from enum import Enum from pydantic import Field from typing import Literal, Optional @@ -108,14 +109,20 @@ class StableDiffusionXLModel(DiffusersModel): config: ModelConfigBase, base_model: BaseModelType, ) -> str: + # The convert script adapted from the diffusers package uses + # strings for the base model type. To avoid making too many + # source code changes, we simply translate here + model_base_to_model_type = {BaseModelType.StableDiffusionXL: 'SDXL', + BaseModelType.StableDiffusionXLRefiner: 'SDXL-Refiner', + } if isinstance(config, cls.CheckpointConfig): from invokeai.backend.model_management.models.stable_diffusion import _convert_ckpt_and_cache return _convert_ckpt_and_cache( version=base_model, model_config=config, output_path=output_path, - model_type='SDXL', - no_safetensors=True, # giving errors for some reason + model_type=model_base_to_model_type[base_model], + use_safetensors=False, # corrupts sdxl models for some reason ) else: return model_path diff --git a/invokeai/backend/model_management/models/stable_diffusion.py b/invokeai/backend/model_management/models/stable_diffusion.py index cf34686d6f..735e2e6bfb 100644 --- a/invokeai/backend/model_management/models/stable_diffusion.py +++ b/invokeai/backend/model_management/models/stable_diffusion.py @@ -16,9 +16,11 @@ from .base import ( InvalidModelException, ) from .sdxl import StableDiffusionXLModel +import invokeai.backend.util.logging as logger from invokeai.app.services.config import InvokeAIAppConfig from omegaconf import OmegaConf + class StableDiffusion1ModelFormat(str, Enum): Checkpoint = "checkpoint" Diffusers = "diffusers" @@ -265,6 +267,9 @@ def _convert_ckpt_and_cache( # to avoid circular import errors from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers + from ...util.devices import choose_torch_device, torch_dtype + + logger.info(f'Converting {weights} to diffusers format') with SilenceWarnings(): convert_ckpt_to_diffusers( weights, @@ -275,6 +280,7 @@ def _convert_ckpt_and_cache( extract_ema=True, scan_needed=True, from_safetensors = weights.suffix == ".safetensors", + precision = torch_dtype(choose_torch_device()), **kwargs, ) return output_path diff --git a/invokeai/backend/util/logging.py b/invokeai/backend/util/logging.py index 09ae600633..d06c036506 100644 --- a/invokeai/backend/util/logging.py +++ b/invokeai/backend/util/logging.py @@ -1,7 +1,7 @@ # Copyright (c) 2023 Lincoln D. Stein and The InvokeAI Development Team """ -invokeai.util.logging +invokeai.backend.util.logging Logging class for InvokeAI that produces console messages