Merge branch 'main' into main

This commit is contained in:
Lincoln Stein 2023-11-04 17:05:01 -04:00 committed by GitHub
commit e147379aa7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 255 additions and 90 deletions

View File

@ -108,13 +108,14 @@ class CompelInvocation(BaseInvocation):
print(f'Warn: trigger: "{trigger}" not found')
with (
ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
tokenizer,
ti_manager,
),
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
text_encoder_info as text_encoder,
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
):
compel = Compel(
tokenizer=tokenizer,
@ -229,13 +230,14 @@ class SDXLPromptInvocationBase:
print(f'Warn: trigger: "{trigger}" not found')
with (
ModelPatcher.apply_lora(text_encoder_info.context.model, _lora_loader(), lora_prefix),
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
tokenizer,
ti_manager,
),
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),
text_encoder_info as text_encoder,
# Apply the LoRA after text_encoder has been moved to its target device for faster patching.
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
):
compel = Compel(
tokenizer=tokenizer,

View File

@ -710,9 +710,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
)
with (
ExitStack() as exit_stack,
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),
set_seamless(unet_info.context.model, self.unet.seamless_axes),
unet_info as unet,
# Apply the LoRA after unet has been moved to its target device for faster patching.
ModelPatcher.apply_lora_unet(unet, _lora_loader()),
):
latents = latents.to(device=unet.device, dtype=unet.dtype)
if noise is not None:

View File

