Fix ckpt conversion

This commit is contained in:
Sergey Borisov 2023-06-23 17:29:54 +03:00
parent 9140e2c0f2
commit 5aaaaf64a1
2 changed files with 2 additions and 31 deletions

View File

@ -184,7 +184,6 @@ def assign_to_checkpoint(
paths,
checkpoint,
old_checkpoint,
attention_paths_to_split=None,
additional_replacements=None,
config=None,
):
@ -199,35 +198,9 @@ def assign_to_checkpoint(
paths, list
), "Paths should be a list of dicts containing 'old' and 'new' keys."
# Splits the attention layers into three variables.
if attention_paths_to_split is not None:
for path, path_map in attention_paths_to_split.items():
old_tensor = old_checkpoint[path]
channels = old_tensor.shape[0] // 3
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
old_tensor = old_tensor.reshape(
(num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]
)
query, key, value = old_tensor.split(channels // num_heads, dim=1)
checkpoint[path_map["to_q"]] = query.reshape(target_shape)
checkpoint[path_map["to_k"]] = key.reshape(target_shape)
checkpoint[path_map["to_v"]] = value.reshape(target_shape)
for path in paths:
new_path = path["new"]
# These have already been assigned
if (
attention_paths_to_split is not None
and new_path in attention_paths_to_split
):
continue
# Global renaming happens here
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
@ -238,9 +211,8 @@ def assign_to_checkpoint(
new_path = new_path.replace(replacement["old"], replacement["new"])
# proj_attn.weight has to be converted from conv 1D to linear
if "to_out.0.weight" in new_path:
if checkpoint[new_path].ndim > 2:
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0: 0]
if "proj_attn.weight" in new_path:
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
else:
checkpoint[new_path] = old_checkpoint[path["old"]]

View File

@ -158,7 +158,6 @@ def _convert_vae_ckpt_and_cache(
checkpoint = checkpoint,
vae_config = config,
image_size = image_size,
model_root = app_config.models_path,
)
vae_model.save_pretrained(
output_path,