Merge branch 'main' into feat/nodes/freeu

This commit is contained in:
Kent Keirsey
2023-11-06 05:39:58 -08:00
committed by GitHub
206 changed files with 11705 additions and 8252 deletions

View File

@ -20,12 +20,12 @@ class InvisibleWatermark:
"""
@classmethod
def invisible_watermark_available(self) -> bool:
def invisible_watermark_available(cls) -> bool:
return config.invisible_watermark
@classmethod
def add_watermark(self, image: Image, watermark_text: str) -> Image:
if not self.invisible_watermark_available():
def add_watermark(cls, image: Image.Image, watermark_text: str) -> Image.Image:
if not cls.invisible_watermark_available():
return image
logger.debug(f'Applying invisible watermark "{watermark_text}"')
bgr = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)

View File

@ -26,8 +26,8 @@ class SafetyChecker:
tried_load: bool = False
@classmethod
def _load_safety_checker(self):
if self.tried_load:
def _load_safety_checker(cls):
if cls.tried_load:
return
if config.nsfw_checker:
@ -35,31 +35,31 @@ class SafetyChecker:
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / CHECKER_PATH)
self.feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / CHECKER_PATH)
cls.safety_checker = StableDiffusionSafetyChecker.from_pretrained(config.models_path / CHECKER_PATH)
cls.feature_extractor = AutoFeatureExtractor.from_pretrained(config.models_path / CHECKER_PATH)
logger.info("NSFW checker initialized")
except Exception as e:
logger.warning(f"Could not load NSFW checker: {str(e)}")
else:
logger.info("NSFW checker loading disabled")
self.tried_load = True
cls.tried_load = True
@classmethod
def safety_checker_available(self) -> bool:
self._load_safety_checker()
return self.safety_checker is not None
def safety_checker_available(cls) -> bool:
cls._load_safety_checker()
return cls.safety_checker is not None
@classmethod
def has_nsfw_concept(self, image: Image) -> bool:
if not self.safety_checker_available():
def has_nsfw_concept(cls, image: Image.Image) -> bool:
if not cls.safety_checker_available():
return False
device = choose_torch_device()
features = self.feature_extractor([image], return_tensors="pt")
features = cls.feature_extractor([image], return_tensors="pt")
features.to(device)
self.safety_checker.to(device)
cls.safety_checker.to(device)
x_image = np.array(image).astype(np.float32) / 255.0
x_image = x_image[None].transpose(0, 3, 1, 2)
with SilenceWarnings():
checked_image, has_nsfw_concept = self.safety_checker(images=x_image, clip_input=features.pixel_values)
checked_image, has_nsfw_concept = cls.safety_checker(images=x_image, clip_input=features.pixel_values)
return has_nsfw_concept[0]

View File

@ -460,6 +460,12 @@ class ModelInstall(object):
possible_conf = path.with_suffix(".yaml")
if possible_conf.exists():
legacy_conf = str(self.relative_to_root(possible_conf))
else:
legacy_conf = Path(
self.config.root_path,
"configs/controlnet",
("cldm_v15.yaml" if info.base_type == BaseModelType("sd-1") else "cldm_v21.yaml"),
)
if legacy_conf:
attributes.update(dict(config=str(legacy_conf)))

View File

@ -67,6 +67,12 @@ class IPAttnProcessor2_0(torch.nn.Module):
temb=None,
ip_adapter_image_prompt_embeds=None,
):
"""Apply IP-Adapter attention.
Args:
ip_adapter_image_prompt_embeds (torch.Tensor): The image prompt embeddings.
Shape: (batch_size, num_ip_images, seq_len, ip_embedding_len).
"""
residual = hidden_states
if attn.spatial_norm is not None:
@ -127,26 +133,35 @@ class IPAttnProcessor2_0(torch.nn.Module):
for ipa_embed, ipa_weights, scale in zip(ip_adapter_image_prompt_embeds, self._weights, self._scales):
# The batch dimensions should match.
assert ipa_embed.shape[0] == encoder_hidden_states.shape[0]
# The channel dimensions should match.
assert ipa_embed.shape[2] == encoder_hidden_states.shape[2]
# The token_len dimensions should match.
assert ipa_embed.shape[-1] == encoder_hidden_states.shape[-1]
ip_hidden_states = ipa_embed
# Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
# Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# The output of sdpa has shape: (batch, num_heads, seq_len, head_dim)
# Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
# Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
hidden_states = hidden_states + scale * ip_hidden_states
# linear proj

View File

@ -1,6 +1,6 @@
from __future__ import annotations
import copy
import pickle
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
@ -56,24 +56,6 @@ class ModelPatcher:
return (module_key, module)
@staticmethod
def _lora_forward_hook(
applied_loras: List[Tuple[LoRAModel, float]],
layer_name: str,
):
def lora_forward(module, input_h, output):
if len(applied_loras) == 0:
return output
for lora, weight in applied_loras:
layer = lora.layers.get(layer_name, None)
if layer is None:
continue
output += layer.forward(module, input_h, weight)
return output
return lora_forward
@classmethod
@contextmanager
def apply_lora_unet(
@ -131,21 +113,40 @@ class ModelPatcher:
if not layer_key.startswith(prefix):
continue
# TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
# should be improved in the following ways:
# 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
# LoRA model is applied.
# 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
# intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
# weights to have valid keys.
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
# All of the LoRA weight calculations will be done on the same device as the module weight.
# (Performance will be best if this is a CUDA device.)
device = module.weight.device
dtype = module.weight.dtype
if module_key not in original_weights:
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
# enable autocast to calc fp16 loras on cpu
# with torch.autocast(device_type="cpu"):
layer.to(dtype=torch.float32)
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
layer_weight = layer.get_weight(original_weights[module_key]) * lora_weight * layer_scale
# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
layer.to(device=device)
layer.to(dtype=torch.float32)
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
layer.to(device="cpu")
if module.weight.shape != layer_weight.shape:
# TODO: debug on lycoris
layer_weight = layer_weight.reshape(module.weight.shape)
module.weight += layer_weight.to(device=module.weight.device, dtype=module.weight.dtype)
module.weight += layer_weight.to(dtype=dtype)
yield # wait for context manager exit
@ -166,7 +167,13 @@ class ModelPatcher:
new_tokens_added = None
try:
ti_tokenizer = copy.deepcopy(tokenizer)
# HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
# workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
# exiting this `apply_ti(...)` context manager.
#
# In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`,
# but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
ti_manager = TextualInversionManager(ti_tokenizer)
init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings
@ -198,7 +205,9 @@ class ModelPatcher:
if model_embeddings.weight.data[token_id].shape != embedding.shape:
raise ValueError(
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {model_embeddings.weight.data[token_id].shape[0]}."
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension"
f" {embedding.shape[0]}, but the current model has token dimension"
f" {model_embeddings.weight.data[token_id].shape[0]}."
)
model_embeddings.weight.data[token_id] = embedding.to(
@ -278,7 +287,8 @@ class TextualInversionModel:
if "string_to_param" in state_dict:
if len(state_dict["string_to_param"]) > 1:
print(
f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first token will be used.'
f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first'
" token will be used."
)
result.embedding = next(iter(state_dict["string_to_param"].values()))
@ -456,7 +466,13 @@ class ONNXModelPatcher:
orig_embeddings = None
try:
ti_tokenizer = copy.deepcopy(tokenizer)
# HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
# workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
# exiting this `apply_ti(...)` context manager.
#
# In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`,
# but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
ti_manager = TextualInversionManager(ti_tokenizer)
def _get_trigger(ti_name, index):
@ -491,7 +507,9 @@ class ONNXModelPatcher:
if embeddings[token_id].shape != embedding.shape:
raise ValueError(
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {embeddings[token_id].shape[0]}."
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension"
f" {embedding.shape[0]}, but the current model has token dimension"
f" {embeddings[token_id].shape[0]}."
)
embeddings[token_id] = embedding

View File

@ -64,7 +64,7 @@ class MemorySnapshot:
return cls(process_ram, vram, malloc_info)
def get_pretty_snapshot_diff(snapshot_1: MemorySnapshot, snapshot_2: MemorySnapshot) -> str:
def get_pretty_snapshot_diff(snapshot_1: Optional[MemorySnapshot], snapshot_2: Optional[MemorySnapshot]) -> str:
"""Get a pretty string describing the difference between two `MemorySnapshot`s."""
def get_msg_line(prefix: str, val1: int, val2: int):
@ -73,6 +73,9 @@ def get_pretty_snapshot_diff(snapshot_1: MemorySnapshot, snapshot_2: MemorySnaps
msg = ""
if snapshot_1 is None or snapshot_2 is None:
return msg
msg += get_msg_line("Process RAM", snapshot_1.process_ram, snapshot_2.process_ram)
if snapshot_1.malloc_info is not None and snapshot_2.malloc_info is not None:

View File

@ -117,6 +117,7 @@ class ModelCache(object):
lazy_offloading: bool = True,
sha_chunksize: int = 16777216,
logger: types.ModuleType = logger,
log_memory_usage: bool = False,
):
"""
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
@ -126,6 +127,10 @@ class ModelCache(object):
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
:param sha_chunksize: Chunksize to use when calculating sha256 model hash
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
behaviour.
"""
self.model_infos: Dict[str, ModelBase] = dict()
# allow lazy offloading only when vram cache enabled
@ -137,6 +142,7 @@ class ModelCache(object):
self.storage_device: torch.device = storage_device
self.sha_chunksize = sha_chunksize
self.logger = logger
self._log_memory_usage = log_memory_usage
# used for stats collection
self.stats = None
@ -144,6 +150,11 @@ class ModelCache(object):
self._cached_models = dict()
self._cache_stack = list()
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
if self._log_memory_usage:
return MemorySnapshot.capture()
return None
def get_key(
self,
model_path: str,
@ -223,10 +234,10 @@ class ModelCache(object):
# Load the model from disk and capture a memory snapshot before/after.
start_load_time = time.time()
snapshot_before = MemorySnapshot.capture()
snapshot_before = self._capture_memory_snapshot()
with skip_torch_weight_init():
model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
snapshot_after = MemorySnapshot.capture()
snapshot_after = self._capture_memory_snapshot()
end_load_time = time.time()
self_reported_model_size_after_load = model_info.get_size(submodel)
@ -275,9 +286,9 @@ class ModelCache(object):
return
start_model_to_time = time.time()
snapshot_before = MemorySnapshot.capture()
snapshot_before = self._capture_memory_snapshot()
cache_entry.model.to(target_device)
snapshot_after = MemorySnapshot.capture()
snapshot_after = self._capture_memory_snapshot()
end_model_to_time = time.time()
self.logger.debug(
f"Moved model '{key}' from {source_device} to"
@ -286,7 +297,12 @@ class ModelCache(object):
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
if snapshot_before.vram is not None and snapshot_after.vram is not None:
if (
snapshot_before is not None
and snapshot_after is not None
and snapshot_before.vram is not None
and snapshot_after.vram is not None
):
vram_change = abs(snapshot_before.vram - snapshot_after.vram)
# If the estimated model size does not match the change in VRAM, log a warning.
@ -422,12 +438,17 @@ class ModelCache(object):
self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")
pos = 0
models_cleared = 0
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
model_key = self._cache_stack[pos]
cache_entry = self._cached_models[model_key]
refs = sys.getrefcount(cache_entry.model)
# HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly
# going against the advice in the Python docs by using `gc.get_referrers(...)` in this way:
# https://docs.python.org/3/library/gc.html#gc.get_referrers
# manualy clear local variable references of just finished function calls
# for some reason python don't want to collect it even by gc.collect() immidiately
if refs > 2:
@ -453,15 +474,16 @@ class ModelCache(object):
f" refs: {refs}"
)
# 2 refs:
# Expected refs:
# 1 from cache_entry
# 1 from getrefcount function
# 1 from onnx runtime object
if not cache_entry.locked and refs <= 3 if "onnx" in model_key else 2:
if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
self.logger.debug(
f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
)
current_size -= cache_entry.size
models_cleared += 1
if self.stats:
self.stats.cleared += 1
del self._cache_stack[pos]
@ -471,7 +493,20 @@ class ModelCache(object):
else:
pos += 1
gc.collect()
if models_cleared > 0:
# There would likely be some 'garbage' to be collected regardless of whether a model was cleared or not, but
# there is a significant time cost to calling `gc.collect()`, so we want to use it sparingly. (The time cost
# is high even if no garbage gets collected.)
#
# Calling gc.collect(...) when a model is cleared seems like a good middle-ground:
# - If models had to be cleared, it's a signal that we are close to our memory limit.
# - If models were cleared, there's a good chance that there's a significant amount of garbage to be
# collected.
#
# Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up
# immediately when their reference count hits 0.
gc.collect()
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
@ -491,7 +526,6 @@ class ModelCache(object):
vram_in_use = torch.cuda.memory_allocated()
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
gc.collect()
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()

View File

@ -17,7 +17,7 @@ def skip_torch_weight_init():
completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager
monkey-patches common torch layers to skip the weight initialization step.
"""
torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd]
torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding]
saved_functions = [m.reset_parameters for m in torch_modules]
try:

