mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fallback to using filename to trigger embeddings (#2752)
Lots of earlier embeds use a common trigger token such as * or the hebrew letter shan. Previously, the textual inversion manager would refuse to load the second and subsequent embeddings that used a previously-claimed trigger. Now, when this case is encountered, the trigger token is replaced by <filename> and the user is informed of the fact.
This commit is contained in:
commit
ab018ccdfe
@ -975,7 +975,7 @@ class Generate:
|
|||||||
ti_path, defer_injecting_tokens=True
|
ti_path, defer_injecting_tokens=True
|
||||||
)
|
)
|
||||||
print(
|
print(
|
||||||
f'>> Textual inversion triggers: {", ".join(self.model.textual_inversion_manager.get_all_trigger_strings())}'
|
f'>> Textual inversion triggers: {", ".join(sorted(self.model.textual_inversion_manager.get_all_trigger_strings()))}'
|
||||||
)
|
)
|
||||||
|
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
|
@ -34,6 +34,7 @@ class TextualInversionManager:
|
|||||||
self.text_encoder = text_encoder
|
self.text_encoder = text_encoder
|
||||||
self.full_precision = full_precision
|
self.full_precision = full_precision
|
||||||
self.hf_concepts_library = HuggingFaceConceptsLibrary()
|
self.hf_concepts_library = HuggingFaceConceptsLibrary()
|
||||||
|
self.trigger_to_sourcefile = dict()
|
||||||
default_textual_inversions: list[TextualInversion] = []
|
default_textual_inversions: list[TextualInversion] = []
|
||||||
self.textual_inversions = default_textual_inversions
|
self.textual_inversions = default_textual_inversions
|
||||||
|
|
||||||
@ -59,15 +60,17 @@ class TextualInversionManager:
|
|||||||
def get_all_trigger_strings(self) -> list[str]:
|
def get_all_trigger_strings(self) -> list[str]:
|
||||||
return [ti.trigger_string for ti in self.textual_inversions]
|
return [ti.trigger_string for ti in self.textual_inversions]
|
||||||
|
|
||||||
def load_textual_inversion(self, ckpt_path: Union[str,Path], defer_injecting_tokens: bool = False):
|
def load_textual_inversion(
|
||||||
|
self, ckpt_path: Union[str, Path], defer_injecting_tokens: bool = False
|
||||||
|
):
|
||||||
ckpt_path = Path(ckpt_path)
|
ckpt_path = Path(ckpt_path)
|
||||||
|
|
||||||
if not ckpt_path.is_file():
|
if not ckpt_path.is_file():
|
||||||
return
|
return
|
||||||
|
|
||||||
if str(ckpt_path).endswith(".DS_Store"):
|
if str(ckpt_path).endswith(".DS_Store"):
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
scan_result = scan_file_path(str(ckpt_path))
|
scan_result = scan_file_path(str(ckpt_path))
|
||||||
if scan_result.infected_files == 1:
|
if scan_result.infected_files == 1:
|
||||||
@ -89,27 +92,45 @@ class TextualInversionManager:
|
|||||||
return
|
return
|
||||||
elif (
|
elif (
|
||||||
self.text_encoder.get_input_embeddings().weight.data[0].shape[0]
|
self.text_encoder.get_input_embeddings().weight.data[0].shape[0]
|
||||||
!= embedding_info['token_dim']
|
!= embedding_info["token_dim"]
|
||||||
):
|
):
|
||||||
print(
|
print(
|
||||||
f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}."
|
f"** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info['token_dim']}."
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
if embedding_info:
|
# Resolve the situation in which an earlier embedding has claimed the same
|
||||||
try:
|
# trigger string. We replace the trigger with '<source_file>', as we used to.
|
||||||
self._add_textual_inversion(
|
trigger_str = embedding_info["name"]
|
||||||
embedding_info["name"],
|
sourcefile = (
|
||||||
embedding_info["embedding"],
|
f"{ckpt_path.parent.name}/{ckpt_path.name}"
|
||||||
defer_injecting_tokens=defer_injecting_tokens,
|
if ckpt_path.name == "learned_embeds.bin"
|
||||||
)
|
else ckpt_path.name
|
||||||
except ValueError as e:
|
)
|
||||||
print(f' | Ignoring incompatible embedding {embedding_info["name"]}')
|
|
||||||
print(f" | The error was {str(e)}")
|
if trigger_str in self.trigger_to_sourcefile:
|
||||||
else:
|
replacement_trigger_str = (
|
||||||
print(
|
f"<{ckpt_path.parent.name}>"
|
||||||
f">> Failed to load embedding located at {str(ckpt_path)}. Unsupported file."
|
if ckpt_path.name == "learned_embeds.bin"
|
||||||
|
else f"<{ckpt_path.stem}>"
|
||||||
)
|
)
|
||||||
|
print(
|
||||||
|
f">> {sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
|
||||||
|
)
|
||||||
|
trigger_str = replacement_trigger_str
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._add_textual_inversion(
|
||||||
|
trigger_str,
|
||||||
|
embedding_info["embedding"],
|
||||||
|
defer_injecting_tokens=defer_injecting_tokens,
|
||||||
|
)
|
||||||
|
# remember which source file claims this trigger
|
||||||
|
self.trigger_to_sourcefile[trigger_str] = sourcefile
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
print(f' | Ignoring incompatible embedding {embedding_info["name"]}')
|
||||||
|
print(f" | The error was {str(e)}")
|
||||||
|
|
||||||
def _add_textual_inversion(
|
def _add_textual_inversion(
|
||||||
self, trigger_str, embedding, defer_injecting_tokens=False
|
self, trigger_str, embedding, defer_injecting_tokens=False
|
||||||
@ -122,7 +143,7 @@ class TextualInversionManager:
|
|||||||
"""
|
"""
|
||||||
if trigger_str in [ti.trigger_string for ti in self.textual_inversions]:
|
if trigger_str in [ti.trigger_string for ti in self.textual_inversions]:
|
||||||
print(
|
print(
|
||||||
f">> TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
|
f"** TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
if not self.full_precision:
|
if not self.full_precision:
|
||||||
@ -131,7 +152,7 @@ class TextualInversionManager:
|
|||||||
embedding = embedding.unsqueeze(0)
|
embedding = embedding.unsqueeze(0)
|
||||||
elif len(embedding.shape) > 2:
|
elif len(embedding.shape) > 2:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"TextualInversionManager cannot add {trigger_str} because the embedding shape {embedding.shape} is incorrect. The embedding must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2."
|
f"** TextualInversionManager cannot add {trigger_str} because the embedding shape {embedding.shape} is incorrect. The embedding must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2."
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -147,7 +168,7 @@ class TextualInversionManager:
|
|||||||
else:
|
else:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
print(
|
print(
|
||||||
f">> TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
|
f"** TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@ -294,7 +315,7 @@ class TextualInversionManager:
|
|||||||
elif file_type == "bin":
|
elif file_type == "bin":
|
||||||
return self._parse_embedding_bin(embedding_file)
|
return self._parse_embedding_bin(embedding_file)
|
||||||
else:
|
else:
|
||||||
print(f">> Not a recognized embedding file: {embedding_file}")
|
print(f"** Notice: unrecognized embedding file format: {embedding_file}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _parse_embedding_pt(self, embedding_file):
|
def _parse_embedding_pt(self, embedding_file):
|
||||||
@ -355,8 +376,9 @@ class TextualInversionManager:
|
|||||||
embedding_info = None
|
embedding_info = None
|
||||||
else:
|
else:
|
||||||
for token in list(embedding_ckpt.keys()):
|
for token in list(embedding_ckpt.keys()):
|
||||||
embedding_info["name"] = token or os.path.basename(
|
embedding_info["name"] = (
|
||||||
os.path.splitext(embedding_file)[0]
|
token
|
||||||
|
or f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>"
|
||||||
)
|
)
|
||||||
embedding_info["embedding"] = embedding_ckpt[token]
|
embedding_info["embedding"] = embedding_ckpt[token]
|
||||||
embedding_info[
|
embedding_info[
|
||||||
@ -380,7 +402,7 @@ class TextualInversionManager:
|
|||||||
embedding_info["name"] = (
|
embedding_info["name"] = (
|
||||||
token
|
token
|
||||||
if token != "*"
|
if token != "*"
|
||||||
else os.path.basename(os.path.splitext(embedding_file)[0])
|
else f"<{os.path.basename(os.path.splitext(embedding_file)[0])}>"
|
||||||
)
|
)
|
||||||
embedding_info["embedding"] = embedding_ckpt[
|
embedding_info["embedding"] = embedding_ckpt[
|
||||||
"string_to_param"
|
"string_to_param"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user