mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
275 lines
10 KiB
Python
275 lines
10 KiB
Python
"""
|
|
Query and install embeddings from the HuggingFace SD Concepts Library
|
|
at https://huggingface.co/sd-concepts-library.
|
|
|
|
The interface is through the Concepts() object.
|
|
"""
|
|
import os
|
|
import re
|
|
from typing import Callable
|
|
from urllib import error as ul_error
|
|
from urllib import request
|
|
|
|
from huggingface_hub import (
|
|
HfApi,
|
|
HfFolder,
|
|
ModelFilter,
|
|
hf_hub_url,
|
|
)
|
|
|
|
import invokeai.backend.util.logging as logger
|
|
from invokeai.app.services.config import get_invokeai_config
|
|
config = get_invokeai_config()
|
|
|
|
class HuggingFaceConceptsLibrary(object):
|
|
def __init__(self, root=None):
|
|
"""
|
|
Initialize the Concepts object. May optionally pass a root directory.
|
|
"""
|
|
self.root = root or config.root
|
|
self.hf_api = HfApi()
|
|
self.local_concepts = dict()
|
|
self.concept_list = None
|
|
self.concepts_loaded = dict()
|
|
self.triggers = dict() # concept name to trigger phrase
|
|
self.concept_names = dict() # trigger phrase to concept name
|
|
self.match_trigger = re.compile(
|
|
"(<[\w\- >]+>)"
|
|
) # trigger is slightly less restrictive than HF concept name
|
|
self.match_concept = re.compile(
|
|
"<([\w\-]+)>"
|
|
) # HF concept name can only contain A-Za-z0-9_-
|
|
|
|
def list_concepts(self) -> list:
|
|
"""
|
|
Return a list of all the concepts by name, without the 'sd-concepts-library' part.
|
|
Also adds local concepts in invokeai/embeddings folder.
|
|
"""
|
|
local_concepts_now = self.get_local_concepts(
|
|
os.path.join(self.root, "embeddings")
|
|
)
|
|
local_concepts_to_add = set(local_concepts_now).difference(
|
|
set(self.local_concepts)
|
|
)
|
|
self.local_concepts.update(local_concepts_now)
|
|
|
|
if self.concept_list is not None:
|
|
if local_concepts_to_add:
|
|
self.concept_list.extend(list(local_concepts_to_add))
|
|
return self.concept_list
|
|
return self.concept_list
|
|
elif config.internet_available is True:
|
|
try:
|
|
models = self.hf_api.list_models(
|
|
filter=ModelFilter(model_name="sd-concepts-library/")
|
|
)
|
|
self.concept_list = [a.id.split("/")[1] for a in models]
|
|
# when init, add all in dir. when not init, add only concepts added between init and now
|
|
self.concept_list.extend(list(local_concepts_to_add))
|
|
except Exception as e:
|
|
logger.warning(
|
|
f"Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}."
|
|
)
|
|
logger.warning(
|
|
"You may load .bin and .pt file(s) manually using the --embedding_directory argument."
|
|
)
|
|
return self.concept_list
|
|
else:
|
|
return self.concept_list
|
|
|
|
def get_concept_model_path(self, concept_name: str) -> str:
|
|
"""
|
|
Returns the path to the 'learned_embeds.bin' file in
|
|
the named concept. Returns None if invalid or cannot
|
|
be downloaded.
|
|
"""
|
|
if not concept_name in self.list_concepts():
|
|
logger.warning(
|
|
f"{concept_name} is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
|
|
)
|
|
return None
|
|
return self.get_concept_file(concept_name.lower(), "learned_embeds.bin")
|
|
|
|
def concept_to_trigger(self, concept_name: str) -> str:
|
|
"""
|
|
Given a concept name returns its trigger by looking in the
|
|
"token_identifier.txt" file.
|
|
"""
|
|
if concept_name in self.triggers:
|
|
return self.triggers[concept_name]
|
|
elif self.concept_is_local(concept_name):
|
|
trigger = f"<{concept_name}>"
|
|
self.triggers[concept_name] = trigger
|
|
self.concept_names[trigger] = concept_name
|
|
return trigger
|
|
|
|
file = self.get_concept_file(
|
|
concept_name, "token_identifier.txt", local_only=True
|
|
)
|
|
if not file:
|
|
return None
|
|
with open(file, "r") as f:
|
|
trigger = f.readline()
|
|
trigger = trigger.strip()
|
|
self.triggers[concept_name] = trigger
|
|
self.concept_names[trigger] = concept_name
|
|
return trigger
|
|
|
|
def trigger_to_concept(self, trigger: str) -> str:
|
|
"""
|
|
Given a trigger phrase, maps it to the concept library name.
|
|
Only works if concept_to_trigger() has previously been called
|
|
on this library. There needs to be a persistent database for
|
|
this.
|
|
"""
|
|
concept = self.concept_names.get(trigger, None)
|
|
return f"<{concept}>" if concept else f"{trigger}"
|
|
|
|
def replace_triggers_with_concepts(self, prompt: str) -> str:
|
|
"""
|
|
Given a prompt string that contains <trigger> tags, replace these
|
|
tags with the concept name. The reason for this is so that the
|
|
concept names get stored in the prompt metadata. There is no
|
|
controlling of colliding triggers in the SD library, so it is
|
|
better to store the concept name (unique) than the concept trigger
|
|
(not necessarily unique!)
|
|
"""
|
|
if not prompt:
|
|
return prompt
|
|
triggers = self.match_trigger.findall(prompt)
|
|
if not triggers:
|
|
return prompt
|
|
|
|
def do_replace(match) -> str:
|
|
return self.trigger_to_concept(match.group(1)) or f"<{match.group(1)}>"
|
|
|
|
return self.match_trigger.sub(do_replace, prompt)
|
|
|
|
def replace_concepts_with_triggers(
|
|
self,
|
|
prompt: str,
|
|
load_concepts_callback: Callable[[list], any],
|
|
excluded_tokens: list[str],
|
|
) -> str:
|
|
"""
|
|
Given a prompt string that contains `<concept_name>` tags, replace
|
|
these tags with the appropriate trigger.
|
|
|
|
If any `<concept_name>` tags are found, `load_concepts_callback()` is called with a list
|
|
of `concepts_name` strings.
|
|
|
|
`excluded_tokens` are any tokens that should not be replaced, typically because they
|
|
are trigger tokens from a locally-loaded embedding.
|
|
"""
|
|
concepts = self.match_concept.findall(prompt)
|
|
if not concepts:
|
|
return prompt
|
|
load_concepts_callback(concepts)
|
|
|
|
def do_replace(match) -> str:
|
|
if excluded_tokens and f"<{match.group(1)}>" in excluded_tokens:
|
|
return f"<{match.group(1)}>"
|
|
return self.concept_to_trigger(match.group(1)) or f"<{match.group(1)}>"
|
|
|
|
return self.match_concept.sub(do_replace, prompt)
|
|
|
|
def get_concept_file(
|
|
self,
|
|
concept_name: str,
|
|
file_name: str = "learned_embeds.bin",
|
|
local_only: bool = False,
|
|
) -> str:
|
|
if not (
|
|
self.concept_is_downloaded(concept_name)
|
|
or self.concept_is_local(concept_name)
|
|
or local_only
|
|
):
|
|
self.download_concept(concept_name)
|
|
|
|
# get local path in invokeai/embeddings if local concept
|
|
if self.concept_is_local(concept_name):
|
|
concept_path = self._concept_local_path(concept_name)
|
|
path = concept_path
|
|
else:
|
|
concept_path = self._concept_path(concept_name)
|
|
path = os.path.join(concept_path, file_name)
|
|
return path if os.path.exists(path) else None
|
|
|
|
def concept_is_local(self, concept_name) -> bool:
|
|
return concept_name in self.local_concepts
|
|
|
|
def concept_is_downloaded(self, concept_name) -> bool:
|
|
concept_directory = self._concept_path(concept_name)
|
|
return os.path.exists(concept_directory)
|
|
|
|
def download_concept(self, concept_name) -> bool:
|
|
repo_id = self._concept_id(concept_name)
|
|
dest = self._concept_path(concept_name)
|
|
|
|
access_token = HfFolder.get_token()
|
|
header = [("Authorization", f"Bearer {access_token}")] if access_token else []
|
|
opener = request.build_opener()
|
|
opener.addheaders = header
|
|
request.install_opener(opener)
|
|
|
|
os.makedirs(dest, exist_ok=True)
|
|
succeeded = True
|
|
|
|
bytes = 0
|
|
|
|
def tally_download_size(chunk, size, total):
|
|
nonlocal bytes
|
|
if chunk == 0:
|
|
bytes += total
|
|
|
|
logger.info(f"Downloading {repo_id}...", end="")
|
|
try:
|
|
for file in (
|
|
"README.md",
|
|
"learned_embeds.bin",
|
|
"token_identifier.txt",
|
|
"type_of_concept.txt",
|
|
):
|
|
url = hf_hub_url(repo_id, file)
|
|
request.urlretrieve(
|
|
url, os.path.join(dest, file), reporthook=tally_download_size
|
|
)
|
|
except ul_error.HTTPError as e:
|
|
if e.code == 404:
|
|
logger.warning(
|
|
f"Concept {concept_name} is not known to the Hugging Face library. Generation will continue without the concept."
|
|
)
|
|
else:
|
|
logger.warning(
|
|
f"Failed to download {concept_name}/{file} ({str(e)}. Generation will continue without the concept.)"
|
|
)
|
|
os.rmdir(dest)
|
|
return False
|
|
except ul_error.URLError as e:
|
|
logger.error(
|
|
f"an error occurred while downloading {concept_name}: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
|
|
)
|
|
os.rmdir(dest)
|
|
return False
|
|
logger.info("...{:.2f}Kb".format(bytes / 1024))
|
|
return succeeded
|
|
|
|
def _concept_id(self, concept_name: str) -> str:
|
|
return f"sd-concepts-library/{concept_name}"
|
|
|
|
def _concept_path(self, concept_name: str) -> str:
|
|
return os.path.join(self.root, "models", "sd-concepts-library", concept_name)
|
|
|
|
def _concept_local_path(self, concept_name: str) -> str:
|
|
filename = self.local_concepts[concept_name]
|
|
return os.path.join(self.root, "embeddings", filename)
|
|
|
|
def get_local_concepts(self, loc_dir: str):
|
|
locs_dic = dict()
|
|
if os.path.isdir(loc_dir):
|
|
for file in os.listdir(loc_dir):
|
|
f = os.path.splitext(file)
|
|
if f[1] == ".bin" or f[1] == ".pt":
|
|
locs_dic[f[0]] = file
|
|
return locs_dic
|