This commit is contained in:
Brandon Rising
2024-03-08 12:24:35 -05:00
committed by Brandon
parent df12e12e09
commit 8ba4b2a150
2 changed files with 4 additions and 5 deletions

View File

@ -858,9 +858,9 @@ def do_textual_inversion_training(
# Let's make sure we don't update any embedding weights besides the newly added token
index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
with torch.no_grad():
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
index_no_updates
] = orig_embeds_params[index_no_updates]
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
orig_embeds_params[index_no_updates]
)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients: