Merge branch 'main' into lstein/new-model-manager

Lincoln Stein
2023-05-03 18:09:29 -04:00
committed by GitHub
40 changed files with 652 additions and 536 deletions
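The bulk of this merge is a logging cleanup: ad-hoc print() diagnostics are replaced with calls to invokeai.backend.util.logging. As a minimal sketch of the pattern (using only the import and the severity methods that actually appear in the hunks below; repo_id stands in for any formatted value):

    import invokeai.backend.util.logging as logger

    # before: severity encoded in message prefixes such as ">>" and " **"
    print(f">> Downloading {repo_id}...")

    # after: severity carried by the method name; prefixes dropped
    logger.info(f"Downloading {repo_id}...")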

View File

@@ -17,6 +17,7 @@ from huggingface_hub import (
     hf_hub_url,
 )
+import invokeai.backend.util.logging as logger
 from invokeai.backend.globals import Globals
@@ -66,11 +67,11 @@ class HuggingFaceConceptsLibrary(object):
                 # when init, add all in dir. when not init, add only concepts added between init and now
                 self.concept_list.extend(list(local_concepts_to_add))
             except Exception as e:
-                print(
-                    f" ** WARNING: Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}."
+                logger.warning(
+                    f"Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}."
                 )
-                print(
-                    " ** You may load .bin and .pt file(s) manually using the --embedding_directory argument."
+                logger.warning(
+                    "You may load .bin and .pt file(s) manually using the --embedding_directory argument."
                 )
             return self.concept_list
         else:
@@ -83,7 +84,7 @@ class HuggingFaceConceptsLibrary(object):
         be downloaded.
         """
         if not concept_name in self.list_concepts():
-            print(
+            logger.warning(
                 f"{concept_name} is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
             )
             return None
@@ -221,7 +222,7 @@ class HuggingFaceConceptsLibrary(object):
            if chunk == 0:
                bytes += total
-                print(f">> Downloading {repo_id}...", end="")
+                logger.info(f"Downloading {repo_id}...")
        try:
            for file in (
                "README.md",
@@ -235,22 +236,22 @@ class HuggingFaceConceptsLibrary(object):
                 )
         except ul_error.HTTPError as e:
             if e.code == 404:
-                print(
+                logger.warning(
                     f"Concept {concept_name} is not known to the Hugging Face library. Generation will continue without the concept."
                 )
             else:
-                print(
+                logger.warning(
                     f"Failed to download {concept_name}/{file} ({str(e)}). Generation will continue without the concept."
                 )
             os.rmdir(dest)
             return False
         except ul_error.URLError as e:
-            print(
-                f"ERROR while downloading {concept_name}: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
+            logger.error(
+                f"An error occurred while downloading {concept_name}: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
             )
             os.rmdir(dest)
             return False
-        print("...{:.2f}Kb".format(bytes / 1024))
+        logger.info("...{:.2f}Kb".format(bytes / 1024))
         return succeeded

     def _concept_id(self, concept_name: str) -> str:
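Taken together, the replacements in this file suggest a rough mapping from the old print() prefixes to logger severities (a summary inferred from the hunks above and below, not a documented convention):

    # old print() prefix        new call
    # ">> ..."              ->  logger.info(...)
    # " ** WARNING ..."     ->  logger.warning(...)
    # "ERROR ..."           ->  logger.error(...)
    # "   | ..."            ->  logger.debug(...)
    # " ** Security ..."    ->  logger.critical(...)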

View File

@@ -13,9 +13,9 @@ from compel.cross_attention_control import Arguments
 from diffusers.models.attention_processor import AttentionProcessor
 from torch import nn
+import invokeai.backend.util.logging as logger
 from ...util import torch_dtype

 class CrossAttentionType(enum.Enum):
     SELF = 1
     TOKENS = 2
@@ -421,7 +421,7 @@ def get_cross_attention_modules(
     expected_count = 16
     if cross_attention_modules_in_model_count != expected_count:
         # non-fatal error but .swap() won't work.
-        print(
+        logger.error(
             f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
             + f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed "
             + "or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "

View File

@@ -8,6 +8,7 @@ import torch
 from diffusers.models.attention_processor import AttentionProcessor
 from typing_extensions import TypeAlias
+import invokeai.backend.util.logging as logger
 from invokeai.backend.globals import Globals

 from .cross_attention_control import (
@@ -466,10 +467,14 @@ class InvokeAIDiffuserComponent:
             outside = torch.count_nonzero(
                 (latents < -current_threshold) | (latents > current_threshold)
             )
-            print(
-                f"\nThreshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})\n"
-                f" | min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}\n"
-                f" | {outside / latents.numel() * 100:.2f}% values outside threshold"
+            logger.info(
+                f"Threshold: %={percent_through} threshold={current_threshold:.3f} (of {threshold:.3f})"
             )
+            logger.debug(
+                f"min, mean, max = {minval:.3f}, {mean:.3f}, {maxval:.3f}\tstd={std}"
+            )
+            logger.debug(
+                f"{outside / latents.numel() * 100:.2f}% values outside threshold"
+            )

         if maxval < current_threshold and minval > -current_threshold:
@@ -496,9 +501,11 @@ class InvokeAIDiffuserComponent:
             )

             if self.debug_thresholding:
-                print(
-                    f" | min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})\n"
-                    f" | {num_altered / latents.numel() * 100:.2f}% values altered"
+                logger.debug(
+                    f"min, , max = {minval:.3f}, , {maxval:.3f}\t(scaled by {scale})"
                 )
+                logger.debug(
+                    f"{num_altered / latents.numel() * 100:.2f}% values altered"
+                )

         return latents
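The multi-line print() reports above are split into one logger call per line because each logging call emits its own formatted record, which makes embedded "\n" separators pointless. One behavioral consequence, assuming invokeai.backend.util.logging wraps Python's standard logging: the statistics demoted to logger.debug are suppressed unless debug output is enabled, e.g.:

    import logging
    logging.getLogger().setLevel(logging.DEBUG)  # re-enable the per-step statistics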

View File

@@ -10,7 +10,7 @@ from torchvision.utils import make_grid
 # import matplotlib.pyplot as plt # TODO: check with Dominik, also bsrgan.py vs bsrgan_light.py
+import invokeai.backend.util.logging as logger

 os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
@@ -191,7 +191,7 @@ def mkdirs(paths):
 def mkdir_and_rename(path):
     if os.path.exists(path):
         new_name = path + "_archived_" + get_timestamp()
-        print("Path already exists. Rename it to [{:s}]".format(new_name))
+        logger.error("Path already exists. Rename it to [{:s}]".format(new_name))
         os.replace(path, new_name)
     os.makedirs(path)

View File

@@ -10,6 +10,7 @@ from compel.embeddings_provider import BaseTextualInversionManager
 from picklescan.scanner import scan_file_path
 from transformers import CLIPTextModel, CLIPTokenizer
+import invokeai.backend.util.logging as logger
 from .concepts_lib import HuggingFaceConceptsLibrary

 @dataclass
@@ -59,12 +60,12 @@ class TextualInversionManager(BaseTextualInversionManager):
                 or self.has_textual_inversion_for_trigger_string(concept_name)
                 or self.has_textual_inversion_for_trigger_string(f"<{concept_name}>")
             ):  # in case a token with literal angle brackets encountered
-                print(f">> Loaded local embedding for trigger {concept_name}")
+                logger.info(f"Loaded local embedding for trigger {concept_name}")
                 continue
             bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
             if not bin_file:
                 continue
-            print(f">> Loaded remote embedding for trigger {concept_name}")
+            logger.info(f"Loaded remote embedding for trigger {concept_name}")
             self.load_textual_inversion(bin_file)
             self.hf_concepts_library.concepts_loaded[concept_name] = True
@@ -85,8 +86,8 @@ class TextualInversionManager(BaseTextualInversionManager):
         embedding_list = self._parse_embedding(str(ckpt_path))
         for embedding_info in embedding_list:
             if (self.text_encoder.get_input_embeddings().weight.data[0].shape[0] != embedding_info.token_dim):
-                print(
-                    f" ** Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info.token_dim}."
+                logger.warning(
+                    f"Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info.token_dim}."
                 )
                 continue
@@ -105,8 +106,8 @@ class TextualInversionManager(BaseTextualInversionManager):
                     if ckpt_path.name == "learned_embeds.bin"
                     else f"<{ckpt_path.stem}>"
                 )
-                print(
-                    f">> {sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
+                logger.info(
+                    f"{sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
                 )
                 trigger_str = replacement_trigger_str
@@ -120,8 +121,8 @@ class TextualInversionManager(BaseTextualInversionManager):
                 self.trigger_to_sourcefile[trigger_str] = sourcefile
             except ValueError as e:
-                print(f' | Ignoring incompatible embedding {embedding_info["name"]}')
-                print(f" | The error was {str(e)}")
+                logger.debug(f'Ignoring incompatible embedding {embedding_info["name"]}')
+                logger.debug(f"The error was {str(e)}")

     def _add_textual_inversion(
         self, trigger_str, embedding, defer_injecting_tokens=False
@@ -133,8 +134,8 @@ class TextualInversionManager(BaseTextualInversionManager):
         :return: The token id for the added embedding, either existing or newly-added.
         """
         if trigger_str in [ti.trigger_string for ti in self.textual_inversions]:
-            print(
-                f"** TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
+            logger.warning(
+                f"TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
             )
             return
         if not self.full_precision:
@@ -155,11 +156,11 @@ class TextualInversionManager(BaseTextualInversionManager):
         except ValueError as e:
             if str(e).startswith("Warning"):
-                print(f">> {str(e)}")
+                logger.warning(f"{str(e)}")
             else:
                 traceback.print_exc()
-                print(
-                    f"** TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
+                logger.error(
+                    f"TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
                 )
                 raise
@@ -219,16 +220,16 @@ class TextualInversionManager(BaseTextualInversionManager):
         for ti in self.textual_inversions:
             if ti.trigger_token_id is None and ti.trigger_string in prompt_string:
                 if ti.embedding_vector_length > 1:
-                    print(
-                        f">> Preparing tokens for textual inversion {ti.trigger_string}..."
+                    logger.info(
+                        f"Preparing tokens for textual inversion {ti.trigger_string}..."
                     )
                 try:
                     self._inject_tokens_and_assign_embeddings(ti)
                 except ValueError as e:
-                    print(
-                        f" | Ignoring incompatible embedding trigger {ti.trigger_string}"
+                    logger.debug(
+                        f"Ignoring incompatible embedding trigger {ti.trigger_string}"
                     )
-                    print(f" | The error was {str(e)}")
+                    logger.debug(f"The error was {str(e)}")
                     continue
                 injected_token_ids.append(ti.trigger_token_id)
                 injected_token_ids.extend(ti.pad_token_ids)
@@ -306,16 +307,16 @@ class TextualInversionManager(BaseTextualInversionManager):
             if suffix in [".pt",".ckpt",".bin"]:
                 scan_result = scan_file_path(embedding_file)
                 if scan_result.infected_files > 0:
-                    print(
-                        f" ** Security Issues Found in Model: {scan_result.issues_count}"
+                    logger.critical(
+                        f"Security Issues Found in Model: {scan_result.issues_count}"
                     )
-                    print(" ** For your safety, InvokeAI will not load this embed.")
+                    logger.critical("For your safety, InvokeAI will not load this embed.")
                     return list()
                 ckpt = torch.load(embedding_file,map_location="cpu")
             else:
                 ckpt = safetensors.torch.load_file(embedding_file)
         except Exception as e:
-            print(f" ** Notice: unrecognized embedding file format: {embedding_file}: {e}")
+            logger.warning(f"Notice: unrecognized embedding file format: {embedding_file}: {e}")
             return list()

         # try to figure out what kind of embedding file it is and parse accordingly
@@ -334,7 +335,7 @@ class TextualInversionManager(BaseTextualInversionManager):
     def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
         basename = Path(file_path).stem
-        print(f' | Loading v1 embedding file: {basename}')
+        logger.debug(f'Loading v1 embedding file: {basename}')
         embeddings = list()
         token_counter = -1
@@ -342,7 +343,7 @@ class TextualInversionManager(BaseTextualInversionManager):
             if token_counter < 0:
                 trigger = embedding_ckpt["name"]
             elif token_counter == 0:
-                trigger = f'<basename>'
+                trigger = '<basename>'
             else:
                 trigger = f'<{basename}-{int(token_counter:=token_counter)}>'
             token_counter += 1
@@ -365,7 +366,7 @@ class TextualInversionManager(BaseTextualInversionManager):
         This handles embedding .pt file variant #2.
         """
         basename = Path(file_path).stem
-        print(f' | Loading v2 embedding file: {basename}')
+        logger.debug(f'Loading v2 embedding file: {basename}')
         embeddings = list()

         if isinstance(
@@ -384,7 +385,7 @@ class TextualInversionManager(BaseTextualInversionManager):
             )
             embeddings.append(embedding_info)
         else:
-            print(f" ** {basename}: Unrecognized embedding format")
+            logger.warning(f"{basename}: Unrecognized embedding format")

         return embeddings
@@ -393,7 +394,7 @@ class TextualInversionManager(BaseTextualInversionManager):
         Parse 'version 3' of the .pt textual inversion embedding files.
         """
         basename = Path(file_path).stem
-        print(f' | Loading v3 embedding file: {basename}')
+        logger.debug(f'Loading v3 embedding file: {basename}')
         embedding = embedding_ckpt['emb_params']
         embedding_info = EmbeddingInfo(
             name = f'<{basename}>',
@@ -411,11 +412,11 @@ class TextualInversionManager(BaseTextualInversionManager):
         basename = Path(filepath).stem
         short_path = Path(filepath).parents[0].name+'/'+Path(filepath).name
-        print(f' | Loading v4 embedding file: {short_path}')
+        logger.debug(f'Loading v4 embedding file: {short_path}')
         embeddings = list()
         if len(embedding_ckpt.keys()) == 0:
-            print(f" ** Invalid embeddings file: {short_path}")
+            logger.warning(f"Invalid embeddings file: {short_path}")
         else:
             for token,embedding in embedding_ckpt.items():
                 embedding_info = EmbeddingInfo(