replace repeated triggers with <filename>

This commit is contained in:
Lincoln Stein 2023-02-20 22:33:13 -05:00
parent 694d5aa2e8
commit 91f7abb398

View File

@ -102,6 +102,7 @@ class TextualInversionManager:
embedding_info["name"],
embedding_info["embedding"],
defer_injecting_tokens=defer_injecting_tokens,
source_file=ckpt_path,
)
except ValueError as e:
print(f' | Ignoring incompatible embedding {embedding_info["name"]}')
@ -112,7 +113,7 @@ class TextualInversionManager:
)
def _add_textual_inversion(
self, trigger_str, embedding, defer_injecting_tokens=False
self, trigger_str, embedding, defer_injecting_tokens=False, source_file=Path
) -> TextualInversion:
"""
Add a textual inversion to be recognised.
@ -120,11 +121,13 @@ class TextualInversionManager:
:param embedding: The actual embedding data that will be inserted into the conditioning at the point where the token_str appears.
:return: The token id for the added embedding, either existing or newly-added.
"""
if trigger_str in [ti.trigger_string for ti in self.textual_inversions]:
triggers = [ti.trigger_string for ti in self.textual_inversions]
if trigger_str in triggers:
new_trigger_str = f'<{source_file.stem}>'
print(
f">> TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
f">> {source_file.parents[0].name}/{source_file.name}: Trigger token '{trigger_str}' already in use. Trigger with {new_trigger_str}"
)
return
trigger_str = new_trigger_str
if not self.full_precision:
embedding = embedding.half()
if len(embedding.shape) == 1: