Fallback to using filename to trigger embeddings (#2752)

Many earlier embeddings use a common trigger token such as * or the
Hebrew letter shin. Previously, the textual inversion manager would
refuse to load the second and subsequent embeddings that used a
previously-claimed trigger. Now, when this case is encountered, the
trigger token is replaced by <filename> and the user is informed of the
fact.
commit ab018ccdfe
Lincoln Stein, 2023-02-21 21:58:11 -05:00 (committed by GitHub)
2 changed files with 47 additions and 25 deletions
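
The renaming rule at the heart of this change is easy to see in isolation. The helper name and example paths below are illustrative only (in the commit the logic lives inline in load_textual_inversion, shown in the diff): HuggingFace concept embeddings arrive as <concept_dir>/learned_embeds.bin, so the directory name identifies the concept, while standalone embedding files fall back to their file stem.

from pathlib import Path

def fallback_trigger(ckpt_path: Path) -> str:
    # Hypothetical helper: derive a replacement trigger from the embedding's
    # file name when its original trigger string is already claimed.
    # HuggingFace concepts are stored as <concept_dir>/learned_embeds.bin,
    # so the directory name identifies the concept; otherwise use the stem.
    if ckpt_path.name == "learned_embeds.bin":
        return f"<{ckpt_path.parent.name}>"
    return f"<{ckpt_path.stem}>"

# Illustrative paths (not from the commit):
assert fallback_trigger(Path("embeddings/loab/learned_embeds.bin")) == "<loab>"
assert fallback_trigger(Path("embeddings/easynegative.pt")) == "<easynegative>"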


@@ -975,7 +975,7 @@ class Generate:
                     ti_path, defer_injecting_tokens=True
                 )
         print(
-            f'>> Textual inversion triggers: {", ".join(self.model.textual_inversion_manager.get_all_trigger_strings())}'
+            f'>> Textual inversion triggers: {", ".join(sorted(self.model.textual_inversion_manager.get_all_trigger_strings()))}'
         )
         self.model_name = model_name


@@ -34,6 +34,7 @@ class TextualInversionManager:
         self.text_encoder = text_encoder
         self.full_precision = full_precision
         self.hf_concepts_library = HuggingFaceConceptsLibrary()
+        self.trigger_to_sourcefile = dict()
         default_textual_inversions: list[TextualInversion] = []
         self.textual_inversions = default_textual_inversions
 
@@ -59,7 +60,9 @@ class TextualInversionManager:
     def get_all_trigger_strings(self) -> list[str]:
         return [ti.trigger_string for ti in self.textual_inversions]
 
-    def load_textual_inversion(self, ckpt_path: Union[str,Path], defer_injecting_tokens: bool = False):
+    def load_textual_inversion(
+        self, ckpt_path: Union[str, Path], defer_injecting_tokens: bool = False
+    ):
         ckpt_path = Path(ckpt_path)
 
         if not ckpt_path.is_file():
@@ -89,27 +92,45 @@ class TextualInversionManager:
             return
         elif (
             self.text_encoder.get_input_embeddings().weight.data[0].shape[0]
-            != embedding_info['token_dim']
+            != embedding_info["token_dim"]
         ):
             print(
                 f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}."
             )
             return
 
-        if embedding_info:
-            try:
-                self._add_textual_inversion(
-                    embedding_info["name"],
-                    embedding_info["embedding"],
-                    defer_injecting_tokens=defer_injecting_tokens,
-                )
-            except ValueError as e:
-                print(f'   | Ignoring incompatible embedding {embedding_info["name"]}')
-                print(f"   | The error was {str(e)}")
-        else:
-            print(
-                f">> Failed to load embedding located at {str(ckpt_path)}. Unsupported file."
-            )
+        # Resolve the situation in which an earlier embedding has claimed the same
+        # trigger string. We replace the trigger with '<source_file>', as we used to.
+        trigger_str = embedding_info["name"]
+        sourcefile = (
+            f"{ckpt_path.parent.name}/{ckpt_path.name}"
+            if ckpt_path.name == "learned_embeds.bin"
+            else ckpt_path.name
+        )
+
+        if trigger_str in self.trigger_to_sourcefile:
+            replacement_trigger_str = (
+                f"<{ckpt_path.parent.name}>"
+                if ckpt_path.name == "learned_embeds.bin"
+                else f"<{ckpt_path.stem}>"
+            )
+            print(
+                f">> {sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
+            )
+            trigger_str = replacement_trigger_str
+
+        try:
+            self._add_textual_inversion(
+                trigger_str,
+                embedding_info["embedding"],
+                defer_injecting_tokens=defer_injecting_tokens,
+            )
+            # remember which source file claims this trigger
+            self.trigger_to_sourcefile[trigger_str] = sourcefile
+
+        except ValueError as e:
+            print(f'   | Ignoring incompatible embedding {embedding_info["name"]}')
+            print(f"   | The error was {str(e)}")
 
     def _add_textual_inversion(
         self, trigger_str, embedding, defer_injecting_tokens=False
@@ -122,7 +143,7 @@ class TextualInversionManager:
         """
         if trigger_str in [ti.trigger_string for ti in self.textual_inversions]:
             print(
-                f">> TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
+                f"** TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
             )
             return
         if not self.full_precision:
@@ -131,7 +152,7 @@ class TextualInversionManager:
             embedding = embedding.unsqueeze(0)
         elif len(embedding.shape) > 2:
             raise ValueError(
-                f"TextualInversionManager cannot add {trigger_str} because the embedding shape {embedding.shape} is incorrect. The embedding must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2."
+                f"** TextualInversionManager cannot add {trigger_str} because the embedding shape {embedding.shape} is incorrect. The embedding must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2."
             )
 
         try:
@@ -147,7 +168,7 @@ class TextualInversionManager:
             else:
                 traceback.print_exc()
                 print(
-                    f">> TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
+                    f"** TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
                 )
                 raise
 
@@ -294,7 +315,7 @@ class TextualInversionManager:
         elif file_type == "bin":
             return self._parse_embedding_bin(embedding_file)
         else:
-            print(f">> Not a recognized embedding file: {embedding_file}")
+            print(f"** Notice: unrecognized embedding file format: {embedding_file}")
             return None
 
     def _parse_embedding_pt(self, embedding_file):
@@ -355,8 +376,9 @@ class TextualInversionManager:
                 embedding_info = None
             else:
                 for token in list(embedding_ckpt.keys()):
-                    embedding_info["name"] = token or os.path.basename(
-                        os.path.splitext(embedding_file)[0]
+                    embedding_info["name"] = (
+                        token
+                        or f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>"
                     )
                     embedding_info["embedding"] = embedding_ckpt[token]
                     embedding_info[
@@ -380,7 +402,7 @@ class TextualInversionManager:
                 embedding_info["name"] = (
                     token
                     if token != "*"
-                    else os.path.basename(os.path.splitext(embedding_file)[0])
+                    else f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>"
                 )
                 embedding_info["embedding"] = embedding_ckpt[
                     "string_to_param"