mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' into slider-fix
This commit is contained in:
commit
ce865a8d69
@ -445,7 +445,11 @@ class Generate:
|
|||||||
self._set_sampler()
|
self._set_sampler()
|
||||||
|
|
||||||
# apply the concepts library to the prompt
|
# apply the concepts library to the prompt
|
||||||
prompt = self.huggingface_concepts_library.replace_concepts_with_triggers(prompt, lambda concepts: self.load_huggingface_concepts(concepts))
|
prompt = self.huggingface_concepts_library.replace_concepts_with_triggers(
|
||||||
|
prompt,
|
||||||
|
lambda concepts: self.load_huggingface_concepts(concepts),
|
||||||
|
self.model.textual_inversion_manager.get_all_trigger_strings()
|
||||||
|
)
|
||||||
|
|
||||||
# bit of a hack to change the cached sampler's karras threshold to
|
# bit of a hack to change the cached sampler's karras threshold to
|
||||||
# whatever the user asked for
|
# whatever the user asked for
|
||||||
|
@ -573,7 +573,7 @@ def import_model(model_path:str, gen, opt, completer):
|
|||||||
|
|
||||||
if model_path.startswith(('http:','https:','ftp:')):
|
if model_path.startswith(('http:','https:','ftp:')):
|
||||||
model_name = import_ckpt_model(model_path, gen, opt, completer)
|
model_name = import_ckpt_model(model_path, gen, opt, completer)
|
||||||
elif os.path.exists(model_path) and model_path.endswith('.ckpt') and os.path.isfile(model_path):
|
elif os.path.exists(model_path) and model_path.endswith(('.ckpt','.safetensors')) and os.path.isfile(model_path):
|
||||||
model_name = import_ckpt_model(model_path, gen, opt, completer)
|
model_name = import_ckpt_model(model_path, gen, opt, completer)
|
||||||
elif re.match('^[\w.+-]+/[\w.+-]+$',model_path):
|
elif re.match('^[\w.+-]+/[\w.+-]+$',model_path):
|
||||||
model_name = import_diffuser_model(model_path, gen, opt, completer)
|
model_name = import_diffuser_model(model_path, gen, opt, completer)
|
||||||
@ -627,9 +627,9 @@ def import_ckpt_model(path_or_url:str, gen, opt, completer)->str:
|
|||||||
model_description=default_description
|
model_description=default_description
|
||||||
)
|
)
|
||||||
config_file = None
|
config_file = None
|
||||||
|
default = Path(Globals.root,'configs/stable-diffusion/v1-inference.yaml')
|
||||||
completer.complete_extensions(('.yaml','.yml'))
|
completer.complete_extensions(('.yaml','.yml'))
|
||||||
completer.set_line('configs/stable-diffusion/v1-inference.yaml')
|
completer.set_line(str(default))
|
||||||
done = False
|
done = False
|
||||||
while not done:
|
while not done:
|
||||||
config_file = input('Configuration file for this model: ').strip()
|
config_file = input('Configuration file for this model: ').strip()
|
||||||
|
@ -59,7 +59,7 @@ class HuggingFaceConceptsLibrary(object):
|
|||||||
be downloaded.
|
be downloaded.
|
||||||
'''
|
'''
|
||||||
if not concept_name in self.list_concepts():
|
if not concept_name in self.list_concepts():
|
||||||
print(f'This concept is not known to the Hugging Face library. Generation will continue without the concept.')
|
print(f'This concept is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept.')
|
||||||
return None
|
return None
|
||||||
return self.get_concept_file(concept_name.lower(),'learned_embeds.bin')
|
return self.get_concept_file(concept_name.lower(),'learned_embeds.bin')
|
||||||
|
|
||||||
@ -115,13 +115,19 @@ class HuggingFaceConceptsLibrary(object):
|
|||||||
return self.trigger_to_concept(match.group(1)) or f'<{match.group(1)}>'
|
return self.trigger_to_concept(match.group(1)) or f'<{match.group(1)}>'
|
||||||
return self.match_trigger.sub(do_replace, prompt)
|
return self.match_trigger.sub(do_replace, prompt)
|
||||||
|
|
||||||
def replace_concepts_with_triggers(self, prompt:str, load_concepts_callback: Callable[[list], any])->str:
|
def replace_concepts_with_triggers(self,
|
||||||
|
prompt:str,
|
||||||
|
load_concepts_callback: Callable[[list], any],
|
||||||
|
excluded_tokens:list[str])->str:
|
||||||
'''
|
'''
|
||||||
Given a prompt string that contains `<concept_name>` tags, replace
|
Given a prompt string that contains `<concept_name>` tags, replace
|
||||||
these tags with the appropriate trigger.
|
these tags with the appropriate trigger.
|
||||||
|
|
||||||
If any `<concept_name>` tags are found, `load_concepts_callback()` is called with a list
|
If any `<concept_name>` tags are found, `load_concepts_callback()` is called with a list
|
||||||
of `concepts_name` strings.
|
of `concepts_name` strings.
|
||||||
|
|
||||||
|
`excluded_tokens` are any tokens that should not be replaced, typically because they
|
||||||
|
are trigger tokens from a locally-loaded embedding.
|
||||||
'''
|
'''
|
||||||
concepts = self.match_concept.findall(prompt)
|
concepts = self.match_concept.findall(prompt)
|
||||||
if not concepts:
|
if not concepts:
|
||||||
@ -129,6 +135,8 @@ class HuggingFaceConceptsLibrary(object):
|
|||||||
load_concepts_callback(concepts)
|
load_concepts_callback(concepts)
|
||||||
|
|
||||||
def do_replace(match)->str:
|
def do_replace(match)->str:
|
||||||
|
if excluded_tokens and f'<{match.group(1)}>' in excluded_tokens:
|
||||||
|
return f'<{match.group(1)}>'
|
||||||
return self.concept_to_trigger(match.group(1)) or f'<{match.group(1)}>'
|
return self.concept_to_trigger(match.group(1)) or f'<{match.group(1)}>'
|
||||||
return self.match_concept.sub(do_replace, prompt)
|
return self.match_concept.sub(do_replace, prompt)
|
||||||
|
|
||||||
|
@ -147,7 +147,7 @@ class ModelManager(object):
|
|||||||
Return true if this is a legacy (.ckpt) model
|
Return true if this is a legacy (.ckpt) model
|
||||||
'''
|
'''
|
||||||
info = self.model_info(model_name)
|
info = self.model_info(model_name)
|
||||||
if 'weights' in info and info['weights'].endswith('.ckpt'):
|
if 'weights' in info and info['weights'].endswith(('.ckpt','.safetensors')):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@ -366,8 +366,14 @@ class ModelManager(object):
|
|||||||
vae = os.path.normpath(os.path.join(Globals.root,vae))
|
vae = os.path.normpath(os.path.join(Globals.root,vae))
|
||||||
if os.path.exists(vae):
|
if os.path.exists(vae):
|
||||||
print(f' | Loading VAE weights from: {vae}')
|
print(f' | Loading VAE weights from: {vae}')
|
||||||
|
vae_ckpt = None
|
||||||
|
vae_dict = None
|
||||||
|
if vae.endswith('.safetensors'):
|
||||||
|
vae_ckpt = safetensors.torch.load_file(vae)
|
||||||
|
vae_dict = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss"}
|
||||||
|
else:
|
||||||
vae_ckpt = torch.load(vae, map_location="cpu")
|
vae_ckpt = torch.load(vae, map_location="cpu")
|
||||||
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
|
vae_dict = {k: v for k, v in vae_ckpt['state_dict'].items() if k[0:4] != "loss"}
|
||||||
model.first_stage_model.load_state_dict(vae_dict, strict=False)
|
model.first_stage_model.load_state_dict(vae_dict, strict=False)
|
||||||
else:
|
else:
|
||||||
print(f' | VAE file {vae} not found. Skipping.')
|
print(f' | VAE file {vae} not found. Skipping.')
|
||||||
|
@ -38,11 +38,15 @@ class TextualInversionManager():
|
|||||||
if concept_name in self.hf_concepts_library.concepts_loaded:
|
if concept_name in self.hf_concepts_library.concepts_loaded:
|
||||||
continue
|
continue
|
||||||
trigger = self.hf_concepts_library.concept_to_trigger(concept_name)
|
trigger = self.hf_concepts_library.concept_to_trigger(concept_name)
|
||||||
if self.has_textual_inversion_for_trigger_string(trigger):
|
if self.has_textual_inversion_for_trigger_string(trigger) \
|
||||||
|
or self.has_textual_inversion_for_trigger_string(concept_name) \
|
||||||
|
or self.has_textual_inversion_for_trigger_string(f'<{concept_name}>'): # in case a token with literal angle brackets encountered
|
||||||
|
print(f'>> Loaded local embedding for trigger {concept_name}')
|
||||||
continue
|
continue
|
||||||
bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
|
bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
|
||||||
if not bin_file:
|
if not bin_file:
|
||||||
continue
|
continue
|
||||||
|
print(f'>> Loaded remote embedding for trigger {concept_name}')
|
||||||
self.load_textual_inversion(bin_file)
|
self.load_textual_inversion(bin_file)
|
||||||
self.hf_concepts_library.concepts_loaded[concept_name]=True
|
self.hf_concepts_library.concepts_loaded[concept_name]=True
|
||||||
|
|
||||||
@ -50,6 +54,8 @@ class TextualInversionManager():
|
|||||||
return [ti.trigger_string for ti in self.textual_inversions]
|
return [ti.trigger_string for ti in self.textual_inversions]
|
||||||
|
|
||||||
def load_textual_inversion(self, ckpt_path, defer_injecting_tokens: bool=False):
|
def load_textual_inversion(self, ckpt_path, defer_injecting_tokens: bool=False):
|
||||||
|
if str(ckpt_path).endswith('.DS_Store'):
|
||||||
|
return
|
||||||
try:
|
try:
|
||||||
scan_result = scan_file_path(ckpt_path)
|
scan_result = scan_file_path(ckpt_path)
|
||||||
if scan_result.infected_files == 1:
|
if scan_result.infected_files == 1:
|
||||||
|
Loading…
Reference in New Issue
Block a user