Merge branch 'main' into main

commit e147379aa7
Lincoln Stein, 2023-11-04 17:05:01 -04:00, committed by GitHub
19 changed files with 255 additions and 90 deletions

View File

@@ -108,13 +108,14 @@ class CompelInvocation(BaseInvocation):
                 print(f'Warn: trigger: "{trigger}" not found')

         with (
-            ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),
             ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
                 tokenizer,
                 ti_manager,
             ),
             ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
             text_encoder_info as text_encoder,
+            # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
+            ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
         ):
             compel = Compel(
                 tokenizer=tokenizer,
@@ -229,13 +230,14 @@ class SDXLPromptInvocationBase:
                 print(f'Warn: trigger: "{trigger}" not found')

         with (
-            ModelPatcher.apply_lora(text_encoder_info.context.model, _lora_loader(), lora_prefix),
             ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
                 tokenizer,
                 ti_manager,
             ),
             ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),
             text_encoder_info as text_encoder,
+            # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
+            ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
         ):
             compel = Compel(
                 tokenizer=tokenizer,

View File

@@ -710,9 +710,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
             )
         with (
             ExitStack() as exit_stack,
-            ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),
             set_seamless(unet_info.context.model, self.unet.seamless_axes),
             unet_info as unet,
+            # Apply the LoRA after unet has been moved to its target device for faster patching.
+            ModelPatcher.apply_lora_unet(unet, _lora_loader()),
         ):
             latents = latents.to(device=unet.device, dtype=unet.dtype)
             if noise is not None:
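Aside (not part of the commit): this reordering, and the matching changes in the two compel hunks above, relies on the fact that context managers listed in a single with statement are entered top-to-bottom and exited in reverse. A minimal sketch showing why placing apply_lora_unet after "unet_info as unet" means the LoRA is patched into a model that is already on its target device, and unpatched before the model context exits:

from contextlib import contextmanager

@contextmanager
def step(name, log):
    # Record enter/exit order so the nesting of the with statement is visible.
    log.append(f"enter {name}")
    try:
        yield name
    finally:
        log.append(f"exit {name}")

log = []
with (  # parenthesized multi-manager form requires Python 3.10+
    step("move model to device", log) as model,
    step("apply LoRA", log),  # entered last, so it sees the already-prepared model
):
    pass

print(log)
# ['enter move model to device', 'enter apply LoRA', 'exit apply LoRA', 'exit move model to device']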

View File

@@ -45,6 +45,7 @@ InvokeAI:
     ram: 13.5
     vram: 0.25
     lazy_offload: true
+    log_memory_usage: false
   Device:
     device: auto
     precision: auto
@@ -261,6 +262,7 @@ class InvokeAIAppConfig(InvokeAISettings):
     ram               : float = Field(default=7.5, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
     vram              : float = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
     lazy_offload      : bool = Field(default=True, description="Keep models in VRAM until their space is needed", json_schema_extra=Categories.ModelCache, )
+    log_memory_usage  : bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.", json_schema_extra=Categories.ModelCache)

     # DEVICE
     device : Literal["auto", "cpu", "cuda", "cuda:1", "mps"] = Field(default="auto", description="Generation device", json_schema_extra=Categories.Device)

View File

@@ -1,6 +1,6 @@
 from __future__ import annotations

-import copy
+import pickle
 from contextlib import contextmanager
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -54,24 +54,6 @@ class ModelPatcher:

         return (module_key, module)

-    @staticmethod
-    def _lora_forward_hook(
-        applied_loras: List[Tuple[LoRAModel, float]],
-        layer_name: str,
-    ):
-        def lora_forward(module, input_h, output):
-            if len(applied_loras) == 0:
-                return output
-
-            for lora, weight in applied_loras:
-                layer = lora.layers.get(layer_name, None)
-                if layer is None:
-                    continue
-                output += layer.forward(module, input_h, weight)
-            return output
-
-        return lora_forward
-
     @classmethod
     @contextmanager
     def apply_lora_unet(
@@ -129,21 +111,40 @@ class ModelPatcher:
                         if not layer_key.startswith(prefix):
                             continue

+                        # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
+                        # should be improved in the following ways:
+                        # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
+                        #    LoRA model is applied.
+                        # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
+                        #    intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
+                        #    weights to have valid keys.
                         module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
+
+                        # All of the LoRA weight calculations will be done on the same device as the module weight.
+                        # (Performance will be best if this is a CUDA device.)
+                        device = module.weight.device
+                        dtype = module.weight.dtype
+
                         if module_key not in original_weights:
                             original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)

-                        # enable autocast to calc fp16 loras on cpu
-                        # with torch.autocast(device_type="cpu"):
-                        layer.to(dtype=torch.float32)
                         layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
-                        layer_weight = layer.get_weight(original_weights[module_key]) * lora_weight * layer_scale
+
+                        # We intentionally move to the target device first, then cast. Experimentally, this was found to
+                        # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
+                        # same thing in a single call to '.to(...)'.
+                        layer.to(device=device)
+                        layer.to(dtype=torch.float32)
+                        # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
+                        # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
+                        layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
+                        layer.to(device="cpu")

                         if module.weight.shape != layer_weight.shape:
                             # TODO: debug on lycoris
                             layer_weight = layer_weight.reshape(module.weight.shape)

-                        module.weight += layer_weight.to(device=module.weight.device, dtype=module.weight.dtype)
+                        module.weight += layer_weight.to(dtype=dtype)

             yield  # wait for context manager exit
@@ -164,7 +165,13 @@ class ModelPatcher:
         new_tokens_added = None

         try:
-            ti_tokenizer = copy.deepcopy(tokenizer)
+            # HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
+            # workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
+            # exiting this `apply_ti(...)` context manager.
+            #
+            # In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`,
+            # but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
+            ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
             ti_manager = TextualInversionManager(ti_tokenizer)
             init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings
@@ -196,7 +203,9 @@ class ModelPatcher:
                 if model_embeddings.weight.data[token_id].shape != embedding.shape:
                     raise ValueError(
-                        f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {model_embeddings.weight.data[token_id].shape[0]}."
+                        f"Cannot load embedding for {trigger}. It was trained on a model with token dimension"
+                        f" {embedding.shape[0]}, but the current model has token dimension"
+                        f" {model_embeddings.weight.data[token_id].shape[0]}."
                     )

                 model_embeddings.weight.data[token_id] = embedding.to(
@@ -257,7 +266,8 @@ class TextualInversionModel:
         if "string_to_param" in state_dict:
             if len(state_dict["string_to_param"]) > 1:
                 print(
-                    f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first token will be used.'
+                    f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first'
+                    " token will be used."
                 )

             result.embedding = next(iter(state_dict["string_to_param"].values()))
@@ -435,7 +445,13 @@ class ONNXModelPatcher:
         orig_embeddings = None

         try:
-            ti_tokenizer = copy.deepcopy(tokenizer)
+            # HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
+            # workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
+            # exiting this `apply_ti(...)` context manager.
+            #
+            # In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`,
+            # but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
+            ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
             ti_manager = TextualInversionManager(ti_tokenizer)

             def _get_trigger(ti_name, index):
@@ -470,7 +486,9 @@ class ONNXModelPatcher:
                 if embeddings[token_id].shape != embedding.shape:
                     raise ValueError(
-                        f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {embeddings[token_id].shape[0]}."
+                        f"Cannot load embedding for {trigger}. It was trained on a model with token dimension"
+                        f" {embedding.shape[0]}, but the current model has token dimension"
+                        f" {embeddings[token_id].shape[0]}."
                     )

                 embeddings[token_id] = embedding
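Aside (not part of the commit): the "move to the target device first, then cast" comment above is a measurable claim. A standalone micro-benchmark sketch comparing the two orderings for a 16-bit CPU tensor when a CUDA device is available; the helper name, tensor size, and iteration count are my own arbitrary choices for illustration.

import time

import torch

def bench(fn, iters=20):
    # Time a transfer/cast closure, synchronizing so async CUDA work is fully counted.
    fn()  # warm-up
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.time() - start) / iters

if torch.cuda.is_available():
    x = torch.randn(4096, 4096, dtype=torch.float16)  # 16-bit tensor on the CPU

    combined = bench(lambda: x.to(device="cuda", dtype=torch.float32))
    split = bench(lambda: x.to(device="cuda").to(dtype=torch.float32))

    print(f"single .to(device, dtype):   {combined * 1000:.2f} ms")
    print(f".to(device) then .to(dtype): {split * 1000:.2f} ms")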

View File

@@ -64,7 +64,7 @@ class MemorySnapshot:
         return cls(process_ram, vram, malloc_info)


-def get_pretty_snapshot_diff(snapshot_1: MemorySnapshot, snapshot_2: MemorySnapshot) -> str:
+def get_pretty_snapshot_diff(snapshot_1: Optional[MemorySnapshot], snapshot_2: Optional[MemorySnapshot]) -> str:
     """Get a pretty string describing the difference between two `MemorySnapshot`s."""

     def get_msg_line(prefix: str, val1: int, val2: int):
@@ -73,6 +73,9 @@ def get_pretty_snapshot_diff(snapshot_1: Optional[MemorySnapshot], snapshot_2: Optional[MemorySnapshot]) -> str:

     msg = ""

+    if snapshot_1 is None or snapshot_2 is None:
+        return msg
+
     msg += get_msg_line("Process RAM", snapshot_1.process_ram, snapshot_2.process_ram)

     if snapshot_1.malloc_info is not None and snapshot_2.malloc_info is not None:

View File

@@ -117,6 +117,7 @@ class ModelCache(object):
         lazy_offloading: bool = True,
         sha_chunksize: int = 16777216,
         logger: types.ModuleType = logger,
+        log_memory_usage: bool = False,
     ):
         """
         :param max_cache_size: Maximum size of the RAM cache [6.0 GB]
@@ -126,6 +127,10 @@ class ModelCache(object):
         :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
         :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
         :param sha_chunksize: Chunksize to use when calculating sha256 model hash
+        :param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
+            operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
+            snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
+            behaviour.
         """
         self.model_infos: Dict[str, ModelBase] = dict()
         # allow lazy offloading only when vram cache enabled
@@ -137,6 +142,7 @@ class ModelCache(object):
         self.storage_device: torch.device = storage_device
         self.sha_chunksize = sha_chunksize
         self.logger = logger
+        self._log_memory_usage = log_memory_usage

         # used for stats collection
         self.stats = None
@@ -144,6 +150,11 @@ class ModelCache(object):
         self._cached_models = dict()
         self._cache_stack = list()

+    def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
+        if self._log_memory_usage:
+            return MemorySnapshot.capture()
+        return None
+
     def get_key(
         self,
         model_path: str,
@@ -223,10 +234,10 @@ class ModelCache(object):

             # Load the model from disk and capture a memory snapshot before/after.
             start_load_time = time.time()
-            snapshot_before = MemorySnapshot.capture()
+            snapshot_before = self._capture_memory_snapshot()
             with skip_torch_weight_init():
                 model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
-            snapshot_after = MemorySnapshot.capture()
+            snapshot_after = self._capture_memory_snapshot()
             end_load_time = time.time()

             self_reported_model_size_after_load = model_info.get_size(submodel)
@@ -275,9 +286,9 @@ class ModelCache(object):
                 return

             start_model_to_time = time.time()
-            snapshot_before = MemorySnapshot.capture()
+            snapshot_before = self._capture_memory_snapshot()
             cache_entry.model.to(target_device)
-            snapshot_after = MemorySnapshot.capture()
+            snapshot_after = self._capture_memory_snapshot()
             end_model_to_time = time.time()
             self.logger.debug(
                 f"Moved model '{key}' from {source_device} to"
@@ -286,7 +297,12 @@ class ModelCache(object):
                 f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
             )

-            if snapshot_before.vram is not None and snapshot_after.vram is not None:
+            if (
+                snapshot_before is not None
+                and snapshot_after is not None
+                and snapshot_before.vram is not None
+                and snapshot_after.vram is not None
+            ):
                 vram_change = abs(snapshot_before.vram - snapshot_after.vram)

                 # If the estimated model size does not match the change in VRAM, log a warning.
@@ -422,12 +438,17 @@ class ModelCache(object):
         self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")

         pos = 0
+        models_cleared = 0
         while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
             model_key = self._cache_stack[pos]
             cache_entry = self._cached_models[model_key]

             refs = sys.getrefcount(cache_entry.model)

+            # HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly
+            # going against the advice in the Python docs by using `gc.get_referrers(...)` in this way:
+            # https://docs.python.org/3/library/gc.html#gc.get_referrers
             # manualy clear local variable references of just finished function calls
             # for some reason python don't want to collect it even by gc.collect() immidiately
             if refs > 2:
@@ -453,15 +474,16 @@ class ModelCache(object):
                     f" refs: {refs}"
                 )

-            # 2 refs:
+            # Expected refs:
             # 1 from cache_entry
             # 1 from getrefcount function
             # 1 from onnx runtime object
-            if not cache_entry.locked and refs <= 3 if "onnx" in model_key else 2:
+            if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
                 self.logger.debug(
                     f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
                 )
                 current_size -= cache_entry.size
+                models_cleared += 1
                 if self.stats:
                     self.stats.cleared += 1
                 del self._cache_stack[pos]
@@ -471,7 +493,20 @@ class ModelCache(object):
             else:
                 pos += 1

-        gc.collect()
+        if models_cleared > 0:
+            # There would likely be some 'garbage' to be collected regardless of whether a model was cleared or not, but
+            # there is a significant time cost to calling `gc.collect()`, so we want to use it sparingly. (The time cost
+            # is high even if no garbage gets collected.)
+            #
+            # Calling gc.collect(...) when a model is cleared seems like a good middle-ground:
+            # - If models had to be cleared, it's a signal that we are close to our memory limit.
+            # - If models were cleared, there's a good chance that there's a significant amount of garbage to be
+            #   collected.
+            #
+            # Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up
+            # immediately when their reference count hits 0.
+            gc.collect()
+
         torch.cuda.empty_cache()
         if choose_torch_device() == torch.device("mps"):
             mps.empty_cache()
@@ -491,7 +526,6 @@ class ModelCache(object):
             vram_in_use = torch.cuda.memory_allocated()
             self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")

-        gc.collect()
         torch.cuda.empty_cache()
         if choose_torch_device() == torch.device("mps"):
             mps.empty_cache()
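Aside (not part of the commit): the "Expected refs" accounting above leans on how sys.getrefcount reports references: the count always includes one temporary reference held by the call itself. A minimal illustration with a stand-in object:

import sys

class FakeModel:
    pass

model = FakeModel()
# One reference from the `model` variable plus one from getrefcount's own argument.
print(sys.getrefcount(model))  # 2

alias = model
# Any extra holder (such as a lingering local in a just-finished frame) pushes the count
# above 2, which is what the cache's `refs > 2` / `refs <= 2` checks are looking for.
print(sys.getrefcount(model))  # 3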

View File

@@ -17,7 +17,7 @@ def skip_torch_weight_init():
     completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager
     monkey-patches common torch layers to skip the weight initialization step.
     """
-    torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd]
+    torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding]
     saved_functions = [m.reset_parameters for m in torch_modules]

     try:
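Aside (not part of the commit): for readers who only see this one-line change, the technique behind skip_torch_weight_init() is to temporarily replace reset_parameters on the listed layer classes with a no-op and restore the saved functions on exit. A rough sketch of that pattern, not the actual implementation in model_load_optimizations.py:

from contextlib import contextmanager

import torch

def _no_op(*args, **kwargs):
    pass

@contextmanager
def skip_weight_init_sketch():
    # Swap reset_parameters for a no-op on the classes of interest, then restore them.
    torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding]
    saved_functions = [m.reset_parameters for m in torch_modules]
    try:
        for m in torch_modules:
            m.reset_parameters = _no_op
        yield
    finally:
        for module, saved in zip(torch_modules, saved_functions):
            module.reset_parameters = saved

with skip_weight_init_sketch():
    layer = torch.nn.Embedding(10, 10)  # weights are left as uninitialized memory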

View File

@@ -351,6 +351,7 @@ class ModelManager(object):
             precision=precision,
             sequential_offload=sequential_offload,
             logger=logger,
+            log_memory_usage=self.app_config.log_memory_usage,
         )

         self._read_models(config)

View File

@@ -440,33 +440,19 @@ class IA3Layer(LoRALayerBase):
 class LoRAModelRaw:  # (torch.nn.Module):
     _name: str
     layers: Dict[str, LoRALayer]
-    _device: torch.device
-    _dtype: torch.dtype

     def __init__(
         self,
         name: str,
         layers: Dict[str, LoRALayer],
-        device: torch.device,
-        dtype: torch.dtype,
     ):
         self._name = name
-        self._device = device or torch.cpu
-        self._dtype = dtype or torch.float32
         self.layers = layers

     @property
     def name(self):
         return self._name

-    @property
-    def device(self):
-        return self._device
-
-    @property
-    def dtype(self):
-        return self._dtype
-
     def to(
         self,
         device: Optional[torch.device] = None,
@@ -475,8 +461,6 @@ class LoRAModelRaw:  # (torch.nn.Module):
         # TODO: try revert if exception?
         for key, layer in self.layers.items():
             layer.to(device=device, dtype=dtype)
-        self._device = device
-        self._dtype = dtype

     def calc_size(self) -> int:
         model_size = 0
@@ -557,8 +541,6 @@ class LoRAModelRaw:  # (torch.nn.Module):
         file_path = Path(file_path)

         model = cls(
-            device=device,
-            dtype=dtype,
             name=file_path.stem,  # TODO:
             layers=dict(),
         )

View File

@@ -274,9 +274,10 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
         else:
             interp = self.interpolations[self.merge_method.value[0]]

+        bases = ["sd-1", "sd-2", "sdxl"]
         args = dict(
             model_names=models,
-            base_model=tuple(BaseModelType)[self.base_select.value[0]],
+            base_model=BaseModelType(bases[self.base_select.value[0]]),
             alpha=self.alpha.value,
             interp=interp,
             force=self.force.value,
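Aside (not part of the commit): the fix above replaces positional indexing into the enum with an explicit list of base names. A hedged sketch of why that matters, using a hypothetical stand-in enum (the real BaseModelType members and their order may differ):

from enum import Enum

class Base(Enum):  # hypothetical stand-in for BaseModelType
    SD1 = "sd-1"
    SD2 = "sd-2"
    SDXL = "sdxl"
    SDXL_REFINER = "sdxl-refiner"

ui_choices = ["sd-1", "sd-2", "sdxl"]  # what the form actually displays
selected = 2  # user picked "sdxl"

# Positional indexing silently depends on enum declaration order and on the UI list
# matching the enum one-to-one; it breaks if members are added or reordered.
fragile = tuple(Base)[selected]

# Mapping the displayed value back through the enum constructor is explicit and stable.
robust = Base(ui_choices[selected])

assert robust is Base.SDXL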

View File

@@ -722,7 +722,9 @@
     "noMatchingModels": "No matching Models",
     "noModelsAvailable": "No models available",
     "selectLoRA": "Select a LoRA",
-    "selectModel": "Select a Model"
+    "selectModel": "Select a Model",
+    "noLoRAsInstalled": "No LoRAs installed",
+    "noRefinerModelsInstalled": "No SDXL Refiner models installed"
   },
   "nodes": {
     "addNode": "Add Node",

View File

@@ -16,15 +16,13 @@ const ParamDynamicPromptsCollapse = () => {
     () =>
       createSelector(stateSelector, ({ dynamicPrompts }) => {
         const count = dynamicPrompts.prompts.length;
-        if (count === 1) {
-          return t('dynamicPrompts.promptsWithCount_one', {
-            count,
-          });
-        } else {
+        if (count > 1) {
           return t('dynamicPrompts.promptsWithCount_other', {
             count,
           });
         }
+
+        return;
       }),
     [t]
   );

View File

@@ -10,6 +10,7 @@ import { loraAdded } from 'features/lora/store/loraSlice';
 import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
 import { forEach } from 'lodash-es';
 import { memo, useCallback, useMemo } from 'react';
+import { useTranslation } from 'react-i18next';
 import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';

 const selector = createSelector(
@@ -24,7 +25,7 @@ const ParamLoRASelect = () => {
   const dispatch = useAppDispatch();
   const { loras } = useAppSelector(selector);
   const { data: loraModels } = useGetLoRAModelsQuery();
-
+  const { t } = useTranslation();
   const currentMainModel = useAppSelector(
     (state: RootState) => state.generation.model
   );
@@ -79,7 +80,7 @@ const ParamLoRASelect = () => {
     return (
       <Flex sx={{ justifyContent: 'center', p: 2 }}>
         <Text sx={{ fontSize: 'sm', color: 'base.500', _dark: 'base.700' }}>
-          No LoRAs Loaded
+          {t('models.noLoRAsInstalled')}
         </Text>
       </Flex>
     );

View File

@@ -28,9 +28,7 @@ export default function ParamAdvancedCollapse() {
   const activeLabel = useMemo(() => {
     const activeLabel: string[] = [];

-    if (shouldUseCpuNoise) {
-      activeLabel.push(t('parameters.cpuNoise'));
-    } else {
+    if (!shouldUseCpuNoise) {
       activeLabel.push(t('parameters.gpuNoise'));
     }

View File

@@ -4,12 +4,13 @@ import { RootState, stateSelector } from 'app/store/store';
 import { useAppSelector } from 'app/store/storeHooks';
 import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
 import IAICollapse from 'common/components/IAICollapse';
+import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
 import { useMemo } from 'react';
+import { useTranslation } from 'react-i18next';
+import ParamHrfHeight from './ParamHrfHeight';
 import ParamHrfStrength from './ParamHrfStrength';
 import ParamHrfToggle from './ParamHrfToggle';
 import ParamHrfWidth from './ParamHrfWidth';
-import ParamHrfHeight from './ParamHrfHeight';
-import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';

 const selector = createSelector(
   stateSelector,
@@ -22,15 +23,14 @@ const selector = createSelector(
 );

 export default function ParamHrfCollapse() {
+  const { t } = useTranslation();
   const isHRFFeatureEnabled = useFeatureStatus('hrf').isFeatureEnabled;
   const { hrfEnabled } = useAppSelector(selector);
   const activeLabel = useMemo(() => {
     if (hrfEnabled) {
-      return 'On';
-    } else {
-      return 'Off';
+      return t('common.on');
     }
-  }, [hrfEnabled]);
+  }, [t, hrfEnabled]);

   if (!isHRFFeatureEnabled) {
     return null;

View File

@@ -1,4 +1,4 @@
-import { Flex } from '@chakra-ui/react';
+import { Flex, Text } from '@chakra-ui/react';
 import { createSelector } from '@reduxjs/toolkit';
 import { stateSelector } from 'app/store/store';
 import { useAppSelector } from 'app/store/storeHooks';
@@ -14,6 +14,7 @@ import ParamSDXLRefinerStart from './SDXLRefiner/ParamSDXLRefinerStart';
 import ParamSDXLRefinerSteps from './SDXLRefiner/ParamSDXLRefinerSteps';
 import ParamUseSDXLRefiner from './SDXLRefiner/ParamUseSDXLRefiner';
 import { useTranslation } from 'react-i18next';
+import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';

 const selector = createSelector(
   stateSelector,
@@ -31,6 +32,19 @@ const selector = createSelector(
 const ParamSDXLRefinerCollapse = () => {
   const { activeLabel, shouldUseSliders } = useAppSelector(selector);
   const { t } = useTranslation();
+  const isRefinerAvailable = useIsRefinerAvailable();
+
+  if (!isRefinerAvailable) {
+    return (
+      <IAICollapse label={t('sdxl.refiner')} activeLabel={activeLabel}>
+        <Flex sx={{ justifyContent: 'center', p: 2 }}>
+          <Text sx={{ fontSize: 'sm', color: 'base.500', _dark: 'base.700' }}>
+            {t('models.noRefinerModelsInstalled')}
+          </Text>
+        </Flex>
+      </IAICollapse>
+    );
+  }

   return (
     <IAICollapse label={t('sdxl.refiner')} activeLabel={activeLabel}>

View File

@@ -0,0 +1,102 @@
+# test that if the model's device changes while the lora is applied, the weights can still be restored
+
+# test that LoRA patching works on both CPU and CUDA
+import pytest
+import torch
+
+from invokeai.backend.model_management.lora import ModelPatcher
+from invokeai.backend.model_management.models.lora import LoRALayer, LoRAModelRaw
+
+
+@pytest.mark.parametrize(
+    "device",
+    [
+        "cpu",
+        pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
+    ],
+)
+@torch.no_grad()
+def test_apply_lora(device):
+    """Test the basic behavior of ModelPatcher.apply_lora(...). Check that patching and unpatching produce the correct
+    result, and that model/LoRA tensors are moved between devices as expected.
+    """
+    linear_in_features = 4
+    linear_out_features = 8
+    lora_dim = 2
+    model = torch.nn.ModuleDict(
+        {"linear_layer_1": torch.nn.Linear(linear_in_features, linear_out_features, device=device, dtype=torch.float16)}
+    )
+
+    lora_layers = {
+        "linear_layer_1": LoRALayer(
+            layer_key="linear_layer_1",
+            values={
+                "lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
+                "lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),
+            },
+        )
+    }
+    lora = LoRAModelRaw("lora_name", lora_layers)
+
+    lora_weight = 0.5
+    orig_linear_weight = model["linear_layer_1"].weight.data.detach().clone()
+    expected_patched_linear_weight = orig_linear_weight + (lora_dim * lora_weight)
+
+    with ModelPatcher.apply_lora(model, [(lora, lora_weight)], prefix=""):
+        # After patching, all LoRA layer weights should have been moved back to the cpu.
+        assert lora_layers["linear_layer_1"].up.device.type == "cpu"
+        assert lora_layers["linear_layer_1"].down.device.type == "cpu"
+
+        # After patching, the patched model should still be on its original device.
+        assert model["linear_layer_1"].weight.data.device.type == device
+
+        torch.testing.assert_close(model["linear_layer_1"].weight.data, expected_patched_linear_weight)
+
+    # After unpatching, the original model weights should have been restored on the original device.
+    assert model["linear_layer_1"].weight.data.device.type == device
+    torch.testing.assert_close(model["linear_layer_1"].weight.data, orig_linear_weight)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")
+@torch.no_grad()
+def test_apply_lora_change_device():
+    """Test that if LoRA patching is applied on the CPU, and then the patched model is moved to the GPU, unpatching
+    still behaves correctly.
+    """
+    linear_in_features = 4
+    linear_out_features = 8
+    lora_dim = 2
+
+    # Initialize the model on the CPU.
+    model = torch.nn.ModuleDict(
+        {"linear_layer_1": torch.nn.Linear(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16)}
+    )
+
+    lora_layers = {
+        "linear_layer_1": LoRALayer(
+            layer_key="linear_layer_1",
+            values={
+                "lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
+                "lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),
+            },
+        )
+    }
+    lora = LoRAModelRaw("lora_name", lora_layers)
+
+    orig_linear_weight = model["linear_layer_1"].weight.data.detach().clone()
+
+    with ModelPatcher.apply_lora(model, [(lora, 0.5)], prefix=""):
+        # After patching, all LoRA layer weights should have been moved back to the cpu.
+        assert lora_layers["linear_layer_1"].up.device.type == "cpu"
+        assert lora_layers["linear_layer_1"].down.device.type == "cpu"
+
+        # After patching, the patched model should still be on the CPU.
+        assert model["linear_layer_1"].weight.data.device.type == "cpu"
+
+        # Move the model to the GPU.
+        assert model.to("cuda")
+
+    # After unpatching, the original model weights should have been restored on the GPU.
+    assert model["linear_layer_1"].weight.data.device.type == "cuda"
+    torch.testing.assert_close(model["linear_layer_1"].weight.data, orig_linear_weight, check_device=False)
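Aside (not part of the commit), on the expected_patched_linear_weight arithmetic in the new test: with all-ones up/down matrices, the LoRA delta is up @ down = lora_dim * ones, so scaling by lora_weight adds lora_dim * lora_weight to every weight entry (no alpha is supplied here, so the layer_scale term is 1.0). A quick standalone check:

import torch

lora_dim, in_features, out_features, lora_weight = 2, 4, 8, 0.5
up = torch.ones(out_features, lora_dim)
down = torch.ones(lora_dim, in_features)

delta = (up @ down) * lora_weight  # every entry equals lora_dim * lora_weight
assert torch.all(delta == lora_dim * lora_weight)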

View File

@@ -13,10 +13,11 @@ def test_memory_snapshot_capture():

 snapshots = [
-    MemorySnapshot(process_ram=1.0, vram=2.0, malloc_info=Struct_mallinfo2()),
-    MemorySnapshot(process_ram=1.0, vram=2.0, malloc_info=None),
-    MemorySnapshot(process_ram=1.0, vram=None, malloc_info=Struct_mallinfo2()),
-    MemorySnapshot(process_ram=1.0, vram=None, malloc_info=None),
+    MemorySnapshot(process_ram=1, vram=2, malloc_info=Struct_mallinfo2()),
+    MemorySnapshot(process_ram=1, vram=2, malloc_info=None),
+    MemorySnapshot(process_ram=1, vram=None, malloc_info=Struct_mallinfo2()),
+    MemorySnapshot(process_ram=1, vram=None, malloc_info=None),
+    None,
 ]
@@ -26,7 +27,9 @@ def test_get_pretty_snapshot_diff(snapshot_1, snapshot_2):
     """Test that get_pretty_snapshot_diff() works with various combinations of missing MemorySnapshot fields."""
     msg = get_pretty_snapshot_diff(snapshot_1, snapshot_2)

-    expected_lines = 1
-    if snapshot_1.vram is not None and snapshot_2.vram is not None:
-        expected_lines += 1
-    if snapshot_1.malloc_info is not None and snapshot_2.malloc_info is not None:
+    expected_lines = 0
+    if snapshot_1 is not None and snapshot_2 is not None:
+        expected_lines += 1
+        if snapshot_1.vram is not None and snapshot_2.vram is not None:
+            expected_lines += 1
+        if snapshot_1.malloc_info is not None and snapshot_2.malloc_info is not None:

View File

@@ -11,6 +11,7 @@ from invokeai.backend.model_management.model_load_optimizations import _no_op, skip_torch_weight_init
         (torch.nn.Conv1d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
         (torch.nn.Conv2d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
         (torch.nn.Conv3d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
+        (torch.nn.Embedding, {"num_embeddings": 10, "embedding_dim": 10}),
     ],
 )
 def test_skip_torch_weight_init_linear(torch_module, layer_args):
@@ -36,11 +37,13 @@ def test_skip_torch_weight_init_linear(torch_module, layer_args):
     # Check that reset_parameters is skipped while `skip_torch_weight_init()` is active.
     assert reset_params_fn_during == _no_op
     assert not torch.allclose(layer_before.weight, layer_during.weight)
-    assert not torch.allclose(layer_before.bias, layer_during.bias)
+    if hasattr(layer_before, "bias"):
+        assert not torch.allclose(layer_before.bias, layer_during.bias)

     # Check that the original behavior is restored after `skip_torch_weight_init()` ends.
     assert reset_params_fn_before is reset_params_fn_after
     assert torch.allclose(layer_before.weight, layer_after.weight)
-    assert torch.allclose(layer_before.bias, layer_after.bias)
+    if hasattr(layer_before, "bias"):
+        assert torch.allclose(layer_before.bias, layer_after.bias)