From 82091b9a669773eeb9a3e38c3e710946f73e23e9 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Sun, 18 Jun 2023 23:46:07 +0300 Subject: [PATCH 1/2] Fix vae conversion --- .../convert_ckpt_to_diffusers.py | 33 ++++++++++--------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py index db099acbb8..0083cfa088 100644 --- a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py +++ b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py @@ -159,17 +159,17 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): new_item = new_item.replace("norm.weight", "group_norm.weight") new_item = new_item.replace("norm.bias", "group_norm.bias") - new_item = new_item.replace("q.weight", "query.weight") - new_item = new_item.replace("q.bias", "query.bias") + new_item = new_item.replace("q.weight", "to_q.weight") + new_item = new_item.replace("q.bias", "to_q.bias") - new_item = new_item.replace("k.weight", "key.weight") - new_item = new_item.replace("k.bias", "key.bias") + new_item = new_item.replace("k.weight", "to_k.weight") + new_item = new_item.replace("k.bias", "to_k.bias") - new_item = new_item.replace("v.weight", "value.weight") - new_item = new_item.replace("v.bias", "value.bias") + new_item = new_item.replace("v.weight", "to_v.weight") + new_item = new_item.replace("v.bias", "to_v.bias") - new_item = new_item.replace("proj_out.weight", "proj_attn.weight") - new_item = new_item.replace("proj_out.bias", "proj_attn.bias") + new_item = new_item.replace("proj_out.weight", "to_out.0.weight") + new_item = new_item.replace("proj_out.bias", "to_out.0.bias") new_item = shave_segments( new_item, n_shave_prefix_segments=n_shave_prefix_segments @@ -214,9 +214,9 @@ def assign_to_checkpoint( ) query, key, value = old_tensor.split(channels // num_heads, dim=1) - checkpoint[path_map["query"]] = query.reshape(target_shape) - checkpoint[path_map["key"]] = key.reshape(target_shape) - checkpoint[path_map["value"]] = value.reshape(target_shape) + checkpoint[path_map["to_q"]] = query.reshape(target_shape) + checkpoint[path_map["to_k"]] = key.reshape(target_shape) + checkpoint[path_map["to_v"]] = value.reshape(target_shape) for path in paths: new_path = path["new"] @@ -238,22 +238,23 @@ def assign_to_checkpoint( new_path = new_path.replace(replacement["old"], replacement["new"]) # proj_attn.weight has to be converted from conv 1D to linear - if "proj_attn.weight" in new_path: - checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] + if "to_out.0.weight" in new_path: + if checkpoint[new_path].ndim > 2: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0: 0] else: checkpoint[new_path] = old_checkpoint[path["old"]] def conv_attn_to_linear(checkpoint): keys = list(checkpoint.keys()) - attn_keys = ["query.weight", "key.weight", "value.weight"] + attn_keys = ["to_q.weight", "to_k.weight", "to_v.weight"] for key in keys: if ".".join(key.split(".")[-2:]) in attn_keys: if checkpoint[key].ndim > 2: checkpoint[key] = checkpoint[key][:, :, 0, 0] - elif "proj_attn.weight" in key: + elif "to_out.0.weight" in key: if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0] + checkpoint[key] = checkpoint[key][:, :, 0, 0] def create_unet_diffusers_config(original_config, image_size: int): From 5aaaaf64a19b08b5f02ef0adaa0bac1d7aa0e267 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Fri, 23 Jun 2023 17:29:54 +0300 Subject: [PATCH 2/2] Fix ckpt conversion --- .../convert_ckpt_to_diffusers.py | 32 ++----------------- .../backend/model_management/models/vae.py | 1 - 2 files changed, 2 insertions(+), 31 deletions(-) diff --git a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py index 0083cfa088..5d097f5a4e 100644 --- a/invokeai/backend/model_management/convert_ckpt_to_diffusers.py +++ b/invokeai/backend/model_management/convert_ckpt_to_diffusers.py @@ -184,7 +184,6 @@ def assign_to_checkpoint( paths, checkpoint, old_checkpoint, - attention_paths_to_split=None, additional_replacements=None, config=None, ): @@ -199,35 +198,9 @@ def assign_to_checkpoint( paths, list ), "Paths should be a list of dicts containing 'old' and 'new' keys." - # Splits the attention layers into three variables. - if attention_paths_to_split is not None: - for path, path_map in attention_paths_to_split.items(): - old_tensor = old_checkpoint[path] - channels = old_tensor.shape[0] // 3 - - target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) - - num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 - - old_tensor = old_tensor.reshape( - (num_heads, 3 * channels // num_heads) + old_tensor.shape[1:] - ) - query, key, value = old_tensor.split(channels // num_heads, dim=1) - - checkpoint[path_map["to_q"]] = query.reshape(target_shape) - checkpoint[path_map["to_k"]] = key.reshape(target_shape) - checkpoint[path_map["to_v"]] = value.reshape(target_shape) - for path in paths: new_path = path["new"] - # These have already been assigned - if ( - attention_paths_to_split is not None - and new_path in attention_paths_to_split - ): - continue - # Global renaming happens here new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") @@ -238,9 +211,8 @@ def assign_to_checkpoint( new_path = new_path.replace(replacement["old"], replacement["new"]) # proj_attn.weight has to be converted from conv 1D to linear - if "to_out.0.weight" in new_path: - if checkpoint[new_path].ndim > 2: - checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0: 0] + if "proj_attn.weight" in new_path: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] else: checkpoint[new_path] = old_checkpoint[path["old"]] diff --git a/invokeai/backend/model_management/models/vae.py b/invokeai/backend/model_management/models/vae.py index 76133b074d..b582f16b30 100644 --- a/invokeai/backend/model_management/models/vae.py +++ b/invokeai/backend/model_management/models/vae.py @@ -158,7 +158,6 @@ def _convert_vae_ckpt_and_cache( checkpoint = checkpoint, vae_config = config, image_size = image_size, - model_root = app_config.models_path, ) vae_model.save_pretrained( output_path,