mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
SDXL & SDXL-refiner models convert correctly
This commit is contained in:
parent
5e59edfaf1
commit
f2a6f0cf21
@ -1121,7 +1121,7 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
prediction_type: str = None,
|
prediction_type: str = None,
|
||||||
model_type: str = None,
|
model_type: str = None,
|
||||||
extract_ema: bool = False,
|
extract_ema: bool = False,
|
||||||
precision: torch.dtype = torch.float16,
|
precision: torch.dtype = torch.float32,
|
||||||
scheduler_type: str = "pndm",
|
scheduler_type: str = "pndm",
|
||||||
num_in_channels: Optional[int] = None,
|
num_in_channels: Optional[int] = None,
|
||||||
upcast_attention: Optional[bool] = None,
|
upcast_attention: Optional[bool] = None,
|
||||||
@ -1250,7 +1250,7 @@ def download_from_original_stable_diffusion_ckpt(
|
|||||||
while "state_dict" in checkpoint:
|
while "state_dict" in checkpoint:
|
||||||
checkpoint = checkpoint["state_dict"]
|
checkpoint = checkpoint["state_dict"]
|
||||||
|
|
||||||
print(f'DEBUG: model_type = {model_type}; original_config_file = {original_config_file}')
|
logger.debug(f'model_type = {model_type}; original_config_file = {original_config_file}')
|
||||||
|
|
||||||
if original_config_file is None:
|
if original_config_file is None:
|
||||||
key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
|
||||||
@ -1624,7 +1624,7 @@ def download_controlnet_from_original_ckpt(
|
|||||||
original_config_file: str,
|
original_config_file: str,
|
||||||
image_size: int = 512,
|
image_size: int = 512,
|
||||||
extract_ema: bool = False,
|
extract_ema: bool = False,
|
||||||
precision: torch.dtype = torch.float16,
|
precision: torch.dtype = torch.float32,
|
||||||
num_in_channels: Optional[int] = None,
|
num_in_channels: Optional[int] = None,
|
||||||
upcast_attention: Optional[bool] = None,
|
upcast_attention: Optional[bool] = None,
|
||||||
device: str = None,
|
device: str = None,
|
||||||
@ -1702,7 +1702,7 @@ def convert_ldm_vae_to_diffusers(checkpoint, vae_config: DictConfig, image_size:
|
|||||||
def convert_ckpt_to_diffusers(
|
def convert_ckpt_to_diffusers(
|
||||||
checkpoint_path: Union[str, Path],
|
checkpoint_path: Union[str, Path],
|
||||||
dump_path: Union[str, Path],
|
dump_path: Union[str, Path],
|
||||||
no_safetensors: bool = False,
|
use_safetensors: bool=True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@ -1714,7 +1714,7 @@ def convert_ckpt_to_diffusers(
|
|||||||
|
|
||||||
pipe.save_pretrained(
|
pipe.save_pretrained(
|
||||||
dump_path,
|
dump_path,
|
||||||
safe_serialization=is_safetensors_available() and not no_safetensors,
|
safe_serialization=use_safetensors and is_safetensors_available(),
|
||||||
)
|
)
|
||||||
|
|
||||||
def convert_controlnet_to_diffusers(
|
def convert_controlnet_to_diffusers(
|
||||||
|
@ -253,10 +253,12 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
|
|||||||
return BaseModelType.StableDiffusion1
|
return BaseModelType.StableDiffusion1
|
||||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
|
if key_name in state_dict and state_dict[key_name].shape[-1] == 1024:
|
||||||
return BaseModelType.StableDiffusion2
|
return BaseModelType.StableDiffusion2
|
||||||
# TODO: This is just a guess based on N=1
|
|
||||||
key_name = 'model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight'
|
key_name = 'model.diffusion_model.input_blocks.4.1.transformer_blocks.0.attn2.to_k.weight'
|
||||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
|
if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
|
||||||
return BaseModelType.StableDiffusionXL
|
return BaseModelType.StableDiffusionXL
|
||||||
|
elif key_name in state_dict and state_dict[key_name].shape[-1] == 1280:
|
||||||
|
return BaseModelType.StableDiffusionXLRefiner
|
||||||
|
else:
|
||||||
raise InvalidModelException("Cannot determine base type")
|
raise InvalidModelException("Cannot determine base type")
|
||||||
|
|
||||||
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
|
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import json
|
import json
|
||||||
|
import invokeai.backend.util.logging as logger
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from typing import Literal, Optional
|
from typing import Literal, Optional
|
||||||
@ -108,14 +109,20 @@ class StableDiffusionXLModel(DiffusersModel):
|
|||||||
config: ModelConfigBase,
|
config: ModelConfigBase,
|
||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
) -> str:
|
) -> str:
|
||||||
|
# The convert script adapted from the diffusers package uses
|
||||||
|
# strings for the base model type. To avoid making too many
|
||||||
|
# source code changes, we simply translate here
|
||||||
|
model_base_to_model_type = {BaseModelType.StableDiffusionXL: 'SDXL',
|
||||||
|
BaseModelType.StableDiffusionXLRefiner: 'SDXL-Refiner',
|
||||||
|
}
|
||||||
if isinstance(config, cls.CheckpointConfig):
|
if isinstance(config, cls.CheckpointConfig):
|
||||||
from invokeai.backend.model_management.models.stable_diffusion import _convert_ckpt_and_cache
|
from invokeai.backend.model_management.models.stable_diffusion import _convert_ckpt_and_cache
|
||||||
return _convert_ckpt_and_cache(
|
return _convert_ckpt_and_cache(
|
||||||
version=base_model,
|
version=base_model,
|
||||||
model_config=config,
|
model_config=config,
|
||||||
output_path=output_path,
|
output_path=output_path,
|
||||||
model_type='SDXL',
|
model_type=model_base_to_model_type[base_model],
|
||||||
no_safetensors=True, # giving errors for some reason
|
use_safetensors=False, # corrupts sdxl models for some reason
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return model_path
|
return model_path
|
||||||
|
@ -16,9 +16,11 @@ from .base import (
|
|||||||
InvalidModelException,
|
InvalidModelException,
|
||||||
)
|
)
|
||||||
from .sdxl import StableDiffusionXLModel
|
from .sdxl import StableDiffusionXLModel
|
||||||
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusion1ModelFormat(str, Enum):
|
class StableDiffusion1ModelFormat(str, Enum):
|
||||||
Checkpoint = "checkpoint"
|
Checkpoint = "checkpoint"
|
||||||
Diffusers = "diffusers"
|
Diffusers = "diffusers"
|
||||||
@ -265,6 +267,9 @@ def _convert_ckpt_and_cache(
|
|||||||
|
|
||||||
# to avoid circular import errors
|
# to avoid circular import errors
|
||||||
from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
|
from ..convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
|
||||||
|
from ...util.devices import choose_torch_device, torch_dtype
|
||||||
|
|
||||||
|
logger.info(f'Converting {weights} to diffusers format')
|
||||||
with SilenceWarnings():
|
with SilenceWarnings():
|
||||||
convert_ckpt_to_diffusers(
|
convert_ckpt_to_diffusers(
|
||||||
weights,
|
weights,
|
||||||
@ -275,6 +280,7 @@ def _convert_ckpt_and_cache(
|
|||||||
extract_ema=True,
|
extract_ema=True,
|
||||||
scan_needed=True,
|
scan_needed=True,
|
||||||
from_safetensors = weights.suffix == ".safetensors",
|
from_safetensors = weights.suffix == ".safetensors",
|
||||||
|
precision = torch_dtype(choose_torch_device()),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
return output_path
|
return output_path
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# Copyright (c) 2023 Lincoln D. Stein and The InvokeAI Development Team
|
# Copyright (c) 2023 Lincoln D. Stein and The InvokeAI Development Team
|
||||||
|
|
||||||
"""
|
"""
|
||||||
invokeai.util.logging
|
invokeai.backend.util.logging
|
||||||
|
|
||||||
Logging class for InvokeAI that produces console messages
|
Logging class for InvokeAI that produces console messages
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user