Mirror of https://github.com/invoke-ai/InvokeAI (synced 2025-07-26 05:17:55 +00:00)
Add keep_ram_copy option to CachedModelOnlyFullLoad.
@@ -9,12 +9,17 @@ class CachedModelOnlyFullLoad:
     MPS memory, etc.
     """
 
-    def __init__(self, model: torch.nn.Module | Any, compute_device: torch.device, total_bytes: int):
+    def __init__(
+        self, model: torch.nn.Module | Any, compute_device: torch.device, total_bytes: int, keep_ram_copy: bool = False
+    ):
         """Initialize a CachedModelOnlyFullLoad.
         Args:
             model (torch.nn.Module | Any): The model to wrap. Should be on the CPU.
             compute_device (torch.device): The compute device to move the model to.
             total_bytes (int): The total size (in bytes) of all the weights in the model.
+            keep_ram_copy (bool): Whether to keep a read-only copy of the model's state dict in RAM. Keeping a RAM copy
+                increases RAM usage, but speeds up model offload from VRAM and LoRA patching (assuming there is
+                sufficient RAM).
         """
         # model is often a torch.nn.Module, but could be any model type. Throughout this class, we handle both cases.
         self._model = model
@@ -23,7 +28,7 @@ class CachedModelOnlyFullLoad:
 
         # A CPU read-only copy of the model's state dict.
         self._cpu_state_dict: dict[str, torch.Tensor] | None = None
-        if isinstance(model, torch.nn.Module):
+        if isinstance(model, torch.nn.Module) and keep_ram_copy:
             self._cpu_state_dict = model.state_dict()
 
         self._total_bytes = total_bytes
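The two hunks above extend the CachedModelOnlyFullLoad constructor and make the CPU state-dict snapshot conditional on the new flag. For illustration only (not part of the commit), a minimal sketch of how the option might be exercised; it assumes a CUDA device is available and that the retained copy is exposed through a get_cpu_state_dict() accessor, as the test named test_cached_model_get_cpu_state_dict further down suggests.

    import torch

    from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
        CachedModelOnlyFullLoad,
    )

    # Any small torch.nn.Module works as a stand-in model for this sketch.
    model = torch.nn.Linear(32, 32)
    total_bytes = sum(p.numel() * p.element_size() for p in model.parameters())

    # keep_ram_copy=True snapshots the state dict at wrap time, trading extra RAM
    # for faster offload from VRAM and faster LoRA patching later.
    cached_model = CachedModelOnlyFullLoad(
        model=model,
        compute_device=torch.device("cuda"),
        total_bytes=total_bytes,
        keep_ram_copy=True,
    )

    assert cached_model.full_load_to_vram() == total_bytes
    assert cached_model.is_in_vram()
    # The retained CPU copy stays available while the weights live on the GPU
    # (accessor name assumed from the test mentioned above).
    cpu_state_dict = cached_model.get_cpu_state_dict()
    assert cpu_state_dict is not None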
@@ -3,7 +3,11 @@ import torch
 from invokeai.backend.model_manager.load.model_cache.cached_model.cached_model_only_full_load import (
     CachedModelOnlyFullLoad,
 )
-from tests.backend.model_manager.load.model_cache.cached_model.utils import DummyModule, parameterize_mps_and_cuda
+from tests.backend.model_manager.load.model_cache.cached_model.utils import (
+    DummyModule,
+    parameterize_keep_ram_copy,
+    parameterize_mps_and_cuda,
+)
 
 
 class NonTorchModel:
@@ -17,16 +21,22 @@ class NonTorchModel:
 
 
 @parameterize_mps_and_cuda
-def test_cached_model_total_bytes(device: str):
+@parameterize_keep_ram_copy
+def test_cached_model_total_bytes(device: str, keep_ram_copy: bool):
     model = DummyModule()
-    cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
+    cached_model = CachedModelOnlyFullLoad(
+        model=model, compute_device=torch.device(device), total_bytes=100, keep_ram_copy=keep_ram_copy
+    )
     assert cached_model.total_bytes() == 100
 
 
 @parameterize_mps_and_cuda
-def test_cached_model_is_in_vram(device: str):
+@parameterize_keep_ram_copy
+def test_cached_model_is_in_vram(device: str, keep_ram_copy: bool):
     model = DummyModule()
-    cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
+    cached_model = CachedModelOnlyFullLoad(
+        model=model, compute_device=torch.device(device), total_bytes=100, keep_ram_copy=keep_ram_copy
+    )
     assert not cached_model.is_in_vram()
     assert cached_model.cur_vram_bytes() == 0
 
@@ -40,9 +50,12 @@ def test_cached_model_is_in_vram(device: str):
 
 
 @parameterize_mps_and_cuda
-def test_cached_model_full_load_and_unload(device: str):
+@parameterize_keep_ram_copy
+def test_cached_model_full_load_and_unload(device: str, keep_ram_copy: bool):
     model = DummyModule()
-    cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
+    cached_model = CachedModelOnlyFullLoad(
+        model=model, compute_device=torch.device(device), total_bytes=100, keep_ram_copy=keep_ram_copy
+    )
     assert cached_model.full_load_to_vram() == 100
     assert cached_model.is_in_vram()
     assert all(p.device.type == device for p in cached_model.model.parameters())
@@ -55,7 +68,9 @@ def test_cached_model_full_load_and_unload(device: str):
 @parameterize_mps_and_cuda
 def test_cached_model_get_cpu_state_dict(device: str):
     model = DummyModule()
-    cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
+    cached_model = CachedModelOnlyFullLoad(
+        model=model, compute_device=torch.device(device), total_bytes=100, keep_ram_copy=True
+    )
     assert not cached_model.is_in_vram()
 
     # The CPU state dict can be accessed and has the expected properties.
@@ -76,9 +91,12 @@ def test_cached_model_get_cpu_state_dict(device: str):
 
 
 @parameterize_mps_and_cuda
-def test_cached_model_full_load_and_inference(device: str):
+@parameterize_keep_ram_copy
+def test_cached_model_full_load_and_inference(device: str, keep_ram_copy: bool):
     model = DummyModule()
-    cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
+    cached_model = CachedModelOnlyFullLoad(
+        model=model, compute_device=torch.device(device), total_bytes=100, keep_ram_copy=keep_ram_copy
+    )
     assert not cached_model.is_in_vram()
 
     # Run inference on the CPU.
@@ -99,9 +117,12 @@ def test_cached_model_full_load_and_inference(device: str):
 
 
 @parameterize_mps_and_cuda
-def test_non_torch_model(device: str):
+@parameterize_keep_ram_copy
+def test_non_torch_model(device: str, keep_ram_copy: bool):
     model = NonTorchModel()
-    cached_model = CachedModelOnlyFullLoad(model=model, compute_device=torch.device(device), total_bytes=100)
+    cached_model = CachedModelOnlyFullLoad(
+        model=model, compute_device=torch.device(device), total_bytes=100, keep_ram_copy=keep_ram_copy
+    )
     assert not cached_model.is_in_vram()
 
     # The model does not have a CPU state dict.
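The tests above build their models from a DummyModule helper imported from the shared utils module; its definition is not part of this diff. As a reading aid only, a hypothetical stand-in along these lines would satisfy the assertions shown (parameters that can be moved between devices and a forward pass that runs on the CPU):

    import torch

    # Hypothetical stand-in for the DummyModule test helper (not taken from the diff).
    class DummyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(10, 32)
            self.linear2 = torch.nn.Linear(32, 64)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.linear2(self.linear1(x))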
@@ -10,7 +10,11 @@ from invokeai.backend.model_manager.load.model_cache.torch_module_autocast.torch
     apply_custom_layers_to_model,
 )
 from invokeai.backend.util.calc_tensor_size import calc_tensor_size
-from tests.backend.model_manager.load.model_cache.cached_model.utils import DummyModule, parameterize_mps_and_cuda
+from tests.backend.model_manager.load.model_cache.cached_model.utils import (
+    DummyModule,
+    parameterize_keep_ram_copy,
+    parameterize_mps_and_cuda,
+)
 
 
 @pytest.fixture
@@ -20,9 +24,6 @@ def model():
     return model
 
 
-parameterize_keep_ram_copy = pytest.mark.parametrize("keep_ram_copy", [True, False])
-
-
 @parameterize_mps_and_cuda
 @parameterize_keep_ram_copy
 def test_cached_model_total_bytes(device: str, model: DummyModule, keep_ram_copy: bool):
@@ -29,3 +29,5 @@ parameterize_mps_and_cuda = pytest.mark.parametrize(
         pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available.")),
     ],
 )
+
+parameterize_keep_ram_copy = pytest.mark.parametrize("keep_ram_copy", [True, False])
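The marker added to the shared utils module pairs with the existing parameterize_mps_and_cuda marker: stacking the two pytest.mark.parametrize decorators runs each test over the cross product of device and keep_ram_copy values. A self-contained sketch of the same pattern (the device marker here is an approximation of the one shown in the context lines above, not a copy of it):

    import pytest
    import torch

    # Approximation of the shared markers: stacked parametrize decorators expand
    # into one test invocation per (device, keep_ram_copy) combination.
    parameterize_mps_and_cuda = pytest.mark.parametrize(
        "device",
        [
            pytest.param("mps", marks=pytest.mark.skipif(not torch.backends.mps.is_available(), reason="MPS is not available.")),
            pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available.")),
        ],
    )
    parameterize_keep_ram_copy = pytest.mark.parametrize("keep_ram_copy", [True, False])


    @parameterize_mps_and_cuda
    @parameterize_keep_ram_copy
    def test_example(device: str, keep_ram_copy: bool):
        # Collected as 4 cases: (mps, True), (mps, False), (cuda, True), (cuda, False);
        # unavailable devices are skipped by the skipif marks.
        assert isinstance(keep_ram_copy, bool)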