InvokeAI/ldm/invoke/concepts_lib.py
Lincoln Stein 1e1f871ee1
Embedding merging (#1526)
* add whole <style token> to vocab for concept library embeddings

* add ability to load multiple concept .bin files

* make --log_tokenization respect custom tokens

* start working on concept downloading system

* preliminary support for dynamic loading and merging of multiple embedded models

- The embedding_manager is now enhanced with ldm.invoke.concepts_lib,
  which handles dynamic downloading and caching of embedded models from
  the Hugging Face concepts library (https://huggingface.co/sd-concepts-library)

- Downloading of an embedded model is triggered by the presence of one or more
  <concept> tags in the prompt.

- Once the embedded model is downloaded, its trigger phrase will be loaded
  into the embedding manager and the prompt's <concept> tag will be replaced
  with the <trigger_phrase>.

- The downloaded model stays on disk for fast loading later.

- The CLI autocomplete will complete partial <concept> tags for you. Type a
  '<' and hit tab to get all ~700 concepts.
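
The tag replacement described above can be sketched roughly as follows. This is a minimal illustration rather than the code in concepts_lib.py: the `TRIGGERS` dictionary is a hypothetical stand-in for the trigger phrases read from each downloaded concept, and the free function below is an assumption made for the example.

```python
import re

# Hypothetical stand-in for the triggers read from each concept's
# token_identifier.txt file after download.
TRIGGERS = {"hoi4-leaders": "<HOI4-Leader>"}

# Matches tags such as <hoi4-leaders>; group 1 is the concept name.
CONCEPT_TAG = re.compile(r"<([\w\-]+)>")

def replace_concepts_with_triggers(prompt: str) -> str:
    """Swap each known <concept> tag for its trigger; leave unknown tags alone."""
    def do_replace(match: re.Match) -> str:
        return TRIGGERS.get(match.group(1), match.group(0))
    return CONCEPT_TAG.sub(do_replace, prompt)

print(replace_concepts_with_triggers("a portrait in the style of <hoi4-leaders>"))
```

Unknown tags are passed through unchanged here so that ordinary angle-bracket text in a prompt is not damaged.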

BUGS AND LIMITATIONS:

- MODEL NAME VS TRIGGER PHRASE

  You must use the name of the concept embed model from the SD
  library, and not the trigger phrase itself. Usually these are the
  same, but not always. For example, the model named "hoi4-leaders"
  corresponds to the trigger "<HOI4-Leader>".

  One reason for this design choice is that there is no apparent
  constraint on the uniqueness of the trigger phrases and one trigger
  phrase may map onto multiple models. So we use the model name
  instead.

  The second reason is that there is no way I know of to search
  Hugging Face for models with certain trigger phrases. So we'd have
  to download all 700 models to index the phrases.

  The problem this presents is that this may confuse users, who will
  want to reuse prompts from distributions that use the trigger phrase
  directly. Usually this will work, but not always.
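
To illustrate why the model name is the safer key, here is a small sketch of a trigger collision. The two `example-style-*` concepts and their shared trigger are invented for illustration; only the "hoi4-leaders" pairing comes from the text above.

```python
from collections import defaultdict

# (concept name, trigger phrase) pairs. The last two entries are hypothetical
# and deliberately share a trigger to show the collision.
CONCEPTS = [
    ("hoi4-leaders", "<HOI4-Leader>"),
    ("example-style-a", "<my-style>"),
    ("example-style-b", "<my-style>"),
]

def index_by_trigger(pairs):
    """Group concept names under each trigger phrase."""
    index = defaultdict(list)
    for name, trigger in pairs:
        index[trigger].append(name)
    return dict(index)

by_trigger = index_by_trigger(CONCEPTS)

# Triggers that map to more than one model cannot be resolved unambiguously.
ambiguous = {t: names for t, names in by_trigger.items() if len(names) > 1}
print(ambiguous)
```

Concept names, by contrast, are unique within the library namespace, so a name-to-trigger lookup never has this problem.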

- WON'T WORK ON A FIREWALLED SYSTEM

  If the host running IAI has no internet connection, it can't
  download the concept libraries. I will add a script that allows
  users to preload a list of concept models.

- BUG IN PROMPT REPLACEMENT WHEN MODEL NOT FOUND

  There's a small bug that occurs when the user provides an invalid
  model name. The <concept> gets replaced with <None> in the prompt.
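
The commit message does not say where the <None> comes from. One plausible cause, shown here purely as an assumption, is a failed lookup returning None and that value being formatted into the tag. The sketch below reproduces that pattern with a hypothetical `TRIGGERS` table and shows a guarded variant that falls back to the original tag.

```python
import re

TRIGGERS = {"hoi4-leaders": "HOI4-Leader"}  # hypothetical lookup table
CONCEPT_TAG = re.compile(r"<([\w\-]+)>")

def buggy_replace(prompt: str) -> str:
    # A missing concept yields None, which the f-string renders as "None".
    return CONCEPT_TAG.sub(lambda m: f"<{TRIGGERS.get(m.group(1))}>", prompt)

def guarded_replace(prompt: str) -> str:
    # Fall back to the original tag when the lookup fails.
    def do_replace(m: re.Match) -> str:
        trigger = TRIGGERS.get(m.group(1))
        return f"<{trigger}>" if trigger else m.group(0)
    return CONCEPT_TAG.sub(do_replace, prompt)

print(buggy_replace("a <no-such-concept>"))
print(guarded_replace("a <no-such-concept>"))
```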

* fix loading .pt embeddings; allow multi-vector embeddings; warn on dupes

* simplify replacement logic and remove cuda assumption

* download list of concepts from hugging face

* remove misleading customization of '*' placeholder

the existing code as-is did not do anything; unclear what it was supposed to do.

the obvious alternative -- using 'placeholder_strings' instead of
'placeholder_tokens' to match model.params.personalization_config.params.placeholder_strings --
caused a crash. i think this is because the passed string also needs to be handed over
on init of the PersonalizedBase as the 'placeholder_token' argument.
this is weird config dict magic and i don't want to touch it. put a
breakpoint in personalized.py line 116 (top of PersonalizedBase.__init__) if
you want to have a crack at it yourself.

* address all the issues raised by damian0815 in review of PR #1526

* actually resize the token_embeddings

* multiple improvements to the concept loader based on code reviews

1. Activated the --embedding_directory option (alias --embedding_path)
   to load a single embedding or an entire directory of embeddings at
   startup time.

2. Can turn off automatic loading of embeddings using --no-embeddings.

3. Embedding checkpoints are scanned with the pickle scanner.

4. More informative error messages when a concept can't be loaded due
   either to a 404 not found error or a network error.
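
As a rough sketch of item 4, a 404 can be told apart from a general network failure with the standard urllib exception classes. The function name `describe_download_error` and the message wording are hypothetical. Note that `HTTPError` is a subclass of `URLError`, so it has to be checked first.

```python
import io
from urllib import error as ul_error

def describe_download_error(exc: Exception) -> str:
    """Return a short, user-facing description of a download failure."""
    # HTTPError subclasses URLError, so it must be tested first.
    if isinstance(exc, ul_error.HTTPError):
        if exc.code == 404:
            return "concept not found (404)"
        return f"HTTP error {exc.code}"
    if isinstance(exc, ul_error.URLError):
        return f"network problem: {exc.reason}"
    return f"unexpected error: {exc}"

# Exceptions built by hand so the example runs without network access.
not_found = ul_error.HTTPError("https://example.invalid/x", 404, "Not Found", None, io.BytesIO())
server_error = ul_error.HTTPError("https://example.invalid/x", 500, "Server Error", None, io.BytesIO())
unreachable = ul_error.URLError("network unreachable")

for exc in (not_found, server_error, unreachable):
    print(describe_download_error(exc))
```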

* autocomplete terms end with ">" now

* fix startup error and network unreachable

1. If the .invokeai file does not contain the --root and --outdir options,
   invoke.py will now fix it.

2. Catch and handle network problems when downloading hugging face textual
   inversion concepts.

* fix misformatted error string

Co-authored-by: Damian Stewart <d@damianstewart.com>
2022-11-28 02:40:24 -05:00

152 lines
6.3 KiB
Python

"""
Query and install embeddings from the HuggingFace SD Concepts Library
at https://huggingface.co/sd-concepts-library.
The interface is through the Concepts() object.
"""
import os
import re
import traceback
from urllib import request, error as ul_error
from huggingface_hub import HfFolder, hf_hub_url, ModelSearchArguments, ModelFilter, HfApi
from ldm.invoke.globals import Globals

class Concepts(object):
    def __init__(self, root=None):
        '''
        Initialize the Concepts object. May optionally pass a root directory.
        '''
        self.root = root or Globals.root
        self.hf_api = HfApi()
        self.concept_list = None
        self.concepts_loaded = dict()
        self.triggers = dict()       # concept name to trigger phrase
        self.concept_names = dict()  # trigger phrase to concept name
        # Raw strings avoid invalid-escape warnings for \w and \-
        self.match_trigger = re.compile(r'(<[\w\-]+>)')
        self.match_concept = re.compile(r'<([\w\-]+)>')

    def list_concepts(self)->list:
        '''
        Return a list of all the concepts by name, without the 'sd-concepts-library' part.
        Returns None if the list could not be retrieved.
        '''
        if self.concept_list is not None:
            return self.concept_list
        try:
            models = self.hf_api.list_models(filter=ModelFilter(model_name='sd-concepts-library/'))
            self.concept_list = [a.id.split('/')[1] for a in models]
        except Exception as e:
            # f-string prefix is required here so that the error text is interpolated
            print(f' ** WARNING: Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}.')
            print(' ** You may load .bin and .pt file(s) manually using the --embedding_directory argument.')
        return self.concept_list

    def get_concept_model_path(self, concept_name:str)->str:
        '''
        Returns the path to the 'learned_embeds.bin' file in
        the named concept. Returns None if invalid or cannot
        be downloaded.
        '''
        return self.get_concept_file(concept_name.lower(),'learned_embeds.bin')

    def concept_to_trigger(self, concept_name:str)->str:
        '''
        Given a concept name returns its trigger by looking in the
        "token_identifier.txt" file.
        '''
        if concept_name in self.triggers:
            return self.triggers[concept_name]
        file = self.get_concept_file(concept_name, 'token_identifier.txt', local_only=True)
        if not file:
            return None
        with open(file,'r') as f:
            trigger = f.readline()
            trigger = trigger.strip()
        self.triggers[concept_name] = trigger
        self.concept_names[trigger] = concept_name
        return trigger

    def trigger_to_concept(self, trigger:str)->str:
        '''
        Given a trigger phrase, maps it to the concept library name.
        Only works if concept_to_trigger() has previously been called
        on this library. There needs to be a persistent database for
        this.
        '''
        concept = self.concept_names.get(trigger,None)
        return f'<{concept}>' if concept else f'{trigger}'

    def replace_triggers_with_concepts(self, prompt:str)->str:
        '''
        Given a prompt string that contains <trigger> tags, replace these
        tags with the concept name. The reason for this is so that the
        concept names get stored in the prompt metadata. There is no
        controlling of colliding triggers in the SD library, so it is
        better to store the concept name (unique) than the concept trigger
        (not necessarily unique!)
        '''
        def do_replace(match)->str:
            return self.trigger_to_concept(match.group(1)) or f'<{match.group(1)}>'
        return self.match_trigger.sub(do_replace, prompt)

    def replace_concepts_with_triggers(self, prompt:str)->str:
        '''
        Given a prompt string that contains <concept_name> tags, replace
        these tags with the appropriate trigger.
        '''
        def do_replace(match)->str:
            return self.concept_to_trigger(match.group(1)) or f'<{match.group(1)}>'
        return self.match_concept.sub(do_replace, prompt)

    def get_concept_file(self, concept_name:str, file_name:str='learned_embeds.bin' , local_only:bool=False)->str:
        if not self.concept_is_downloaded(concept_name) and not local_only:
            self.download_concept(concept_name)
        path = os.path.join(self._concept_path(concept_name), file_name)
        return path if os.path.exists(path) else None

    def concept_is_downloaded(self, concept_name)->bool:
        concept_directory = self._concept_path(concept_name)
        return os.path.exists(concept_directory)

    def download_concept(self,concept_name)->bool:
        repo_id = self._concept_id(concept_name)
        dest = self._concept_path(concept_name)

        # Send the stored Hugging Face token, if any, with every request
        access_token = HfFolder.get_token()
        header = [("Authorization", f'Bearer {access_token}')] if access_token else []
        opener = request.build_opener()
        opener.addheaders = header
        request.install_opener(opener)

        os.makedirs(dest, exist_ok=True)
        succeeded = True

        total_bytes = 0  # renamed from 'bytes' to avoid shadowing the builtin
        def tally_download_size(chunk, size, total):
            nonlocal total_bytes
            # urlretrieve calls the hook once per block; count each file once
            if chunk==0:
                total_bytes += total

        print(f'>> Downloading {repo_id}...',end='')
        try:
            for file in ('README.md','learned_embeds.bin','token_identifier.txt','type_of_concept.txt'):
                url = hf_hub_url(repo_id, file)
                request.urlretrieve(url, os.path.join(dest,file),reporthook=tally_download_size)
        except ul_error.HTTPError as e:
            if e.code==404:
                print('This concept is not known to the Hugging Face library. Generation will continue without the concept.')
            else:
                print(f'Failed to download {concept_name}/{file} ({str(e)}). Generation will continue without the concept.')
            # NOTE: os.rmdir() only removes an empty directory, so this raises
            # OSError if an earlier file in the loop was already saved.
            os.rmdir(dest)
            return False
        except ul_error.URLError as e:
            print(f'ERROR: {str(e)}. This may reflect a network issue. Generation will continue without the concept.')
            os.rmdir(dest)
            return False
        print('...{:.2f}Kb'.format(total_bytes/1024))
        return succeeded

    def _concept_id(self, concept_name:str)->str:
        return f'sd-concepts-library/{concept_name}'

    def _concept_path(self, concept_name:str)->str:
        return os.path.join(self.root,'models','sd-concepts-library',concept_name)
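
One caveat about the error paths in download_concept(): os.rmdir() only removes empty directories, so if one file was saved before a later one failed, the cleanup call itself would raise. The sketch below shows a more forgiving cleanup based on shutil.rmtree. The helper name `remove_partial_download` is hypothetical and not part of this module, and the directory layout is simulated in a temporary folder.

```python
import os
import shutil
import tempfile

def remove_partial_download(dest: str) -> None:
    """Delete dest and anything inside it; do nothing if it is already gone."""
    shutil.rmtree(dest, ignore_errors=True)

# Simulate a download that saved one file before failing.
root = tempfile.mkdtemp()
dest = os.path.join(root, "models", "sd-concepts-library", "example-concept")
os.makedirs(dest, exist_ok=True)
with open(os.path.join(dest, "README.md"), "w") as f:
    f.write("partial")

try:
    os.rmdir(dest)          # fails because the directory is not empty
    rmdir_failed = False
except OSError:
    rmdir_failed = True

remove_partial_download(dest)
still_there = os.path.exists(dest)

# Remove the temporary tree used for the demonstration.
shutil.rmtree(root, ignore_errors=True)
print(rmdir_failed, still_there)
```

Removing the whole directory also keeps concept_is_downloaded() accurate, since it only checks whether the concept directory exists.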