mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'main' of github.com:invoke-ai/InvokeAI into feat/controlnet-control-modes
Only "real" conflicts were in: invokeai/frontend/web/src/features/controlNet/components/ControlNet.tsx invokeai/frontend/web/src/features/controlNet/store/controlNetSlice.ts
This commit is contained in:
@ -1,7 +1,6 @@
|
||||
"""
|
||||
Initialization file for the invokeai.backend.stable_diffusion package
|
||||
"""
|
||||
from .concepts_lib import HuggingFaceConceptsLibrary
|
||||
from .diffusers_pipeline import (
|
||||
ConditioningData,
|
||||
PipelineIntermediateState,
|
||||
@ -10,4 +9,3 @@ from .diffusers_pipeline import (
|
||||
from .diffusion import InvokeAIDiffuserComponent
|
||||
from .diffusion.cross_attention_map_saving import AttentionMapSaver
|
||||
from .diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||
from .textual_inversion_manager import TextualInversionManager
|
||||
|
@ -1,275 +0,0 @@
|
||||
"""
|
||||
Query and install embeddings from the HuggingFace SD Concepts Library
|
||||
at https://huggingface.co/sd-concepts-library.
|
||||
|
||||
The interface is through the Concepts() object.
|
||||
"""
|
||||
import os
|
||||
import re
|
||||
from typing import Callable
|
||||
from urllib import error as ul_error
|
||||
from urllib import request
|
||||
|
||||
from huggingface_hub import (
|
||||
HfApi,
|
||||
HfFolder,
|
||||
ModelFilter,
|
||||
hf_hub_url,
|
||||
)
|
||||
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
logger = InvokeAILogger.getLogger()
|
||||
|
||||
class HuggingFaceConceptsLibrary(object):
|
||||
def __init__(self, root=None):
|
||||
"""
|
||||
Initialize the Concepts object. May optionally pass a root directory.
|
||||
"""
|
||||
self.config = InvokeAIAppConfig.get_config()
|
||||
self.root = root or self.config.root
|
||||
self.hf_api = HfApi()
|
||||
self.local_concepts = dict()
|
||||
self.concept_list = None
|
||||
self.concepts_loaded = dict()
|
||||
self.triggers = dict() # concept name to trigger phrase
|
||||
self.concept_names = dict() # trigger phrase to concept name
|
||||
self.match_trigger = re.compile(
|
||||
"(<[\w\- >]+>)"
|
||||
) # trigger is slightly less restrictive than HF concept name
|
||||
self.match_concept = re.compile(
|
||||
"<([\w\-]+)>"
|
||||
) # HF concept name can only contain A-Za-z0-9_-
|
||||
|
||||
def list_concepts(self) -> list:
|
||||
"""
|
||||
Return a list of all the concepts by name, without the 'sd-concepts-library' part.
|
||||
Also adds local concepts in invokeai/embeddings folder.
|
||||
"""
|
||||
local_concepts_now = self.get_local_concepts(
|
||||
os.path.join(self.root, "embeddings")
|
||||
)
|
||||
local_concepts_to_add = set(local_concepts_now).difference(
|
||||
set(self.local_concepts)
|
||||
)
|
||||
self.local_concepts.update(local_concepts_now)
|
||||
|
||||
if self.concept_list is not None:
|
||||
if local_concepts_to_add:
|
||||
self.concept_list.extend(list(local_concepts_to_add))
|
||||
return self.concept_list
|
||||
return self.concept_list
|
||||
elif self.config.internet_available is True:
|
||||
try:
|
||||
models = self.hf_api.list_models(
|
||||
filter=ModelFilter(model_name="sd-concepts-library/")
|
||||
)
|
||||
self.concept_list = [a.id.split("/")[1] for a in models]
|
||||
# when init, add all in dir. when not init, add only concepts added between init and now
|
||||
self.concept_list.extend(list(local_concepts_to_add))
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Hugging Face textual inversion concepts libraries could not be loaded. The error was {str(e)}."
|
||||
)
|
||||
logger.warning(
|
||||
"You may load .bin and .pt file(s) manually using the --embedding_directory argument."
|
||||
)
|
||||
return self.concept_list
|
||||
else:
|
||||
return self.concept_list
|
||||
|
||||
def get_concept_model_path(self, concept_name: str) -> str:
|
||||
"""
|
||||
Returns the path to the 'learned_embeds.bin' file in
|
||||
the named concept. Returns None if invalid or cannot
|
||||
be downloaded.
|
||||
"""
|
||||
if not concept_name in self.list_concepts():
|
||||
logger.warning(
|
||||
f"{concept_name} is not a local embedding trigger, nor is it a HuggingFace concept. Generation will continue without the concept."
|
||||
)
|
||||
return None
|
||||
return self.get_concept_file(concept_name.lower(), "learned_embeds.bin")
|
||||
|
||||
def concept_to_trigger(self, concept_name: str) -> str:
|
||||
"""
|
||||
Given a concept name returns its trigger by looking in the
|
||||
"token_identifier.txt" file.
|
||||
"""
|
||||
if concept_name in self.triggers:
|
||||
return self.triggers[concept_name]
|
||||
elif self.concept_is_local(concept_name):
|
||||
trigger = f"<{concept_name}>"
|
||||
self.triggers[concept_name] = trigger
|
||||
self.concept_names[trigger] = concept_name
|
||||
return trigger
|
||||
|
||||
file = self.get_concept_file(
|
||||
concept_name, "token_identifier.txt", local_only=True
|
||||
)
|
||||
if not file:
|
||||
return None
|
||||
with open(file, "r") as f:
|
||||
trigger = f.readline()
|
||||
trigger = trigger.strip()
|
||||
self.triggers[concept_name] = trigger
|
||||
self.concept_names[trigger] = concept_name
|
||||
return trigger
|
||||
|
||||
def trigger_to_concept(self, trigger: str) -> str:
|
||||
"""
|
||||
Given a trigger phrase, maps it to the concept library name.
|
||||
Only works if concept_to_trigger() has previously been called
|
||||
on this library. There needs to be a persistent database for
|
||||
this.
|
||||
"""
|
||||
concept = self.concept_names.get(trigger, None)
|
||||
return f"<{concept}>" if concept else f"{trigger}"
|
||||
|
||||
def replace_triggers_with_concepts(self, prompt: str) -> str:
|
||||
"""
|
||||
Given a prompt string that contains <trigger> tags, replace these
|
||||
tags with the concept name. The reason for this is so that the
|
||||
concept names get stored in the prompt metadata. There is no
|
||||
controlling of colliding triggers in the SD library, so it is
|
||||
better to store the concept name (unique) than the concept trigger
|
||||
(not necessarily unique!)
|
||||
"""
|
||||
if not prompt:
|
||||
return prompt
|
||||
triggers = self.match_trigger.findall(prompt)
|
||||
if not triggers:
|
||||
return prompt
|
||||
|
||||
def do_replace(match) -> str:
|
||||
return self.trigger_to_concept(match.group(1)) or f"<{match.group(1)}>"
|
||||
|
||||
return self.match_trigger.sub(do_replace, prompt)
|
||||
|
||||
def replace_concepts_with_triggers(
|
||||
self,
|
||||
prompt: str,
|
||||
load_concepts_callback: Callable[[list], any],
|
||||
excluded_tokens: list[str],
|
||||
) -> str:
|
||||
"""
|
||||
Given a prompt string that contains `<concept_name>` tags, replace
|
||||
these tags with the appropriate trigger.
|
||||
|
||||
If any `<concept_name>` tags are found, `load_concepts_callback()` is called with a list
|
||||
of `concepts_name` strings.
|
||||
|
||||
`excluded_tokens` are any tokens that should not be replaced, typically because they
|
||||
are trigger tokens from a locally-loaded embedding.
|
||||
"""
|
||||
concepts = self.match_concept.findall(prompt)
|
||||
if not concepts:
|
||||
return prompt
|
||||
load_concepts_callback(concepts)
|
||||
|
||||
def do_replace(match) -> str:
|
||||
if excluded_tokens and f"<{match.group(1)}>" in excluded_tokens:
|
||||
return f"<{match.group(1)}>"
|
||||
return self.concept_to_trigger(match.group(1)) or f"<{match.group(1)}>"
|
||||
|
||||
return self.match_concept.sub(do_replace, prompt)
|
||||
|
||||
def get_concept_file(
|
||||
self,
|
||||
concept_name: str,
|
||||
file_name: str = "learned_embeds.bin",
|
||||
local_only: bool = False,
|
||||
) -> str:
|
||||
if not (
|
||||
self.concept_is_downloaded(concept_name)
|
||||
or self.concept_is_local(concept_name)
|
||||
or local_only
|
||||
):
|
||||
self.download_concept(concept_name)
|
||||
|
||||
# get local path in invokeai/embeddings if local concept
|
||||
if self.concept_is_local(concept_name):
|
||||
concept_path = self._concept_local_path(concept_name)
|
||||
path = concept_path
|
||||
else:
|
||||
concept_path = self._concept_path(concept_name)
|
||||
path = os.path.join(concept_path, file_name)
|
||||
return path if os.path.exists(path) else None
|
||||
|
||||
def concept_is_local(self, concept_name) -> bool:
|
||||
return concept_name in self.local_concepts
|
||||
|
||||
def concept_is_downloaded(self, concept_name) -> bool:
|
||||
concept_directory = self._concept_path(concept_name)
|
||||
return os.path.exists(concept_directory)
|
||||
|
||||
def download_concept(self, concept_name) -> bool:
|
||||
repo_id = self._concept_id(concept_name)
|
||||
dest = self._concept_path(concept_name)
|
||||
|
||||
access_token = HfFolder.get_token()
|
||||
header = [("Authorization", f"Bearer {access_token}")] if access_token else []
|
||||
opener = request.build_opener()
|
||||
opener.addheaders = header
|
||||
request.install_opener(opener)
|
||||
|
||||
os.makedirs(dest, exist_ok=True)
|
||||
succeeded = True
|
||||
|
||||
bytes = 0
|
||||
|
||||
def tally_download_size(chunk, size, total):
|
||||
nonlocal bytes
|
||||
if chunk == 0:
|
||||
bytes += total
|
||||
|
||||
logger.info(f"Downloading {repo_id}...", end="")
|
||||
try:
|
||||
for file in (
|
||||
"README.md",
|
||||
"learned_embeds.bin",
|
||||
"token_identifier.txt",
|
||||
"type_of_concept.txt",
|
||||
):
|
||||
url = hf_hub_url(repo_id, file)
|
||||
request.urlretrieve(
|
||||
url, os.path.join(dest, file), reporthook=tally_download_size
|
||||
)
|
||||
except ul_error.HTTPError as e:
|
||||
if e.code == 404:
|
||||
logger.warning(
|
||||
f"Concept {concept_name} is not known to the Hugging Face library. Generation will continue without the concept."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Failed to download {concept_name}/{file} ({str(e)}. Generation will continue without the concept.)"
|
||||
)
|
||||
os.rmdir(dest)
|
||||
return False
|
||||
except ul_error.URLError as e:
|
||||
logger.error(
|
||||
f"an error occurred while downloading {concept_name}: {str(e)}. This may reflect a network issue. Generation will continue without the concept."
|
||||
)
|
||||
os.rmdir(dest)
|
||||
return False
|
||||
logger.info("...{:.2f}Kb".format(bytes / 1024))
|
||||
return succeeded
|
||||
|
||||
def _concept_id(self, concept_name: str) -> str:
|
||||
return f"sd-concepts-library/{concept_name}"
|
||||
|
||||
def _concept_path(self, concept_name: str) -> str:
|
||||
return os.path.join(self.root, "models", "sd-concepts-library", concept_name)
|
||||
|
||||
def _concept_local_path(self, concept_name: str) -> str:
|
||||
filename = self.local_concepts[concept_name]
|
||||
return os.path.join(self.root, "embeddings", filename)
|
||||
|
||||
def get_local_concepts(self, loc_dir: str):
|
||||
locs_dic = dict()
|
||||
if os.path.isdir(loc_dir):
|
||||
for file in os.listdir(loc_dir):
|
||||
f = os.path.splitext(file)
|
||||
if f[1] == ".bin" or f[1] == ".pt":
|
||||
locs_dic[f[0]] = file
|
||||
return locs_dic
|
@ -16,7 +16,6 @@ from accelerate.utils import set_seed
|
||||
import psutil
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from compel import EmbeddingsProvider
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.models.controlnet import ControlNetModel, ControlNetOutput
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
@ -48,7 +47,6 @@ from .diffusion import (
|
||||
PostprocessingSettings,
|
||||
)
|
||||
from .offloading import FullyLoadedModelGroup, LazilyLoadedModelGroup, ModelGroup
|
||||
from .textual_inversion_manager import TextualInversionManager
|
||||
|
||||
@dataclass
|
||||
class PipelineIntermediateState:
|
||||
@ -319,6 +317,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
requires_safety_checker: bool = False,
|
||||
precision: str = "float32",
|
||||
control_model: ControlNetModel = None,
|
||||
execution_device: Optional[torch.device] = None,
|
||||
):
|
||||
super().__init__(
|
||||
vae,
|
||||
@ -343,22 +342,10 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
# control_model=control_model,
|
||||
)
|
||||
self.invokeai_diffuser = InvokeAIDiffuserComponent(
|
||||
self.unet, self._unet_forward, is_running_diffusers=True
|
||||
)
|
||||
use_full_precision = precision == "float32" or precision == "autocast"
|
||||
self.textual_inversion_manager = TextualInversionManager(
|
||||
tokenizer=self.tokenizer,
|
||||
text_encoder=self.text_encoder,
|
||||
full_precision=use_full_precision,
|
||||
)
|
||||
# InvokeAI's interface for text embeddings and whatnot
|
||||
self.embeddings_provider = EmbeddingsProvider(
|
||||
tokenizer=self.tokenizer,
|
||||
text_encoder=self.text_encoder,
|
||||
textual_inversion_manager=self.textual_inversion_manager,
|
||||
self.unet, self._unet_forward
|
||||
)
|
||||
|
||||
self._model_group = FullyLoadedModelGroup(self.unet.device)
|
||||
self._model_group = FullyLoadedModelGroup(execution_device or self.unet.device)
|
||||
self._model_group.install(*self._submodels)
|
||||
self.control_model = control_model
|
||||
|
||||
@ -406,50 +393,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
else:
|
||||
self.disable_attention_slicing()
|
||||
|
||||
def enable_offload_submodels(self, device: torch.device):
|
||||
"""
|
||||
Offload each submodel when it's not in use.
|
||||
|
||||
Useful for low-vRAM situations where the size of the model in memory is a big chunk of
|
||||
the total available resource, and you want to free up as much for inference as possible.
|
||||
|
||||
This requires more moving parts and may add some delay as the U-Net is swapped out for the
|
||||
VAE and vice-versa.
|
||||
"""
|
||||
models = self._submodels
|
||||
if self._model_group is not None:
|
||||
self._model_group.uninstall(*models)
|
||||
group = LazilyLoadedModelGroup(device)
|
||||
group.install(*models)
|
||||
self._model_group = group
|
||||
|
||||
def disable_offload_submodels(self):
|
||||
"""
|
||||
Leave all submodels loaded.
|
||||
|
||||
Appropriate for cases where the size of the model in memory is small compared to the memory
|
||||
required for inference. Avoids the delay and complexity of shuffling the submodels to and
|
||||
from the GPU.
|
||||
"""
|
||||
models = self._submodels
|
||||
if self._model_group is not None:
|
||||
self._model_group.uninstall(*models)
|
||||
group = FullyLoadedModelGroup(self._model_group.execution_device)
|
||||
group.install(*models)
|
||||
self._model_group = group
|
||||
|
||||
def offload_all(self):
|
||||
"""Offload all this pipeline's models to CPU."""
|
||||
self._model_group.offload_current()
|
||||
|
||||
def ready(self):
|
||||
"""
|
||||
Ready this pipeline's models.
|
||||
|
||||
i.e. preload them to the GPU if appropriate.
|
||||
"""
|
||||
self._model_group.ready()
|
||||
|
||||
def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings=False):
|
||||
# overridden method; types match the superclass.
|
||||
if torch_device is None:
|
||||
@ -1013,25 +956,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
device = self._model_group.device_for(self.safety_checker)
|
||||
return super().run_safety_checker(image, device, dtype)
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_learned_conditioning(
|
||||
self, c: List[List[str]], *, return_tokens=True, fragment_weights=None
|
||||
):
|
||||
"""
|
||||
Compatibility function for invokeai.models.diffusion.ddpm.LatentDiffusion.
|
||||
"""
|
||||
return self.embeddings_provider.get_embeddings_for_weighted_prompt_fragments(
|
||||
text_batch=c,
|
||||
fragment_weights_batch=fragment_weights,
|
||||
should_return_tokens=return_tokens,
|
||||
device=self._model_group.device_for(self.unet),
|
||||
)
|
||||
|
||||
@property
|
||||
def channels(self) -> int:
|
||||
"""Compatible with DiffusionWrapper"""
|
||||
return self.unet.config.in_channels
|
||||
|
||||
def decode_latents(self, latents):
|
||||
# Explicit call to get the vae loaded, since `decode` isn't the forward method.
|
||||
self._model_group.load(self.vae)
|
||||
@ -1048,8 +972,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
|
||||
# Copied from diffusers pipeline_stable_diffusion_controlnet.py
|
||||
# Returns torch.Tensor of shape (batch_size, 3, height, width)
|
||||
@staticmethod
|
||||
def prepare_control_image(
|
||||
self,
|
||||
image,
|
||||
# FIXME: need to fix hardwiring of width and height, change to basing on latents dimensions?
|
||||
# latents,
|
||||
|
@ -18,7 +18,6 @@ from .cross_attention_control import (
|
||||
CrossAttentionType,
|
||||
SwapCrossAttnContext,
|
||||
get_cross_attention_modules,
|
||||
restore_default_cross_attention,
|
||||
setup_cross_attention_control_attention_processors,
|
||||
)
|
||||
from .cross_attention_map_saving import AttentionMapSaver
|
||||
@ -66,7 +65,6 @@ class InvokeAIDiffuserComponent:
|
||||
self,
|
||||
model,
|
||||
model_forward_callback: ModelForwardCallback,
|
||||
is_running_diffusers: bool = False,
|
||||
):
|
||||
"""
|
||||
:param model: the unet model to pass through to cross attention control
|
||||
@ -75,7 +73,6 @@ class InvokeAIDiffuserComponent:
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
self.conditioning = None
|
||||
self.model = model
|
||||
self.is_running_diffusers = is_running_diffusers
|
||||
self.model_forward_callback = model_forward_callback
|
||||
self.cross_attention_control_context = None
|
||||
self.sequential_guidance = config.sequential_guidance
|
||||
@ -112,37 +109,6 @@ class InvokeAIDiffuserComponent:
|
||||
# TODO resuscitate attention map saving
|
||||
# self.remove_attention_map_saving()
|
||||
|
||||
# apparently unused code
|
||||
# TODO: delete
|
||||
# def override_cross_attention(
|
||||
# self, conditioning: ExtraConditioningInfo, step_count: int
|
||||
# ) -> Dict[str, AttentionProcessor]:
|
||||
# """
|
||||
# setup cross attention .swap control. for diffusers this replaces the attention processor, so
|
||||
# the previous attention processor is returned so that the caller can restore it later.
|
||||
# """
|
||||
# self.conditioning = conditioning
|
||||
# self.cross_attention_control_context = Context(
|
||||
# arguments=self.conditioning.cross_attention_control_args,
|
||||
# step_count=step_count,
|
||||
# )
|
||||
# return override_cross_attention(
|
||||
# self.model,
|
||||
# self.cross_attention_control_context,
|
||||
# is_running_diffusers=self.is_running_diffusers,
|
||||
# )
|
||||
|
||||
def restore_default_cross_attention(
|
||||
self, restore_attention_processor: Optional["AttentionProcessor"] = None
|
||||
):
|
||||
self.conditioning = None
|
||||
self.cross_attention_control_context = None
|
||||
restore_default_cross_attention(
|
||||
self.model,
|
||||
is_running_diffusers=self.is_running_diffusers,
|
||||
restore_attention_processor=restore_attention_processor,
|
||||
)
|
||||
|
||||
def setup_attention_map_saving(self, saver: AttentionMapSaver):
|
||||
def callback(slice, dim, offset, slice_size, key):
|
||||
if dim is not None:
|
||||
@ -204,9 +170,7 @@ class InvokeAIDiffuserComponent:
|
||||
cross_attention_control_types_to_do = []
|
||||
context: Context = self.cross_attention_control_context
|
||||
if self.cross_attention_control_context is not None:
|
||||
percent_through = self.calculate_percent_through(
|
||||
sigma, step_index, total_step_count
|
||||
)
|
||||
percent_through = step_index / total_step_count
|
||||
cross_attention_control_types_to_do = (
|
||||
context.get_active_cross_attention_control_types_for_step(
|
||||
percent_through
|
||||
@ -264,9 +228,7 @@ class InvokeAIDiffuserComponent:
|
||||
total_step_count,
|
||||
) -> torch.Tensor:
|
||||
if postprocessing_settings is not None:
|
||||
percent_through = self.calculate_percent_through(
|
||||
sigma, step_index, total_step_count
|
||||
)
|
||||
percent_through = step_index / total_step_count
|
||||
latents = self.apply_threshold(
|
||||
postprocessing_settings, latents, percent_through
|
||||
)
|
||||
@ -275,22 +237,6 @@ class InvokeAIDiffuserComponent:
|
||||
)
|
||||
return latents
|
||||
|
||||
def calculate_percent_through(self, sigma, step_index, total_step_count):
|
||||
if step_index is not None and total_step_count is not None:
|
||||
# 🧨diffusers codepath
|
||||
percent_through = (
|
||||
step_index / total_step_count
|
||||
) # will never reach 1.0 - this is deliberate
|
||||
else:
|
||||
# legacy compvis codepath
|
||||
# TODO remove when compvis codepath support is dropped
|
||||
if step_index is None and sigma is None:
|
||||
raise ValueError(
|
||||
"Either step_index or sigma is required when doing cross attention control, but both are None."
|
||||
)
|
||||
percent_through = self.estimate_percent_through(step_index, sigma)
|
||||
return percent_through
|
||||
|
||||
# methods below are called from do_diffusion_step and should be considered private to this class.
|
||||
|
||||
def _apply_standard_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
|
||||
@ -323,6 +269,7 @@ class InvokeAIDiffuserComponent:
|
||||
conditioned_next_x = conditioned_next_x.clone()
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
# TODO: looks unused
|
||||
def _apply_hybrid_conditioning(self, x, sigma, unconditioning, conditioning, **kwargs):
|
||||
assert isinstance(conditioning, dict)
|
||||
assert isinstance(unconditioning, dict)
|
||||
@ -350,34 +297,6 @@ class InvokeAIDiffuserComponent:
|
||||
conditioning,
|
||||
cross_attention_control_types_to_do,
|
||||
**kwargs,
|
||||
):
|
||||
if self.is_running_diffusers:
|
||||
return self._apply_cross_attention_controlled_conditioning__diffusers(
|
||||
x,
|
||||
sigma,
|
||||
unconditioning,
|
||||
conditioning,
|
||||
cross_attention_control_types_to_do,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
return self._apply_cross_attention_controlled_conditioning__compvis(
|
||||
x,
|
||||
sigma,
|
||||
unconditioning,
|
||||
conditioning,
|
||||
cross_attention_control_types_to_do,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _apply_cross_attention_controlled_conditioning__diffusers(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
sigma,
|
||||
unconditioning,
|
||||
conditioning,
|
||||
cross_attention_control_types_to_do,
|
||||
**kwargs,
|
||||
):
|
||||
context: Context = self.cross_attention_control_context
|
||||
|
||||
@ -409,54 +328,6 @@ class InvokeAIDiffuserComponent:
|
||||
)
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
def _apply_cross_attention_controlled_conditioning__compvis(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
sigma,
|
||||
unconditioning,
|
||||
conditioning,
|
||||
cross_attention_control_types_to_do,
|
||||
**kwargs,
|
||||
):
|
||||
# print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do)
|
||||
# slower non-batched path (20% slower on mac MPS)
|
||||
# We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
|
||||
# unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x.
|
||||
# This messes app their application later, due to mismatched shape of dim 0 (seems to be 16 for batched vs. 8)
|
||||
# (For the batched invocation the `wrangler` function gets attention tensor with shape[0]=16,
|
||||
# representing batched uncond + cond, but then when it comes to applying the saved attention, the
|
||||
# wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.)
|
||||
# todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
|
||||
context: Context = self.cross_attention_control_context
|
||||
|
||||
try:
|
||||
unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning, **kwargs)
|
||||
|
||||
# process x using the original prompt, saving the attention maps
|
||||
# print("saving attention maps for", cross_attention_control_types_to_do)
|
||||
for ca_type in cross_attention_control_types_to_do:
|
||||
context.request_save_attention_maps(ca_type)
|
||||
_ = self.model_forward_callback(x, sigma, conditioning, **kwargs,)
|
||||
context.clear_requests(cleanup=False)
|
||||
|
||||
# process x again, using the saved attention maps to control where self.edited_conditioning will be applied
|
||||
# print("applying saved attention maps for", cross_attention_control_types_to_do)
|
||||
for ca_type in cross_attention_control_types_to_do:
|
||||
context.request_apply_saved_attention_maps(ca_type)
|
||||
edited_conditioning = (
|
||||
self.conditioning.cross_attention_control_args.edited_conditioning
|
||||
)
|
||||
conditioned_next_x = self.model_forward_callback(
|
||||
x, sigma, edited_conditioning, **kwargs,
|
||||
)
|
||||
context.clear_requests(cleanup=True)
|
||||
|
||||
except:
|
||||
context.clear_requests(cleanup=True)
|
||||
raise
|
||||
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
|
||||
def _combine(self, unconditioned_next_x, conditioned_next_x, guidance_scale):
|
||||
# to scale how much effect conditioning has, calculate the changes it does and then scale that
|
||||
scaled_delta = (conditioned_next_x - unconditioned_next_x) * guidance_scale
|
||||
|
@ -157,7 +157,7 @@ class LazilyLoadedModelGroup(ModelGroup):
|
||||
def offload_current(self):
|
||||
module = self._current_model_ref()
|
||||
if module is not NO_MODEL:
|
||||
module.to(device=OFFLOAD_DEVICE)
|
||||
module.to(OFFLOAD_DEVICE)
|
||||
self.clear_current_model()
|
||||
|
||||
def _load(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
@ -228,7 +228,7 @@ class FullyLoadedModelGroup(ModelGroup):
|
||||
def install(self, *models: torch.nn.Module):
|
||||
for model in models:
|
||||
self._models.add(model)
|
||||
model.to(device=self.execution_device)
|
||||
model.to(self.execution_device)
|
||||
|
||||
def uninstall(self, *models: torch.nn.Module):
|
||||
for model in models:
|
||||
@ -238,11 +238,11 @@ class FullyLoadedModelGroup(ModelGroup):
|
||||
self.uninstall(*self._models)
|
||||
|
||||
def load(self, model):
|
||||
model.to(device=self.execution_device)
|
||||
model.to(self.execution_device)
|
||||
|
||||
def offload_current(self):
|
||||
for model in self._models:
|
||||
model.to(device=OFFLOAD_DEVICE)
|
||||
model.to(OFFLOAD_DEVICE)
|
||||
|
||||
def ready(self):
|
||||
for model in self._models:
|
||||
@ -252,7 +252,7 @@ class FullyLoadedModelGroup(ModelGroup):
|
||||
self.execution_device = device
|
||||
for model in self._models:
|
||||
if model.device != OFFLOAD_DEVICE:
|
||||
model.to(device=device)
|
||||
model.to(device)
|
||||
|
||||
def device_for(self, model):
|
||||
if model not in self:
|
||||
|
@ -1,13 +1,14 @@
|
||||
from diffusers import DDIMScheduler, DPMSolverMultistepScheduler, KDPM2DiscreteScheduler, \
|
||||
KDPM2AncestralDiscreteScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler, \
|
||||
HeunDiscreteScheduler, LMSDiscreteScheduler, PNDMScheduler, UniPCMultistepScheduler, \
|
||||
DPMSolverSinglestepScheduler, DEISMultistepScheduler, DDPMScheduler
|
||||
DPMSolverSinglestepScheduler, DEISMultistepScheduler, DDPMScheduler, DPMSolverSDEScheduler
|
||||
|
||||
SCHEDULER_MAP = dict(
|
||||
ddim=(DDIMScheduler, dict()),
|
||||
ddpm=(DDPMScheduler, dict()),
|
||||
deis=(DEISMultistepScheduler, dict()),
|
||||
lms=(LMSDiscreteScheduler, dict()),
|
||||
lms=(LMSDiscreteScheduler, dict(use_karras_sigmas=False)),
|
||||
lms_k=(LMSDiscreteScheduler, dict(use_karras_sigmas=True)),
|
||||
pndm=(PNDMScheduler, dict()),
|
||||
heun=(HeunDiscreteScheduler, dict(use_karras_sigmas=False)),
|
||||
heun_k=(HeunDiscreteScheduler, dict(use_karras_sigmas=True)),
|
||||
@ -16,8 +17,13 @@ SCHEDULER_MAP = dict(
|
||||
euler_a=(EulerAncestralDiscreteScheduler, dict()),
|
||||
kdpm_2=(KDPM2DiscreteScheduler, dict()),
|
||||
kdpm_2_a=(KDPM2AncestralDiscreteScheduler, dict()),
|
||||
dpmpp_2s=(DPMSolverSinglestepScheduler, dict()),
|
||||
dpmpp_2s=(DPMSolverSinglestepScheduler, dict(use_karras_sigmas=False)),
|
||||
dpmpp_2s_k=(DPMSolverSinglestepScheduler, dict(use_karras_sigmas=True)),
|
||||
dpmpp_2m=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False)),
|
||||
dpmpp_2m_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True)),
|
||||
dpmpp_2m_sde=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=False, algorithm_type='sde-dpmsolver++')),
|
||||
dpmpp_2m_sde_k=(DPMSolverMultistepScheduler, dict(use_karras_sigmas=True, algorithm_type='sde-dpmsolver++')),
|
||||
dpmpp_sde=(DPMSolverSDEScheduler, dict(use_karras_sigmas=False, noise_sampler_seed=0)),
|
||||
dpmpp_sde_k=(DPMSolverSDEScheduler, dict(use_karras_sigmas=True, noise_sampler_seed=0)),
|
||||
unipc=(UniPCMultistepScheduler, dict(cpu_only=True))
|
||||
)
|
||||
|
@ -1,429 +0,0 @@
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union, List
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
|
||||
from compel.embeddings_provider import BaseTextualInversionManager
|
||||
from picklescan.scanner import scan_file_path
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from .concepts_lib import HuggingFaceConceptsLibrary
|
||||
|
||||
@dataclass
|
||||
class EmbeddingInfo:
|
||||
name: str
|
||||
embedding: torch.Tensor
|
||||
num_vectors_per_token: int
|
||||
token_dim: int
|
||||
trained_steps: int = None
|
||||
trained_model_name: str = None
|
||||
trained_model_checksum: str = None
|
||||
|
||||
@dataclass
|
||||
class TextualInversion:
|
||||
trigger_string: str
|
||||
embedding: torch.Tensor
|
||||
trigger_token_id: Optional[int] = None
|
||||
pad_token_ids: Optional[list[int]] = None
|
||||
|
||||
@property
|
||||
def embedding_vector_length(self) -> int:
|
||||
return self.embedding.shape[0]
|
||||
|
||||
|
||||
class TextualInversionManager(BaseTextualInversionManager):
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_encoder: CLIPTextModel,
|
||||
full_precision: bool = True,
|
||||
):
|
||||
self.tokenizer = tokenizer
|
||||
self.text_encoder = text_encoder
|
||||
self.full_precision = full_precision
|
||||
self.hf_concepts_library = HuggingFaceConceptsLibrary()
|
||||
self.trigger_to_sourcefile = dict()
|
||||
default_textual_inversions: list[TextualInversion] = []
|
||||
self.textual_inversions = default_textual_inversions
|
||||
|
||||
def load_huggingface_concepts(self, concepts: list[str]):
|
||||
for concept_name in concepts:
|
||||
if concept_name in self.hf_concepts_library.concepts_loaded:
|
||||
continue
|
||||
trigger = self.hf_concepts_library.concept_to_trigger(concept_name)
|
||||
if (
|
||||
self.has_textual_inversion_for_trigger_string(trigger)
|
||||
or self.has_textual_inversion_for_trigger_string(concept_name)
|
||||
or self.has_textual_inversion_for_trigger_string(f"<{concept_name}>")
|
||||
): # in case a token with literal angle brackets encountered
|
||||
logger.info(f"Loaded local embedding for trigger {concept_name}")
|
||||
continue
|
||||
bin_file = self.hf_concepts_library.get_concept_model_path(concept_name)
|
||||
if not bin_file:
|
||||
continue
|
||||
logger.info(f"Loaded remote embedding for trigger {concept_name}")
|
||||
self.load_textual_inversion(bin_file)
|
||||
self.hf_concepts_library.concepts_loaded[concept_name] = True
|
||||
|
||||
def get_all_trigger_strings(self) -> list[str]:
|
||||
return [ti.trigger_string for ti in self.textual_inversions]
|
||||
|
||||
def load_textual_inversion(
|
||||
self, ckpt_path: Union[str, Path], defer_injecting_tokens: bool = False
|
||||
):
|
||||
ckpt_path = Path(ckpt_path)
|
||||
|
||||
if not ckpt_path.is_file():
|
||||
return
|
||||
|
||||
if str(ckpt_path).endswith(".DS_Store"):
|
||||
return
|
||||
|
||||
embedding_list = self._parse_embedding(str(ckpt_path))
|
||||
for embedding_info in embedding_list:
|
||||
if (self.text_encoder.get_input_embeddings().weight.data[0].shape[0] != embedding_info.token_dim):
|
||||
logger.warning(
|
||||
f"Notice: {ckpt_path.parents[0].name}/{ckpt_path.name} was trained on a model with an incompatible token dimension: {self.text_encoder.get_input_embeddings().weight.data[0].shape[0]} vs {embedding_info.token_dim}."
|
||||
)
|
||||
continue
|
||||
|
||||
# Resolve the situation in which an earlier embedding has claimed the same
|
||||
# trigger string. We replace the trigger with '<source_file>', as we used to.
|
||||
trigger_str = embedding_info.name
|
||||
sourcefile = (
|
||||
f"{ckpt_path.parent.name}/{ckpt_path.name}"
|
||||
if ckpt_path.name == "learned_embeds.bin"
|
||||
else ckpt_path.name
|
||||
)
|
||||
|
||||
if trigger_str in self.trigger_to_sourcefile:
|
||||
replacement_trigger_str = (
|
||||
f"<{ckpt_path.parent.name}>"
|
||||
if ckpt_path.name == "learned_embeds.bin"
|
||||
else f"<{ckpt_path.stem}>"
|
||||
)
|
||||
logger.info(
|
||||
f"{sourcefile}: Trigger token '{trigger_str}' is already claimed by '{self.trigger_to_sourcefile[trigger_str]}'. Trigger this concept with {replacement_trigger_str}"
|
||||
)
|
||||
trigger_str = replacement_trigger_str
|
||||
|
||||
try:
|
||||
self._add_textual_inversion(
|
||||
trigger_str,
|
||||
embedding_info.embedding,
|
||||
defer_injecting_tokens=defer_injecting_tokens,
|
||||
)
|
||||
# remember which source file claims this trigger
|
||||
self.trigger_to_sourcefile[trigger_str] = sourcefile
|
||||
|
||||
except ValueError as e:
|
||||
logger.debug(f'Ignoring incompatible embedding {embedding_info["name"]}')
|
||||
logger.debug(f"The error was {str(e)}")
|
||||
|
||||
def _add_textual_inversion(
|
||||
self, trigger_str, embedding, defer_injecting_tokens=False
|
||||
) -> Optional[TextualInversion]:
|
||||
"""
|
||||
Add a textual inversion to be recognised.
|
||||
:param trigger_str: The trigger text in the prompt that activates this textual inversion. If unknown to the embedder's tokenizer, will be added.
|
||||
:param embedding: The actual embedding data that will be inserted into the conditioning at the point where the token_str appears.
|
||||
:return: The token id for the added embedding, either existing or newly-added.
|
||||
"""
|
||||
if trigger_str in [ti.trigger_string for ti in self.textual_inversions]:
|
||||
logger.warning(
|
||||
f"TextualInversionManager refusing to overwrite already-loaded token '{trigger_str}'"
|
||||
)
|
||||
return
|
||||
if not self.full_precision:
|
||||
embedding = embedding.half()
|
||||
if len(embedding.shape) == 1:
|
||||
embedding = embedding.unsqueeze(0)
|
||||
elif len(embedding.shape) > 2:
|
||||
raise ValueError(
|
||||
f"** TextualInversionManager cannot add {trigger_str} because the embedding shape {embedding.shape} is incorrect. The embedding must have shape [token_dim] or [V, token_dim] where V is vector length and token_dim is 768 for SD1 or 1280 for SD2."
|
||||
)
|
||||
|
||||
try:
|
||||
ti = TextualInversion(trigger_string=trigger_str, embedding=embedding)
|
||||
if not defer_injecting_tokens:
|
||||
self._inject_tokens_and_assign_embeddings(ti)
|
||||
self.textual_inversions.append(ti)
|
||||
return ti
|
||||
|
||||
except ValueError as e:
|
||||
if str(e).startswith("Warning"):
|
||||
logger.warning(f"{str(e)}")
|
||||
else:
|
||||
traceback.print_exc()
|
||||
logger.error(
|
||||
f"TextualInversionManager was unable to add a textual inversion with trigger string {trigger_str}."
|
||||
)
|
||||
raise
|
||||
|
||||
def _inject_tokens_and_assign_embeddings(self, ti: TextualInversion) -> int:
|
||||
if ti.trigger_token_id is not None:
|
||||
raise ValueError(
|
||||
f"Tokens already injected for textual inversion with trigger '{ti.trigger_string}'"
|
||||
)
|
||||
|
||||
trigger_token_id = self._get_or_create_token_id_and_assign_embedding(
|
||||
ti.trigger_string, ti.embedding[0]
|
||||
)
|
||||
|
||||
if ti.embedding_vector_length > 1:
|
||||
# for embeddings with vector length > 1
|
||||
pad_token_strings = [
|
||||
ti.trigger_string + "-!pad-" + str(pad_index)
|
||||
for pad_index in range(1, ti.embedding_vector_length)
|
||||
]
|
||||
# todo: batched UI for faster loading when vector length >2
|
||||
pad_token_ids = [
|
||||
self._get_or_create_token_id_and_assign_embedding(
|
||||
pad_token_str, ti.embedding[1 + i]
|
||||
)
|
||||
for (i, pad_token_str) in enumerate(pad_token_strings)
|
||||
]
|
||||
else:
|
||||
pad_token_ids = []
|
||||
|
||||
ti.trigger_token_id = trigger_token_id
|
||||
ti.pad_token_ids = pad_token_ids
|
||||
return ti.trigger_token_id
|
||||
|
||||
def has_textual_inversion_for_trigger_string(self, trigger_string: str) -> bool:
|
||||
try:
|
||||
ti = self.get_textual_inversion_for_trigger_string(trigger_string)
|
||||
return ti is not None
|
||||
except StopIteration:
|
||||
return False
|
||||
|
||||
def get_textual_inversion_for_trigger_string(
|
||||
self, trigger_string: str
|
||||
) -> TextualInversion:
|
||||
return next(
|
||||
ti for ti in self.textual_inversions if ti.trigger_string == trigger_string
|
||||
)
|
||||
|
||||
def get_textual_inversion_for_token_id(self, token_id: int) -> TextualInversion:
|
||||
return next(
|
||||
ti for ti in self.textual_inversions if ti.trigger_token_id == token_id
|
||||
)
|
||||
|
||||
def create_deferred_token_ids_for_any_trigger_terms(
|
||||
self, prompt_string: str
|
||||
) -> list[int]:
|
||||
injected_token_ids = []
|
||||
for ti in self.textual_inversions:
|
||||
if ti.trigger_token_id is None and ti.trigger_string in prompt_string:
|
||||
if ti.embedding_vector_length > 1:
|
||||
logger.info(
|
||||
f"Preparing tokens for textual inversion {ti.trigger_string}..."
|
||||
)
|
||||
try:
|
||||
self._inject_tokens_and_assign_embeddings(ti)
|
||||
except ValueError as e:
|
||||
logger.debug(
|
||||
f"Ignoring incompatible embedding trigger {ti.trigger_string}"
|
||||
)
|
||||
logger.debug(f"The error was {str(e)}")
|
||||
continue
|
||||
injected_token_ids.append(ti.trigger_token_id)
|
||||
injected_token_ids.extend(ti.pad_token_ids)
|
||||
return injected_token_ids
|
||||
|
||||
def expand_textual_inversion_token_ids_if_necessary(
|
||||
self, prompt_token_ids: list[int]
|
||||
) -> list[int]:
|
||||
"""
|
||||
Insert padding tokens as necessary into the passed-in list of token ids to match any textual inversions it includes.
|
||||
|
||||
:param prompt_token_ids: The prompt as a list of token ids (`int`s). Should not include bos and eos markers.
|
||||
:return: The prompt token ids with any necessary padding to account for textual inversions inserted. May be too
|
||||
long - caller is responsible for prepending/appending eos and bos token ids, and truncating if necessary.
|
||||
"""
|
||||
if len(prompt_token_ids) == 0:
|
||||
return prompt_token_ids
|
||||
|
||||
if prompt_token_ids[0] == self.tokenizer.bos_token_id:
|
||||
raise ValueError("prompt_token_ids must not start with bos_token_id")
|
||||
if prompt_token_ids[-1] == self.tokenizer.eos_token_id:
|
||||
raise ValueError("prompt_token_ids must not end with eos_token_id")
|
||||
textual_inversion_trigger_token_ids = [
|
||||
ti.trigger_token_id for ti in self.textual_inversions
|
||||
]
|
||||
prompt_token_ids = prompt_token_ids.copy()
|
||||
for i, token_id in reversed(list(enumerate(prompt_token_ids))):
|
||||
if token_id in textual_inversion_trigger_token_ids:
|
||||
textual_inversion = next(
|
||||
ti
|
||||
for ti in self.textual_inversions
|
||||
if ti.trigger_token_id == token_id
|
||||
)
|
||||
for pad_idx in range(0, textual_inversion.embedding_vector_length - 1):
|
||||
prompt_token_ids.insert(
|
||||
i + pad_idx + 1, textual_inversion.pad_token_ids[pad_idx]
|
||||
)
|
||||
|
||||
return prompt_token_ids
|
||||
|
||||
def _get_or_create_token_id_and_assign_embedding(
|
||||
self, token_str: str, embedding: torch.Tensor
|
||||
) -> int:
|
||||
if len(embedding.shape) != 1:
|
||||
raise ValueError(
|
||||
"Embedding has incorrect shape - must be [token_dim] where token_dim is 768 for SD1 or 1280 for SD2"
|
||||
)
|
||||
existing_token_id = self.tokenizer.convert_tokens_to_ids(token_str)
|
||||
if existing_token_id == self.tokenizer.unk_token_id:
|
||||
num_tokens_added = self.tokenizer.add_tokens(token_str)
|
||||
current_embeddings = self.text_encoder.resize_token_embeddings(None)
|
||||
current_token_count = current_embeddings.num_embeddings
|
||||
new_token_count = current_token_count + num_tokens_added
|
||||
# the following call is slow - todo make batched for better performance with vector length >1
|
||||
self.text_encoder.resize_token_embeddings(new_token_count)
|
||||
|
||||
token_id = self.tokenizer.convert_tokens_to_ids(token_str)
|
||||
if token_id == self.tokenizer.unk_token_id:
|
||||
raise RuntimeError(f"Unable to find token id for token '{token_str}'")
|
||||
if (
|
||||
self.text_encoder.get_input_embeddings().weight.data[token_id].shape
|
||||
!= embedding.shape
|
||||
):
|
||||
raise ValueError(
|
||||
f"Warning. Cannot load embedding for {token_str}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {self.text_encoder.get_input_embeddings().weight.data[token_id].shape[0]}."
|
||||
)
|
||||
self.text_encoder.get_input_embeddings().weight.data[token_id] = embedding
|
||||
|
||||
return token_id
|
||||
|
||||
|
||||
def _parse_embedding(self, embedding_file: str)->List[EmbeddingInfo]:
|
||||
suffix = Path(embedding_file).suffix
|
||||
try:
|
||||
if suffix in [".pt",".ckpt",".bin"]:
|
||||
scan_result = scan_file_path(embedding_file)
|
||||
if scan_result.infected_files > 0:
|
||||
logger.critical(
|
||||
f"Security Issues Found in Model: {scan_result.issues_count}"
|
||||
)
|
||||
logger.critical("For your safety, InvokeAI will not load this embed.")
|
||||
return list()
|
||||
ckpt = torch.load(embedding_file,map_location="cpu")
|
||||
else:
|
||||
ckpt = safetensors.torch.load_file(embedding_file)
|
||||
except Exception as e:
|
||||
logger.warning(f"Notice: unrecognized embedding file format: {embedding_file}: {e}")
|
||||
return list()
|
||||
|
||||
# try to figure out what kind of embedding file it is and parse accordingly
|
||||
keys = list(ckpt.keys())
|
||||
if all(x in keys for x in ['string_to_token','string_to_param','name','step']):
|
||||
return self._parse_embedding_v1(ckpt, embedding_file) # example rem_rezero.pt
|
||||
|
||||
elif all(x in keys for x in ['string_to_token','string_to_param']):
|
||||
return self._parse_embedding_v2(ckpt, embedding_file) # example midj-strong.pt
|
||||
|
||||
elif 'emb_params' in keys:
|
||||
return self._parse_embedding_v3(ckpt, embedding_file) # example easynegative.safetensors
|
||||
|
||||
else:
|
||||
return self._parse_embedding_v4(ckpt, embedding_file) # usually a '.bin' file
|
||||
|
||||
def _parse_embedding_v1(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
|
||||
basename = Path(file_path).stem
|
||||
logger.debug(f'Loading v1 embedding file: {basename}')
|
||||
|
||||
embeddings = list()
|
||||
token_counter = -1
|
||||
for token,embedding in embedding_ckpt["string_to_param"].items():
|
||||
if token_counter < 0:
|
||||
trigger = embedding_ckpt["name"]
|
||||
elif token_counter == 0:
|
||||
trigger = '<basename>'
|
||||
else:
|
||||
trigger = f'<{basename}-{int(token_counter:=token_counter)}>'
|
||||
token_counter += 1
|
||||
embedding_info = EmbeddingInfo(
|
||||
name = trigger,
|
||||
embedding = embedding,
|
||||
num_vectors_per_token = embedding.size()[0],
|
||||
token_dim = embedding.size()[1],
|
||||
trained_steps = embedding_ckpt["step"],
|
||||
trained_model_name = embedding_ckpt["sd_checkpoint_name"],
|
||||
trained_model_checksum = embedding_ckpt["sd_checkpoint"]
|
||||
)
|
||||
embeddings.append(embedding_info)
|
||||
return embeddings
|
||||
|
||||
def _parse_embedding_v2 (
|
||||
self, embedding_ckpt: dict, file_path: str
|
||||
) -> List[EmbeddingInfo]:
|
||||
"""
|
||||
This handles embedding .pt file variant #2.
|
||||
"""
|
||||
basename = Path(file_path).stem
|
||||
logger.debug(f'Loading v2 embedding file: {basename}')
|
||||
embeddings = list()
|
||||
|
||||
if isinstance(
|
||||
list(embedding_ckpt["string_to_token"].values())[0], torch.Tensor
|
||||
):
|
||||
token_counter = 0
|
||||
for token,embedding in embedding_ckpt["string_to_param"].items():
|
||||
trigger = token if token != '*' \
|
||||
else f'<{basename}>' if token_counter == 0 \
|
||||
else f'<{basename}-{int(token_counter:=token_counter+1)}>'
|
||||
embedding_info = EmbeddingInfo(
|
||||
name = trigger,
|
||||
embedding = embedding,
|
||||
num_vectors_per_token = embedding.size()[0],
|
||||
token_dim = embedding.size()[1],
|
||||
)
|
||||
embeddings.append(embedding_info)
|
||||
else:
|
||||
logger.warning(f"{basename}: Unrecognized embedding format")
|
||||
|
||||
return embeddings
|
||||
|
||||
def _parse_embedding_v3(self, embedding_ckpt: dict, file_path: str)->List[EmbeddingInfo]:
|
||||
"""
|
||||
Parse 'version 3' of the .pt textual inversion embedding files.
|
||||
"""
|
||||
basename = Path(file_path).stem
|
||||
logger.debug(f'Loading v3 embedding file: {basename}')
|
||||
embedding = embedding_ckpt['emb_params']
|
||||
embedding_info = EmbeddingInfo(
|
||||
name = f'<{basename}>',
|
||||
embedding = embedding,
|
||||
num_vectors_per_token = embedding.size()[0],
|
||||
token_dim = embedding.size()[1],
|
||||
)
|
||||
return [embedding_info]
|
||||
|
||||
def _parse_embedding_v4(self, embedding_ckpt: dict, filepath: str)->List[EmbeddingInfo]:
|
||||
"""
|
||||
Parse 'version 4' of the textual inversion embedding files. This one
|
||||
is usually associated with .bin files trained by HuggingFace diffusers.
|
||||
"""
|
||||
basename = Path(filepath).stem
|
||||
short_path = Path(filepath).parents[0].name+'/'+Path(filepath).name
|
||||
|
||||
logger.debug(f'Loading v4 embedding file: {short_path}')
|
||||
|
||||
embeddings = list()
|
||||
if list(embedding_ckpt.keys()) == 0:
|
||||
logger.warning(f"Invalid embeddings file: {short_path}")
|
||||
else:
|
||||
for token,embedding in embedding_ckpt.items():
|
||||
embedding_info = EmbeddingInfo(
|
||||
name = token or f"<{basename}>",
|
||||
embedding = embedding,
|
||||
num_vectors_per_token = 1, # All Concepts seem to default to 1
|
||||
token_dim = embedding.size()[0],
|
||||
)
|
||||
embeddings.append(embedding_info)
|
||||
return embeddings
|
Reference in New Issue
Block a user