mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
[feat] Better status reporting when loading embeds and concepts (#2372)
This PR improves the console reporting of the process of recognizing trigger tokens and loading their embeds. 1. Do not report "concept is not known to HuggingFace" if the trigger term is in fact a local embedding trigger. 2. When a trigger term is first recognized during a session, report the fact. This should help debug embedding issues in the future. Note that the local embeddings produced by the new InvokeAI TI training script default to the format <trigger> with literal angle brackets. This sets them off from the rest of the text well and will enable autocomplete at some point in the future. However, this means that they supersede like-named HuggingFace concepts, and may cause problems for people uploading them to the HuggingFace repository (although that problem already exists).
This commit is contained in:
commit
f687d90bca
@ -445,7 +445,11 @@ class Generate:
|
||||
self._set_sampler()
|
||||
|
||||
# apply the concepts library to the prompt
|
||||
prompt = self.huggingface_concepts_library.replace_concepts_with_triggers(prompt, lambda concepts: self.load_huggingface_concepts(concepts))
|
||||
prompt = self.huggingface_concepts_library.replace_concepts_with_triggers(
|
||||
prompt,
|
||||
lambda concepts: self.load_huggingface_concepts(concepts),
|
||||
self.model.textual_inversion_manager.get_all_trigger_strings()
|
||||
)
|
||||
|
||||
# bit of a hack to change the cached sampler's karras threshold to
|
||||
# whatever the user asked for
|
||||
|
@ -59,7 +59,7 @@ class HuggingFaceConceptsLibrary(object):
|
||||
be downloaded.
|
||||
'''
|
||||
if not concept_name in self.list_concepts():
|
||||
print(f'This concept is not known to the Hugging Face library. Generation will continue without the concept.')
|
||||
print(f'This concept is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept.')
|
||||
return None
|
||||
return self.get_concept_file(concept_name.lower(),'learned_embeds.bin')
|
||||
|
||||
@ -115,13 +115,19 @@ class HuggingFaceConceptsLibrary(object):
|
||||
return self.trigger_to_concept(match.group(1)) or f'<{match.group(1)}>'
|
||||
return self.match_trigger.sub(do_replace, prompt)
|
||||
|
||||
def replace_concepts_with_triggers(self, prompt:str, load_concepts_callback: Callable[[list], any])->str:
|
||||
def replace_concepts_with_triggers(self,
|
||||
prompt:str,
|
||||
load_concepts_callback: Callable[[list], any],
|
||||
excluded_tokens:list[str])->str:
|
||||
'''
|
||||
Given a prompt string that contains `<concept_name>` tags, replace
|
||||
these tags with the appropriate trigger.
|
||||
|
||||
If any `<concept_name>` tags are found, `load_concepts_callback()` is called with a list
|
||||
of `concepts_name` strings.
|
||||
|
||||
`excluded_tokens` are any tokens that should not be replaced, typically because they
|
||||
are trigger tokens from a locally-loaded embedding.
|
||||
'''
|
||||
concepts = self.match_concept.findall(prompt)
|
||||
if not concepts:
|
||||
@ -129,6 +135,8 @@ class HuggingFaceConceptsLibrary(object):
|
||||
load_concepts_callback(concepts)
|
||||
|
||||
def do_replace(match)->str:
|
||||
if excluded_tokens and f'<{match.group(1)}>' in excluded_tokens:
|
||||
return f'<{match.group(1)}>'
|
||||
return self.concept_to_trigger(match.group(1)) or f'<{match.group(1)}>'
|
||||
return self.match_concept.sub(do_replace, prompt)
|
||||
|
||||
|
@ -38,11 +38,15 @@ class TextualInversionManager():
|
||||
if concept_name in self.hf_concepts_library.concepts_loaded:
|
||||
continue
|
||||
trigger = self.hf_concepts_library.concept_to_trigger(concept_name)
|
||||
if self.has_textual_inversion_for_trigger_string(trigger):
|
||||
if self.has_textual_inversion_for_trigger_string(trigger) \
|
||||
or self.has_textual_inversion_for_trigger_string(concept_name) \
|
||||
or self.has_textual_inversion_for_trigger_string(f'<{concept_name}>'): # in case a token with literal angle brackets encountered
|
||||
print(f'>> Loaded local embedding for trigger {concept_name}')
|
||||
continue
|
||||
bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
|
||||
if not bin_file:
|
||||
continue
|
||||
print(f'>> Loaded remote embedding for trigger {concept_name}')
|
||||
self.load_textual_inversion(bin_file)
|
||||
self.hf_concepts_library.concepts_loaded[concept_name]=True
|
||||
|
||||
@ -50,6 +54,8 @@ class TextualInversionManager():
|
||||
return [ti.trigger_string for ti in self.textual_inversions]
|
||||
|
||||
def load_textual_inversion(self, ckpt_path, defer_injecting_tokens: bool=False):
|
||||
if str(ckpt_path).endswith('.DS_Store'):
|
||||
return
|
||||
try:
|
||||
scan_result = scan_file_path(ckpt_path)
|
||||
if scan_result.infected_files == 1:
|
||||
|
Loading…
Reference in New Issue
Block a user