Fix vae conversion

Sergey Borisov 2023-06-18 23:46:07 +03:00
parent f312e1448f
commit 82091b9a66


@@ -159,17 +159,17 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
         new_item = new_item.replace("norm.weight", "group_norm.weight")
         new_item = new_item.replace("norm.bias", "group_norm.bias")
 
-        new_item = new_item.replace("q.weight", "query.weight")
-        new_item = new_item.replace("q.bias", "query.bias")
+        new_item = new_item.replace("q.weight", "to_q.weight")
+        new_item = new_item.replace("q.bias", "to_q.bias")
 
-        new_item = new_item.replace("k.weight", "key.weight")
-        new_item = new_item.replace("k.bias", "key.bias")
+        new_item = new_item.replace("k.weight", "to_k.weight")
+        new_item = new_item.replace("k.bias", "to_k.bias")
 
-        new_item = new_item.replace("v.weight", "value.weight")
-        new_item = new_item.replace("v.bias", "value.bias")
+        new_item = new_item.replace("v.weight", "to_v.weight")
+        new_item = new_item.replace("v.bias", "to_v.bias")
 
-        new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
-        new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
+        new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
+        new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
 
         new_item = shave_segments(
             new_item, n_shave_prefix_segments=n_shave_prefix_segments
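For context: this hunk only swaps the mapping targets from the old diffusers AttentionBlock parameter names (query/key/value, proj_attn) to the names used by diffusers' Attention module (to_q/to_k/to_v, to_out.0). A minimal sketch of the new replace() chain applied to sample checkpoint keys — the mid.attn_1.* inputs are illustrative, and shave_segments() from the real helper is omitted:

# Illustration only: the same replace() chain as the new
# renew_vae_attention_paths(), applied to sample LDM VAE keys.
def rename_vae_attention_key(old_item: str) -> str:
    new_item = old_item
    new_item = new_item.replace("norm.weight", "group_norm.weight")
    new_item = new_item.replace("norm.bias", "group_norm.bias")
    new_item = new_item.replace("q.weight", "to_q.weight")
    new_item = new_item.replace("q.bias", "to_q.bias")
    new_item = new_item.replace("k.weight", "to_k.weight")
    new_item = new_item.replace("k.bias", "to_k.bias")
    new_item = new_item.replace("v.weight", "to_v.weight")
    new_item = new_item.replace("v.bias", "to_v.bias")
    new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
    new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
    return new_item

print(rename_vae_attention_key("mid.attn_1.q.weight"))       # mid.attn_1.to_q.weight
print(rename_vae_attention_key("mid.attn_1.proj_out.bias"))  # mid.attn_1.to_out.0.bias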
@@ -214,9 +214,9 @@ def assign_to_checkpoint(
             )
             query, key, value = old_tensor.split(channels // num_heads, dim=1)
 
-            checkpoint[path_map["query"]] = query.reshape(target_shape)
-            checkpoint[path_map["key"]] = key.reshape(target_shape)
-            checkpoint[path_map["value"]] = value.reshape(target_shape)
+            checkpoint[path_map["to_q"]] = query.reshape(target_shape)
+            checkpoint[path_map["to_k"]] = key.reshape(target_shape)
+            checkpoint[path_map["to_v"]] = value.reshape(target_shape)
 
     for path in paths:
         new_path = path["new"]
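The qkv split itself is untouched here; only the path_map keys follow the rename. A rough, self-contained sketch of what this branch does, under assumed shapes — num_heads, head_dim, in_dim, and the sample path_map are hypothetical, with the reshape performed earlier in assign_to_checkpoint() inlined; requires torch:

# Hypothetical shapes: a fused qkv weight with num_heads groups of
# [q | k | v] channels along dim 1 after the earlier reshape.
import torch

num_heads, head_dim, in_dim = 8, 64, 512
channels = num_heads * head_dim               # per-projection output channels
old_tensor = torch.randn(num_heads, 3 * head_dim, in_dim)
target_shape = (-1, in_dim)

# channels // num_heads == head_dim, so each projection gets head_dim rows per head
query, key, value = old_tensor.split(channels // num_heads, dim=1)

checkpoint = {}
path_map = {"to_q": "attn.to_q.weight", "to_k": "attn.to_k.weight", "to_v": "attn.to_v.weight"}
checkpoint[path_map["to_q"]] = query.reshape(target_shape)   # shape (512, 512)
checkpoint[path_map["to_k"]] = key.reshape(target_shape)
checkpoint[path_map["to_v"]] = value.reshape(target_shape)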
@@ -238,22 +238,23 @@ def assign_to_checkpoint(
                 new_path = new_path.replace(replacement["old"], replacement["new"])
 
         # proj_attn.weight has to be converted from conv 1D to linear
-        if "proj_attn.weight" in new_path:
-            checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
+        if "to_out.0.weight" in new_path:
+            if checkpoint[new_path].ndim > 2:
+                checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
         else:
             checkpoint[new_path] = old_checkpoint[path["old"]]
 
 
 def conv_attn_to_linear(checkpoint):
     keys = list(checkpoint.keys())
-    attn_keys = ["query.weight", "key.weight", "value.weight"]
+    attn_keys = ["to_q.weight", "to_k.weight", "to_v.weight"]
     for key in keys:
         if ".".join(key.split(".")[-2:]) in attn_keys:
             if checkpoint[key].ndim > 2:
                 checkpoint[key] = checkpoint[key][:, :, 0, 0]
-        elif "proj_attn.weight" in key:
+        elif "to_out.0.weight" in key:
             if checkpoint[key].ndim > 2:
-                checkpoint[key] = checkpoint[key][:, :, 0]
+                checkpoint[key] = checkpoint[key][:, :, 0, 0]
 
 
 def create_unet_diffusers_config(original_config, image_size: int):
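Why the [:, :, 0, 0] slice is the right conv-to-linear conversion: the original VAE attention projections are stored as 1x1 Conv2d weights of shape (out, in, 1, 1), and dropping the two singleton kernel dims yields exactly the weight matrix the linear to_q/to_k/to_v/to_out.0 projections expect (the old [:, :, 0] handled the 3-D conv-1D case the same way). A quick sanity check under that assumption — the channel count is illustrative; requires torch:

import torch
import torch.nn.functional as F

c = 512
conv_weight = torch.randn(c, c, 1, 1)     # 1x1 Conv2d weight, as stored in the ckpt
linear_weight = conv_weight[:, :, 0, 0]   # the slice conv_attn_to_linear() performs

x = torch.randn(1, c)
# A 1x1 convolution over a single spatial position is exactly the linear map:
conv_out = F.conv2d(x.view(1, c, 1, 1), conv_weight).view(1, c)
linear_out = F.linear(x, linear_weight)
assert torch.allclose(conv_out, linear_out, atol=1e-4)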