mirror of
https://github.com/invoke-ai/InvokeAI
synced 2025-07-26 05:17:55 +00:00
Add CachedModelWithPartialLoad to manage partially-loaded models using the new autocast modules.
This commit is contained in:
@ -0,0 +1,183 @@
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch_module_autocast import (
|
||||
AUTOCAST_MODULE_TYPE_MAPPING,
|
||||
apply_custom_layers_to_model,
|
||||
remove_custom_layers_from_model,
|
||||
)
|
||||
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
|
||||
def set_nested_attr(obj: object, attr: str, value: object):
|
||||
"""A helper function that extends setattr() to support nested attributes.
|
||||
|
||||
Example:
|
||||
set_nested_attr(model, "module.encoder.conv1.weight", new_conv1_weight)
|
||||
"""
|
||||
attrs = attr.split(".")
|
||||
for attr in attrs[:-1]:
|
||||
obj = getattr(obj, attr)
|
||||
setattr(obj, attrs[-1], value)
|
||||
|
||||
|
||||
class CachedModelWithPartialLoad:
|
||||
"""A wrapper around a PyTorch model to handle partial loads and unloads between the CPU and the compute device.
|
||||
|
||||
Note: "VRAM" is used throughout this class to refer to the memory on the compute device. It could be CUDA memory,
|
||||
MPS memory, etc.
|
||||
"""
|
||||
|
||||
def __init__(self, model: torch.nn.Module, compute_device: torch.device):
|
||||
self._model = model
|
||||
self._compute_device = compute_device
|
||||
|
||||
# A CPU read-only copy of the model's state dict.
|
||||
self._cpu_state_dict: dict[str, torch.Tensor] = model.state_dict()
|
||||
|
||||
# TODO(ryand): Handle the case where the model sizes changes after initial load (e.g. due to dtype casting).
|
||||
# Consider how we should handle this for both self._total_bytes and self._cur_vram_bytes.
|
||||
self._total_bytes = sum(calc_tensor_size(p) for p in self._cpu_state_dict.values())
|
||||
self._cur_vram_bytes: int | None = None
|
||||
|
||||
self._modules_that_support_autocast = self._find_modules_that_support_autocast()
|
||||
self._keys_in_modules_that_do_not_support_autocast = self._find_keys_in_modules_that_do_not_support_autocast()
|
||||
|
||||
def _find_modules_that_support_autocast(self) -> dict[str, torch.nn.Module]:
|
||||
"""Find all modules that support autocasting."""
|
||||
return {n: m for n, m in self._model.named_modules() if type(m) in AUTOCAST_MODULE_TYPE_MAPPING}
|
||||
|
||||
def _find_keys_in_modules_that_do_not_support_autocast(self) -> set[str]:
|
||||
keys_in_modules_that_do_not_support_autocast = set()
|
||||
for key in self._cpu_state_dict.keys():
|
||||
for module_name in self._modules_that_support_autocast.keys():
|
||||
if key.startswith(module_name):
|
||||
break
|
||||
else:
|
||||
keys_in_modules_that_do_not_support_autocast.add(key)
|
||||
return keys_in_modules_that_do_not_support_autocast
|
||||
|
||||
@property
|
||||
def model(self) -> torch.nn.Module:
|
||||
return self._model
|
||||
|
||||
def get_cpu_state_dict(self) -> dict[str, torch.Tensor] | None:
|
||||
"""Get a read-only copy of the model's state dict in RAM."""
|
||||
# TODO(ryand): Document this better.
|
||||
return self._cpu_state_dict
|
||||
|
||||
def total_bytes(self) -> int:
|
||||
"""Get the total size (in bytes) of all the weights in the model."""
|
||||
return self._total_bytes
|
||||
|
||||
def cur_vram_bytes(self) -> int:
|
||||
"""Get the size (in bytes) of the weights that are currently in VRAM."""
|
||||
if self._cur_vram_bytes is None:
|
||||
cur_state_dict = self._model.state_dict()
|
||||
self._cur_vram_bytes = sum(
|
||||
calc_tensor_size(p) for p in cur_state_dict.values() if p.device.type == self._compute_device.type
|
||||
)
|
||||
return self._cur_vram_bytes
|
||||
|
||||
def full_load_to_vram(self) -> int:
|
||||
"""Load all weights into VRAM."""
|
||||
return self.partial_load_to_vram(self.total_bytes())
|
||||
|
||||
def full_unload_from_vram(self) -> int:
|
||||
"""Unload all weights from VRAM."""
|
||||
return self.partial_unload_from_vram(self.total_bytes())
|
||||
|
||||
@torch.no_grad()
|
||||
def partial_load_to_vram(self, vram_bytes_to_load: int) -> int:
|
||||
"""Load more weights into VRAM without exceeding vram_bytes_to_load.
|
||||
|
||||
Returns:
|
||||
The number of bytes loaded into VRAM.
|
||||
"""
|
||||
# TODO(ryand): Handle the case where an exception is thrown while loading or unloading weights. At the very
|
||||
# least, we should reset self._cur_vram_bytes to None.
|
||||
|
||||
vram_bytes_loaded = 0
|
||||
|
||||
cur_state_dict = self._model.state_dict()
|
||||
|
||||
# First, process the keys *must* be loaded into VRAM.
|
||||
for key in self._keys_in_modules_that_do_not_support_autocast:
|
||||
param = cur_state_dict[key]
|
||||
if param.device.type == self._compute_device.type:
|
||||
continue
|
||||
|
||||
param_size = calc_tensor_size(param)
|
||||
cur_state_dict[key] = param.to(self._compute_device, copy=True)
|
||||
vram_bytes_loaded += param_size
|
||||
|
||||
if vram_bytes_loaded > vram_bytes_to_load:
|
||||
logger = InvokeAILogger.get_logger()
|
||||
logger.warning(
|
||||
f"Loaded {vram_bytes_loaded / 2**20} MB into VRAM, but only {vram_bytes_to_load / 2**20} MB were "
|
||||
"requested. This is the minimum set of weights in VRAM required to run the model."
|
||||
)
|
||||
|
||||
# Next, process the keys that can optionally be loaded into VRAM.
|
||||
fully_loaded = True
|
||||
for key, param in cur_state_dict.items():
|
||||
if param.device.type == self._compute_device.type:
|
||||
continue
|
||||
|
||||
param_size = calc_tensor_size(param)
|
||||
if vram_bytes_loaded + param_size > vram_bytes_to_load:
|
||||
# TODO(ryand): Should we just break here? If we couldn't fit this parameter into VRAM, is it really
|
||||
# worth continuing to search for a smaller parameter that would fit?
|
||||
fully_loaded = False
|
||||
continue
|
||||
|
||||
cur_state_dict[key] = param.to(self._compute_device, copy=True)
|
||||
vram_bytes_loaded += param_size
|
||||
|
||||
if vram_bytes_loaded > 0:
|
||||
# We load the entire state dict, not just the parameters that changed, in case there are modules that
|
||||
# override _load_from_state_dict() and do some funky stuff that requires the entire state dict.
|
||||
# Alternatively, in the future, grouping parameters by module could probably solve this problem.
|
||||
self._model.load_state_dict(cur_state_dict, assign=True)
|
||||
|
||||
if self._cur_vram_bytes is not None:
|
||||
self._cur_vram_bytes += vram_bytes_loaded
|
||||
|
||||
if fully_loaded:
|
||||
remove_custom_layers_from_model(self._model)
|
||||
# TODO(ryand): Warn if the self.cur_vram_bytes() and self.total_bytes() are out of sync.
|
||||
else:
|
||||
apply_custom_layers_to_model(self._model)
|
||||
|
||||
# TODO(ryand): Handle non-persistent buffers.
|
||||
return vram_bytes_loaded
|
||||
|
||||
@torch.no_grad()
|
||||
def partial_unload_from_vram(self, vram_bytes_to_free: int) -> int:
|
||||
"""Unload weights from VRAM until vram_bytes_to_free bytes are freed. Or the entire model is unloaded.
|
||||
|
||||
Returns:
|
||||
The number of bytes unloaded from VRAM.
|
||||
"""
|
||||
vram_bytes_freed = 0
|
||||
|
||||
offload_device = "cpu"
|
||||
cur_state_dict = self._model.state_dict()
|
||||
for key, param in cur_state_dict.items():
|
||||
if vram_bytes_freed >= vram_bytes_to_free:
|
||||
break
|
||||
|
||||
if param.device.type == offload_device:
|
||||
continue
|
||||
|
||||
cur_state_dict[key] = self._cpu_state_dict[key]
|
||||
vram_bytes_freed += calc_tensor_size(param)
|
||||
|
||||
if vram_bytes_freed > 0:
|
||||
self._model.load_state_dict(cur_state_dict, assign=True)
|
||||
|
||||
if self._cur_vram_bytes is not None:
|
||||
self._cur_vram_bytes -= vram_bytes_freed
|
||||
|
||||
apply_custom_layers_to_model(self._model)
|
||||
return vram_bytes_freed
|
@ -0,0 +1,295 @@
|
||||
import itertools
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_with_partial_load import (
|
||||
CachedModelWithPartialLoad,
|
||||
)
|
||||
from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.autocast_modules import CustomLinear
|
||||
from invokeai.backend.util.calc_tensor_size import calc_tensor_size
|
||||
|
||||
|
||||
class DummyModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.linear1 = torch.nn.Linear(10, 10)
|
||||
self.linear2 = torch.nn.Linear(10, 10)
|
||||
self.register_buffer("buffer1", torch.ones(10, 10))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.linear1(x)
|
||||
x = self.linear2(x)
|
||||
return x
|
||||
|
||||
|
||||
parameterize_mps_and_cuda = pytest.mark.parametrize(
|
||||
("device"),
|
||||
[
|
||||
pytest.param(
|
||||
"mps", marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS is not available.")
|
||||
),
|
||||
pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available.")),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_total_bytes(device: str):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
linear_numel = 10 * 10 + 10
|
||||
buffer_numel = 10 * 10
|
||||
assert cached_model.total_bytes() == (2 * linear_numel + buffer_numel) * 4
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_cur_vram_bytes(device: str):
|
||||
model = DummyModule()
|
||||
# Model starts in CPU memory.
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
# Full load the model into VRAM.
|
||||
cached_model.full_load_to_vram()
|
||||
assert cached_model.cur_vram_bytes() > 0
|
||||
assert cached_model.cur_vram_bytes() == cached_model.total_bytes()
|
||||
assert all(p.device.type == device for p in model.parameters())
|
||||
assert all(p.device.type == device for p in model.buffers())
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_partial_load(device: str):
|
||||
model = DummyModule()
|
||||
# Model starts in CPU memory.
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
# Partially load the model into VRAM.
|
||||
target_vram_bytes = int(model_total_bytes * 0.6)
|
||||
loaded_bytes = cached_model.partial_load_to_vram(target_vram_bytes)
|
||||
|
||||
# Check that the model is partially loaded into VRAM.
|
||||
assert loaded_bytes > 0
|
||||
assert loaded_bytes < model_total_bytes
|
||||
assert loaded_bytes == cached_model.cur_vram_bytes()
|
||||
assert loaded_bytes == sum(
|
||||
calc_tensor_size(p) for p in itertools.chain(model.parameters(), model.buffers()) if p.device.type == device
|
||||
)
|
||||
|
||||
# Check that the model's modules have been patched with CustomLinear layers.
|
||||
assert type(model.linear1) is CustomLinear
|
||||
assert type(model.linear2) is CustomLinear
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_partial_unload(device: str):
|
||||
model = DummyModule()
|
||||
# Model starts in CPU memory.
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
# Full load the model into VRAM.
|
||||
cached_model.full_load_to_vram()
|
||||
assert cached_model.cur_vram_bytes() == model_total_bytes
|
||||
|
||||
# Partially unload the model from VRAM.
|
||||
bytes_to_free = int(model_total_bytes * 0.4)
|
||||
freed_bytes = cached_model.partial_unload_from_vram(bytes_to_free)
|
||||
|
||||
# Check that the model is partially unloaded from VRAM.
|
||||
assert freed_bytes >= bytes_to_free
|
||||
assert freed_bytes < model_total_bytes
|
||||
assert freed_bytes == model_total_bytes - cached_model.cur_vram_bytes()
|
||||
assert freed_bytes == sum(
|
||||
calc_tensor_size(p) for p in itertools.chain(model.parameters(), model.buffers()) if p.device.type == "cpu"
|
||||
)
|
||||
|
||||
# Check that the model's modules are still patched with CustomLinear layers.
|
||||
assert type(model.linear1) is CustomLinear
|
||||
assert type(model.linear2) is CustomLinear
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_full_load_and_unload(device: str):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
|
||||
# Model starts in CPU memory.
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
# Full load the model into VRAM.
|
||||
loaded_bytes = cached_model.full_load_to_vram()
|
||||
assert loaded_bytes > 0
|
||||
assert loaded_bytes == model_total_bytes
|
||||
assert loaded_bytes == cached_model.cur_vram_bytes()
|
||||
assert all(p.device.type == device for p in itertools.chain(model.parameters(), model.buffers()))
|
||||
assert type(model.linear1) is torch.nn.Linear
|
||||
assert type(model.linear2) is torch.nn.Linear
|
||||
|
||||
# Full unload the model from VRAM.
|
||||
unloaded_bytes = cached_model.full_unload_from_vram()
|
||||
|
||||
# Check that the model is fully unloaded from VRAM.
|
||||
assert unloaded_bytes > 0
|
||||
assert unloaded_bytes == model_total_bytes
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
assert all(p.device.type == "cpu" for p in itertools.chain(model.parameters(), model.buffers()))
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_full_load_from_partial(device: str):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
|
||||
# Model starts in CPU memory.
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
# Partially load the model into VRAM.
|
||||
target_vram_bytes = int(model_total_bytes * 0.6)
|
||||
loaded_bytes = cached_model.partial_load_to_vram(target_vram_bytes)
|
||||
assert loaded_bytes > 0
|
||||
assert loaded_bytes < model_total_bytes
|
||||
assert loaded_bytes == cached_model.cur_vram_bytes()
|
||||
assert type(model.linear1) is CustomLinear
|
||||
assert type(model.linear2) is CustomLinear
|
||||
|
||||
# Full load the rest of the model into VRAM.
|
||||
loaded_bytes_2 = cached_model.full_load_to_vram()
|
||||
assert loaded_bytes_2 > 0
|
||||
assert loaded_bytes_2 < model_total_bytes
|
||||
assert loaded_bytes + loaded_bytes_2 == cached_model.cur_vram_bytes()
|
||||
assert loaded_bytes + loaded_bytes_2 == model_total_bytes
|
||||
assert all(p.device.type == device for p in itertools.chain(model.parameters(), model.buffers()))
|
||||
assert type(model.linear1) is torch.nn.Linear
|
||||
assert type(model.linear2) is torch.nn.Linear
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_full_unload_from_partial(device: str):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
|
||||
# Model starts in CPU memory.
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
# Partially load the model into VRAM.
|
||||
target_vram_bytes = int(model_total_bytes * 0.6)
|
||||
loaded_bytes = cached_model.partial_load_to_vram(target_vram_bytes)
|
||||
assert loaded_bytes > 0
|
||||
assert loaded_bytes < model_total_bytes
|
||||
assert loaded_bytes == cached_model.cur_vram_bytes()
|
||||
|
||||
# Full unload the model from VRAM.
|
||||
unloaded_bytes = cached_model.full_unload_from_vram()
|
||||
assert unloaded_bytes > 0
|
||||
assert unloaded_bytes == loaded_bytes
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
assert all(p.device.type == "cpu" for p in itertools.chain(model.parameters(), model.buffers()))
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_get_cpu_state_dict(device: str):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
|
||||
# Model starts in CPU memory.
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
# The CPU state dict can be accessed and has the expected properties.
|
||||
cpu_state_dict = cached_model.get_cpu_state_dict()
|
||||
assert cpu_state_dict is not None
|
||||
assert len(cpu_state_dict) == len(model.state_dict())
|
||||
assert all(p.device.type == "cpu" for p in cpu_state_dict.values())
|
||||
|
||||
# Full load the model into VRAM.
|
||||
cached_model.full_load_to_vram()
|
||||
assert cached_model.cur_vram_bytes() == cached_model.total_bytes()
|
||||
|
||||
# The CPU state dict is still available, and still on the CPU.
|
||||
cpu_state_dict = cached_model.get_cpu_state_dict()
|
||||
assert cpu_state_dict is not None
|
||||
assert len(cpu_state_dict) == len(model.state_dict())
|
||||
assert all(p.device.type == "cpu" for p in cpu_state_dict.values())
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_full_load_and_inference(device: str):
|
||||
model = DummyModule()
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
# Model starts in CPU memory.
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
# Run inference on the CPU.
|
||||
x = model(torch.randn(1, 10))
|
||||
output1 = model(x)
|
||||
assert output1.device.type == "cpu"
|
||||
|
||||
# Full load the model into VRAM.
|
||||
loaded_bytes = cached_model.full_load_to_vram()
|
||||
assert loaded_bytes > 0
|
||||
assert loaded_bytes == model_total_bytes
|
||||
assert loaded_bytes == cached_model.cur_vram_bytes()
|
||||
assert all(p.device.type == device for p in itertools.chain(model.parameters(), model.buffers()))
|
||||
|
||||
# Run inference on the GPU.
|
||||
output2 = model(x.to(device))
|
||||
assert output2.device.type == device
|
||||
|
||||
# Full unload the model from VRAM.
|
||||
unloaded_bytes = cached_model.full_unload_from_vram()
|
||||
assert unloaded_bytes > 0
|
||||
assert unloaded_bytes == model_total_bytes
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
assert all(p.device.type == "cpu" for p in itertools.chain(model.parameters(), model.buffers()))
|
||||
|
||||
# Run inference on the CPU again.
|
||||
output3 = model(x)
|
||||
assert output3.device.type == "cpu"
|
||||
|
||||
# The outputs should be the same for all three runs.
|
||||
assert torch.allclose(output1, output2.to("cpu"))
|
||||
assert torch.allclose(output1, output3)
|
||||
|
||||
|
||||
@parameterize_mps_and_cuda
|
||||
def test_cached_model_partial_load_and_inference(device: str):
|
||||
model = DummyModule()
|
||||
# Model starts in CPU memory.
|
||||
cached_model = CachedModelWithPartialLoad(model=model, compute_device=torch.device(device))
|
||||
model_total_bytes = cached_model.total_bytes()
|
||||
assert cached_model.cur_vram_bytes() == 0
|
||||
|
||||
# Run inference on the CPU.
|
||||
x = model(torch.randn(1, 10))
|
||||
output1 = model(x)
|
||||
assert output1.device.type == "cpu"
|
||||
|
||||
# Partially load the model into VRAM.
|
||||
target_vram_bytes = int(model_total_bytes * 0.6)
|
||||
loaded_bytes = cached_model.partial_load_to_vram(target_vram_bytes)
|
||||
|
||||
# Check that the model is partially loaded into VRAM.
|
||||
assert loaded_bytes > 0
|
||||
assert loaded_bytes < model_total_bytes
|
||||
assert loaded_bytes == cached_model.cur_vram_bytes()
|
||||
assert loaded_bytes == sum(
|
||||
calc_tensor_size(p) for p in itertools.chain(model.parameters(), model.buffers()) if p.device.type == device
|
||||
)
|
||||
|
||||
# Check that the model's modules have been patched with CustomLinear layers.
|
||||
assert type(model.linear1) is CustomLinear
|
||||
assert type(model.linear2) is CustomLinear
|
||||
|
||||
# Run inference on the GPU.
|
||||
output2 = model(x.to(device))
|
||||
assert output2.device.type == device
|
||||
|
||||
# The output should be the same as the output from the CPU.
|
||||
assert torch.allclose(output1, output2.to("cpu"))
|
Reference in New Issue
Block a user