mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into slider-fix
This commit is contained in:
commit
ce865a8d69
@ -445,7 +445,11 @@ class Generate:
|
||||
self._set_sampler()
|
||||
|
||||
# apply the concepts library to the prompt
|
||||
prompt = self.huggingface_concepts_library.replace_concepts_with_triggers(prompt, lambda concepts: self.load_huggingface_concepts(concepts))
|
||||
prompt = self.huggingface_concepts_library.replace_concepts_with_triggers(
|
||||
prompt,
|
||||
lambda concepts: self.load_huggingface_concepts(concepts),
|
||||
self.model.textual_inversion_manager.get_all_trigger_strings()
|
||||
)
|
||||
|
||||
# bit of a hack to change the cached sampler's karras threshold to
|
||||
# whatever the user asked for
|
||||
|
@ -573,7 +573,7 @@ def import_model(model_path:str, gen, opt, completer):
|
||||
|
||||
if model_path.startswith(('http:','https:','ftp:')):
|
||||
model_name = import_ckpt_model(model_path, gen, opt, completer)
|
||||
elif os.path.exists(model_path) and model_path.endswith('.ckpt') and os.path.isfile(model_path):
|
||||
elif os.path.exists(model_path) and model_path.endswith(('.ckpt','.safetensors')) and os.path.isfile(model_path):
|
||||
model_name = import_ckpt_model(model_path, gen, opt, completer)
|
||||
elif re.match('^[\w.+-]+/[\w.+-]+$',model_path):
|
||||
model_name = import_diffuser_model(model_path, gen, opt, completer)
|
||||
@ -627,9 +627,9 @@ def import_ckpt_model(path_or_url:str, gen, opt, completer)->str:
|
||||
model_description=default_description
|
||||
)
|
||||
config_file = None
|
||||
|
||||
default = Path(Globals.root,'configs/stable-diffusion/v1-inference.yaml')
|
||||
completer.complete_extensions(('.yaml','.yml'))
|
||||
completer.set_line('configs/stable-diffusion/v1-inference.yaml')
|
||||
completer.set_line(str(default))
|
||||
done = False
|
||||
while not done:
|
||||
config_file = input('Configuration file for this model: ').strip()
|
||||
|
@ -59,7 +59,7 @@ class HuggingFaceConceptsLibrary(object):
|
||||
be downloaded.
|
||||
'''
|
||||
if not concept_name in self.list_concepts():
|
||||
print(f'This concept is not known to the Hugging Face library. Generation will continue without the concept.')
|
||||
print(f'This concept is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept.')
|
||||
return None
|
||||
return self.get_concept_file(concept_name.lower(),'learned_embeds.bin')
|
||||
|
||||
@ -115,13 +115,19 @@ class HuggingFaceConceptsLibrary(object):
|
||||
return self.trigger_to_concept(match.group(1)) or f'<{match.group(1)}>'
|
||||
return self.match_trigger.sub(do_replace, prompt)
|
||||
|
||||
def replace_concepts_with_triggers(self, prompt:str, load_concepts_callback: Callable[[list], any])->str:
|
||||
def replace_concepts_with_triggers(self,
|
||||
prompt:str,
|
||||
load_concepts_callback: Callable[[list], any],
|
||||
excluded_tokens:list[str])->str:
|
||||
'''
|
||||
Given a prompt string that contains `<concept_name>` tags, replace
|
||||
these tags with the appropriate trigger.
|
||||
|
||||
If any `<concept_name>` tags are found, `load_concepts_callback()` is called with a list
|
||||
of `concepts_name` strings.
|
||||
|
||||
`excluded_tokens` are any tokens that should not be replaced, typically because they
|
||||
are trigger tokens from a locally-loaded embedding.
|
||||
'''
|
||||
concepts = self.match_concept.findall(prompt)
|
||||
if not concepts:
|
||||
@ -129,6 +135,8 @@ class HuggingFaceConceptsLibrary(object):
|
||||
load_concepts_callback(concepts)
|
||||
|
||||
def do_replace(match)->str:
|
||||
if excluded_tokens and f'<{match.group(1)}>' in excluded_tokens:
|
||||
return f'<{match.group(1)}>'
|
||||
return self.concept_to_trigger(match.group(1)) or f'<{match.group(1)}>'
|
||||
return self.match_concept.sub(do_replace, prompt)
|
||||
|
||||
|
@ -147,7 +147,7 @@ class ModelManager(object):
|
||||
Return true if this is a legacy (.ckpt) model
|
||||
'''
|
||||
info = self.model_info(model_name)
|
||||
if 'weights' in info and info['weights'].endswith('.ckpt'):
|
||||
if 'weights' in info and info['weights'].endswith(('.ckpt','.safetensors')):
|
||||
return True
|
||||
return False
|
||||
|
||||
@ -366,8 +366,14 @@ class ModelManager(object):
|
||||
vae = os.path.normpath(os.path.join(Globals.root,vae))
|
||||
if os.path.exists(vae):
|
||||
print(f' | Loading VAE weights from: {vae}')
|
||||
vae_ckpt = None
|
||||
vae_dict = None
|
||||
if vae.endswith('.safetensors'):
|
||||
vae_ckpt = safetensors.torch.load_file(vae)
|
||||
vae_dict = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss"}
|
||||
else:
|
||||
vae_ckpt = torch.load(vae, map_location="cpu")
|
||||
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
|
||||
vae_dict = {k: v for k, v in vae_ckpt['state_dict'].items() if k[0:4] != "loss"}
|
||||
model.first_stage_model.load_state_dict(vae_dict, strict=False)
|
||||
else:
|
||||
print(f' | VAE file {vae} not found. Skipping.')
|
||||
|
@ -38,11 +38,15 @@ class TextualInversionManager():
|
||||
if concept_name in self.hf_concepts_library.concepts_loaded:
|
||||
continue
|
||||
trigger = self.hf_concepts_library.concept_to_trigger(concept_name)
|
||||
if self.has_textual_inversion_for_trigger_string(trigger):
|
||||
if self.has_textual_inversion_for_trigger_string(trigger) \
|
||||
or self.has_textual_inversion_for_trigger_string(concept_name) \
|
||||
or self.has_textual_inversion_for_trigger_string(f'<{concept_name}>'): # in case a token with literal angle brackets encountered
|
||||
print(f'>> Loaded local embedding for trigger {concept_name}')
|
||||
continue
|
||||
bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
|
||||
if not bin_file:
|
||||
continue
|
||||
print(f'>> Loaded remote embedding for trigger {concept_name}')
|
||||
self.load_textual_inversion(bin_file)
|
||||
self.hf_concepts_library.concepts_loaded[concept_name]=True
|
||||
|
||||
@ -50,6 +54,8 @@ class TextualInversionManager():
|
||||
return [ti.trigger_string for ti in self.textual_inversions]
|
||||
|
||||
def load_textual_inversion(self, ckpt_path, defer_injecting_tokens: bool=False):
|
||||
if str(ckpt_path).endswith('.DS_Store'):
|
||||
return
|
||||
try:
|
||||
scan_result = scan_file_path(ckpt_path)
|
||||
if scan_result.infected_files == 1:
|
||||
|
Loading…
Reference in New Issue
Block a user