diff --git a/ldm/generate.py b/ldm/generate.py
index 80ebfef227..1afb642721 100644
--- a/ldm/generate.py
+++ b/ldm/generate.py
@@ -445,7 +445,11 @@ class Generate:
         self._set_sampler()
 
         # apply the concepts library to the prompt
-        prompt = self.huggingface_concepts_library.replace_concepts_with_triggers(prompt, lambda concepts: self.load_huggingface_concepts(concepts))
+        prompt = self.huggingface_concepts_library.replace_concepts_with_triggers(
+            prompt,
+            lambda concepts: self.load_huggingface_concepts(concepts),
+            self.model.textual_inversion_manager.get_all_trigger_strings()
+        )
 
         # bit of a hack to change the cached sampler's karras threshold to
         # whatever the user asked for
diff --git a/ldm/invoke/concepts_lib.py b/ldm/invoke/concepts_lib.py
index 246dea362a..c774f29674 100644
--- a/ldm/invoke/concepts_lib.py
+++ b/ldm/invoke/concepts_lib.py
@@ -59,7 +59,7 @@ class HuggingFaceConceptsLibrary(object):
         be downloaded.
         '''
         if not concept_name in self.list_concepts():
-            print(f'This concept is not known to the Hugging Face library. Generation will continue without the concept.')
+            print(f'This concept is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept.')
             return None
         return self.get_concept_file(concept_name.lower(),'learned_embeds.bin')
 
@@ -115,13 +115,19 @@ class HuggingFaceConceptsLibrary(object):
             return self.trigger_to_concept(match.group(1)) or f'<{match.group(1)}>'
         return self.match_trigger.sub(do_replace, prompt)
 
-    def replace_concepts_with_triggers(self, prompt:str, load_concepts_callback: Callable[[list], any])->str:
+    def replace_concepts_with_triggers(self,
+                                       prompt:str,
+                                       load_concepts_callback: Callable[[list], any],
+                                       excluded_tokens:list[str])->str:
         '''
         Given a prompt string that contains `<concept_name>` tags, replace
         these tags with the appropriate trigger.
 
         If any `<concept_name>` tags are found, `load_concepts_callback()` is called with a list
         of `concepts_name` strings.
+
+        `excluded_tokens` are any tokens that should not be replaced, typically because they
+        are trigger tokens from a locally-loaded embedding.
         '''
         concepts = self.match_concept.findall(prompt)
         if not concepts:
@@ -129,6 +135,8 @@ class HuggingFaceConceptsLibrary(object):
         load_concepts_callback(concepts)
 
         def do_replace(match)->str:
+            if excluded_tokens and f'<{match.group(1)}>' in excluded_tokens:
+                return f'<{match.group(1)}>'
             return self.concept_to_trigger(match.group(1)) or f'<{match.group(1)}>'
         return self.match_concept.sub(do_replace, prompt)
 
diff --git a/ldm/modules/textual_inversion_manager.py b/ldm/modules/textual_inversion_manager.py
index f7ced79a52..cf28cf8c7a 100644
--- a/ldm/modules/textual_inversion_manager.py
+++ b/ldm/modules/textual_inversion_manager.py
@@ -38,11 +38,15 @@ class TextualInversionManager():
             if concept_name in self.hf_concepts_library.concepts_loaded:
                 continue
             trigger = self.hf_concepts_library.concept_to_trigger(concept_name)
-            if self.has_textual_inversion_for_trigger_string(trigger):
+            if self.has_textual_inversion_for_trigger_string(trigger) \
+               or self.has_textual_inversion_for_trigger_string(concept_name) \
+               or self.has_textual_inversion_for_trigger_string(f'<{concept_name}>'): # in case a token with literal angle brackets encountered
+                print(f'>> Loaded local embedding for trigger {concept_name}')
                 continue
             bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
             if not bin_file:
                 continue
+            print(f'>> Loaded remote embedding for trigger {concept_name}')
             self.load_textual_inversion(bin_file)
             self.hf_concepts_library.concepts_loaded[concept_name]=True
 
@@ -50,6 +54,8 @@ class TextualInversionManager():
         return [ti.trigger_string for ti in self.textual_inversions]
 
     def load_textual_inversion(self, ckpt_path, defer_injecting_tokens: bool=False):
+        if str(ckpt_path).endswith('.DS_Store'):
+            return
         try:
             scan_result = scan_file_path(ckpt_path)
             if scan_result.infected_files == 1: