mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into feat/import-with-vae
This commit is contained in:
commit
ce52d0c42b
@ -445,7 +445,11 @@ class Generate:
|
|||||||
self._set_sampler()
|
self._set_sampler()
|
||||||
|
|
||||||
# apply the concepts library to the prompt
|
# apply the concepts library to the prompt
|
||||||
prompt = self.huggingface_concepts_library.replace_concepts_with_triggers(prompt, lambda concepts: self.load_huggingface_concepts(concepts))
|
prompt = self.huggingface_concepts_library.replace_concepts_with_triggers(
|
||||||
|
prompt,
|
||||||
|
lambda concepts: self.load_huggingface_concepts(concepts),
|
||||||
|
self.model.textual_inversion_manager.get_all_trigger_strings()
|
||||||
|
)
|
||||||
|
|
||||||
# bit of a hack to change the cached sampler's karras threshold to
|
# bit of a hack to change the cached sampler's karras threshold to
|
||||||
# whatever the user asked for
|
# whatever the user asked for
|
||||||
|
@ -59,7 +59,7 @@ class HuggingFaceConceptsLibrary(object):
|
|||||||
be downloaded.
|
be downloaded.
|
||||||
'''
|
'''
|
||||||
if not concept_name in self.list_concepts():
|
if not concept_name in self.list_concepts():
|
||||||
print(f'This concept is not known to the Hugging Face library. Generation will continue without the concept.')
|
print(f'This concept is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept.')
|
||||||
return None
|
return None
|
||||||
return self.get_concept_file(concept_name.lower(),'learned_embeds.bin')
|
return self.get_concept_file(concept_name.lower(),'learned_embeds.bin')
|
||||||
|
|
||||||
@ -115,13 +115,19 @@ class HuggingFaceConceptsLibrary(object):
|
|||||||
return self.trigger_to_concept(match.group(1)) or f'<{match.group(1)}>'
|
return self.trigger_to_concept(match.group(1)) or f'<{match.group(1)}>'
|
||||||
return self.match_trigger.sub(do_replace, prompt)
|
return self.match_trigger.sub(do_replace, prompt)
|
||||||
|
|
||||||
def replace_concepts_with_triggers(self, prompt:str, load_concepts_callback: Callable[[list], any])->str:
|
def replace_concepts_with_triggers(self,
|
||||||
|
prompt:str,
|
||||||
|
load_concepts_callback: Callable[[list], any],
|
||||||
|
excluded_tokens:list[str])->str:
|
||||||
'''
|
'''
|
||||||
Given a prompt string that contains `<concept_name>` tags, replace
|
Given a prompt string that contains `<concept_name>` tags, replace
|
||||||
these tags with the appropriate trigger.
|
these tags with the appropriate trigger.
|
||||||
|
|
||||||
If any `<concept_name>` tags are found, `load_concepts_callback()` is called with a list
|
If any `<concept_name>` tags are found, `load_concepts_callback()` is called with a list
|
||||||
of `concepts_name` strings.
|
of `concepts_name` strings.
|
||||||
|
|
||||||
|
`excluded_tokens` are any tokens that should not be replaced, typically because they
|
||||||
|
are trigger tokens from a locally-loaded embedding.
|
||||||
'''
|
'''
|
||||||
concepts = self.match_concept.findall(prompt)
|
concepts = self.match_concept.findall(prompt)
|
||||||
if not concepts:
|
if not concepts:
|
||||||
@ -129,6 +135,8 @@ class HuggingFaceConceptsLibrary(object):
|
|||||||
load_concepts_callback(concepts)
|
load_concepts_callback(concepts)
|
||||||
|
|
||||||
def do_replace(match)->str:
|
def do_replace(match)->str:
|
||||||
|
if excluded_tokens and f'<{match.group(1)}>' in excluded_tokens:
|
||||||
|
return f'<{match.group(1)}>'
|
||||||
return self.concept_to_trigger(match.group(1)) or f'<{match.group(1)}>'
|
return self.concept_to_trigger(match.group(1)) or f'<{match.group(1)}>'
|
||||||
return self.match_concept.sub(do_replace, prompt)
|
return self.match_concept.sub(do_replace, prompt)
|
||||||
|
|
||||||
|
@ -38,11 +38,15 @@ class TextualInversionManager():
|
|||||||
if concept_name in self.hf_concepts_library.concepts_loaded:
|
if concept_name in self.hf_concepts_library.concepts_loaded:
|
||||||
continue
|
continue
|
||||||
trigger = self.hf_concepts_library.concept_to_trigger(concept_name)
|
trigger = self.hf_concepts_library.concept_to_trigger(concept_name)
|
||||||
if self.has_textual_inversion_for_trigger_string(trigger):
|
if self.has_textual_inversion_for_trigger_string(trigger) \
|
||||||
|
or self.has_textual_inversion_for_trigger_string(concept_name) \
|
||||||
|
or self.has_textual_inversion_for_trigger_string(f'<{concept_name}>'): # in case a token with literal angle brackets encountered
|
||||||
|
print(f'>> Loaded local embedding for trigger {concept_name}')
|
||||||
continue
|
continue
|
||||||
bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
|
bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
|
||||||
if not bin_file:
|
if not bin_file:
|
||||||
continue
|
continue
|
||||||
|
print(f'>> Loaded remote embedding for trigger {concept_name}')
|
||||||
self.load_textual_inversion(bin_file)
|
self.load_textual_inversion(bin_file)
|
||||||
self.hf_concepts_library.concepts_loaded[concept_name]=True
|
self.hf_concepts_library.concepts_loaded[concept_name]=True
|
||||||
|
|
||||||
@ -50,6 +54,8 @@ class TextualInversionManager():
|
|||||||
return [ti.trigger_string for ti in self.textual_inversions]
|
return [ti.trigger_string for ti in self.textual_inversions]
|
||||||
|
|
||||||
def load_textual_inversion(self, ckpt_path, defer_injecting_tokens: bool=False):
|
def load_textual_inversion(self, ckpt_path, defer_injecting_tokens: bool=False):
|
||||||
|
if str(ckpt_path).endswith('.DS_Store'):
|
||||||
|
return
|
||||||
try:
|
try:
|
||||||
scan_result = scan_file_path(ckpt_path)
|
scan_result = scan_file_path(ckpt_path)
|
||||||
if scan_result.infected_files == 1:
|
if scan_result.infected_files == 1:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user