mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
f1748d7017
Before making a concept download request to HuggingFace, the concepts library module now checks the concept name against a downloaded list of all the concepts currently known to HuggingFace. If the requested concept is not on the list, then no download request is made.
170 lines
7.1 KiB
Python
170 lines
7.1 KiB
Python
"""
|
|
Query and install embeddings from the HuggingFace SD Concepts Library
|
|
at https://huggingface.co/sd-concepts-library.
|
|
|
|
The interface is through the Concepts() object.
|
|
"""
|
|
import os
|
|
import re
|
|
import traceback
|
|
from typing import Callable
|
|
from urllib import request, error as ul_error
|
|
from huggingface_hub import HfFolder, hf_hub_url, ModelSearchArguments, ModelFilter, HfApi
|
|
from ldm.invoke.globals import Globals
|
|
|
|
class Concepts(object):
|
|
def __init__(self, root=None):
|
|
'''
|
|
Initialize the Concepts object. May optionally pass a root directory.
|
|
'''
|
|
self.root = root or Globals.root
|
|
self.hf_api = HfApi()
|
|
self.concept_list = None
|
|
self.concepts_loaded = dict()
|
|
self.triggers = dict() # concept name to trigger phrase
|
|
self.concept_names = dict() # trigger phrase to concept name
|
|
self.match_trigger = re.compile('(<[\w\- >]+>)') # trigger is slightly less restrictive than HF concept name
|
|
self.match_concept = re.compile('<([\w\-]+)>') # HF concept name can only contain A-Za-z0-9_-
|
|
|
|
def list_concepts(self)->list:
|
|
'''
|
|
Return a list of all the concepts by name, without the 'sd-concepts-library' part.
|
|
'''
|
|
if self.concept_list is not None:
|
|
return self.concept_list
|
|
try:
|
|
models = self.hf_api.list_models(filter=ModelFilter(model_name='sd-concepts-library/'))
|
|
self.concept_list = [a.id.split('/')[1] for a in models]
|
|
except Exception as e:
|
|
print(f' ** WARNING: Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}.')
|
|
print(' ** You may load .bin and .pt file(s) manually using the --embedding_directory argument.')
|
|
return self.concept_list
|
|
|
|
def get_concept_model_path(self, concept_name:str)->str:
|
|
'''
|
|
Returns the path to the 'learned_embeds.bin' file in
|
|
the named concept. Returns None if invalid or cannot
|
|
be downloaded.
|
|
'''
|
|
if not concept_name in self.list_concepts():
|
|
print(f'This concept is not known to the Hugging Face library. Generation will continue without the concept.')
|
|
return None
|
|
return self.get_concept_file(concept_name.lower(),'learned_embeds.bin')
|
|
|
|
def concept_to_trigger(self, concept_name:str)->str:
|
|
'''
|
|
Given a concept name returns its trigger by looking in the
|
|
"token_identifier.txt" file.
|
|
'''
|
|
if concept_name in self.triggers:
|
|
return self.triggers[concept_name]
|
|
file = self.get_concept_file(concept_name, 'token_identifier.txt', local_only=True)
|
|
if not file:
|
|
return None
|
|
with open(file,'r') as f:
|
|
trigger = f.readline()
|
|
trigger = trigger.strip()
|
|
self.triggers[concept_name] = trigger
|
|
self.concept_names[trigger] = concept_name
|
|
return trigger
|
|
|
|
def trigger_to_concept(self, trigger:str)->str:
|
|
'''
|
|
Given a trigger phrase, maps it to the concept library name.
|
|
Only works if concept_to_trigger() has previously been called
|
|
on this library. There needs to be a persistent database for
|
|
this.
|
|
'''
|
|
concept = self.concept_names.get(trigger,None)
|
|
return f'<{concept}>' if concept else f'{trigger}'
|
|
|
|
def replace_triggers_with_concepts(self, prompt:str)->str:
|
|
'''
|
|
Given a prompt string that contains <trigger> tags, replace these
|
|
tags with the concept name. The reason for this is so that the
|
|
concept names get stored in the prompt metadata. There is no
|
|
controlling of colliding triggers in the SD library, so it is
|
|
better to store the concept name (unique) than the concept trigger
|
|
(not necessarily unique!)
|
|
'''
|
|
if not prompt:
|
|
return prompt
|
|
triggers = self.match_trigger.findall(prompt)
|
|
if not triggers:
|
|
return prompt
|
|
|
|
def do_replace(match)->str:
|
|
return self.trigger_to_concept(match.group(1)) or f'<{match.group(1)}>'
|
|
return self.match_trigger.sub(do_replace, prompt)
|
|
|
|
def replace_concepts_with_triggers(self, prompt:str, load_concepts_callback: Callable[[list], any])->str:
|
|
'''
|
|
Given a prompt string that contains `<concept_name>` tags, replace
|
|
these tags with the appropriate trigger.
|
|
|
|
If any `<concept_name>` tags are found, `load_concepts_callback()` is called with a list
|
|
of `concepts_name` strings.
|
|
'''
|
|
concepts = self.match_concept.findall(prompt)
|
|
if not concepts:
|
|
return prompt
|
|
load_concepts_callback(concepts)
|
|
|
|
def do_replace(match)->str:
|
|
return self.concept_to_trigger(match.group(1)) or f'<{match.group(1)}>'
|
|
return self.match_concept.sub(do_replace, prompt)
|
|
|
|
def get_concept_file(self, concept_name:str, file_name:str='learned_embeds.bin' , local_only:bool=False)->str:
|
|
if not self.concept_is_downloaded(concept_name) and not local_only:
|
|
self.download_concept(concept_name)
|
|
path = os.path.join(self._concept_path(concept_name), file_name)
|
|
return path if os.path.exists(path) else None
|
|
|
|
def concept_is_downloaded(self, concept_name)->bool:
|
|
concept_directory = self._concept_path(concept_name)
|
|
return os.path.exists(concept_directory)
|
|
|
|
def download_concept(self,concept_name)->bool:
|
|
repo_id = self._concept_id(concept_name)
|
|
dest = self._concept_path(concept_name)
|
|
|
|
access_token = HfFolder.get_token()
|
|
header = [("Authorization", f'Bearer {access_token}')] if access_token else []
|
|
opener = request.build_opener()
|
|
opener.addheaders = header
|
|
request.install_opener(opener)
|
|
|
|
os.makedirs(dest, exist_ok=True)
|
|
succeeded = True
|
|
|
|
bytes = 0
|
|
def tally_download_size(chunk, size, total):
|
|
nonlocal bytes
|
|
if chunk==0:
|
|
bytes += total
|
|
|
|
print(f'>> Downloading {repo_id}...',end='')
|
|
try:
|
|
for file in ('README.md','learned_embeds.bin','token_identifier.txt','type_of_concept.txt'):
|
|
url = hf_hub_url(repo_id, file)
|
|
request.urlretrieve(url, os.path.join(dest,file),reporthook=tally_download_size)
|
|
except ul_error.HTTPError as e:
|
|
if e.code==404:
|
|
print(f'This concept is not known to the Hugging Face library. Generation will continue without the concept.')
|
|
else:
|
|
print(f'Failed to download {concept_name}/{file} ({str(e)}. Generation will continue without the concept.)')
|
|
os.rmdir(dest)
|
|
return False
|
|
except ul_error.URLError as e:
|
|
print(f'ERROR: {str(e)}. This may reflect a network issue. Generation will continue without the concept.')
|
|
os.rmdir(dest)
|
|
return False
|
|
print('...{:.2f}Kb'.format(bytes/1024))
|
|
return succeeded
|
|
|
|
def _concept_id(self, concept_name:str)->str:
|
|
return f'sd-concepts-library/{concept_name}'
|
|
|
|
def _concept_path(self, concept_name:str)->str:
|
|
return os.path.join(self.root,'models','sd-concepts-library',concept_name)
|