mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Local embeddings support (CLI autocomplete) (#2211)
* integrate local embeds with HF embeds * Update concepts_lib.py * Update concepts_lib.py Co-authored-by: BuildTools <unconfigured@null.spigotmc.org> Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
This commit is contained in:
parent
6c6e534c1a
commit
21bf512056
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,4 +1,5 @@
|
|||||||
# ignore default image save location and model symbolic link
|
# ignore default image save location and model symbolic link
|
||||||
|
embeddings/
|
||||||
outputs/
|
outputs/
|
||||||
models/ldm/stable-diffusion-v1/model.ckpt
|
models/ldm/stable-diffusion-v1/model.ckpt
|
||||||
**/restoration/codeformer/weights
|
**/restoration/codeformer/weights
|
||||||
|
@ -19,6 +19,7 @@ class Concepts(object):
|
|||||||
'''
|
'''
|
||||||
self.root = root or Globals.root
|
self.root = root or Globals.root
|
||||||
self.hf_api = HfApi()
|
self.hf_api = HfApi()
|
||||||
|
self.local_concepts = dict()
|
||||||
self.concept_list = None
|
self.concept_list = None
|
||||||
self.concepts_loaded = dict()
|
self.concepts_loaded = dict()
|
||||||
self.triggers = dict() # concept name to trigger phrase
|
self.triggers = dict() # concept name to trigger phrase
|
||||||
@ -28,17 +29,28 @@ class Concepts(object):
|
|||||||
|
|
||||||
def list_concepts(self)->list:
|
def list_concepts(self)->list:
|
||||||
'''
|
'''
|
||||||
Return a list of all the concepts by name, without the 'sd-concepts-library' part.
|
Return a list of all the concepts by name, without the 'sd-concepts-library' part.
|
||||||
|
Also adds local concepts in invokeai/embeddings folder.
|
||||||
'''
|
'''
|
||||||
|
local_concepts_now = self.get_local_concepts(os.path.join(self.root, 'embeddings'))
|
||||||
|
local_concepts_to_add = set(local_concepts_now).difference(set(self.local_concepts))
|
||||||
|
self.local_concepts.update(local_concepts_now)
|
||||||
|
|
||||||
if self.concept_list is not None:
|
if self.concept_list is not None:
|
||||||
|
if local_concepts_to_add:
|
||||||
|
self.concept_list.extend(list(local_concepts_to_add))
|
||||||
|
return self.concept_list
|
||||||
|
return self.concept_list
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
models = self.hf_api.list_models(filter=ModelFilter(model_name='sd-concepts-library/'))
|
||||||
|
self.concept_list = [a.id.split('/')[1] for a in models]
|
||||||
|
# when init, add all in dir. when not init, add only concepts added between init and now
|
||||||
|
self.concept_list.extend(list(local_concepts_to_add))
|
||||||
|
except Exception as e:
|
||||||
|
print(f' ** WARNING: Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}.')
|
||||||
|
print(' ** You may load .bin and .pt file(s) manually using the --embedding_directory argument.')
|
||||||
return self.concept_list
|
return self.concept_list
|
||||||
try:
|
|
||||||
models = self.hf_api.list_models(filter=ModelFilter(model_name='sd-concepts-library/'))
|
|
||||||
self.concept_list = [a.id.split('/')[1] for a in models]
|
|
||||||
except Exception as e:
|
|
||||||
print(f' ** WARNING: Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}.')
|
|
||||||
print(' ** You may load .bin and .pt file(s) manually using the --embedding_directory argument.')
|
|
||||||
return self.concept_list
|
|
||||||
|
|
||||||
def get_concept_model_path(self, concept_name:str)->str:
|
def get_concept_model_path(self, concept_name:str)->str:
|
||||||
'''
|
'''
|
||||||
@ -58,6 +70,12 @@ class Concepts(object):
|
|||||||
'''
|
'''
|
||||||
if concept_name in self.triggers:
|
if concept_name in self.triggers:
|
||||||
return self.triggers[concept_name]
|
return self.triggers[concept_name]
|
||||||
|
elif self.concept_is_local(concept_name):
|
||||||
|
trigger = f'<{concept_name}>'
|
||||||
|
self.triggers[concept_name] = trigger
|
||||||
|
self.concept_names[trigger] = concept_name
|
||||||
|
return trigger
|
||||||
|
|
||||||
file = self.get_concept_file(concept_name, 'token_identifier.txt', local_only=True)
|
file = self.get_concept_file(concept_name, 'token_identifier.txt', local_only=True)
|
||||||
if not file:
|
if not file:
|
||||||
return None
|
return None
|
||||||
@ -115,10 +133,20 @@ class Concepts(object):
|
|||||||
return self.match_concept.sub(do_replace, prompt)
|
return self.match_concept.sub(do_replace, prompt)
|
||||||
|
|
||||||
def get_concept_file(self, concept_name:str, file_name:str='learned_embeds.bin' , local_only:bool=False)->str:
|
def get_concept_file(self, concept_name:str, file_name:str='learned_embeds.bin' , local_only:bool=False)->str:
|
||||||
if not self.concept_is_downloaded(concept_name) and not local_only:
|
if not (self.concept_is_downloaded(concept_name) or self.concept_is_local(concept_name) or local_only):
|
||||||
self.download_concept(concept_name)
|
self.download_concept(concept_name)
|
||||||
path = os.path.join(self._concept_path(concept_name), file_name)
|
|
||||||
|
# get local path in invokeai/embeddings if local concept
|
||||||
|
if self.concept_is_local(concept_name):
|
||||||
|
concept_path = self._concept_local_path(concept_name)
|
||||||
|
path = concept_path
|
||||||
|
else:
|
||||||
|
concept_path = self._concept_path(concept_name)
|
||||||
|
path = os.path.join(concept_path, file_name)
|
||||||
return path if os.path.exists(path) else None
|
return path if os.path.exists(path) else None
|
||||||
|
|
||||||
|
def concept_is_local(self, concept_name)->bool:
|
||||||
|
return concept_name in self.local_concepts
|
||||||
|
|
||||||
def concept_is_downloaded(self, concept_name)->bool:
|
def concept_is_downloaded(self, concept_name)->bool:
|
||||||
concept_directory = self._concept_path(concept_name)
|
concept_directory = self._concept_path(concept_name)
|
||||||
@ -167,3 +195,16 @@ class Concepts(object):
|
|||||||
|
|
||||||
def _concept_path(self, concept_name:str)->str:
|
def _concept_path(self, concept_name:str)->str:
|
||||||
return os.path.join(self.root,'models','sd-concepts-library',concept_name)
|
return os.path.join(self.root,'models','sd-concepts-library',concept_name)
|
||||||
|
|
||||||
|
def _concept_local_path(self, concept_name:str)->str:
|
||||||
|
filename = self.local_concepts[concept_name]
|
||||||
|
return os.path.join(self.root,'embeddings',filename)
|
||||||
|
|
||||||
|
def get_local_concepts(self, loc_dir:str):
|
||||||
|
locs_dic = dict()
|
||||||
|
if os.path.isdir(loc_dir):
|
||||||
|
for file in os.listdir(loc_dir):
|
||||||
|
f = os.path.splitext(file)
|
||||||
|
if f[1] == '.bin' or f[1] == '.pt':
|
||||||
|
locs_dic[f[0]] = file
|
||||||
|
return locs_dic
|
||||||
|
@ -126,6 +126,7 @@ class Completer(object):
|
|||||||
elif re.search('(-S\s*|--seed[=\s])\d*$',buffer):
|
elif re.search('(-S\s*|--seed[=\s])\d*$',buffer):
|
||||||
self.matches= self._seed_completions(text,state)
|
self.matches= self._seed_completions(text,state)
|
||||||
|
|
||||||
|
# looking for an embedding concept
|
||||||
elif re.search('<[\w-]*$',buffer):
|
elif re.search('<[\w-]*$',buffer):
|
||||||
self.matches= self._concept_completions(text,state)
|
self.matches= self._concept_completions(text,state)
|
||||||
|
|
||||||
@ -272,12 +273,15 @@ class Completer(object):
|
|||||||
def add_embedding_terms(self, terms:list[str]):
|
def add_embedding_terms(self, terms:list[str]):
|
||||||
self.embedding_terms = set(terms)
|
self.embedding_terms = set(terms)
|
||||||
if self.concepts:
|
if self.concepts:
|
||||||
self.embedding_terms.update(self.concepts)
|
self.embedding_terms.update(set(self.concepts.list_concepts()))
|
||||||
|
|
||||||
def _concept_completions(self, text, state):
|
def _concept_completions(self, text, state):
|
||||||
if self.concepts is None:
|
if self.concepts is None:
|
||||||
self.concepts = set(Concepts().list_concepts())
|
# cache Concepts() instance so we can check for updates in concepts_list during runtime.
|
||||||
self.embedding_terms.update(self.concepts)
|
self.concepts = Concepts()
|
||||||
|
self.embedding_terms.update(set(self.concepts.list_concepts()))
|
||||||
|
else:
|
||||||
|
self.embedding_terms.update(set(self.concepts.list_concepts()))
|
||||||
|
|
||||||
partial = text[1:] # this removes the leading '<'
|
partial = text[1:] # this removes the leading '<'
|
||||||
if len(partial) == 0:
|
if len(partial) == 0:
|
||||||
|
Loading…
Reference in New Issue
Block a user