Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Compare commits: install/re...bugfix/cli (16 commits)

Commits (SHA1):
ec1e66dcd3, 69543c23d0, 2bbba323c6, aa02ebf8f5, fb3d0c4b12, 8488ab0134, 875231ed3d, 43b300498f,
5b420653f9, 3d32ce2b58, e391f3c9a8, 6e7a3f0546, 4a683cc669, 3781e56e57, 267e709ba2, 8ff49109a8
@@ -45,6 +45,7 @@ InvokeAI:
     ram: 13.5
     vram: 0.25
     lazy_offload: true
+    log_memory_usage: false
   Device:
     device: auto
     precision: auto
@@ -261,6 +262,7 @@ class InvokeAIAppConfig(InvokeAISettings):
     ram : float = Field(default=7.5, gt=0, description="Maximum memory amount used by model cache for rapid switching (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
     vram : float = Field(default=0.25, ge=0, description="Amount of VRAM reserved for model storage (floating point number, GB)", json_schema_extra=Categories.ModelCache, )
     lazy_offload : bool = Field(default=True, description="Keep models in VRAM until their space is needed", json_schema_extra=Categories.ModelCache, )
+    log_memory_usage : bool = Field(default=False, description="If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to only enable this feature if you are actively inspecting the model cache's behaviour.", json_schema_extra=Categories.ModelCache)

     # DEVICE
     device : Literal["auto", "cpu", "cuda", "cuda:1", "mps"] = Field(default="auto", description="Generation device", json_schema_extra=Categories.Device)
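The `gt=0` / `ge=0` bounds on the cache fields above are standard Pydantic validation constraints: `ram` must be strictly positive, while `vram` may legitimately be zero (no VRAM reserved). Below is a minimal, self-contained sketch of that behaviour; `CacheSettingsSketch` and its usage are illustrative, not part of InvokeAI.

```python
from pydantic import BaseModel, Field, ValidationError

class CacheSettingsSketch(BaseModel):
    # Hypothetical stand-in mirroring the constraint style of the cache fields above.
    ram: float = Field(default=7.5, gt=0, description="RAM cache size for rapid model switching, GB")
    vram: float = Field(default=0.25, ge=0, description="VRAM reserved for model storage, GB")
    lazy_offload: bool = Field(default=True, description="Keep models in VRAM until the space is needed")
    log_memory_usage: bool = Field(default=False, description="Snapshot memory around every cache operation")

print(CacheSettingsSketch(vram=0.0).vram)  # 0.0 -- ge=0 permits disabling the VRAM cache entirely
try:
    CacheSettingsSketch(ram=0.0)           # gt=0 rejects a zero-sized RAM cache
except ValidationError:
    print("ram=0.0 rejected: value must be greater than 0")
```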
@@ -64,7 +64,7 @@ class MemorySnapshot:
         return cls(process_ram, vram, malloc_info)


-def get_pretty_snapshot_diff(snapshot_1: MemorySnapshot, snapshot_2: MemorySnapshot) -> str:
+def get_pretty_snapshot_diff(snapshot_1: Optional[MemorySnapshot], snapshot_2: Optional[MemorySnapshot]) -> str:
     """Get a pretty string describing the difference between two `MemorySnapshot`s."""

     def get_msg_line(prefix: str, val1: int, val2: int):
@@ -73,6 +73,9 @@ def get_pretty_snapshot_diff(snapshot_1: MemorySnapshot, snapshot_2: MemorySnapshot) -> str:

     msg = ""

+    if snapshot_1 is None or snapshot_2 is None:
+        return msg
+
     msg += get_msg_line("Process RAM", snapshot_1.process_ram, snapshot_2.process_ram)

     if snapshot_1.malloc_info is not None and snapshot_2.malloc_info is not None:
@@ -117,6 +117,7 @@ class ModelCache(object):
         lazy_offloading: bool = True,
         sha_chunksize: int = 16777216,
         logger: types.ModuleType = logger,
+        log_memory_usage: bool = False,
     ):
         """
         :param max_cache_size: Maximum size of the RAM cache [6.0 GB]
@@ -126,6 +127,10 @@ class ModelCache(object):
         :param lazy_offloading: Keep model in VRAM until another model needs to be loaded
         :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially
         :param sha_chunksize: Chunksize to use when calculating sha256 model hash
+        :param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache
+            operation, and the result will be logged (at debug level). There is a time cost to capturing the memory
+            snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's
+            behaviour.
         """
         self.model_infos: Dict[str, ModelBase] = dict()
         # allow lazy offloading only when vram cache enabled
@@ -137,6 +142,7 @@ class ModelCache(object):
         self.storage_device: torch.device = storage_device
         self.sha_chunksize = sha_chunksize
         self.logger = logger
+        self._log_memory_usage = log_memory_usage

         # used for stats collection
         self.stats = None
@@ -144,6 +150,11 @@ class ModelCache(object):
         self._cached_models = dict()
         self._cache_stack = list()

+    def _capture_memory_snapshot(self) -> Optional[MemorySnapshot]:
+        if self._log_memory_usage:
+            return MemorySnapshot.capture()
+        return None
+
     def get_key(
         self,
         model_path: str,
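The `_capture_memory_snapshot()` helper added above is the whole gating mechanism: when `log_memory_usage` is off it returns `None`, and `get_pretty_snapshot_diff()` (changed earlier in this diff to accept `Optional` snapshots) simply returns an empty string in that case, so call sites never branch on the flag. Here is a rough, self-contained sketch of the same idea with hypothetical names, tracking only process RAM via `psutil`:

```python
from dataclasses import dataclass
from typing import Optional

import psutil

@dataclass
class SnapshotSketch:
    process_ram: int  # resident set size, in bytes

def capture_if_enabled(enabled: bool) -> Optional[SnapshotSketch]:
    # Pay the snapshot cost only when memory logging is turned on.
    if enabled:
        return SnapshotSketch(process_ram=psutil.Process().memory_info().rss)
    return None

def pretty_diff(before: Optional[SnapshotSketch], after: Optional[SnapshotSketch]) -> str:
    # Tolerate missing snapshots so callers stay unconditional.
    if before is None or after is None:
        return ""
    return f"Process RAM change: {(after.process_ram - before.process_ram) / 2**30:+.3f} GB"

before = capture_if_enabled(enabled=True)
data = bytearray(50 * 2**20)  # allocate ~50 MB so the diff has something to report
after = capture_if_enabled(enabled=True)
print(pretty_diff(before, after))
print(repr(pretty_diff(None, after)))  # "" when logging was disabled for either side
```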
@@ -223,10 +234,10 @@ class ModelCache(object):

             # Load the model from disk and capture a memory snapshot before/after.
             start_load_time = time.time()
-            snapshot_before = MemorySnapshot.capture()
+            snapshot_before = self._capture_memory_snapshot()
             with skip_torch_weight_init():
                 model = model_info.get_model(child_type=submodel, torch_dtype=self.precision)
-            snapshot_after = MemorySnapshot.capture()
+            snapshot_after = self._capture_memory_snapshot()
             end_load_time = time.time()

             self_reported_model_size_after_load = model_info.get_size(submodel)
@@ -275,9 +286,9 @@ class ModelCache(object):
             return

         start_model_to_time = time.time()
-        snapshot_before = MemorySnapshot.capture()
+        snapshot_before = self._capture_memory_snapshot()
         cache_entry.model.to(target_device)
-        snapshot_after = MemorySnapshot.capture()
+        snapshot_after = self._capture_memory_snapshot()
         end_model_to_time = time.time()
         self.logger.debug(
             f"Moved model '{key}' from {source_device} to"
@@ -286,7 +297,12 @@ class ModelCache(object):
             f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
         )

-        if snapshot_before.vram is not None and snapshot_after.vram is not None:
+        if (
+            snapshot_before is not None
+            and snapshot_after is not None
+            and snapshot_before.vram is not None
+            and snapshot_after.vram is not None
+        ):
             vram_change = abs(snapshot_before.vram - snapshot_after.vram)

             # If the estimated model size does not match the change in VRAM, log a warning.
@@ -422,12 +438,17 @@ class ModelCache(object):
         self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")

         pos = 0
+        models_cleared = 0
         while current_size + bytes_needed > maximum_size and pos < len(self._cache_stack):
             model_key = self._cache_stack[pos]
             cache_entry = self._cached_models[model_key]

             refs = sys.getrefcount(cache_entry.model)

+            # HACK: This is a workaround for a memory-management issue that we haven't tracked down yet. We are directly
+            # going against the advice in the Python docs by using `gc.get_referrers(...)` in this way:
+            # https://docs.python.org/3/library/gc.html#gc.get_referrers
+
             # manualy clear local variable references of just finished function calls
             # for some reason python don't want to collect it even by gc.collect() immidiately
             if refs > 2:
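The eviction loop above leans on `sys.getrefcount()`, which always reports one extra reference for the temporary argument created by the call itself; that is why the surrounding code reasons about thresholds like 2 or 3 rather than 1. A tiny standalone illustration (CPython-specific behaviour):

```python
import sys

x = object()
# Typically 2 in CPython: the name `x` plus the temporary reference
# created when `x` is passed into getrefcount() itself.
print(sys.getrefcount(x))

y = x
# Typically 3 now: `x`, `y`, and the call argument.
print(sys.getrefcount(x))
```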
@@ -453,15 +474,16 @@ class ModelCache(object):
                     f" refs: {refs}"
                 )

-            # 2 refs:
+            # Expected refs:
             # 1 from cache_entry
             # 1 from getrefcount function
             # 1 from onnx runtime object
-            if not cache_entry.locked and refs <= 3 if "onnx" in model_key else 2:
+            if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
                 self.logger.debug(
                     f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
                 )
                 current_size -= cache_entry.size
+                models_cleared += 1
                 if self.stats:
                     self.stats.cleared += 1
                 del self._cache_stack[pos]
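The parenthesised form above is a real bug fix, not a style change: Python's conditional expression binds more loosely than `<=`, so the old line evaluated `(refs <= 3) if "onnx" in model_key else 2`, and for non-ONNX models the condition collapsed to the always-truthy integer 2. A quick demonstration, independent of the cache code:

```python
refs = 10
is_onnx = False

# Old form: parsed as ((refs <= 3) if is_onnx else 2), so the non-ONNX branch
# yields the integer 2, which is truthy regardless of how many refs exist.
old_form = refs <= 3 if is_onnx else 2
print(bool(old_form))  # True, even though refs is far above the intended limit of 2

# Fixed form: pick the limit first, then compare against it.
fixed_form = refs <= (3 if is_onnx else 2)
print(bool(fixed_form))  # False
```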
@@ -471,7 +493,20 @@ class ModelCache(object):
             else:
                 pos += 1

-        gc.collect()
+        if models_cleared > 0:
+            # There would likely be some 'garbage' to be collected regardless of whether a model was cleared or not, but
+            # there is a significant time cost to calling `gc.collect()`, so we want to use it sparingly. (The time cost
+            # is high even if no garbage gets collected.)
+            #
+            # Calling gc.collect(...) when a model is cleared seems like a good middle-ground:
+            # - If models had to be cleared, it's a signal that we are close to our memory limit.
+            # - If models were cleared, there's a good chance that there's a significant amount of garbage to be
+            #   collected.
+            #
+            # Keep in mind that gc is only responsible for handling reference cycles. Most objects should be cleaned up
+            # immediately when their reference count hits 0.
+            gc.collect()
+
         torch.cuda.empty_cache()
         if choose_torch_device() == torch.device("mps"):
             mps.empty_cache()
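The comment block above asserts that `gc.collect()` is costly even when there is little to collect, which is the whole reason for gating it on `models_cleared`. A rough way to see that cost locally (the absolute numbers depend on how many objects the interpreter is tracking):

```python
import gc
import timeit

gc.collect()  # start from a mostly clean state
# Each call still walks the graph of gc-tracked objects even if no cycles are found.
per_call_s = timeit.timeit(gc.collect, number=100) / 100
print(f"~{per_call_s * 1000:.2f} ms per gc.collect() call with little garbage present")
```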
@@ -491,7 +526,6 @@ class ModelCache(object):
                 vram_in_use = torch.cuda.memory_allocated()
                 self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")

-        gc.collect()
         torch.cuda.empty_cache()
         if choose_torch_device() == torch.device("mps"):
             mps.empty_cache()
@@ -17,7 +17,7 @@ def skip_torch_weight_init():
     completely unnecessary if the intent is to load checkpoint weights from disk for the layer. This context manager
     monkey-patches common torch layers to skip the weight initialization step.
     """
-    torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd]
+    torch_modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding]
     saved_functions = [m.reset_parameters for m in torch_modules]

     try:
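For context, `skip_torch_weight_init()` works by temporarily swapping the `reset_parameters` method of the listed layer classes for a no-op and restoring the originals on exit; adding `torch.nn.Embedding` to `torch_modules` extends that to embedding layers. A stripped-down sketch of the pattern (not the project's exact implementation):

```python
from contextlib import contextmanager

import torch

def _no_op(*args, **kwargs):
    pass

@contextmanager
def skip_weight_init_sketch():
    # Temporarily disable random weight initialization for these layer classes.
    modules = [torch.nn.Linear, torch.nn.modules.conv._ConvNd, torch.nn.Embedding]
    saved = [m.reset_parameters for m in modules]
    try:
        for m in modules:
            m.reset_parameters = _no_op
        yield
    finally:
        # Always restore the original initializers, even if layer construction fails.
        for module, original in zip(modules, saved):
            module.reset_parameters = original

with skip_weight_init_sketch():
    # The weight tensor is allocated but not randomly initialized, which is fine
    # when the values are about to be overwritten by checkpoint weights anyway.
    layer = torch.nn.Embedding(num_embeddings=10, embedding_dim=8)
```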
@@ -351,6 +351,7 @@ class ModelManager(object):
             precision=precision,
             sequential_offload=sequential_offload,
             logger=logger,
+            log_memory_usage=self.app_config.log_memory_usage,
         )

         self._read_models(config)
@@ -90,6 +90,7 @@ def _parse_args() -> Namespace:
 # ------------------------- GUI HERE -------------------------
 class mergeModelsForm(npyscreen.FormMultiPageAction):
     interpolations = ["weighted_sum", "sigmoid", "inv_sigmoid"]
+    bases = ["sd-1", "sd-2", "sdxl"]

     def __init__(self, parentApp, name):
         self.parentApp = parentApp
@@ -276,7 +277,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):

         args = dict(
             model_names=models,
-            base_model=tuple(BaseModelType)[self.base_select.value[0]],
+            base_model=BaseModelType(self.bases[self.base_select.value[0]]),
             alpha=self.alpha.value,
             interp=interp,
             force=self.force.value,
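The `base_model` change above replaces positional indexing into the enum, `tuple(BaseModelType)[i]`, which silently depends on the order in which the enum members happen to be declared, with construction from the selected value string, which keeps working if members are reordered or new bases are inserted. A toy example of the difference (the enum below is illustrative, not InvokeAI's actual `BaseModelType`):

```python
from enum import Enum

class BaseKind(Enum):
    # Hypothetical stand-in for BaseModelType; values mirror the strings shown in the UI.
    StableDiffusion1 = "sd-1"
    StableDiffusionXL = "sdxl"  # imagine this member was declared before "sd-2" was added
    StableDiffusion2 = "sd-2"

bases = ["sd-1", "sd-2", "sdxl"]  # the order presented in the selection widget
selected = 1                      # the user picked "sd-2"

by_position = tuple(BaseKind)[selected]  # BaseKind.StableDiffusionXL -- tied to declaration order
by_value = BaseKind(bases[selected])     # BaseKind.StableDiffusion2  -- matches the chosen label
print(by_position.name, by_value.name)
```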
@@ -319,8 +320,7 @@ class mergeModelsForm(npyscreen.FormMultiPageAction):
         return sorted(model_names)

     def _populate_models(self, value=None):
-        bases = ["sd-1", "sd-2", "sdxl"]
-        base_model = BaseModelType(bases[value[0]])
+        base_model = BaseModelType(self.bases[value[0]])
         self.model_names = self.get_model_names(base_model)

         models_plus_none = self.model_names.copy()
@@ -722,7 +722,9 @@
     "noMatchingModels": "No matching Models",
     "noModelsAvailable": "No models available",
     "selectLoRA": "Select a LoRA",
-    "selectModel": "Select a Model"
+    "selectModel": "Select a Model",
+    "noLoRAsInstalled": "No LoRAs installed",
+    "noRefinerModelsInstalled": "No SDXL Refiner models installed"
   },
   "nodes": {
     "addNode": "Add Node",
@@ -10,6 +10,7 @@ import { loraAdded } from 'features/lora/store/loraSlice';
 import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
 import { forEach } from 'lodash-es';
 import { memo, useCallback, useMemo } from 'react';
+import { useTranslation } from 'react-i18next';
 import { useGetLoRAModelsQuery } from 'services/api/endpoints/models';

 const selector = createSelector(
@@ -24,7 +25,7 @@ const ParamLoRASelect = () => {
   const dispatch = useAppDispatch();
   const { loras } = useAppSelector(selector);
   const { data: loraModels } = useGetLoRAModelsQuery();
-
+  const { t } = useTranslation();
   const currentMainModel = useAppSelector(
     (state: RootState) => state.generation.model
   );
@@ -79,7 +80,7 @@ const ParamLoRASelect = () => {
     return (
       <Flex sx={{ justifyContent: 'center', p: 2 }}>
         <Text sx={{ fontSize: 'sm', color: 'base.500', _dark: 'base.700' }}>
-          No LoRAs Loaded
+          {t('models.noLoRAsInstalled')}
         </Text>
       </Flex>
     );
@@ -1,4 +1,4 @@
-import { Flex } from '@chakra-ui/react';
+import { Flex, Text } from '@chakra-ui/react';
 import { createSelector } from '@reduxjs/toolkit';
 import { stateSelector } from 'app/store/store';
 import { useAppSelector } from 'app/store/storeHooks';
@@ -14,6 +14,7 @@ import ParamSDXLRefinerStart from './SDXLRefiner/ParamSDXLRefinerStart';
 import ParamSDXLRefinerSteps from './SDXLRefiner/ParamSDXLRefinerSteps';
 import ParamUseSDXLRefiner from './SDXLRefiner/ParamUseSDXLRefiner';
 import { useTranslation } from 'react-i18next';
+import { useIsRefinerAvailable } from 'services/api/hooks/useIsRefinerAvailable';

 const selector = createSelector(
   stateSelector,
@@ -31,6 +32,19 @@ const selector = createSelector(
 const ParamSDXLRefinerCollapse = () => {
   const { activeLabel, shouldUseSliders } = useAppSelector(selector);
   const { t } = useTranslation();
+  const isRefinerAvailable = useIsRefinerAvailable();

+  if (!isRefinerAvailable) {
+    return (
+      <IAICollapse label={t('sdxl.refiner')} activeLabel={activeLabel}>
+        <Flex sx={{ justifyContent: 'center', p: 2 }}>
+          <Text sx={{ fontSize: 'sm', color: 'base.500', _dark: 'base.700' }}>
+            {t('models.noRefinerModelsInstalled')}
+          </Text>
+        </Flex>
+      </IAICollapse>
+    );
+  }
+
   return (
     <IAICollapse label={t('sdxl.refiner')} activeLabel={activeLabel}>
@@ -13,10 +13,11 @@ def test_memory_snapshot_capture():


 snapshots = [
-    MemorySnapshot(process_ram=1.0, vram=2.0, malloc_info=Struct_mallinfo2()),
-    MemorySnapshot(process_ram=1.0, vram=2.0, malloc_info=None),
-    MemorySnapshot(process_ram=1.0, vram=None, malloc_info=Struct_mallinfo2()),
-    MemorySnapshot(process_ram=1.0, vram=None, malloc_info=None),
+    MemorySnapshot(process_ram=1, vram=2, malloc_info=Struct_mallinfo2()),
+    MemorySnapshot(process_ram=1, vram=2, malloc_info=None),
+    MemorySnapshot(process_ram=1, vram=None, malloc_info=Struct_mallinfo2()),
+    MemorySnapshot(process_ram=1, vram=None, malloc_info=None),
+    None,
 ]

@@ -26,10 +27,12 @@ def test_get_pretty_snapshot_diff(snapshot_1, snapshot_2):
     """Test that get_pretty_snapshot_diff() works with various combinations of missing MemorySnapshot fields."""
     msg = get_pretty_snapshot_diff(snapshot_1, snapshot_2)

-    expected_lines = 1
-    if snapshot_1.vram is not None and snapshot_2.vram is not None:
+    expected_lines = 0
+    if snapshot_1 is not None and snapshot_2 is not None:
         expected_lines += 1
-    if snapshot_1.malloc_info is not None and snapshot_2.malloc_info is not None:
-        expected_lines += 5
+        if snapshot_1.vram is not None and snapshot_2.vram is not None:
+            expected_lines += 1
+        if snapshot_1.malloc_info is not None and snapshot_2.malloc_info is not None:
+            expected_lines += 5

     assert len(msg.splitlines()) == expected_lines
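The module-level `snapshots` list above (which now includes `None`) is presumably what parametrizes `test_get_pretty_snapshot_diff` with every ordered pair of entries; the decorator itself falls outside the hunk. The sketch below only illustrates that pairing pattern, using hypothetical placeholder values:

```python
import itertools

import pytest

# Placeholder values standing in for the MemorySnapshot instances (and None) above.
snapshots = ["snap_a", "snap_b", None]

@pytest.mark.parametrize("snapshot_1,snapshot_2", itertools.product(snapshots, snapshots))
def test_every_ordered_pair(snapshot_1, snapshot_2):
    # 3 x 3 = 9 cases are generated, including pairs where either side is None,
    # which is exactly the situation the reworked expected_lines logic handles.
    assert snapshot_1 in snapshots and snapshot_2 in snapshots
```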
@@ -11,6 +11,7 @@ from invokeai.backend.model_management.model_load_optimizations import _no_op, skip_torch_weight_init
         (torch.nn.Conv1d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
         (torch.nn.Conv2d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
         (torch.nn.Conv3d, {"in_channels": 10, "out_channels": 20, "kernel_size": 3}),
+        (torch.nn.Embedding, {"num_embeddings": 10, "embedding_dim": 10}),
     ],
 )
 def test_skip_torch_weight_init_linear(torch_module, layer_args):
@@ -36,12 +37,14 @@ def test_skip_torch_weight_init_linear(torch_module, layer_args):
     # Check that reset_parameters is skipped while `skip_torch_weight_init()` is active.
     assert reset_params_fn_during == _no_op
     assert not torch.allclose(layer_before.weight, layer_during.weight)
-    assert not torch.allclose(layer_before.bias, layer_during.bias)
+    if hasattr(layer_before, "bias"):
+        assert not torch.allclose(layer_before.bias, layer_during.bias)

     # Check that the original behavior is restored after `skip_torch_weight_init()` ends.
     assert reset_params_fn_before is reset_params_fn_after
     assert torch.allclose(layer_before.weight, layer_after.weight)
-    assert torch.allclose(layer_before.bias, layer_after.bias)
+    if hasattr(layer_before, "bias"):
+        assert torch.allclose(layer_before.bias, layer_after.bias)


 def test_skip_torch_weight_init_restores_base_class_behavior():
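The `hasattr(layer_before, "bias")` guards added above are needed because the newly parametrized `torch.nn.Embedding` carries only a weight matrix, unlike `Linear` and the convolution layers, which register a `bias` parameter by default. A quick check:

```python
import torch

print(hasattr(torch.nn.Linear(4, 4), "bias"))                 # True
print(hasattr(torch.nn.Conv2d(3, 8, kernel_size=3), "bias"))  # True
print(hasattr(torch.nn.Embedding(10, 8), "bias"))             # False: embeddings have no bias parameter
```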