mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix ckpt conversion
This commit is contained in:
parent
9140e2c0f2
commit
5aaaaf64a1
@ -184,7 +184,6 @@ def assign_to_checkpoint(
|
||||
paths,
|
||||
checkpoint,
|
||||
old_checkpoint,
|
||||
attention_paths_to_split=None,
|
||||
additional_replacements=None,
|
||||
config=None,
|
||||
):
|
||||
@ -199,35 +198,9 @@ def assign_to_checkpoint(
|
||||
paths, list
|
||||
), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
||||
|
||||
# Splits the attention layers into three variables.
|
||||
if attention_paths_to_split is not None:
|
||||
for path, path_map in attention_paths_to_split.items():
|
||||
old_tensor = old_checkpoint[path]
|
||||
channels = old_tensor.shape[0] // 3
|
||||
|
||||
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
||||
|
||||
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
||||
|
||||
old_tensor = old_tensor.reshape(
|
||||
(num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]
|
||||
)
|
||||
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
||||
|
||||
checkpoint[path_map["to_q"]] = query.reshape(target_shape)
|
||||
checkpoint[path_map["to_k"]] = key.reshape(target_shape)
|
||||
checkpoint[path_map["to_v"]] = value.reshape(target_shape)
|
||||
|
||||
for path in paths:
|
||||
new_path = path["new"]
|
||||
|
||||
# These have already been assigned
|
||||
if (
|
||||
attention_paths_to_split is not None
|
||||
and new_path in attention_paths_to_split
|
||||
):
|
||||
continue
|
||||
|
||||
# Global renaming happens here
|
||||
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
|
||||
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
|
||||
@ -238,9 +211,8 @@ def assign_to_checkpoint(
|
||||
new_path = new_path.replace(replacement["old"], replacement["new"])
|
||||
|
||||
# proj_attn.weight has to be converted from conv 1D to linear
|
||||
if "to_out.0.weight" in new_path:
|
||||
if checkpoint[new_path].ndim > 2:
|
||||
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0: 0]
|
||||
if "proj_attn.weight" in new_path:
|
||||
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
|
||||
else:
|
||||
checkpoint[new_path] = old_checkpoint[path["old"]]
|
||||
|
||||
|
@ -158,7 +158,6 @@ def _convert_vae_ckpt_and_cache(
|
||||
checkpoint = checkpoint,
|
||||
vae_config = config,
|
||||
image_size = image_size,
|
||||
model_root = app_config.models_path,
|
||||
)
|
||||
vae_model.save_pretrained(
|
||||
output_path,
|
||||
|
Loading…
Reference in New Issue
Block a user