mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Run ruff
This commit is contained in:
parent
ee38fbe89c
commit
df12e12e09
@ -17,7 +17,8 @@ class MigrateCallback(Protocol):
|
|||||||
See :class:`Migration` for an example.
|
See :class:`Migration` for an example.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __call__(self, cursor: sqlite3.Cursor) -> None: ...
|
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
class MigrationError(RuntimeError):
|
class MigrationError(RuntimeError):
|
||||||
|
@ -858,9 +858,9 @@ def do_textual_inversion_training(
|
|||||||
# Let's make sure we don't update any embedding weights besides the newly added token
|
# Let's make sure we don't update any embedding weights besides the newly added token
|
||||||
index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
|
index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
|
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
|
||||||
orig_embeds_params[index_no_updates]
|
index_no_updates
|
||||||
)
|
] = orig_embeds_params[index_no_updates]
|
||||||
|
|
||||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||||
if accelerator.sync_gradients:
|
if accelerator.sync_gradients:
|
||||||
|
@ -173,9 +173,7 @@ def test_inplace_install(
|
|||||||
assert Path(job.config_out.path) == embedding_file
|
assert Path(job.config_out.path) == embedding_file
|
||||||
|
|
||||||
|
|
||||||
def test_delete_install(
|
def test_delete_install(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
|
||||||
mm2_installer: ModelInstallServiceBase, embedding_file: Path
|
|
||||||
) -> None:
|
|
||||||
store = mm2_installer.record_store
|
store = mm2_installer.record_store
|
||||||
key = mm2_installer.install_path(embedding_file)
|
key = mm2_installer.install_path(embedding_file)
|
||||||
model_record = store.get_model(key)
|
model_record = store.get_model(key)
|
||||||
|
Loading…
Reference in New Issue
Block a user