mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
improve status reporting when loading local and remote embeddings
- During trigger token processing, emit better status messages indicating which triggers were found. - Suppress message "<token> is not known to HuggingFace library, when token is in fact a local embed.
This commit is contained in:
parent
ce17051b28
commit
3c3d893b9d
@ -445,7 +445,11 @@ class Generate:
|
|||||||
self._set_sampler()
|
self._set_sampler()
|
||||||
|
|
||||||
# apply the concepts library to the prompt
|
# apply the concepts library to the prompt
|
||||||
prompt = self.huggingface_concepts_library.replace_concepts_with_triggers(prompt, lambda concepts: self.load_huggingface_concepts(concepts))
|
prompt = self.huggingface_concepts_library.replace_concepts_with_triggers(
|
||||||
|
prompt,
|
||||||
|
lambda concepts: self.load_huggingface_concepts(concepts),
|
||||||
|
self.model.textual_inversion_manager.get_all_trigger_strings()
|
||||||
|
)
|
||||||
|
|
||||||
# bit of a hack to change the cached sampler's karras threshold to
|
# bit of a hack to change the cached sampler's karras threshold to
|
||||||
# whatever the user asked for
|
# whatever the user asked for
|
||||||
|
@ -115,13 +115,19 @@ class HuggingFaceConceptsLibrary(object):
|
|||||||
return self.trigger_to_concept(match.group(1)) or f'<{match.group(1)}>'
|
return self.trigger_to_concept(match.group(1)) or f'<{match.group(1)}>'
|
||||||
return self.match_trigger.sub(do_replace, prompt)
|
return self.match_trigger.sub(do_replace, prompt)
|
||||||
|
|
||||||
def replace_concepts_with_triggers(self, prompt:str, load_concepts_callback: Callable[[list], any])->str:
|
def replace_concepts_with_triggers(self,
|
||||||
|
prompt:str,
|
||||||
|
load_concepts_callback: Callable[[list], any],
|
||||||
|
excluded_tokens:list[str])->str:
|
||||||
'''
|
'''
|
||||||
Given a prompt string that contains `<concept_name>` tags, replace
|
Given a prompt string that contains `<concept_name>` tags, replace
|
||||||
these tags with the appropriate trigger.
|
these tags with the appropriate trigger.
|
||||||
|
|
||||||
If any `<concept_name>` tags are found, `load_concepts_callback()` is called with a list
|
If any `<concept_name>` tags are found, `load_concepts_callback()` is called with a list
|
||||||
of `concepts_name` strings.
|
of `concepts_name` strings.
|
||||||
|
|
||||||
|
`excluded_tokens` are any tokens that should not be replaced, typically because they
|
||||||
|
are trigger tokens from a locally-loaded embedding.
|
||||||
'''
|
'''
|
||||||
concepts = self.match_concept.findall(prompt)
|
concepts = self.match_concept.findall(prompt)
|
||||||
if not concepts:
|
if not concepts:
|
||||||
@ -129,6 +135,8 @@ class HuggingFaceConceptsLibrary(object):
|
|||||||
load_concepts_callback(concepts)
|
load_concepts_callback(concepts)
|
||||||
|
|
||||||
def do_replace(match)->str:
|
def do_replace(match)->str:
|
||||||
|
if excluded_tokens and f'<{match.group(1)}>' in excluded_tokens:
|
||||||
|
return f'<{match.group(1)}>'
|
||||||
return self.concept_to_trigger(match.group(1)) or f'<{match.group(1)}>'
|
return self.concept_to_trigger(match.group(1)) or f'<{match.group(1)}>'
|
||||||
return self.match_concept.sub(do_replace, prompt)
|
return self.match_concept.sub(do_replace, prompt)
|
||||||
|
|
||||||
|
@ -38,11 +38,15 @@ class TextualInversionManager():
|
|||||||
if concept_name in self.hf_concepts_library.concepts_loaded:
|
if concept_name in self.hf_concepts_library.concepts_loaded:
|
||||||
continue
|
continue
|
||||||
trigger = self.hf_concepts_library.concept_to_trigger(concept_name)
|
trigger = self.hf_concepts_library.concept_to_trigger(concept_name)
|
||||||
if self.has_textual_inversion_for_trigger_string(trigger):
|
if self.has_textual_inversion_for_trigger_string(trigger) \
|
||||||
|
or self.has_textual_inversion_for_trigger_string(concept_name) \
|
||||||
|
or self.has_textual_inversion_for_trigger_string(f'<{concept_name}>'): # in case a token with literal angle brackets encountered
|
||||||
|
print(f'>> Loaded local embedding for trigger {concept_name}')
|
||||||
continue
|
continue
|
||||||
bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
|
bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
|
||||||
if not bin_file:
|
if not bin_file:
|
||||||
continue
|
continue
|
||||||
|
print(f'>> Loaded remote embedding for trigger {concept_name}')
|
||||||
self.load_textual_inversion(bin_file)
|
self.load_textual_inversion(bin_file)
|
||||||
self.hf_concepts_library.concepts_loaded[concept_name]=True
|
self.hf_concepts_library.concepts_loaded[concept_name]=True
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user