From b6065d63281a62a5887197cb6f5765f547042401 Mon Sep 17 00:00:00 2001
From: Brandon Rising
Date: Thu, 7 Mar 2024 12:19:51 -0500
Subject: [PATCH] Run ruff with newest version of ruff

---
 .../shared/sqlite_migrator/sqlite_migrator_common.py    | 3 +--
 invokeai/backend/training/textual_inversion_training.py | 6 +++---
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_common.py b/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_common.py
index 47ed5da505..9b2444dae4 100644
--- a/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_common.py
+++ b/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_common.py
@@ -17,8 +17,7 @@ class MigrateCallback(Protocol):
     See :class:`Migration` for an example.
     """
 
-    def __call__(self, cursor: sqlite3.Cursor) -> None:
-        ...
+    def __call__(self, cursor: sqlite3.Cursor) -> None: ...
 
 
 class MigrationError(RuntimeError):
diff --git a/invokeai/backend/training/textual_inversion_training.py b/invokeai/backend/training/textual_inversion_training.py
index 9a38c006a5..7ddcf14367 100644
--- a/invokeai/backend/training/textual_inversion_training.py
+++ b/invokeai/backend/training/textual_inversion_training.py
@@ -858,9 +858,9 @@ def do_textual_inversion_training(
                 # Let's make sure we don't update any embedding weights besides the newly added token
                 index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
                 with torch.no_grad():
-                    accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
-                        index_no_updates
-                    ] = orig_embeds_params[index_no_updates]
+                    accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
+                        orig_embeds_params[index_no_updates]
+                    )
 
             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients: