Merge branch 'main' into slider-fix

This commit is contained in:
Lincoln Stein 2023-01-24 12:21:39 -05:00 committed by GitHub
commit ce865a8d69
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 34 additions and 10 deletions

View File

@ -445,7 +445,11 @@ class Generate:
self._set_sampler()
# apply the concepts library to the prompt
prompt = self.huggingface_concepts_library.replace_concepts_with_triggers(prompt, lambda concepts: self.load_huggingface_concepts(concepts))
prompt = self.huggingface_concepts_library.replace_concepts_with_triggers(
prompt,
lambda concepts: self.load_huggingface_concepts(concepts),
self.model.textual_inversion_manager.get_all_trigger_strings()
)
# bit of a hack to change the cached sampler's karras threshold to
# whatever the user asked for

View File

@ -573,7 +573,7 @@ def import_model(model_path:str, gen, opt, completer):
if model_path.startswith(('http:','https:','ftp:')):
model_name = import_ckpt_model(model_path, gen, opt, completer)
elif os.path.exists(model_path) and model_path.endswith('.ckpt') and os.path.isfile(model_path):
elif os.path.exists(model_path) and model_path.endswith(('.ckpt','.safetensors')) and os.path.isfile(model_path):
model_name = import_ckpt_model(model_path, gen, opt, completer)
elif re.match('^[\w.+-]+/[\w.+-]+$',model_path):
model_name = import_diffuser_model(model_path, gen, opt, completer)
@ -627,9 +627,9 @@ def import_ckpt_model(path_or_url:str, gen, opt, completer)->str:
model_description=default_description
)
config_file = None
default = Path(Globals.root,'configs/stable-diffusion/v1-inference.yaml')
completer.complete_extensions(('.yaml','.yml'))
completer.set_line('configs/stable-diffusion/v1-inference.yaml')
completer.set_line(str(default))
done = False
while not done:
config_file = input('Configuration file for this model: ').strip()

View File

@ -59,7 +59,7 @@ class HuggingFaceConceptsLibrary(object):
be downloaded.
'''
if not concept_name in self.list_concepts():
print(f'This concept is not known to the Hugging Face library. Generation will continue without the concept.')
print(f'This concept is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept.')
return None
return self.get_concept_file(concept_name.lower(),'learned_embeds.bin')
@ -115,13 +115,19 @@ class HuggingFaceConceptsLibrary(object):
return self.trigger_to_concept(match.group(1)) or f'<{match.group(1)}>'
return self.match_trigger.sub(do_replace, prompt)
def replace_concepts_with_triggers(self, prompt:str, load_concepts_callback: Callable[[list], any])->str:
def replace_concepts_with_triggers(self,
prompt:str,
load_concepts_callback: Callable[[list], any],
excluded_tokens:list[str])->str:
'''
Given a prompt string that contains `<concept_name>` tags, replace
these tags with the appropriate trigger.
If any `<concept_name>` tags are found, `load_concepts_callback()` is called with a list
of `concepts_name` strings.
`excluded_tokens` are any tokens that should not be replaced, typically because they
are trigger tokens from a locally-loaded embedding.
'''
concepts = self.match_concept.findall(prompt)
if not concepts:
@ -129,6 +135,8 @@ class HuggingFaceConceptsLibrary(object):
load_concepts_callback(concepts)
def do_replace(match)->str:
if excluded_tokens and f'<{match.group(1)}>' in excluded_tokens:
return f'<{match.group(1)}>'
return self.concept_to_trigger(match.group(1)) or f'<{match.group(1)}>'
return self.match_concept.sub(do_replace, prompt)

View File

@ -147,7 +147,7 @@ class ModelManager(object):
Return true if this is a legacy (.ckpt) model
'''
info = self.model_info(model_name)
if 'weights' in info and info['weights'].endswith('.ckpt'):
if 'weights' in info and info['weights'].endswith(('.ckpt','.safetensors')):
return True
return False
@ -366,8 +366,14 @@ class ModelManager(object):
vae = os.path.normpath(os.path.join(Globals.root,vae))
if os.path.exists(vae):
print(f' | Loading VAE weights from: {vae}')
vae_ckpt = None
vae_dict = None
if vae.endswith('.safetensors'):
vae_ckpt = safetensors.torch.load_file(vae)
vae_dict = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss"}
else:
vae_ckpt = torch.load(vae, map_location="cpu")
vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
vae_dict = {k: v for k, v in vae_ckpt['state_dict'].items() if k[0:4] != "loss"}
model.first_stage_model.load_state_dict(vae_dict, strict=False)
else:
print(f' | VAE file {vae} not found. Skipping.')

View File

@ -38,11 +38,15 @@ class TextualInversionManager():
if concept_name in self.hf_concepts_library.concepts_loaded:
continue
trigger = self.hf_concepts_library.concept_to_trigger(concept_name)
if self.has_textual_inversion_for_trigger_string(trigger):
if self.has_textual_inversion_for_trigger_string(trigger) \
or self.has_textual_inversion_for_trigger_string(concept_name) \
or self.has_textual_inversion_for_trigger_string(f'<{concept_name}>'): # in case a token with literal angle brackets encountered
print(f'>> Loaded local embedding for trigger {concept_name}')
continue
bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
if not bin_file:
continue
print(f'>> Loaded remote embedding for trigger {concept_name}')
self.load_textual_inversion(bin_file)
self.hf_concepts_library.concepts_loaded[concept_name]=True
@ -50,6 +54,8 @@ class TextualInversionManager():
return [ti.trigger_string for ti in self.textual_inversions]
def load_textual_inversion(self, ckpt_path, defer_injecting_tokens: bool=False):
if str(ckpt_path).endswith('.DS_Store'):
return
try:
scan_result = scan_file_path(ckpt_path)
if scan_result.infected_files == 1: