From 04229f4a21e1bb064d993f52c4d0c4a1c732f618 Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Thu, 7 Mar 2024 12:17:25 -0500 Subject: [PATCH] Run ruff --- invokeai/app/services/model_install/model_install_base.py | 4 +--- .../shared/sqlite_migrator/sqlite_migrator_common.py | 3 ++- invokeai/backend/training/textual_inversion_training.py | 6 +++--- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index a1ad0e1a87..b7385495e5 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -131,9 +131,7 @@ class URLModelSource(StringLikeSource): return str(self.url) -ModelSource = Annotated[ - Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type") -] +ModelSource = Annotated[Union[LocalModelSource, HFModelSource, URLModelSource], Field(discriminator="type")] MODEL_SOURCE_TO_TYPE_MAP = { URLModelSource: ModelSourceType.Url, diff --git a/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_common.py b/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_common.py index 9b2444dae4..47ed5da505 100644 --- a/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_common.py +++ b/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_common.py @@ -17,7 +17,8 @@ class MigrateCallback(Protocol): See :class:`Migration` for an example. """ - def __call__(self, cursor: sqlite3.Cursor) -> None: ... + def __call__(self, cursor: sqlite3.Cursor) -> None: + ... class MigrationError(RuntimeError): diff --git a/invokeai/backend/training/textual_inversion_training.py b/invokeai/backend/training/textual_inversion_training.py index 7ddcf14367..9a38c006a5 100644 --- a/invokeai/backend/training/textual_inversion_training.py +++ b/invokeai/backend/training/textual_inversion_training.py @@ -858,9 +858,9 @@ def do_textual_inversion_training( # Let's make sure we don't update any embedding weights besides the newly added token index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id with torch.no_grad(): - accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = ( - orig_embeds_params[index_no_updates] - ) + accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ + index_no_updates + ] = orig_embeds_params[index_no_updates] # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: