mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix vae conversion
This commit is contained in:
parent
f312e1448f
commit
82091b9a66
@ -159,17 +159,17 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
||||
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
||||
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
||||
|
||||
new_item = new_item.replace("q.weight", "query.weight")
|
||||
new_item = new_item.replace("q.bias", "query.bias")
|
||||
new_item = new_item.replace("q.weight", "to_q.weight")
|
||||
new_item = new_item.replace("q.bias", "to_q.bias")
|
||||
|
||||
new_item = new_item.replace("k.weight", "key.weight")
|
||||
new_item = new_item.replace("k.bias", "key.bias")
|
||||
new_item = new_item.replace("k.weight", "to_k.weight")
|
||||
new_item = new_item.replace("k.bias", "to_k.bias")
|
||||
|
||||
new_item = new_item.replace("v.weight", "value.weight")
|
||||
new_item = new_item.replace("v.bias", "value.bias")
|
||||
new_item = new_item.replace("v.weight", "to_v.weight")
|
||||
new_item = new_item.replace("v.bias", "to_v.bias")
|
||||
|
||||
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
||||
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
||||
new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
|
||||
new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
|
||||
|
||||
new_item = shave_segments(
|
||||
new_item, n_shave_prefix_segments=n_shave_prefix_segments
|
||||
@ -214,9 +214,9 @@ def assign_to_checkpoint(
|
||||
)
|
||||
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
||||
|
||||
checkpoint[path_map["query"]] = query.reshape(target_shape)
|
||||
checkpoint[path_map["key"]] = key.reshape(target_shape)
|
||||
checkpoint[path_map["value"]] = value.reshape(target_shape)
|
||||
checkpoint[path_map["to_q"]] = query.reshape(target_shape)
|
||||
checkpoint[path_map["to_k"]] = key.reshape(target_shape)
|
||||
checkpoint[path_map["to_v"]] = value.reshape(target_shape)
|
||||
|
||||
for path in paths:
|
||||
new_path = path["new"]
|
||||
@ -238,22 +238,23 @@ def assign_to_checkpoint(
|
||||
new_path = new_path.replace(replacement["old"], replacement["new"])
|
||||
|
||||
# proj_attn.weight has to be converted from conv 1D to linear
|
||||
if "proj_attn.weight" in new_path:
|
||||
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
|
||||
if "to_out.0.weight" in new_path:
|
||||
if checkpoint[new_path].ndim > 2:
|
||||
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0: 0]
|
||||
else:
|
||||
checkpoint[new_path] = old_checkpoint[path["old"]]
|
||||
|
||||
|
||||
def conv_attn_to_linear(checkpoint):
|
||||
keys = list(checkpoint.keys())
|
||||
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
||||
attn_keys = ["to_q.weight", "to_k.weight", "to_v.weight"]
|
||||
for key in keys:
|
||||
if ".".join(key.split(".")[-2:]) in attn_keys:
|
||||
if checkpoint[key].ndim > 2:
|
||||
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
||||
elif "proj_attn.weight" in key:
|
||||
elif "to_out.0.weight" in key:
|
||||
if checkpoint[key].ndim > 2:
|
||||
checkpoint[key] = checkpoint[key][:, :, 0]
|
||||
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
||||
|
||||
|
||||
def create_unet_diffusers_config(original_config, image_size: int):
|
||||
|
Loading…
Reference in New Issue
Block a user