Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Commit e147379aa7: Merge branch 'main' into main
@@ -108,13 +108,14 @@ class CompelInvocation(BaseInvocation):
                 print(f'Warn: trigger: "{trigger}" not found')

         with (
-            ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),
             ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
                 tokenizer,
                 ti_manager,
             ),
             ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
             text_encoder_info as text_encoder,
+            # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
+            ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
         ):
             compel = Compel(
                 tokenizer=tokenizer,
@@ -229,13 +230,14 @@ class SDXLPromptInvocationBase:
                 print(f'Warn: trigger: "{trigger}" not found')

         with (
-            ModelPatcher.apply_lora(text_encoder_info.context.model, _lora_loader(), lora_prefix),
             ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
                 tokenizer,
                 ti_manager,
             ),
             ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),
             text_encoder_info as text_encoder,
+            # Apply the LoRA after text_encoder has been moved to its target device for faster patching.
+            ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
         ):
             compel = Compel(
                 tokenizer=tokenizer,
@@ -710,9 +710,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
         )
         with (
             ExitStack() as exit_stack,
-            ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),
             set_seamless(unet_info.context.model, self.unet.seamless_axes),
             unet_info as unet,
+            # Apply the LoRA after unet has been moved to its target device for faster patching.
+            ModelPatcher.apply_lora_unet(unet, _lora_loader()),
         ):
             latents = latents.to(device=unet.device, dtype=unet.dtype)
             if noise is not None:
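Note: the three hunks above make the same change in different invocations: the LoRA patch context manager is moved below the context manager that yields the loaded, on-device model. The sketch below is illustrative only (not InvokeAI code); it shows that a parenthesized `with (...)` block (Python 3.10+ syntax, as used in the diff itself) enters its context managers top to bottom, which is why the reordering guarantees the weights are already on their target device when patching starts.

    from contextlib import contextmanager

    @contextmanager
    def step(name):
        # Print on entry and exit so the ordering is visible.
        print(f"enter: {name}")
        yield name
        print(f"exit: {name}")

    with (
        step("load model onto target device"),
        step("apply LoRA to the on-device weights"),
    ):
        pass
    # "enter" lines print in declaration order; "exit" lines print in reverse.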
@@ -45,6 +45,7 @@ InvokeAI:
     ram: 13.5
     vram: 0.25
     lazy_offload: true
+    log_memory_usage: false
   Device:
     device: auto
     precision: auto
@@ -261,6 +262,7 @@ class InvokeAIAppConfig(InvokeAISettings):
     ram               : float = Field(default=7.5, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
     vram              : float = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
     lazy_offload      : bool = Field(default=True, description="Keep models in VRAM until their space is needed", json_schema_extra=Categories.ModelCache, )
+    log_memory_usage  : bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.", json_schema_extra=Categories.ModelCache)

     # DEVICE
     device            : Literal["auto", "cpu", "cuda", "cuda:1", "mps"] = Field(default="auto", description="Generation device", json_schema_extra=Categories.Device)
@@ -1,6 +1,6 @@
 from __future__ import annotations

-import copy
+import pickle
 from contextlib import contextmanager
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -54,24 +54,6 @@ class ModelPatcher:

         return (module_key, module)

-    @staticmethod
-    def _lora_forward_hook(
-        applied_loras: List[Tuple[LoRAModel, float]],
-        layer_name: str,
-    ):
-        def lora_forward(module, input_h, output):
-            if len(applied_loras) == 0:
-                return output
-
-            for lora, weight in applied_loras:
-                layer = lora.layers.get(layer_name, None)
-                if layer is None:
-                    continue
-                output += layer.forward(module, input_h, weight)
-            return output
-
-        return lora_forward
-
     @classmethod
     @contextmanager
     def apply_lora_unet(
@@ -129,21 +111,40 @@ class ModelPatcher:
                 if not layer_key.startswith(prefix):
                     continue

+                # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
+                # should be improved in the following ways:
+                # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
+                #    LoRA model is applied.
+                # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
+                #    intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
+                #    weights to have valid keys.
                 module_key, module = cls._resolve_lora_key(model, layer_key, prefix)

+                # All of the LoRA weight calculations will be done on the same device as the module weight.
+                # (Performance will be best if this is a CUDA device.)
+                device = module.weight.device
+                dtype = module.weight.dtype
+
                 if module_key not in original_weights:
                     original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)

-                # enable autocast to calc fp16 loras on cpu
-                # with torch.autocast(device_type="cpu"):
-                layer.to(dtype=torch.float32)
                 layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
-                layer_weight = layer.get_weight(original_weights[module_key]) * lora_weight * layer_scale
+
+                # We intentionally move to the target device first, then cast. Experimentally, this was found to
+                # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
+                # same thing in a single call to '.to(...)'.
+                layer.to(device=device)
+                layer.to(dtype=torch.float32)
+                # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
+                # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
+                layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
+                layer.to(device="cpu")

                 if module.weight.shape != layer_weight.shape:
                     # TODO: debug on lycoris
                     layer_weight = layer_weight.reshape(module.weight.shape)

-                module.weight += layer_weight.to(device=module.weight.device, dtype=module.weight.dtype)
+                module.weight += layer_weight.to(dtype=dtype)

         yield  # wait for context manager exit

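Note: the "move to the target device first, then cast" claim in the new comment can be sanity-checked with a standalone micro-benchmark along the lines below. This is a sketch, not part of the change: it assumes a CUDA device is present, and the measured gap will vary with hardware, driver, and tensor size.

    import time

    import torch

    def time_it(fn, iters=50):
        torch.cuda.synchronize()
        start = time.time()
        for _ in range(iters):
            fn()
        torch.cuda.synchronize()
        return (time.time() - start) / iters

    if torch.cuda.is_available():
        x = torch.randn(4096, 4096, dtype=torch.float16, device="cpu")
        combined = time_it(lambda: x.to(device="cuda", dtype=torch.float32))
        two_step = time_it(lambda: x.to(device="cuda").to(dtype=torch.float32))
        print(f"single .to(device, dtype):   {combined * 1e3:.2f} ms")
        print(f".to(device) then .to(dtype): {two_step * 1e3:.2f} ms")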
@@ -164,7 +165,13 @@ class ModelPatcher:
         new_tokens_added = None

         try:
-            ti_tokenizer = copy.deepcopy(tokenizer)
+            # HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
+            # workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
+            # exiting this `apply_ti(...)` context manager.
+            #
+            # In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`,
+            # but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
+            ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
             ti_manager = TextualInversionManager(ti_tokenizer)
             init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings

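Note: the "1 sec vs. 0.05 secs" figure in the comment can be reproduced with a quick standalone timing sketch like the one below. It is illustrative only: it assumes the `transformers` package is installed and that the CLIP tokenizer can be downloaded, and absolute timings depend on the machine.

    import copy
    import pickle
    import time

    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    start = time.time()
    _ = copy.deepcopy(tokenizer)
    print(f"copy.deepcopy:    {time.time() - start:.3f}s")

    start = time.time()
    _ = pickle.loads(pickle.dumps(tokenizer))
    print(f"pickle roundtrip: {time.time() - start:.3f}s")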
@@ -196,7 +203,9 @@ class ModelPatcher:

                 if model_embeddings.weight.data[token_id].shape != embedding.shape:
                     raise ValueError(
-                        f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {model_embeddings.weight.data[token_id].shape[0]}."
+                        f"Cannot load embedding for {trigger}. It was trained on a model with token dimension"
+                        f" {embedding.shape[0]}, but the current model has token dimension"
+                        f" {model_embeddings.weight.data[token_id].shape[0]}."
                     )

                 model_embeddings.weight.data[token_id] = embedding.to(
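Note: this hunk and the similar message re-wraps below change only formatting, not behaviour; adjacent string literals inside the parentheses are concatenated by the compiler into the same single-line message. A small self-contained illustration (the values are made up):

    trigger = "<my-embedding>"
    dim_trained, dim_current = 768, 1024

    msg = (
        f"Cannot load embedding for {trigger}. It was trained on a model with token dimension"
        f" {dim_trained}, but the current model has token dimension"
        f" {dim_current}."
    )
    # Adjacent f-strings are joined at compile time into one string.
    assert "token dimension 768, but the current model has token dimension 1024." in msg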
@@ -257,7 +266,8 @@ class TextualInversionModel:
         if "string_to_param" in state_dict:
             if len(state_dict["string_to_param"]) > 1:
                 print(
-                    f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first token will be used.'
+                    f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first'
+                    " token will be used."
                 )

             result.embedding = next(iter(state_dict["string_to_param"].values()))
@@ -435,7 +445,13 @@ class ONNXModelPatcher:
         orig_embeddings = None

         try:
-            ti_tokenizer = copy.deepcopy(tokenizer)
+            # HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
+            # workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
+            # exiting this `apply_ti(...)` context manager.
+            #
+            # In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`,
+            # but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
+            ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
             ti_manager = TextualInversionManager(ti_tokenizer)

             def _get_trigger(ti_name, index):
@@ -470,7 +486,9 @@ class ONNXModelPatcher:

                 if embeddings[token_id].shape != embedding.shape:
                     raise ValueError(
-                        f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {embeddings[token_id].shape[0]}."
+                        f"Cannot load embedding for {trigger}. It was trained on a model with token dimension"
+                        f" {embedding.shape[0]}, but the current model has token dimension"
+                        f" {embeddings[token_id].shape[0]}."
                     )

                 embeddings[token_id] = embedding
@@ -64,7 +64,7 @@ class MemorySnapshot:
         return cls(process_ram, vram, malloc_info)


-def get_pretty_snapshot_diff(snapshot_1: MemorySnapshot, snapshot_2: MemorySnapshot) -> str:
+def get_pretty_snapshot_diff(snapshot_1: Optional[MemorySnapshot], snapshot_2: Optional[MemorySnapshot]) -> str:
     """Get a pretty string describing the difference between two `MemorySnapshot`s."""

     def get_msg_line(prefix: str, val1: int, val2: int):
@@ -73,6 +73,9 @@ def get_pretty_snapshot_diff(snapshot_1: MemorySnapshot, snapshot_2: MemorySnapshot) -> str:

     msg = ""

+    if snapshot_1 is None or snapshot_2 is None:
+        return msg
+
     msg += get_msg_line("Process RAM", snapshot_1.process_ram, snapshot_2.process_ram)

     if snapshot_1.malloc_info is not None and snapshot_2.malloc_info is not None:
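Note: with the `Optional[...]` signature and the early return above, callers can pass whatever `_capture_memory_snapshot()` returned without a `None` check of their own. A small usage sketch follows; the module path is an assumption based on this diff's test files, and the output is simply an empty string when either snapshot is missing.

    from invokeai.backend.model_management.memory_snapshot import (  # assumed path
        MemorySnapshot,
        get_pretty_snapshot_diff,
    )

    snapshot_before = None  # e.g. log_memory_usage=False, so nothing was captured
    snapshot_after = MemorySnapshot.capture()

    # With either snapshot missing, the diff message is just "".
    print(get_pretty_snapshot_diff(snapshot_before, snapshot_after))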
@@ -117,6 +117,7 @@ class ModelCache(object):
         lazy_offloading: bool = True,
         sha_chunksize: int = 16777216,
         logger: types.ModuleType = logger,
+        log_memory_usage: bool = False,
     ):
         """
         :param max_cache_size: Maximum size of the RAM cache [6.0 GB]
@@ -126,6 +127,10 @@ class ModelCache(object):
         :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
         :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
         :param sha_chunksize: Chunksize to use when calculating sha256 model hash
+        :param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
+            operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
+            snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
+            behaviour.
         """
         self.model_infos: Dict[str, ModelBase] = dict()
         # allow lazy offloading only when vram cache enabled
@@ -137,6 +142,7 @@ class ModelCache(object):
         self.storage_device: torch.device = storage_device
         self.sha_chunksize = sha_chunksize
         self.logger = logger
+        self._log_memory_usage = log_memory_usage

         # used for stats collection
         self.stats = None
@@ -144,6 +150,11 @@ class ModelCache(object):
         self._cached_models = dict()
         self._cache_stack = list()

+    def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
+        if self._log_memory_usage:
+            return MemorySnapshot.capture()
+        return None
+
     def get_key(
         self,
         model_path: str,
@@ -223,10 +234,10 @@ class ModelCache(object):

         # Load the model from disk and capture a memory snapshot before/after.
         start_load_time = time.time()
-        snapshot_before = MemorySnapshot.capture()
+        snapshot_before = self._capture_memory_snapshot()
         with skip_torch_weight_init():
             model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
-        snapshot_after = MemorySnapshot.capture()
+        snapshot_after = self._capture_memory_snapshot()
         end_load_time = time.time()

         self_reported_model_size_after_load = model_info.get_size(submodel)
@@ -275,9 +286,9 @@ class ModelCache(object):
             return

         start_model_to_time = time.time()
-        snapshot_before = MemorySnapshot.capture()
+        snapshot_before = self._capture_memory_snapshot()
         cache_entry.model.to(target_device)
-        snapshot_after = MemorySnapshot.capture()
+        snapshot_after = self._capture_memory_snapshot()
         end_model_to_time = time.time()
         self.logger.debug(
             f"Moved model '{key}' from {source_device} to"
@@ -286,7 +297,12 @@ class ModelCache(object):
             f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
         )

-        if snapshot_before.vram is not None and snapshot_after.vram is not None:
+        if (
+            snapshot_before is not None
+            and snapshot_after is not None
+            and snapshot_before.vram is not None
+            and snapshot_after.vram is not None
+        ):
             vram_change = abs(snapshot_before.vram - snapshot_after.vram)

             # If the estimated model size does not match the change in VRAM, log a warning.
@@ -422,12 +438,17 @@ class ModelCache(object):
         self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")

         pos = 0
+        models_cleared = 0
         while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
             model_key = self._cache_stack[pos]
             cache_entry = self._cached_models[model_key]

             refs = sys.getrefcount(cache_entry.model)

+            # HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly
+            # going against the advice in the Python docs by using `gc.get_referrers(...)` in this way:
+            # https://docs.python.org/3/library/gc.html#gc.get_referrers
+
             # manualy clear local variable references of just finished function calls
             # for some reason python don't want to collect it even by gc.collect() immidiately
             if refs > 2:
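Note: for context on the `refs > 2` threshold used here (and the refs accounting in the next hunk), `sys.getrefcount()` always reports one extra reference held by its own argument, so an object reachable from a single variable reports 2. A minimal illustration:

    import sys

    class Dummy:
        pass

    obj = Dummy()
    # One reference from the local name `obj`, plus the temporary reference that
    # getrefcount() holds on its argument while it runs.
    print(sys.getrefcount(obj))  # 2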
@@ -453,15 +474,16 @@ class ModelCache(object):
                     f" refs: {refs}"
                 )

-            # 2 refs:
+            # Expected refs:
             # 1 from cache_entry
             # 1 from getrefcount function
             # 1 from onnx runtime object
-            if not cache_entry.locked and refs <= 3 if "onnx" in model_key else 2:
+            if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
                 self.logger.debug(
                     f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
                 )
                 current_size -= cache_entry.size
+                models_cleared += 1
                 if self.stats:
                     self.stats.cleared += 1
                 del self._cache_stack[pos]
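Note: the added parentheses fix an operator-precedence bug rather than style. A conditional expression binds more loosely than `<=`, so the old condition grouped as `(refs <= 3) if "onnx" in model_key else 2` and reduced to the constant 2 (always truthy) for non-ONNX models. A minimal illustration with made-up values:

    refs = 10
    model_key = "sd-1/main/some-model"  # not an ONNX model

    # Old form: evaluates to 2 here, which is always truthy.
    old_result = refs <= 3 if "onnx" in model_key else 2
    # New form: compares refs against the intended threshold.
    new_result = refs <= (3 if "onnx" in model_key else 2)

    print(old_result, new_result)  # 2 False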
@@ -471,7 +493,20 @@ class ModelCache(object):
             else:
                 pos += 1

-        gc.collect()
+        if models_cleared > 0:
+            # There would likely be some 'garbage' to be collected regardless of whether a model was cleared or not, but
+            # there is a significant time cost to calling `gc.collect()`, so we want to use it sparingly. (The time cost
+            # is high even if no garbage gets collected.)
+            #
+            # Calling gc.collect(...) when a model is cleared seems like a good middle-ground:
+            # - If models had to be cleared, it's a signal that we are close to our memory limit.
+            # - If models were cleared, there's a good chance that there's a significant amount of garbage to be
+            #   collected.
+            #
+            # Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up
+            # immediately when their reference count hits 0.
+            gc.collect()
         torch.cuda.empty_cache()
         if choose_torch_device() == torch.device("mps"):
             mps.empty_cache()
@@ -491,7 +526,6 @@ class ModelCache(object):
         vram_in_use = torch.cuda.memory_allocated()
         self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")

-        gc.collect()
         torch.cuda.empty_cache()
         if choose_torch_device() == torch.device("mps"):
             mps.empty_cache()
@@ -17,7 +17,7 @@ def skip_torch_weight_init():
     completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager
     monkey-patches common torch layers to skip the weight initialization step.
     """
-    torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd]
+    torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding]
     saved_functions = [m.reset_parameters for m in torch_modules]

     try:
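Note: for readers unfamiliar with the helper being extended here, `skip_torch_weight_init()` works by temporarily replacing each listed layer class's `reset_parameters` with a no-op and restoring the saved functions on exit; adding `torch.nn.Embedding` brings embedding layers under the same treatment. The sketch below is a simplified stand-in for that pattern, not the actual InvokeAI implementation.

    from contextlib import contextmanager

    import torch

    def _no_op(*args, **kwargs):
        pass

    @contextmanager
    def skip_weight_init_sketch():
        modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding]
        saved = [m.reset_parameters for m in modules]
        try:
            for m in modules:
                m.reset_parameters = _no_op
            yield
        finally:
            # Restore the original initializers even if construction raises.
            for m, fn in zip(modules, saved):
                m.reset_parameters = fn

    with skip_weight_init_sketch():
        # Weights are left uninitialized and are expected to be overwritten later,
        # e.g. by load_state_dict().
        layer = torch.nn.Embedding(num_embeddings=10, embedding_dim=10)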
@@ -351,6 +351,7 @@ class ModelManager(object):
             precision=precision,
             sequential_offload=sequential_offload,
             logger=logger,
+            log_memory_usage=self.app_config.log_memory_usage,
         )

         self._read_models(config)
@@ -440,33 +440,19 @@ class IA3Layer(LoRALayerBase):
 class LoRAModelRaw:  # (torch.nn.Module):
     _name: str
     layers: Dict[str, LoRALayer]
-    _device: torch.device
-    _dtype: torch.dtype

     def __init__(
         self,
         name: str,
         layers: Dict[str, LoRALayer],
-        device: torch.device,
-        dtype: torch.dtype,
     ):
         self._name = name
-        self._device = device or torch.cpu
-        self._dtype = dtype or torch.float32
         self.layers = layers

     @property
     def name(self):
         return self._name

-    @property
-    def device(self):
-        return self._device
-
-    @property
-    def dtype(self):
-        return self._dtype
-
     def to(
         self,
         device: Optional[torch.device] = None,
@@ -475,8 +461,6 @@ class LoRAModelRaw:  # (torch.nn.Module):
         # TODO: try revert if exception?
         for key, layer in self.layers.items():
             layer.to(device=device, dtype=dtype)
-        self._device = device
-        self._dtype = dtype

     def calc_size(self) -> int:
         model_size = 0
@@ -557,8 +541,6 @@ class LoRAModelRaw:  # (torch.nn.Module):
         file_path = Path(file_path)

         model = cls(
-            device=device,
-            dtype=dtype,
             name=file_path.stem,  # TODO:
             layers=dict(),
         )
@@ -274,9 +274,10 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
         else:
             interp = self.interpolations[self.merge_method.value[0]]

+        bases = ["sd-1", "sd-2", "sdxl"]
         args = dict(
             model_names=models,
-            base_model=tuple(BaseModelType)[self.base_select.value[0]],
+            base_model=BaseModelType(bases[self.base_select.value[0]]),
             alpha=self.alpha.value,
             interp=interp,
             force=self.force.value,
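Note: the change above replaces positional indexing into the enum with construction from an explicit value list, so the UI selection no longer depends on the declaration order of `BaseModelType` members. A hedged illustration follows; the enum below is a stand-in, not the real `BaseModelType`.

    from enum import Enum

    class BaseModelTypeSketch(str, Enum):
        StableDiffusion1 = "sd-1"
        StableDiffusion2 = "sd-2"
        StableDiffusionXL = "sdxl"
        StableDiffusionXLRefiner = "sdxl-refiner"

    bases = ["sd-1", "sd-2", "sdxl"]
    selected = 2  # index coming from the selection widget

    # Constructing from the value keeps the mapping stable even if new members are
    # added to the enum; tuple(BaseModelTypeSketch)[selected] would not.
    assert BaseModelTypeSketch(bases[selected]) is BaseModelTypeSketch.StableDiffusionXL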
@@ -722,7 +722,9 @@
     "noMatchingModels": "No matching Models",
     "noModelsAvailable": "No models available",
     "selectLoRA": "Select a LoRA",
-    "selectModel": "Select a Model"
+    "selectModel": "Select a Model",
+    "noLoRAsInstalled": "No LoRAs installed",
+    "noRefinerModelsInstalled": "No SDXL Refiner models installed"
   },
   "nodes": {
     "addNode": "Add Node",
@@ -16,15 +16,13 @@ const ParamDynamicPromptsCollapse = () => {
     () =>
       createSelector(stateSelector, ({ dynamicPrompts }) => {
         const count = dynamicPrompts.prompts.length;
-        if (count === 1) {
-          return t('dynamicPrompts.promptsWithCount_one', {
-            count,
-          });
-        } else {
+        if (count > 1) {
           return t('dynamicPrompts.promptsWithCount_other', {
             count,
           });
         }
+
+        return;
       }),
     [t]
   );
@@ -10,6 +10,7 @@ import { loraAdded } from 'features/lora/store/loraSlice';
 import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
 import { forEach } from 'lodash-es';
 import { memo, useCallback, useMemo } from 'react';
+import { useTranslation } from 'react-i18next';
 import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';

 const selector = createSelector(
@@ -24,7 +25,7 @@ const ParamLoRASelect = () => {
   const dispatch = useAppDispatch();
   const { loras } = useAppSelector(selector);
   const { data: loraModels } = useGetLoRAModelsQuery();
-
+  const { t } = useTranslation();
   const currentMainModel = useAppSelector(
     (state: RootState) => state.generation.model
   );
@@ -79,7 +80,7 @@ const ParamLoRASelect = () => {
     return (
       <Flex sx={{ justifyContent: 'center', p: 2 }}>
         <Text sx={{ fontSize: 'sm', color: 'base.500', _dark: 'base.700' }}>
-          No LoRAs Loaded
+          {t('models.noLoRAsInstalled')}
         </Text>
       </Flex>
     );
@@ -28,9 +28,7 @@ export default function ParamAdvancedCollapse() {
   const activeLabel = useMemo(() => {
     const activeLabel: string[] = [];

-    if (shouldUseCpuNoise) {
-      activeLabel.push(t('parameters.cpuNoise'));
-    } else {
+    if (!shouldUseCpuNoise) {
       activeLabel.push(t('parameters.gpuNoise'));
     }

@@ -4,12 +4,13 @@ import { RootState, stateSelector } from 'app/store/store';
 import { useAppSelector } from 'app/store/storeHooks';
 import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
 import IAICollapse from 'common/components/IAICollapse';
+import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
 import { useMemo } from 'react';
+import { useTranslation } from 'react-i18next';
+import ParamHrfHeight from './ParamHrfHeight';
 import ParamHrfStrength from './ParamHrfStrength';
 import ParamHrfToggle from './ParamHrfToggle';
 import ParamHrfWidth from './ParamHrfWidth';
-import ParamHrfHeight from './ParamHrfHeight';
-import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';

 const selector = createSelector(
   stateSelector,
@@ -22,15 +23,14 @@ const selector = createSelector(
 );

 export default function ParamHrfCollapse() {
+  const { t } = useTranslation();
   const isHRFFeatureEnabled = useFeatureStatus('hrf').isFeatureEnabled;
   const { hrfEnabled } = useAppSelector(selector);
   const activeLabel = useMemo(() => {
     if (hrfEnabled) {
-      return 'On';
-    } else {
-      return 'Off';
+      return t('common.on');
     }
-  }, [hrfEnabled]);
+  }, [t, hrfEnabled]);

   if (!isHRFFeatureEnabled) {
     return null;
@@ -1,4 +1,4 @@
-import { Flex } from '@chakra-ui/react';
+import { Flex, Text } from '@chakra-ui/react';
 import { createSelector } from '@reduxjs/toolkit';
 import { stateSelector } from 'app/store/store';
 import { useAppSelector } from 'app/store/storeHooks';
@@ -14,6 +14,7 @@ import ParamSDXLRefinerStart from './SDXLRefiner/ParamSDXLRefinerStart';
 import ParamSDXLRefinerSteps from './SDXLRefiner/ParamSDXLRefinerSteps';
 import ParamUseSDXLRefiner from './SDXLRefiner/ParamUseSDXLRefiner';
 import { useTranslation } from 'react-i18next';
+import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';

 const selector = createSelector(
   stateSelector,
@@ -31,6 +32,19 @@ const selector = createSelector(
 const ParamSDXLRefinerCollapse = () => {
   const { activeLabel, shouldUseSliders } = useAppSelector(selector);
   const { t } = useTranslation();
+  const isRefinerAvailable = useIsRefinerAvailable();
+
+  if (!isRefinerAvailable) {
+    return (
+      <IAICollapse label={t('sdxl.refiner')} activeLabel={activeLabel}>
+        <Flex sx={{ justifyContent: 'center', p: 2 }}>
+          <Text sx={{ fontSize: 'sm', color: 'base.500', _dark: 'base.700' }}>
+            {t('models.noRefinerModelsInstalled')}
+          </Text>
+        </Flex>
+      </IAICollapse>
+    );
+  }

   return (
     <IAICollapse label={t('sdxl.refiner')} activeLabel={activeLabel}>
tests/backend/model_management/test_lora.py (new file, +102 lines)
@@ -0,0 +1,102 @@
+# test that if the model's device changes while the lora is applied, the weights can still be restored
+
+# test that LoRA patching works on both CPU and CUDA
+
+import pytest
+import torch
+
+from invokeai.backend.model_management.lora import ModelPatcher
+from invokeai.backend.model_management.models.lora import LoRALayer, LoRAModelRaw
+
+
+@pytest.mark.parametrize(
+    "device",
+    [
+        "cpu",
+        pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
+    ],
+)
+@torch.no_grad()
+def test_apply_lora(device):
+    """Test the basic behavior of ModelPatcher.apply_lora(...). Check that patching and unpatching produce the correct
+    result, and that model/LoRA tensors are moved between devices as expected.
+    """
+
+    linear_in_features = 4
+    linear_out_features = 8
+    lora_dim = 2
+    model = torch.nn.ModuleDict(
+        {"linear_layer_1": torch.nn.Linear(linear_in_features, linear_out_features, device=device, dtype=torch.float16)}
+    )
+
+    lora_layers = {
+        "linear_layer_1": LoRALayer(
+            layer_key="linear_layer_1",
+            values={
+                "lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
+                "lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),
+            },
+        )
+    }
+    lora = LoRAModelRaw("lora_name", lora_layers)
+
+    lora_weight = 0.5
+    orig_linear_weight = model["linear_layer_1"].weight.data.detach().clone()
+    expected_patched_linear_weight = orig_linear_weight + (lora_dim * lora_weight)
+
+    with ModelPatcher.apply_lora(model, [(lora, lora_weight)], prefix=""):
+        # After patching, all LoRA layer weights should have been moved back to the cpu.
+        assert lora_layers["linear_layer_1"].up.device.type == "cpu"
+        assert lora_layers["linear_layer_1"].down.device.type == "cpu"
+
+        # After patching, the patched model should still be on its original device.
+        assert model["linear_layer_1"].weight.data.device.type == device
+
+        torch.testing.assert_close(model["linear_layer_1"].weight.data, expected_patched_linear_weight)
+
+    # After unpatching, the original model weights should have been restored on the original device.
+    assert model["linear_layer_1"].weight.data.device.type == device
+    torch.testing.assert_close(model["linear_layer_1"].weight.data, orig_linear_weight)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")
+@torch.no_grad()
+def test_apply_lora_change_device():
+    """Test that if LoRA patching is applied on the CPU, and then the patched model is moved to the GPU, unpatching
+    still behaves correctly.
+    """
+    linear_in_features = 4
+    linear_out_features = 8
+    lora_dim = 2
+    # Initialize the model on the CPU.
+    model = torch.nn.ModuleDict(
+        {"linear_layer_1": torch.nn.Linear(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16)}
+    )
+
+    lora_layers = {
+        "linear_layer_1": LoRALayer(
+            layer_key="linear_layer_1",
+            values={
+                "lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
+                "lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),
+            },
+        )
+    }
+    lora = LoRAModelRaw("lora_name", lora_layers)
+
+    orig_linear_weight = model["linear_layer_1"].weight.data.detach().clone()
+
+    with ModelPatcher.apply_lora(model, [(lora, 0.5)], prefix=""):
+        # After patching, all LoRA layer weights should have been moved back to the cpu.
+        assert lora_layers["linear_layer_1"].up.device.type == "cpu"
+        assert lora_layers["linear_layer_1"].down.device.type == "cpu"
+
+        # After patching, the patched model should still be on the CPU.
+        assert model["linear_layer_1"].weight.data.device.type == "cpu"
+
+        # Move the model to the GPU.
+        assert model.to("cuda")
+
+    # After unpatching, the original model weights should have been restored on the GPU.
+    assert model["linear_layer_1"].weight.data.device.type == "cuda"
+    torch.testing.assert_close(model["linear_layer_1"].weight.data, orig_linear_weight, check_device=False)
@@ -13,10 +13,11 @@ def test_memory_snapshot_capture():


 snapshots = [
-    MemorySnapshot(process_ram=1.0, vram=2.0, malloc_info=Struct_mallinfo2()),
-    MemorySnapshot(process_ram=1.0, vram=2.0, malloc_info=None),
-    MemorySnapshot(process_ram=1.0, vram=None, malloc_info=Struct_mallinfo2()),
-    MemorySnapshot(process_ram=1.0, vram=None, malloc_info=None),
+    MemorySnapshot(process_ram=1, vram=2, malloc_info=Struct_mallinfo2()),
+    MemorySnapshot(process_ram=1, vram=2, malloc_info=None),
+    MemorySnapshot(process_ram=1, vram=None, malloc_info=Struct_mallinfo2()),
+    MemorySnapshot(process_ram=1, vram=None, malloc_info=None),
+    None,
 ]


@@ -26,7 +27,9 @@ def test_get_pretty_snapshot_diff(snapshot_1, snapshot_2):
     """Test that get_pretty_snapshot_diff() works with various combinations of missing MemorySnapshot fields."""
     msg = get_pretty_snapshot_diff(snapshot_1, snapshot_2)

-    expected_lines = 1
-    if snapshot_1.vram is not None and snapshot_2.vram is not None:
-        expected_lines += 1
-    if snapshot_1.malloc_info is not None and snapshot_2.malloc_info is not None:
+    expected_lines = 0
+    if snapshot_1 is not None and snapshot_2 is not None:
+        expected_lines += 1
+        if snapshot_1.vram is not None and snapshot_2.vram is not None:
+            expected_lines += 1
+        if snapshot_1.malloc_info is not None and snapshot_2.malloc_info is not None:
@@ -11,6 +11,7 @@ from invokeai.backend.model_management.model_load_optimizations import _no_op, skip_torch_weight_init
         (torch.nn.Conv1d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
         (torch.nn.Conv2d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
         (torch.nn.Conv3d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
+        (torch.nn.Embedding, {"num_embeddings": 10, "embedding_dim": 10}),
     ],
 )
 def test_skip_torch_weight_init_linear(torch_module, layer_args):
@@ -36,11 +37,13 @@ def test_skip_torch_weight_init_linear(torch_module, layer_args):
     # Check that reset_parameters is skipped while `skip_torch_weight_init()` is active.
     assert reset_params_fn_during == _no_op
     assert not torch.allclose(layer_before.weight, layer_during.weight)
-    assert not torch.allclose(layer_before.bias, layer_during.bias)
+    if hasattr(layer_before, "bias"):
+        assert not torch.allclose(layer_before.bias, layer_during.bias)

     # Check that the original behavior is restored after `skip_torch_weight_init()` ends.
     assert reset_params_fn_before is reset_params_fn_after
     assert torch.allclose(layer_before.weight, layer_after.weight)
-    assert torch.allclose(layer_before.bias, layer_after.bias)
+    if hasattr(layer_before, "bias"):
+        assert torch.allclose(layer_before.bias, layer_after.bias)
