mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix ckpt conversion
This commit is contained in:
parent
9140e2c0f2
commit
5aaaaf64a1
@ -184,7 +184,6 @@ def assign_to_checkpoint(
|
|||||||
paths,
|
paths,
|
||||||
checkpoint,
|
checkpoint,
|
||||||
old_checkpoint,
|
old_checkpoint,
|
||||||
attention_paths_to_split=None,
|
|
||||||
additional_replacements=None,
|
additional_replacements=None,
|
||||||
config=None,
|
config=None,
|
||||||
):
|
):
|
||||||
@ -199,35 +198,9 @@ def assign_to_checkpoint(
|
|||||||
paths, list
|
paths, list
|
||||||
), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
||||||
|
|
||||||
# Splits the attention layers into three variables.
|
|
||||||
if attention_paths_to_split is not None:
|
|
||||||
for path, path_map in attention_paths_to_split.items():
|
|
||||||
old_tensor = old_checkpoint[path]
|
|
||||||
channels = old_tensor.shape[0] // 3
|
|
||||||
|
|
||||||
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
|
||||||
|
|
||||||
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
|
||||||
|
|
||||||
old_tensor = old_tensor.reshape(
|
|
||||||
(num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]
|
|
||||||
)
|
|
||||||
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
|
||||||
|
|
||||||
checkpoint[path_map["to_q"]] = query.reshape(target_shape)
|
|
||||||
checkpoint[path_map["to_k"]] = key.reshape(target_shape)
|
|
||||||
checkpoint[path_map["to_v"]] = value.reshape(target_shape)
|
|
||||||
|
|
||||||
for path in paths:
|
for path in paths:
|
||||||
new_path = path["new"]
|
new_path = path["new"]
|
||||||
|
|
||||||
# These have already been assigned
|
|
||||||
if (
|
|
||||||
attention_paths_to_split is not None
|
|
||||||
and new_path in attention_paths_to_split
|
|
||||||
):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Global renaming happens here
|
# Global renaming happens here
|
||||||
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
|
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
|
||||||
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
|
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
|
||||||
@ -238,9 +211,8 @@ def assign_to_checkpoint(
|
|||||||
new_path = new_path.replace(replacement["old"], replacement["new"])
|
new_path = new_path.replace(replacement["old"], replacement["new"])
|
||||||
|
|
||||||
# proj_attn.weight has to be converted from conv 1D to linear
|
# proj_attn.weight has to be converted from conv 1D to linear
|
||||||
if "to_out.0.weight" in new_path:
|
if "proj_attn.weight" in new_path:
|
||||||
if checkpoint[new_path].ndim > 2:
|
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
|
||||||
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0: 0]
|
|
||||||
else:
|
else:
|
||||||
checkpoint[new_path] = old_checkpoint[path["old"]]
|
checkpoint[new_path] = old_checkpoint[path["old"]]
|
||||||
|
|
||||||
|
@ -158,7 +158,6 @@ def _convert_vae_ckpt_and_cache(
|
|||||||
checkpoint = checkpoint,
|
checkpoint = checkpoint,
|
||||||
vae_config = config,
|
vae_config = config,
|
||||||
image_size = image_size,
|
image_size = image_size,
|
||||||
model_root = app_config.models_path,
|
|
||||||
)
|
)
|
||||||
vae_model.save_pretrained(
|
vae_model.save_pretrained(
|
||||||
output_path,
|
output_path,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user