Fix vae conversion (#3555)

Unsure at which moment it broke, but now I can't convert vae(and model
as vae it's part) without this fix.
Need further research - maybe it's breaking change in `transformers`?
This commit is contained in:
Lincoln Stein 2023-06-23 15:55:26 +01:00 committed by GitHub
commit 9de54b2266
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 11 additions and 39 deletions

View File

@ -159,17 +159,17 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
new_item = new_item.replace("norm.weight", "group_norm.weight") new_item = new_item.replace("norm.weight", "group_norm.weight")
new_item = new_item.replace("norm.bias", "group_norm.bias") new_item = new_item.replace("norm.bias", "group_norm.bias")
new_item = new_item.replace("q.weight", "query.weight") new_item = new_item.replace("q.weight", "to_q.weight")
new_item = new_item.replace("q.bias", "query.bias") new_item = new_item.replace("q.bias", "to_q.bias")
new_item = new_item.replace("k.weight", "key.weight") new_item = new_item.replace("k.weight", "to_k.weight")
new_item = new_item.replace("k.bias", "key.bias") new_item = new_item.replace("k.bias", "to_k.bias")
new_item = new_item.replace("v.weight", "value.weight") new_item = new_item.replace("v.weight", "to_v.weight")
new_item = new_item.replace("v.bias", "value.bias") new_item = new_item.replace("v.bias", "to_v.bias")
new_item = new_item.replace("proj_out.weight", "proj_attn.weight") new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
new_item = new_item.replace("proj_out.bias", "proj_attn.bias") new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
new_item = shave_segments( new_item = shave_segments(
new_item, n_shave_prefix_segments=n_shave_prefix_segments new_item, n_shave_prefix_segments=n_shave_prefix_segments
@ -184,7 +184,6 @@ def assign_to_checkpoint(
paths, paths,
checkpoint, checkpoint,
old_checkpoint, old_checkpoint,
attention_paths_to_split=None,
additional_replacements=None, additional_replacements=None,
config=None, config=None,
): ):
@ -199,35 +198,9 @@ def assign_to_checkpoint(
paths, list paths, list
), "Paths should be a list of dicts containing 'old' and 'new' keys." ), "Paths should be a list of dicts containing 'old' and 'new' keys."
# Splits the attention layers into three variables.
if attention_paths_to_split is not None:
for path, path_map in attention_paths_to_split.items():
old_tensor = old_checkpoint[path]
channels = old_tensor.shape[0] // 3
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
old_tensor = old_tensor.reshape(
(num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]
)
query, key, value = old_tensor.split(channels // num_heads, dim=1)
checkpoint[path_map["query"]] = query.reshape(target_shape)
checkpoint[path_map["key"]] = key.reshape(target_shape)
checkpoint[path_map["value"]] = value.reshape(target_shape)
for path in paths: for path in paths:
new_path = path["new"] new_path = path["new"]
# These have already been assigned
if (
attention_paths_to_split is not None
and new_path in attention_paths_to_split
):
continue
# Global renaming happens here # Global renaming happens here
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
@ -246,14 +219,14 @@ def assign_to_checkpoint(
def conv_attn_to_linear(checkpoint): def conv_attn_to_linear(checkpoint):
keys = list(checkpoint.keys()) keys = list(checkpoint.keys())
attn_keys = ["query.weight", "key.weight", "value.weight"] attn_keys = ["to_q.weight", "to_k.weight", "to_v.weight"]
for key in keys: for key in keys:
if ".".join(key.split(".")[-2:]) in attn_keys: if ".".join(key.split(".")[-2:]) in attn_keys:
if checkpoint[key].ndim > 2: if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0, 0] checkpoint[key] = checkpoint[key][:, :, 0, 0]
elif "proj_attn.weight" in key: elif "to_out.0.weight" in key:
if checkpoint[key].ndim > 2: if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0] checkpoint[key] = checkpoint[key][:, :, 0, 0]
def create_unet_diffusers_config(original_config, image_size: int): def create_unet_diffusers_config(original_config, image_size: int):

View File

@ -158,7 +158,6 @@ def _convert_vae_ckpt_and_cache(
checkpoint = checkpoint, checkpoint = checkpoint,
vae_config = config, vae_config = config,
image_size = image_size, image_size = image_size,
model_root = app_config.models_path,
) )
vae_model.save_pretrained( vae_model.save_pretrained(
output_path, output_path,