Add keep_ram_copy option to CachedModelOnlyFullLoad.

author Ryan Dick
date 2025-01-16 15:08:23 +00:00
parent 04087c38ce
commit c76d08d1fd
4 changed files with 47 additions and 18 deletions
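
For context, a minimal usage sketch of the new flag. This is not code from the commit: the model, byte count, and device choice below are illustrative stand-ins for whatever the model cache actually wraps.

import torch

from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
    CachedModelOnlyFullLoad,
)

# Illustrative stand-ins only; the real call sites live in the model cache and are not part of this diff.
model = torch.nn.Linear(8, 8)
total_bytes = sum(p.numel() * p.element_size() for p in model.parameters())

cached = CachedModelOnlyFullLoad(
    model=model,
    compute_device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    total_bytes=total_bytes,
    # New in this commit: keep_ram_copy defaults to False, so callers that still want the
    # read-only CPU state dict (previously always kept for torch modules) must opt in.
    keep_ram_copy=True,
)

assert cached.total_bytes() == total_bytes
assert not cached.is_in_vram()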

View File

@@ -9,12 +9,17 @@ class CachedModelOnlyFullLoad:
     MPS memory, etc.
     """
 
-    def __init__(self, model: torch.nn.Module | Any, compute_device: torch.device, total_bytes: int):
+    def __init__(
+        self, model: torch.nn.Module | Any, compute_device: torch.device, total_bytes: int, keep_ram_copy: bool = False
+    ):
         """Initialize a CachedModelOnlyFullLoad.
 
         Args:
             model (torch.nn.Module | Any): The model to wrap. Should be on the CPU.
             compute_device (torch.device): The compute device to move the model to.
             total_bytes (int): The total size (in bytes) of all the weights in the model.
+            keep_ram_copy (bool): Whether to keep a read-only copy of the model's state dict in RAM. Keeping a RAM copy
+                increases RAM usage, but speeds up model offload from VRAM and LoRA patching (assuming there is
+                sufficient RAM).
         """
         # model is often a torch.nn.Module, but could be any model type. Throughout this class, we handle both cases.
         self._model = model
@@ -23,7 +28,7 @@ class CachedModelOnlyFullLoad:
         # A CPU read-only copy of the model's state dict.
         self._cpu_state_dict: dict[str, torch.Tensor] | None = None
-        if isinstance(model, torch.nn.Module):
+        if isinstance(model, torch.nn.Module) and keep_ram_copy:
             self._cpu_state_dict = model.state_dict()
 
         self._total_bytes = total_bytes
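
The trade-off described in the new docstring can be made concrete with a rough sketch. The helpers below are hypothetical, not code from this repository: with a kept CPU state dict, returning a model to RAM (or reverting a LoRA patch) can reassign the existing CPU tensors instead of copying every parameter back out of VRAM.

import torch


# Hypothetical illustration of the docstring's trade-off; these helpers are not part of InvokeAI.
def offload_without_ram_copy(model: torch.nn.Module) -> None:
    # Every parameter is copied from VRAM into freshly allocated CPU memory.
    model.to("cpu")


def offload_with_ram_copy(model: torch.nn.Module, cpu_state_dict: dict[str, torch.Tensor]) -> None:
    # The read-only CPU copy already holds the original weights, so they can be
    # reassigned directly (assign=True is available in torch >= 2.1). The same
    # copy lets LoRA unpatching restore original weights without a GPU round trip.
    model.load_state_dict(cpu_state_dict, assign=True)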

View File

@@ -3,7 +3,11 @@ import torch
 from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
     CachedModelOnlyFullLoad,
 )
-from tests.backend.model_manager.load.model_cache.cached_model.utils import DummyModule, parameterize_mps_and_cuda
+from tests.backend.model_manager.load.model_cache.cached_model.utils import (
+    DummyModule,
+    parameterize_keep_ram_copy,
+    parameterize_mps_and_cuda,
+)
 
 
 class NonTorchModel:
@@ -17,16 +21,22 @@ class NonTorchModel:
 @parameterize_mps_and_cuda
-def test_cached_model_total_bytes(device: str):
+@parameterize_keep_ram_copy
+def test_cached_model_total_bytes(device: str, keep_ram_copy: bool):
     model = DummyModule()
-    cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
+    cached_model = CachedModelOnlyFullLoad(
+        model=model, compute_device=torch.device(device), total_bytes=100, keep_ram_copy=keep_ram_copy
+    )
     assert cached_model.total_bytes() == 100
 
 
 @parameterize_mps_and_cuda
-def test_cached_model_is_in_vram(device: str):
+@parameterize_keep_ram_copy
+def test_cached_model_is_in_vram(device: str, keep_ram_copy: bool):
     model = DummyModule()
-    cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
+    cached_model = CachedModelOnlyFullLoad(
+        model=model, compute_device=torch.device(device), total_bytes=100, keep_ram_copy=keep_ram_copy
+    )
     assert not cached_model.is_in_vram()
     assert cached_model.cur_vram_bytes() == 0
@@ -40,9 +50,12 @@ def test_cached_model_is_in_vram(device: str):
 @parameterize_mps_and_cuda
-def test_cached_model_full_load_and_unload(device: str):
+@parameterize_keep_ram_copy
+def test_cached_model_full_load_and_unload(device: str, keep_ram_copy: bool):
     model = DummyModule()
-    cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
+    cached_model = CachedModelOnlyFullLoad(
+        model=model, compute_device=torch.device(device), total_bytes=100, keep_ram_copy=keep_ram_copy
+    )
     assert cached_model.full_load_to_vram() == 100
     assert cached_model.is_in_vram()
     assert all(p.device.type == device for p in cached_model.model.parameters())
@@ -55,7 +68,9 @@ def test_cached_model_full_load_and_unload(device: str):
 @parameterize_mps_and_cuda
 def test_cached_model_get_cpu_state_dict(device: str):
     model = DummyModule()
-    cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
+    cached_model = CachedModelOnlyFullLoad(
+        model=model, compute_device=torch.device(device), total_bytes=100, keep_ram_copy=True
+    )
     assert not cached_model.is_in_vram()
 
     # The CPU state dict can be accessed and has the expected properties.
@@ -76,9 +91,12 @@ def test_cached_model_get_cpu_state_dict(device: str):
 @parameterize_mps_and_cuda
-def test_cached_model_full_load_and_inference(device: str):
+@parameterize_keep_ram_copy
+def test_cached_model_full_load_and_inference(device: str, keep_ram_copy: bool):
     model = DummyModule()
-    cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
+    cached_model = CachedModelOnlyFullLoad(
+        model=model, compute_device=torch.device(device), total_bytes=100, keep_ram_copy=keep_ram_copy
+    )
     assert not cached_model.is_in_vram()
 
     # Run inference on the CPU.
@@ -99,9 +117,12 @@ def test_cached_model_full_load_and_inference(device: str):
 @parameterize_mps_and_cuda
-def test_non_torch_model(device: str):
+@parameterize_keep_ram_copy
+def test_non_torch_model(device: str, keep_ram_copy: bool):
     model = NonTorchModel()
-    cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
+    cached_model = CachedModelOnlyFullLoad(
+        model=model, compute_device=torch.device(device), total_bytes=100, keep_ram_copy=keep_ram_copy
+    )
     assert not cached_model.is_in_vram()
 
     # The model does not have a CPU state dict.

View File

@@ -10,7 +10,11 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch
     apply_custom_layers_to_model,
 )
 from invokeai.backend.util.calc_tensor_size import calc_tensor_size
-from tests.backend.model_manager.load.model_cache.cached_model.utils import DummyModule, parameterize_mps_and_cuda
+from tests.backend.model_manager.load.model_cache.cached_model.utils import (
+    DummyModule,
+    parameterize_keep_ram_copy,
+    parameterize_mps_and_cuda,
+)
 
 
 @pytest.fixture
@@ -20,9 +24,6 @@ def model():
     return model
 
 
-parameterize_keep_ram_copy = pytest.mark.parametrize("keep_ram_copy", [True, False])
-
-
 @parameterize_mps_and_cuda
 @parameterize_keep_ram_copy
 def test_cached_model_total_bytes(device: str, model: DummyModule, keep_ram_copy: bool):

View File

@@ -29,3 +29,5 @@ parameterize_mps_and_cuda = pytest.mark.parametrize(
         pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available.")),
     ],
 )
+
+parameterize_keep_ram_copy = pytest.mark.parametrize("keep_ram_copy", [True, False])
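
A usage note, not part of the diff: pytest parametrize marks compose, so stacking this new mark with parameterize_mps_and_cuda runs each decorated test over the full device x keep_ram_copy matrix. A self-contained sketch of the pattern, with a stand-in device list:

import pytest

# Same pattern as the shared marks above; "cpu" stands in for the real mps/cuda params.
parameterize_device = pytest.mark.parametrize("device", ["cpu"])
parameterize_keep_ram_copy = pytest.mark.parametrize("keep_ram_copy", [True, False])


@parameterize_device
@parameterize_keep_ram_copy
def test_matrix(device: str, keep_ram_copy: bool):
    # Stacked marks produce the cross-product: 1 device x 2 flag values = 2 test cases here.
    assert device == "cpu"
    assert isinstance(keep_ram_copy, bool)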