@ -45,6 +45,7 @@ InvokeAI:
ram: 13.5
vram: 0.25
lazy_offload: true
log_memory_usage: false
Device:
device: auto
precision: auto
@ -261,6 +262,7 @@ class InvokeAIAppConfig(InvokeAISettings):
ram : float = Field(default=7.5, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
vram : float = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", json_schema_extra=Categories.ModelCache, )
log_memory_usage : bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.", json_schema_extra=Categories.ModelCache)
# DEVICE
device : Literal["auto", "cpu", "cuda", "cuda:1", "mps"] = Field(default="auto", description="Generation device", json_schema_extra=Categories.Device)

View File

@ -1,6 +1,6 @@
from __future__ import annotations
import copy
import pickle
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
@ -54,24 +54,6 @@ class ModelPatcher:
return (module_key, module)
@staticmethod
def _lora_forward_hook(
applied_loras: List[Tuple[LoRAModel, float]],
layer_name: str,
):
def lora_forward(module, input_h, output):
if len(applied_loras) == 0:
return output
for lora, weight in applied_loras:
layer = lora.layers.get(layer_name, None)
if layer is None:
continue
output += layer.forward(module, input_h, weight)
return output
return lora_forward
@classmethod
@contextmanager
def apply_lora_unet(
@ -129,21 +111,40 @@ class ModelPatcher:
if not layer_key.startswith(prefix):
continue
# TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
# should be improved in the following ways:
# 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
# LoRA model is applied.
# 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
# intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
# weights to have valid keys.
module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
# All of the LoRA weight calculations will be done on the same device as the module weight.
# (Performance will be best if this is a CUDA device.)
device = module.weight.device
dtype = module.weight.dtype
if module_key not in original_weights:
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
# enable autocast to calc fp16 loras on cpu
# with torch.autocast(device_type="cpu"):
layer.to(dtype=torch.float32)
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
layer_weight = layer.get_weight(original_weights[module_key]) * lora_weight * layer_scale
# We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'.
layer.to(device=device)
layer.to(dtype=torch.float32)
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
layer.to(device="cpu")
if module.weight.shape != layer_weight.shape:
# TODO: debug on lycoris
layer_weight = layer_weight.reshape(module.weight.shape)
module.weight += layer_weight.to(device=module.weight.device, dtype=module.weight.dtype)
module.weight += layer_weight.to(dtype=dtype)
yield # wait for context manager exit
@ -164,7 +165,13 @@ class ModelPatcher:
new_tokens_added = None
try:
ti_tokenizer = copy.deepcopy(tokenizer)
# HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
# workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
# exiting this `apply_ti(...)` context manager.
#
# In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`,
# but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
ti_manager = TextualInversionManager(ti_tokenizer)
init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings
@ -196,7 +203,9 @@ class ModelPatcher:
if model_embeddings.weight.data[token_id].shape != embedding.shape:
raise ValueError(
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {model_embeddings.weight.data[token_id].shape[0]}."
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension"
f" {embedding.shape[0]}, but the current model has token dimension"
f" {model_embeddings.weight.data[token_id].shape[0]}."
)
model_embeddings.weight.data[token_id] = embedding.to(
@ -257,7 +266,8 @@ class TextualInversionModel:
if "string_to_param" in state_dict:
if len(state_dict["string_to_param"]) > 1:
print(
f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first token will be used.'
f'Warn: Embedding "{file_path.name}" contains multiple tokens, which is not supported. The first'
" token will be used."
)
result.embedding = next(iter(state_dict["string_to_param"].values()))
@ -435,7 +445,13 @@ class ONNXModelPatcher:
orig_embeddings = None
try:
ti_tokenizer = copy.deepcopy(tokenizer)
# HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
# workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
# exiting this `apply_ti(...)` context manager.
#
# In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`,
# but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
ti_manager = TextualInversionManager(ti_tokenizer)
def _get_trigger(ti_name, index):
@ -470,7 +486,9 @@ class ONNXModelPatcher:
if embeddings[token_id].shape != embedding.shape:
raise ValueError(
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {embeddings[token_id].shape[0]}."
f"Cannot load embedding for {trigger}. It was trained on a model with token dimension"
f" {embedding.shape[0]}, but the current model has token dimension"
f" {embeddings[token_id].shape[0]}."
)
embeddings[token_id] = embedding

View File

@ -64,7 +64,7 @@ class MemorySnapshot:
return cls(process_ram, vram, malloc_info)
def get_pretty_snapshot_diff(snapshot_1: MemorySnapshot, snapshot_2: MemorySnapshot) -> str:
def get_pretty_snapshot_diff(snapshot_1: Optional[MemorySnapshot], snapshot_2: Optional[MemorySnapshot]) -> str:
"""Get a pretty string describing the difference between two `MemorySnapshot`s."""
def get_msg_line(prefix: str, val1: int, val2: int):
@ -73,6 +73,9 @@ def get_pretty_snapshot_diff(snapshot_1: MemorySnapshot, snapshot_2: MemorySnaps
msg = ""
if snapshot_1 is None or snapshot_2 is None:
return msg
msg += get_msg_line("Process RAM", snapshot_1.process_ram, snapshot_2.process_ram)
if snapshot_1.malloc_info is not None and snapshot_2.malloc_info is not None:

View File

@ -117,6 +117,7 @@ class ModelCache(object):
lazy_offloading: bool = True,
sha_chunksize: int = 16777216,
logger: types.ModuleType = logger,
log_memory_usage: bool = False,
):
"""
:param max_cache_size: Maximum size of the RAM cache [6.0 GB]
@ -126,6 +127,10 @@ class ModelCache(object):
:param lazy_offloading: Keep model in VRAM until another model needs to be loaded
:param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
:param sha_chunksize: Chunksize to use when calculating sha256 model hash
:param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
behaviour.
"""
self.model_infos: Dict[str, ModelBase] = dict()
# allow lazy offloading only when vram cache enabled
@ -137,6 +142,7 @@ class ModelCache(object):
self.storage_device: torch.device = storage_device
self.sha_chunksize = sha_chunksize
self.logger = logger
self._log_memory_usage = log_memory_usage
# used for stats collection
self.stats = None
@ -144,6 +150,11 @@ class ModelCache(object):
self._cached_models = dict()
self._cache_stack = list()
def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
if self._log_memory_usage:
return MemorySnapshot.capture()
return None
def get_key(
self,
model_path: str,
@ -223,10 +234,10 @@ class ModelCache(object):
# Load the model from disk and capture a memory snapshot before/after.
start_load_time = time.time()
snapshot_before = MemorySnapshot.capture()
snapshot_before = self._capture_memory_snapshot()
with skip_torch_weight_init():
model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
snapshot_after = MemorySnapshot.capture()
snapshot_after = self._capture_memory_snapshot()
end_load_time = time.time()
self_reported_model_size_after_load = model_info.get_size(submodel)
@ -275,9 +286,9 @@ class ModelCache(object):
return
start_model_to_time = time.time()
snapshot_before = MemorySnapshot.capture()
snapshot_before = self._capture_memory_snapshot()
cache_entry.model.to(target_device)
snapshot_after = MemorySnapshot.capture()
snapshot_after = self._capture_memory_snapshot()
end_model_to_time = time.time()
self.logger.debug(
f"Moved model '{key}' from {source_device} to"
@ -286,7 +297,12 @@ class ModelCache(object):
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
if snapshot_before.vram is not None and snapshot_after.vram is not None:
if (
snapshot_before is not None
and snapshot_after is not None
and snapshot_before.vram is not None
and snapshot_after.vram is not None
):
vram_change = abs(snapshot_before.vram - snapshot_after.vram)
# If the estimated model size does not match the change in VRAM, log a warning.
@ -422,12 +438,17 @@ class ModelCache(object):
self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")
pos = 0
models_cleared = 0
while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
model_key = self._cache_stack[pos]
cache_entry = self._cached_models[model_key]
refs = sys.getrefcount(cache_entry.model)
# HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly
# going against the advice in the Python docs by using `gc.get_referrers(...)` in this way:
# https://docs.python.org/3/library/gc.html#gc.get_referrers
# manualy clear local variable references of just finished function calls
# for some reason python don't want to collect it even by gc.collect() immidiately
if refs > 2:
@ -453,15 +474,16 @@ class ModelCache(object):
f" refs: {refs}"
)
# 2 refs:
# Expected refs:
# 1 from cache_entry
# 1 from getrefcount function
# 1 from onnx runtime object
if not cache_entry.locked and refs <= 3 if "onnx" in model_key else 2:
if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
self.logger.debug(
f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
)
current_size -= cache_entry.size
models_cleared += 1
if self.stats:
self.stats.cleared += 1
del self._cache_stack[pos]
@ -471,7 +493,20 @@ class ModelCache(object):
else:
pos += 1
gc.collect()
if models_cleared > 0:
# There would likely be some 'garbage' to be collected regardless of whether a model was cleared or not, but
# there is a significant time cost to calling `gc.collect()`, so we want to use it sparingly. (The time cost
# is high even if no garbage gets collected.)
#
# Calling gc.collect(...) when a model is cleared seems like a good middle-ground:
# - If models had to be cleared, it's a signal that we are close to our memory limit.
# - If models were cleared, there's a good chance that there's a significant amount of garbage to be
# collected.
#
# Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up
# immediately when their reference count hits 0.
gc.collect()
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
@ -491,7 +526,6 @@ class ModelCache(object):
vram_in_use = torch.cuda.memory_allocated()
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
gc.collect()
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()

View File

@ -17,7 +17,7 @@ def skip_torch_weight_init():
completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager
monkey-patches common torch layers to skip the weight initialization step.
"""
torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd]
torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding]
saved_functions = [m.reset_parameters for m in torch_modules]
try:

View File

@ -351,6 +351,7 @@ class ModelManager(object):
precision=precision,
sequential_offload=sequential_offload,
logger=logger,
log_memory_usage=self.app_config.log_memory_usage,
)
self._read_models(config)

View File

@ -440,33 +440,19 @@ class IA3Layer(LoRALayerBase):
class LoRAModelRaw: # (torch.nn.Module):
_name: str
layers: Dict[str, LoRALayer]
_device: torch.device
_dtype: torch.dtype
def __init__(
self,
name: str,
layers: Dict[str, LoRALayer],
device: torch.device,
dtype: torch.dtype,
):
self._name = name
self._device = device or torch.cpu
self._dtype = dtype or torch.float32
self.layers = layers
@property
def name(self):
return self._name
@property
def device(self):
return self._device
@property
def dtype(self):
return self._dtype
def to(
self,
device: Optional[torch.device] = None,
@ -475,8 +461,6 @@ class LoRAModelRaw: # (torch.nn.Module):
# TODO: try revert if exception?
for key, layer in self.layers.items():
layer.to(device=device, dtype=dtype)
self._device = device
self._dtype = dtype
def calc_size(self) -> int:
model_size = 0
@ -557,8 +541,6 @@ class LoRAModelRaw: # (torch.nn.Module):
file_path = Path(file_path)
model = cls(
device=device,
dtype=dtype,
name=file_path.stem, # TODO:
layers=dict(),
)

View File

@ -274,9 +274,10 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
else:
interp = self.interpolations[self.merge_method.value[0]]
bases = ["sd-1", "sd-2", "sdxl"]
args = dict(
model_names=models,
base_model=tuple(BaseModelType)[self.base_select.value[0]],
base_model=BaseModelType(bases[self.base_select.value[0]]),
alpha=self.alpha.value,
interp=interp,
force=self.force.value,

View File

@ -722,7 +722,9 @@
"noMatchingModels": "No matching Models",
"noModelsAvailable": "No models available",
"selectLoRA": "Select a LoRA",
"selectModel": "Select a Model"
"selectModel": "Select a Model",
"noLoRAsInstalled": "No LoRAs installed",
"noRefinerModelsInstalled": "No SDXL Refiner models installed"
},
"nodes": {
"addNode": "Add Node",

View File

@ -16,15 +16,13 @@ const ParamDynamicPromptsCollapse = () => {
() =>
createSelector(stateSelector, ({ dynamicPrompts }) => {
const count = dynamicPrompts.prompts.length;
if (count === 1) {
return t('dynamicPrompts.promptsWithCount_one', {
count,
});
} else {
if (count > 1) {
return t('dynamicPrompts.promptsWithCount_other', {
count,
});
}
return;
}),
[t]
);

View File

@ -10,6 +10,7 @@ import { loraAdded } from 'features/lora/store/loraSlice';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';
const selector = createSelector(
@ -24,7 +25,7 @@ const ParamLoRASelect = () => {
const dispatch = useAppDispatch();
const { loras } = useAppSelector(selector);
const { data: loraModels } = useGetLoRAModelsQuery();
const { t } = useTranslation();
const currentMainModel = useAppSelector(
(state: RootState) => state.generation.model
);
@ -79,7 +80,7 @@ const ParamLoRASelect = () => {
return (
<Flex sx={{ justifyContent: 'center', p: 2 }}>
<Text sx={{ fontSize: 'sm', color: 'base.500', _dark: 'base.700' }}>
No LoRAs Loaded
{t('models.noLoRAsInstalled')}
</Text>
</Flex>
);

View File

@ -28,9 +28,7 @@ export default function ParamAdvancedCollapse() {
const activeLabel = useMemo(() => {
const activeLabel: string[] = [];
if (shouldUseCpuNoise) {
activeLabel.push(t('parameters.cpuNoise'));
} else {
if (!shouldUseCpuNoise) {
activeLabel.push(t('parameters.gpuNoise'));
}

View File

@ -4,12 +4,13 @@ import { RootState, stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
import { defaultSelectorOptions } from 'app/store/util/defaultMemoizeOptions';
import IAICollapse from 'common/components/IAICollapse';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { useMemo } from 'react';
import { useTranslation } from 'react-i18next';
import ParamHrfHeight from './ParamHrfHeight';
import ParamHrfStrength from './ParamHrfStrength';
import ParamHrfToggle from './ParamHrfToggle';
import ParamHrfWidth from './ParamHrfWidth';
import ParamHrfHeight from './ParamHrfHeight';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
const selector = createSelector(
stateSelector,
@ -22,15 +23,14 @@ const selector = createSelector(
);
export default function ParamHrfCollapse() {
const { t } = useTranslation();
const isHRFFeatureEnabled = useFeatureStatus('hrf').isFeatureEnabled;
const { hrfEnabled } = useAppSelector(selector);
const activeLabel = useMemo(() => {
if (hrfEnabled) {
return 'On';
} else {
return 'Off';
return t('common.on');
}
}, [hrfEnabled]);
}, [t, hrfEnabled]);
if (!isHRFFeatureEnabled) {
return null;

View File

@ -1,4 +1,4 @@
import { Flex } from '@chakra-ui/react';
import { Flex, Text } from '@chakra-ui/react';
import { createSelector } from '@reduxjs/toolkit';
import { stateSelector } from 'app/store/store';
import { useAppSelector } from 'app/store/storeHooks';
@ -14,6 +14,7 @@ import ParamSDXLRefinerStart from './SDXLRefiner/ParamSDXLRefinerStart';
import ParamSDXLRefinerSteps from './SDXLRefiner/ParamSDXLRefinerSteps';
import ParamUseSDXLRefiner from './SDXLRefiner/ParamUseSDXLRefiner';
import { useTranslation } from 'react-i18next';
import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';
const selector = createSelector(
stateSelector,
@ -31,6 +32,19 @@ const selector = createSelector(
const ParamSDXLRefinerCollapse = () => {
const { activeLabel, shouldUseSliders } = useAppSelector(selector);
const { t } = useTranslation();
const isRefinerAvailable = useIsRefinerAvailable();
if (!isRefinerAvailable) {
return (
<IAICollapse label={t('sdxl.refiner')} activeLabel={activeLabel}>
<Flex sx={{ justifyContent: 'center', p: 2 }}>
<Text sx={{ fontSize: 'sm', color: 'base.500', _dark: 'base.700' }}>
{t('models.noRefinerModelsInstalled')}
</Text>
</Flex>
</IAICollapse>
);
}
return (
<IAICollapse label={t('sdxl.refiner')} activeLabel={activeLabel}>

View File

@ -0,0 +1,102 @@
# test that if the model's device changes while the lora is applied, the weights can still be restored
# test that LoRA patching works on both CPU and CUDA
import pytest
import torch
from invokeai.backend.model_management.lora import ModelPatcher
from invokeai.backend.model_management.models.lora import LoRALayer, LoRAModelRaw
@pytest.mark.parametrize(
"device",
[
"cpu",
pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
],
)
@torch.no_grad()
def test_apply_lora(device):
"""Test the basic behavior of ModelPatcher.apply_lora(...). Check that patching and unpatching produce the correct
result, and that model/LoRA tensors are moved between devices as expected.
"""
linear_in_features = 4
linear_out_features = 8
lora_dim = 2
model = torch.nn.ModuleDict(
{"linear_layer_1": torch.nn.Linear(linear_in_features, linear_out_features, device=device, dtype=torch.float16)}
)
lora_layers = {
"linear_layer_1": LoRALayer(
layer_key="linear_layer_1",
values={
"lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
"lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),
},
)
}
lora = LoRAModelRaw("lora_name", lora_layers)
lora_weight = 0.5
orig_linear_weight = model["linear_layer_1"].weight.data.detach().clone()
expected_patched_linear_weight = orig_linear_weight + (lora_dim * lora_weight)
with ModelPatcher.apply_lora(model, [(lora, lora_weight)], prefix=""):
# After patching, all LoRA layer weights should have been moved back to the cpu.
assert lora_layers["linear_layer_1"].up.device.type == "cpu"
assert lora_layers["linear_layer_1"].down.device.type == "cpu"
# After patching, the patched model should still be on its original device.
assert model["linear_layer_1"].weight.data.device.type == device
torch.testing.assert_close(model["linear_layer_1"].weight.data, expected_patched_linear_weight)
# After unpatching, the original model weights should have been restored on the original device.
assert model["linear_layer_1"].weight.data.device.type == device
torch.testing.assert_close(model["linear_layer_1"].weight.data, orig_linear_weight)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")
@torch.no_grad()
def test_apply_lora_change_device():
"""Test that if LoRA patching is applied on the CPU, and then the patched model is moved to the GPU, unpatching
still behaves correctly.
"""
linear_in_features = 4
linear_out_features = 8
lora_dim = 2
# Initialize the model on the CPU.
model = torch.nn.ModuleDict(
{"linear_layer_1": torch.nn.Linear(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16)}
)
lora_layers = {
"linear_layer_1": LoRALayer(
layer_key="linear_layer_1",
values={
"lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
"lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),
},
)
}
lora = LoRAModelRaw("lora_name", lora_layers)
orig_linear_weight = model["linear_layer_1"].weight.data.detach().clone()
with ModelPatcher.apply_lora(model, [(lora, 0.5)], prefix=""):
# After patching, all LoRA layer weights should have been moved back to the cpu.
assert lora_layers["linear_layer_1"].up.device.type == "cpu"
assert lora_layers["linear_layer_1"].down.device.type == "cpu"
# After patching, the patched model should still be on the CPU.
assert model["linear_layer_1"].weight.data.device.type == "cpu"
# Move the model to the GPU.
assert model.to("cuda")
# After unpatching, the original model weights should have been restored on the GPU.
assert model["linear_layer_1"].weight.data.device.type == "cuda"
torch.testing.assert_close(model["linear_layer_1"].weight.data, orig_linear_weight, check_device=False)

View File

@ -13,10 +13,11 @@ def test_memory_snapshot_capture():
snapshots = [
MemorySnapshot(process_ram=1.0, vram=2.0, malloc_info=Struct_mallinfo2()),
MemorySnapshot(process_ram=1.0, vram=2.0, malloc_info=None),
MemorySnapshot(process_ram=1.0, vram=None, malloc_info=Struct_mallinfo2()),
MemorySnapshot(process_ram=1.0, vram=None, malloc_info=None),
MemorySnapshot(process_ram=1, vram=2, malloc_info=Struct_mallinfo2()),
MemorySnapshot(process_ram=1, vram=2, malloc_info=None),
MemorySnapshot(process_ram=1, vram=None, malloc_info=Struct_mallinfo2()),
MemorySnapshot(process_ram=1, vram=None, malloc_info=None),
None,
]
@ -26,10 +27,12 @@ def test_get_pretty_snapshot_diff(snapshot_1, snapshot_2):
"""Test that get_pretty_snapshot_diff() works with various combinations of missing MemorySnapshot fields."""
msg = get_pretty_snapshot_diff(snapshot_1, snapshot_2)
expected_lines = 1
if snapshot_1.vram is not None and snapshot_2.vram is not None:
expected_lines = 0
if snapshot_1 is not None and snapshot_2 is not None:
expected_lines += 1
if snapshot_1.malloc_info is not None and snapshot_2.malloc_info is not None:
expected_lines += 5
if snapshot_1.vram is not None and snapshot_2.vram is not None:
expected_lines += 1
if snapshot_1.malloc_info is not None and snapshot_2.malloc_info is not None:
expected_lines += 5
assert len(msg.splitlines()) == expected_lines

View File

@ -11,6 +11,7 @@ from invokeai.backend.model_management.model_load_optimizations import _no_op, s
(torch.nn.Conv1d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
(torch.nn.Conv2d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
(torch.nn.Conv3d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
(torch.nn.Embedding, {"num_embeddings": 10, "embedding_dim": 10}),
],
)
def test_skip_torch_weight_init_linear(torch_module, layer_args):
@ -36,12 +37,14 @@ def test_skip_torch_weight_init_linear(torch_module, layer_args):
# Check that reset_parameters is skipped while `skip_torch_weight_init()` is active.
assert reset_params_fn_during == _no_op
assert not torch.allclose(layer_before.weight, layer_during.weight)
assert not torch.allclose(layer_before.bias, layer_during.bias)
if hasattr(layer_before, "bias"):
assert not torch.allclose(layer_before.bias, layer_during.bias)
# Check that the original behavior is restored after `skip_torch_weight_init()` ends.
assert reset_params_fn_before is reset_params_fn_after
assert torch.allclose(layer_before.weight, layer_after.weight)
assert torch.allclose(layer_before.bias, layer_after.bias)
if hasattr(layer_before, "bias"):
assert torch.allclose(layer_before.bias, layer_after.bias)
def test_skip_torch_weight_init_restores_base_class_behavior():