Update calc_model_size_by_data(...) to handle all expected model types, and to log an error if an unexpected model type is received.

This commit is contained in:
Ryan Dick 2024-07-02 21:14:12 -04:00
parent 0fe92cd406
commit 414750a45d
4 changed files with 40 additions and 10 deletions

View File

@ -136,11 +136,11 @@ class IPAdapter(RawModel):
self._image_proj_model.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
self.attn_weights.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
def calc_size(self):
# workaround for circular import
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
def calc_size(self) -> int:
# HACK(ryand): Fix this issue with circular imports.
from invokeai.backend.model_manager.load.model_util import calc_module_size
return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)
return calc_module_size(self._image_proj_model) + calc_module_size(self.attn_weights)
def _init_image_proj_model(
self, state_dict: dict[str, torch.Tensor]

View File

@ -160,7 +160,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
key = self._make_cache_key(key, submodel_type)
if key in self._cached_models:
return
size = calc_model_size_by_data(model)
size = calc_model_size_by_data(self.logger, model)
self.make_room(size)
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None

View File

@ -2,25 +2,46 @@
"""Various utility functions needed by the loader and caching system."""
import json
import logging
from pathlib import Path
from typing import Optional
import torch
from diffusers import DiffusionPipeline
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from transformers import CLIPTokenizer
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager.config import AnyModel
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.textual_inversion import TextualInversionModelRaw
def calc_model_size_by_data(model: AnyModel) -> int:
def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
"""Get size of a model in memory in bytes."""
# TODO(ryand): We should create a CacheableModel interface for all models, and move the size calculations down to
# the models themselves.
if isinstance(model, DiffusionPipeline):
return _calc_pipeline_by_data(model)
elif isinstance(model, torch.nn.Module):
return _calc_model_by_data(model)
return calc_module_size(model)
elif isinstance(model, IAIOnnxRuntimeModel):
return _calc_onnx_model_by_data(model)
elif isinstance(model, SchedulerMixin):
return 0
elif isinstance(model, CLIPTokenizer):
# TODO(ryand): Accurately calculate the tokenizer's size. It's small enough that it shouldn't matter for now.
return 0
elif isinstance(model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw)):
return model.calc_size()
else:
# TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
# supported model types.
logger.error(
f"Failed to calculate model size for unexpected model type: {type(model)}. The model will be treated as "
"having size 0."
)
return 0
@ -30,11 +51,12 @@ def _calc_pipeline_by_data(pipeline: DiffusionPipeline) -> int:
for submodel_key in pipeline.components.keys():
submodel = getattr(pipeline, submodel_key)
if submodel is not None and isinstance(submodel, torch.nn.Module):
res += _calc_model_by_data(submodel)
res += calc_module_size(submodel)
return res
def _calc_model_by_data(model: torch.nn.Module) -> int:
def calc_module_size(model: torch.nn.Module) -> int:
"""Calculate the size (in bytes) of a torch.nn.Module."""
mem_params = sum([param.nelement() * param.element_size() for param in model.parameters()])
mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()])
mem: int = mem_params + mem_bufs # in bytes

View File

@ -77,6 +77,14 @@ class TextualInversionModelRaw(RawModel):
if emb is not None:
emb.to(device=device, dtype=dtype, non_blocking=non_blocking)
def calc_size(self) -> int:
"""Get the size of this model in bytes."""
embedding_size = self.embedding.element_size() * self.embedding.nelement()
embedding_2_size = 0
if self.embedding_2 is not None:
embedding_2_size = self.embedding_2.element_size() * self.embedding_2.nelement()
return embedding_size + embedding_2_size
class TextualInversionManager(BaseTextualInversionManager):
"""TextualInversionManager implements the BaseTextualInversionManager ABC from the compel library."""