This commit is contained in:
Brandon Rising 2024-03-08 11:57:21 -05:00 committed by Brandon
parent ee38fbe89c
commit df12e12e09
3 changed files with 6 additions and 7 deletions

View File

@ -17,7 +17,8 @@ class MigrateCallback(Protocol):
See :class:`Migration` for an example.
"""
def __call__(self, cursor: sqlite3.Cursor) -> None: ...
def __call__(self, cursor: sqlite3.Cursor) -> None:
...
class MigrationError(RuntimeError):

View File

@ -858,9 +858,9 @@ def do_textual_inversion_training(
# Let's make sure we don't update any embedding weights besides the newly added token
index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id
with torch.no_grad():
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = (
orig_embeds_params[index_no_updates]
)
accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[
index_no_updates
] = orig_embeds_params[index_no_updates]
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:

View File

@ -173,9 +173,7 @@ def test_inplace_install(
assert Path(job.config_out.path) == embedding_file
def test_delete_install(
mm2_installer: ModelInstallServiceBase, embedding_file: Path
) -> None:
def test_delete_install(mm2_installer: ModelInstallServiceBase, embedding_file: Path) -> None:
store = mm2_installer.record_store
key = mm2_installer.install_path(embedding_file)
model_record = store.get_model(key)