View File

@ -351,6 +351,7 @@ class ModelManager(object):
precision=precision,
sequential_offload=sequential_offload,
logger=logger,
log_memory_usage=self.app_config.log_memory_usage,
)
self._read_models(config)
@ -1011,6 +1012,8 @@ class ModelManager(object):
self.logger.warning(f"Not a valid model: {model_path}. {e}")
except NotImplementedError as e:
self.logger.warning(e)
except Exception as e:
self.logger.warning(f"Error loading model {model_path}. {e}")
imported_models = self.scan_autoimport_directory()
if (new_models_found or imported_models) and self.config_path:

View File

@ -132,13 +132,14 @@ def _convert_controlnet_ckpt_and_cache(
model_path: str,
output_path: str,
base_model: BaseModelType,
model_config: ControlNetModel.CheckpointConfig,
model_config: str,
) -> str:
"""
Convert the controlnet from checkpoint format to diffusers format,
cache it to disk, and return Path to converted
file. If already on disk then just returns Path.
"""
print(f"DEBUG: controlnet config = {model_config}")
app_config = InvokeAIAppConfig.get_config()
weights = app_config.root_path / model_path
output_path = Path(output_path)

View File

@ -440,33 +440,19 @@ class IA3Layer(LoRALayerBase):
class LoRAModelRaw: # (torch.nn.Module):
_name: str
layers: Dict[str, LoRALayer]
_device: torch.device
_dtype: torch.dtype
def __init__(
self,
name: str,
layers: Dict[str, LoRALayer],
device: torch.device,
dtype: torch.dtype,
):
self._name = name
self._device = device or torch.cpu
self._dtype = dtype or torch.float32
self.layers = layers
@property
def name(self):
return self._name
@property
def device(self):
return self._device
@property
def dtype(self):
return self._dtype
def to(
self,
device: Optional[torch.device] = None,
@ -475,8 +461,6 @@ class LoRAModelRaw: # (torch.nn.Module):
# TODO: try revert if exception?
for key, layer in self.layers.items():
layer.to(device=device, dtype=dtype)
self._device = device
self._dtype = dtype
def calc_size(self) -> int:
model_size = 0
@ -557,8 +541,6 @@ class LoRAModelRaw: # (torch.nn.Module):
file_path = Path(file_path)
model = cls(
device=device,
dtype=dtype,
name=file_path.stem, # TODO:
layers=dict(),
)

View File

@ -55,11 +55,11 @@ class PostprocessingSettings:
class IPAdapterConditioningInfo:
cond_image_prompt_embeds: torch.Tensor
"""IP-Adapter image encoder conditioning embeddings.
Shape: (batch_size, num_tokens, encoding_dim).
Shape: (num_images, num_tokens, encoding_dim).
"""
uncond_image_prompt_embeds: torch.Tensor
"""IP-Adapter image encoding embeddings to use for unconditional generation.
Shape: (batch_size, num_tokens, encoding_dim).
Shape: (num_images, num_tokens, encoding_dim).
"""

View File

@ -345,9 +345,12 @@ class InvokeAIDiffuserComponent:
cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None:
# Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": [
torch.cat([ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds])
torch.stack(
[ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds]
)
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
]
}
@ -415,9 +418,10 @@ class InvokeAIDiffuserComponent:
# Run unconditional UNet denoising.
cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None:
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": [
ipa_conditioning.uncond_image_prompt_embeds
torch.unsqueeze(ipa_conditioning.uncond_image_prompt_embeds, dim=0)
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
]
}
@ -444,9 +448,10 @@ class InvokeAIDiffuserComponent:
# Run conditional UNet denoising.
cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None:
# Note that we 'unsqueeze' to produce tensors of shape (batch_size=1, num_ip_images, seq_len, token_len).
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": [
ipa_conditioning.cond_image_prompt_embeds
torch.unsqueeze(ipa_conditioning.cond_image_prompt_embeds, dim=0)
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
]
}

View File

@ -41,7 +41,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
# invokeai stuff
from invokeai.app.services.config import InvokeAIAppConfig, PagingArgumentParser
from invokeai.app.services.model_manager_service import ModelManagerService
from invokeai.app.services.model_manager import ModelManagerService
from invokeai.backend.model_management.models import SubModelType
if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):