Update calc_model_size_by_data(...) to handle all expected model types, and to log an error if an unexpected model type is received.

Ryan Dick 2024-07-02 21:14:12 -04:00
parent 0fe92cd406
commit 414750a45d
4 changed files with 40 additions and 10 deletions
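A minimal usage sketch of the behavior described above (the calls below are illustrative and not part of the diff): the function now takes a logger, measures the model types it knows about, and falls back to logging an error and reporting 0 bytes for anything it does not recognize.

import logging

import torch

from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data

logger = logging.getLogger(__name__)

# A plain torch.nn.Module is measured from its parameters and buffers.
size = calc_model_size_by_data(logger, torch.nn.Linear(4, 4))

# An object of an unhandled type logs an error and is treated as having size 0.
assert calc_model_size_by_data(logger, object()) == 0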


@@ -136,11 +136,11 @@ class IPAdapter(RawModel):
         self._image_proj_model.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
         self.attn_weights.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)

-    def calc_size(self):
-        # workaround for circular import
-        from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
-        return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)
+    def calc_size(self) -> int:
+        # HACK(ryand): Fix this issue with circular imports.
+        from invokeai.backend.model_manager.load.model_util import calc_module_size
+        return calc_module_size(self._image_proj_model) + calc_module_size(self.attn_weights)

     def _init_image_proj_model(
         self, state_dict: dict[str, torch.Tensor]


@@ -160,7 +160,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         key = self._make_cache_key(key, submodel_type)
         if key in self._cached_models:
             return
-        size = calc_model_size_by_data(model)
+        size = calc_model_size_by_data(self.logger, model)
         self.make_room(size)
         state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None


@@ -2,25 +2,46 @@
 """Various utility functions needed by the loader and caching system."""

 import json
+import logging
 from pathlib import Path
 from typing import Optional

 import torch
-from diffusers import DiffusionPipeline
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
+from transformers import CLIPTokenizer

+from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
+from invokeai.backend.lora import LoRAModelRaw
 from invokeai.backend.model_manager.config import AnyModel
 from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
+from invokeai.backend.textual_inversion import TextualInversionModelRaw


-def calc_model_size_by_data(model: AnyModel) -> int:
+def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
     """Get size of a model in memory in bytes."""
+    # TODO(ryand): We should create a CacheableModel interface for all models, and move the size calculations down to
+    # the models themselves.
     if isinstance(model, DiffusionPipeline):
         return _calc_pipeline_by_data(model)
     elif isinstance(model, torch.nn.Module):
-        return _calc_model_by_data(model)
+        return calc_module_size(model)
     elif isinstance(model, IAIOnnxRuntimeModel):
         return _calc_onnx_model_by_data(model)
+    elif isinstance(model, SchedulerMixin):
+        return 0
+    elif isinstance(model, CLIPTokenizer):
+        # TODO(ryand): Accurately calculate the tokenizer's size. It's small enough that it shouldn't matter for now.
+        return 0
+    elif isinstance(model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw)):
+        return model.calc_size()
     else:
+        # TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
+        # supported model types.
+        logger.error(
+            f"Failed to calculate model size for unexpected model type: {type(model)}. The model will be treated as "
+            "having size 0."
+        )
         return 0

@@ -30,11 +51,12 @@ def _calc_pipeline_by_data(pipeline: DiffusionPipeline) -> int:
     for submodel_key in pipeline.components.keys():
         submodel = getattr(pipeline, submodel_key)
         if submodel is not None and isinstance(submodel, torch.nn.Module):
-            res += _calc_model_by_data(submodel)
+            res += calc_module_size(submodel)
     return res


-def _calc_model_by_data(model: torch.nn.Module) -> int:
+def calc_module_size(model: torch.nn.Module) -> int:
+    """Calculate the size (in bytes) of a torch.nn.Module."""
     mem_params = sum([param.nelement() * param.element_size() for param in model.parameters()])
     mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()])
     mem: int = mem_params + mem_bufs  # in bytes
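As a concrete illustration of what the renamed calc_module_size counts (the toy module below is chosen for illustration and is not part of the commit):

import torch

# Illustration only: a float32 Linear(16, 32) has 16*32 weight elements plus 32 bias
# elements = 544 parameters, 4 bytes each, and no registered buffers.
module = torch.nn.Linear(16, 32)
mem_params = sum(p.nelement() * p.element_size() for p in module.parameters())
mem_bufs = sum(b.nelement() * b.element_size() for b in module.buffers())
assert mem_params == (16 * 32 + 32) * 4  # 2176 bytes
assert mem_bufs == 0
assert mem_params + mem_bufs == 2176     # the value calc_module_size would report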


@@ -77,6 +77,14 @@ class TextualInversionModelRaw(RawModel):
             if emb is not None:
                 emb.to(device=device, dtype=dtype, non_blocking=non_blocking)

+    def calc_size(self) -> int:
+        """Get the size of this model in bytes."""
+        embedding_size = self.embedding.element_size() * self.embedding.nelement()
+        embedding_2_size = 0
+        if self.embedding_2 is not None:
+            embedding_2_size = self.embedding_2.element_size() * self.embedding_2.nelement()
+        return embedding_size + embedding_2_size
+

 class TextualInversionManager(BaseTextualInversionManager):
     """TextualInversionManager implements the BaseTextualInversionManager ABC from the compel library."""