mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add memory usage calculations for controlnet, scheduler, tokenizer and upscaler
This commit is contained in:
parent a40fa8e83b
commit ba8f06c285

@@ -293,7 +293,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
                 new_dict[k] = v.to(target_device, copy=True)
             cache_entry.model.load_state_dict(new_dict, assign=True)
         try:
-            cache_entry.model.to(target_device, non_blocking=TorchDevice.get_non_blocking(target_device))
+            cache_entry.model.to(target_device)
         except TypeError as e:
             if "got an unexpected keyword argument 'non_blocking'" in str(e):
                 cache_entry.model.to(target_device)
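
The try/except that survives this hunk exists presumably because not every cached model is a torch.nn.Module: some cached wrapper types expose a .to() that accepts only a device. A minimal sketch of that fallback pattern (RawModel is a hypothetical stand-in for such a wrapper, not a class from the codebase):

    import torch

    class RawModel:
        # Hypothetical wrapper whose .to() takes only a device, unlike nn.Module.
        def to(self, device: torch.device) -> "RawModel":
            return self

    def move_to_device(model, device: torch.device):
        try:
            # nn.Module.to() accepts non_blocking; custom wrappers may not.
            return model.to(device, non_blocking=True)
        except TypeError as e:
            if "got an unexpected keyword argument 'non_blocking'" in str(e):
                return model.to(device)  # retry as a plain blocking move
            raise

move_to_device(RawModel(), torch.device("cpu")) exercises the fallback branch; an nn.Module takes the first path.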

@@ -3,8 +3,9 @@

 import json
 import logging
+import sys
 from pathlib import Path
-from typing import Optional
+from typing import Any, Optional

 import torch
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline

@@ -30,12 +31,14 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
     elif isinstance(model, IAIOnnxRuntimeModel):
         return _calc_onnx_model_by_data(model)
     elif isinstance(model, SchedulerMixin):
-        return 0
+        assert hasattr(model, "config")  # size is dominated by config
+        return sys.getsizeof(model.config)
     elif isinstance(model, CLIPTokenizer):
-        # TODO(ryand): Accurately calculate the tokenizer's size. It's small enough that it shouldn't matter for now.
-        return 0
+        return sys.getsizeof(model.get_vocab())  # size is dominated by the vocab dict
     elif isinstance(model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw, SpandrelImageToImageModel)):
         return model.calc_size()
+    elif isinstance(model, dict):
+        return _calc_size_from_dict(model, logger)
     else:
         # TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
         # supported model types.
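
Both new branches lean on sys.getsizeof, which measures only the container object itself, not everything it references. As the inline comments note, that is close enough here: a scheduler config and a tokenizer vocab are dict-dominated and tiny next to model weights. A rough illustration of the magnitudes involved (the vocab contents are invented for the example):

    import sys

    config = {"num_train_timesteps": 1000, "beta_start": 0.00085, "beta_end": 0.012}
    vocab = {f"token_{i}": i for i in range(49408)}  # CLIP-sized vocab, fake entries

    print(sys.getsizeof(config))  # a few hundred bytes: just the dict's own table
    print(sys.getsizeof(vocab))   # a few MB: the hash table scales with entry
                                  # count, but keys/values are not included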

@@ -70,6 +73,19 @@ def _calc_onnx_model_by_data(model: IAIOnnxRuntimeModel) -> int:
     return mem


+def _calc_size_from_dict(model: dict[str, Any] | torch.Tensor | torch.nn.Module, logger: logging.Logger) -> int:
+    total = sys.getsizeof(model)  # get python overhead for object
+    if isinstance(model, dict):
+        total += sum(_calc_size_from_dict(model[x], logger) for x in model.keys())
+    elif isinstance(model, torch.Tensor):
+        total += model.element_size() * model.nelement()
+    elif isinstance(model, torch.nn.Module):
+        total += calc_module_size(model)
+    else:
+        logger.warning(f"Failed to calculate model size for unexpected model type: {type(model)}.")
+    return total
+
+
 def calc_model_size_by_fs(model_path: Path, subfolder: Optional[str] = None, variant: Optional[str] = None) -> int:
     """Estimate the size of a model on disk in bytes."""
     if model_path.is_file():
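
A quick sanity check of what the new helper reports for a nested state dict. The shapes are arbitrary, and the snippet assumes _calc_size_from_dict is in scope (it is private to the module patched above):

    import logging
    import torch

    state = {
        "down": {"weight": torch.zeros(320, 4, 3, 3), "bias": torch.zeros(320)},
        "up": {"weight": torch.zeros(4, 320, 3, 3)},
    }
    # Each float32 tensor contributes element_size() * nelement() bytes, so the
    # two 11520-element conv weights dominate: 2 * 11520 * 4 bytes ~= 92 KB,
    # plus sys.getsizeof overhead for each dict level the recursion visits.
    print(_calc_size_from_dict(state, logging.getLogger(__name__)))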

@@ -88,6 +88,7 @@ def test_download_diffusers_subfolder(mock_context: InvocationContext) -> None:
         model_path / "diffusion_pytorch_model.safetensors"
     ).exists()

+
 def test_download_diffusers_preserve_subfolders(mock_context: InvocationContext) -> None:
     model_path = mock_context.models.download_and_cache_model(
         "stabilityai/sdxl-turbo::/vae",

@@ -98,4 +99,3 @@ def test_download_diffusers_preserve_subfolders(mock_context: InvocationContext)
     assert (model_path / "diffusion_pytorch_model.fp16.safetensors").exists() or (
         model_path / "diffusion_pytorch_model.safetensors"
     ).exists()
-