Fix vae conversion (#3555)
Unsure at which moment it broke, but I can no longer convert a VAE (or a full model, since a VAE is part of it) without this fix. Needs further research - maybe it's a breaking change in `transformers`?
Commit: 9de54b2266
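For context: the renames in the diff below track the attention parameter names used by recent diffusers releases, where the VAE attention block exposes `to_q`/`to_k`/`to_v`/`to_out.0` instead of the older `query`/`key`/`value`/`proj_attn`. A minimal sketch of the key mapping the converter now has to produce - the names `LDM_TO_DIFFUSERS` and `rename_vae_attn_key` are hypothetical and only illustrate the chained `replace` calls in `renew_vae_attention_paths()`:

# Hedged sketch, not part of the patch: the ldm -> diffusers VAE attention
# key substitutions applied by renew_vae_attention_paths() after this fix.
LDM_TO_DIFFUSERS = {
    "norm.weight": "group_norm.weight",
    "norm.bias": "group_norm.bias",
    "q.weight": "to_q.weight",
    "q.bias": "to_q.bias",
    "k.weight": "to_k.weight",
    "k.bias": "to_k.bias",
    "v.weight": "to_v.weight",
    "v.bias": "to_v.bias",
    "proj_out.weight": "to_out.0.weight",
    "proj_out.bias": "to_out.0.bias",
}

def rename_vae_attn_key(key: str) -> str:
    # Apply the substitutions in order, mirroring the chained str.replace() calls.
    for old, new in LDM_TO_DIFFUSERS.items():
        key = key.replace(old, new)
    return key

print(rename_vae_attn_key("mid.attn_1.q.weight"))  # -> "mid.attn_1.to_q.weight"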
@@ -159,17 +159,17 @@ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
         new_item = new_item.replace("norm.weight", "group_norm.weight")
         new_item = new_item.replace("norm.bias", "group_norm.bias")
 
-        new_item = new_item.replace("q.weight", "query.weight")
-        new_item = new_item.replace("q.bias", "query.bias")
+        new_item = new_item.replace("q.weight", "to_q.weight")
+        new_item = new_item.replace("q.bias", "to_q.bias")
 
-        new_item = new_item.replace("k.weight", "key.weight")
-        new_item = new_item.replace("k.bias", "key.bias")
+        new_item = new_item.replace("k.weight", "to_k.weight")
+        new_item = new_item.replace("k.bias", "to_k.bias")
 
-        new_item = new_item.replace("v.weight", "value.weight")
-        new_item = new_item.replace("v.bias", "value.bias")
+        new_item = new_item.replace("v.weight", "to_v.weight")
+        new_item = new_item.replace("v.bias", "to_v.bias")
 
-        new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
-        new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
+        new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
+        new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
 
         new_item = shave_segments(
             new_item, n_shave_prefix_segments=n_shave_prefix_segments
@@ -184,7 +184,6 @@ def assign_to_checkpoint(
     paths,
     checkpoint,
     old_checkpoint,
-    attention_paths_to_split=None,
     additional_replacements=None,
     config=None,
 ):
@@ -199,35 +198,9 @@ def assign_to_checkpoint(
         paths, list
     ), "Paths should be a list of dicts containing 'old' and 'new' keys."
 
-    # Splits the attention layers into three variables.
-    if attention_paths_to_split is not None:
-        for path, path_map in attention_paths_to_split.items():
-            old_tensor = old_checkpoint[path]
-            channels = old_tensor.shape[0] // 3
-
-            target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
-
-            num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
-
-            old_tensor = old_tensor.reshape(
-                (num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]
-            )
-            query, key, value = old_tensor.split(channels // num_heads, dim=1)
-
-            checkpoint[path_map["query"]] = query.reshape(target_shape)
-            checkpoint[path_map["key"]] = key.reshape(target_shape)
-            checkpoint[path_map["value"]] = value.reshape(target_shape)
-
     for path in paths:
         new_path = path["new"]
 
-        # These have already been assigned
-        if (
-            attention_paths_to_split is not None
-            and new_path in attention_paths_to_split
-        ):
-            continue
-
         # Global renaming happens here
         new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
         new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
@@ -246,14 +219,14 @@ def assign_to_checkpoint(
 
 def conv_attn_to_linear(checkpoint):
     keys = list(checkpoint.keys())
-    attn_keys = ["query.weight", "key.weight", "value.weight"]
+    attn_keys = ["to_q.weight", "to_k.weight", "to_v.weight"]
     for key in keys:
         if ".".join(key.split(".")[-2:]) in attn_keys:
             if checkpoint[key].ndim > 2:
                 checkpoint[key] = checkpoint[key][:, :, 0, 0]
-        elif "proj_attn.weight" in key:
+        elif "to_out.0.weight" in key:
             if checkpoint[key].ndim > 2:
-                checkpoint[key] = checkpoint[key][:, :, 0]
+                checkpoint[key] = checkpoint[key][:, :, 0, 0]
 
 
 def create_unet_diffusers_config(original_config, image_size: int):
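The `conv_attn_to_linear` change above follows the same renaming. In the original LDM VAE these projections are 1x1 `Conv2d` layers, so their weights are 4-D with shape `(out, in, 1, 1)`; the diffusers attention module uses `Linear` layers, so the converter drops the trailing kernel dimensions. A small sketch of that squeeze, assuming a PyTorch state dict (the tensor name and the 512 channel count are made up for illustration):

import torch

# Hedged sketch of the reshape performed by conv_attn_to_linear() after the fix:
# 1x1 Conv2d weights of shape (out, in, 1, 1) become Linear weights of shape (out, in).
state_dict = {"mid_block.attentions.0.to_q.weight": torch.randn(512, 512, 1, 1)}

for key in list(state_dict.keys()):
    if key.endswith(("to_q.weight", "to_k.weight", "to_v.weight", "to_out.0.weight")):
        if state_dict[key].ndim > 2:
            state_dict[key] = state_dict[key][:, :, 0, 0]

print(state_dict["mid_block.attentions.0.to_q.weight"].shape)  # torch.Size([512, 512])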
|
@@ -158,7 +158,6 @@ def _convert_vae_ckpt_and_cache(
         checkpoint = checkpoint,
         vae_config = config,
         image_size = image_size,
-        model_root = app_config.models_path,
     )
     vae_model.save_pretrained(
         output_path,
|