[feat] Better status reporting when loading embeds and concepts (#2372)

This PR improves the console reporting of the process of recognizing
trigger tokens and loading their embeds.

1. Do not report "concept is not known to HuggingFace" if the trigger
term is in fact a local embedding trigger.
2. When a trigger term is first recognized during a session, report the
fact.
This should help debug embedding issues in the future.

Note that the local embeddings produced by the new InvokeAI TI training
script default to the format <trigger> with literal angle brackets. This
sets them off from the rest of the text well and will enable
autocomplete at some point in the future. However, this means that they
supersede like-named HuggingFace concepts, and may cause problems for
people uploading them to the HuggingFace repository (although that
problem already exists).
This commit is contained in:
Lincoln Stein 2023-01-24 09:35:53 -05:00 committed by GitHub
commit f687d90bca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 22 additions and 4 deletions

View File

@ -445,7 +445,11 @@ class Generate:
self._set_sampler()
# apply the concepts library to the prompt
prompt = self.huggingface_concepts_library.replace_concepts_with_triggers(prompt, lambda concepts: self.load_huggingface_concepts(concepts))
prompt = self.huggingface_concepts_library.replace_concepts_with_triggers(
prompt,
lambda concepts: self.load_huggingface_concepts(concepts),
self.model.textual_inversion_manager.get_all_trigger_strings()
)
# bit of a hack to change the cached sampler's karras threshold to
# whatever the user asked for

View File

@ -59,7 +59,7 @@ class HuggingFaceConceptsLibrary(object):
be downloaded.
'''
if not concept_name in self.list_concepts():
print(f'This concept is not known to the Hugging Face library. Generation will continue without the concept.')
print(f'This concept is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept.')
return None
return self.get_concept_file(concept_name.lower(),'learned_embeds.bin')
@ -115,13 +115,19 @@ class HuggingFaceConceptsLibrary(object):
return self.trigger_to_concept(match.group(1)) or f'<{match.group(1)}>'
return self.match_trigger.sub(do_replace, prompt)
def replace_concepts_with_triggers(self, prompt:str, load_concepts_callback: Callable[[list], any])->str:
def replace_concepts_with_triggers(self,
prompt:str,
load_concepts_callback: Callable[[list], any],
excluded_tokens:list[str])->str:
'''
Given a prompt string that contains `<concept_name>` tags, replace
these tags with the appropriate trigger.
If any `<concept_name>` tags are found, `load_concepts_callback()` is called with a list
of `concepts_name` strings.
`excluded_tokens` are any tokens that should not be replaced, typically because they
are trigger tokens from a locally-loaded embedding.
'''
concepts = self.match_concept.findall(prompt)
if not concepts:
@ -129,6 +135,8 @@ class HuggingFaceConceptsLibrary(object):
load_concepts_callback(concepts)
def do_replace(match)->str:
if excluded_tokens and f'<{match.group(1)}>' in excluded_tokens:
return f'<{match.group(1)}>'
return self.concept_to_trigger(match.group(1)) or f'<{match.group(1)}>'
return self.match_concept.sub(do_replace, prompt)

View File

@ -38,11 +38,15 @@ class TextualInversionManager():
if concept_name in self.hf_concepts_library.concepts_loaded:
continue
trigger = self.hf_concepts_library.concept_to_trigger(concept_name)
if self.has_textual_inversion_for_trigger_string(trigger):
if self.has_textual_inversion_for_trigger_string(trigger) \
or self.has_textual_inversion_for_trigger_string(concept_name) \
or self.has_textual_inversion_for_trigger_string(f'<{concept_name}>'): # in case a token with literal angle brackets encountered
print(f'>> Loaded local embedding for trigger {concept_name}')
continue
bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
if not bin_file:
continue
print(f'>> Loaded remote embedding for trigger {concept_name}')
self.load_textual_inversion(bin_file)
self.hf_concepts_library.concepts_loaded[concept_name]=True
@ -50,6 +54,8 @@ class TextualInversionManager():
return [ti.trigger_string for ti in self.textual_inversions]
def load_textual_inversion(self, ckpt_path, defer_injecting_tokens: bool=False):
if str(ckpt_path).endswith('.DS_Store'):
return
try:
scan_result = scan_file_path(ckpt_path)
if scan_result.infected_files == 1: