Local embeddings support (CLI autocomplete) (#2211)

* integrate local embeds with HF embeds

* Update concepts_lib.py

* Update concepts_lib.py

Co-authored-by: BuildTools <unconfigured@null.spigotmc.org>
Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
This commit is contained in:
Minjune Song 2023-01-04 01:22:10 -05:00 committed by GitHub
parent 6c6e534c1a
commit 21bf512056
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 59 additions and 13 deletions

1
.gitignore vendored
View File

@ -1,4 +1,5 @@
# ignore default image save location and model symbolic link
embeddings/
outputs/
models/ldm/stable-diffusion-v1/model.ckpt
**/restoration/codeformer/weights

View File

@ -19,6 +19,7 @@ class Concepts(object):
'''
self.root = root or Globals.root
self.hf_api = HfApi()
self.local_concepts = dict()
self.concept_list = None
self.concepts_loaded = dict()
self.triggers = dict() # concept name to trigger phrase
@ -28,17 +29,28 @@ class Concepts(object):
def list_concepts(self)->list:
'''
Return a list of all the concepts by name, without the 'sd-concepts-library' part.
Return a list of all the concepts by name, without the 'sd-concepts-library' part.
Also adds local concepts in invokeai/embeddings folder.
'''
local_concepts_now = self.get_local_concepts(os.path.join(self.root, 'embeddings'))
local_concepts_to_add = set(local_concepts_now).difference(set(self.local_concepts))
self.local_concepts.update(local_concepts_now)
if self.concept_list is not None:
if local_concepts_to_add:
self.concept_list.extend(list(local_concepts_to_add))
return self.concept_list
return self.concept_list
else:
try:
models = self.hf_api.list_models(filter=ModelFilter(model_name='sd-concepts-library/'))
self.concept_list = [a.id.split('/')[1] for a in models]
# when init, add all in dir. when not init, add only concepts added between init and now
self.concept_list.extend(list(local_concepts_to_add))
except Exception as e:
print(f' ** WARNING: Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}.')
print(' ** You may load .bin and .pt file(s) manually using the --embedding_directory argument.')
return self.concept_list
try:
models = self.hf_api.list_models(filter=ModelFilter(model_name='sd-concepts-library/'))
self.concept_list = [a.id.split('/')[1] for a in models]
except Exception as e:
print(f' ** WARNING: Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}.')
print(' ** You may load .bin and .pt file(s) manually using the --embedding_directory argument.')
return self.concept_list
def get_concept_model_path(self, concept_name:str)->str:
'''
@ -58,6 +70,12 @@ class Concepts(object):
'''
if concept_name in self.triggers:
return self.triggers[concept_name]
elif self.concept_is_local(concept_name):
trigger = f'<{concept_name}>'
self.triggers[concept_name] = trigger
self.concept_names[trigger] = concept_name
return trigger
file = self.get_concept_file(concept_name, 'token_identifier.txt', local_only=True)
if not file:
return None
@ -115,10 +133,20 @@ class Concepts(object):
return self.match_concept.sub(do_replace, prompt)
def get_concept_file(self, concept_name:str, file_name:str='learned_embeds.bin' , local_only:bool=False)->str:
if not self.concept_is_downloaded(concept_name) and not local_only:
if not (self.concept_is_downloaded(concept_name) or self.concept_is_local(concept_name) or local_only):
self.download_concept(concept_name)
path = os.path.join(self._concept_path(concept_name), file_name)
# get local path in invokeai/embeddings if local concept
if self.concept_is_local(concept_name):
concept_path = self._concept_local_path(concept_name)
path = concept_path
else:
concept_path = self._concept_path(concept_name)
path = os.path.join(concept_path, file_name)
return path if os.path.exists(path) else None
def concept_is_local(self, concept_name)->bool:
return concept_name in self.local_concepts
def concept_is_downloaded(self, concept_name)->bool:
concept_directory = self._concept_path(concept_name)
@ -167,3 +195,16 @@ class Concepts(object):
def _concept_path(self, concept_name:str)->str:
return os.path.join(self.root,'models','sd-concepts-library',concept_name)
def _concept_local_path(self, concept_name:str)->str:
filename = self.local_concepts[concept_name]
return os.path.join(self.root,'embeddings',filename)
def get_local_concepts(self, loc_dir:str):
locs_dic = dict()
if os.path.isdir(loc_dir):
for file in os.listdir(loc_dir):
f = os.path.splitext(file)
if f[1] == '.bin' or f[1] == '.pt':
locs_dic[f[0]] = file
return locs_dic

View File

@ -126,6 +126,7 @@ class Completer(object):
elif re.search('(-S\s*|--seed[=\s])\d*$',buffer):
self.matches= self._seed_completions(text,state)
# looking for an embedding concept
elif re.search('<[\w-]*$',buffer):
self.matches= self._concept_completions(text,state)
@ -272,12 +273,15 @@ class Completer(object):
def add_embedding_terms(self, terms:list[str]):
self.embedding_terms = set(terms)
if self.concepts:
self.embedding_terms.update(self.concepts)
self.embedding_terms.update(set(self.concepts.list_concepts()))
def _concept_completions(self, text, state):
if self.concepts is None:
self.concepts = set(Concepts().list_concepts())
self.embedding_terms.update(self.concepts)
# cache Concepts() instance so we can check for updates in concepts_list during runtime.
self.concepts = Concepts()
self.embedding_terms.update(set(self.concepts.list_concepts()))
else:
self.embedding_terms.update(set(self.concepts.list_concepts()))
partial = text[1:] # this removes the leading '<'
if len(partial) == 0: