mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Update calc_model_size_by_data(...) to handle all expected model types, and to log an error if an unexpected model type is received.
This commit is contained in:
@ -160,7 +160,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
key = self._make_cache_key(key, submodel_type)
|
||||
if key in self._cached_models:
|
||||
return
|
||||
size = calc_model_size_by_data(model)
|
||||
size = calc_model_size_by_data(self.logger, model)
|
||||
self.make_room(size)
|
||||
|
||||
state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
|
||||
|
@ -2,25 +2,46 @@
|
||||
"""Various utility functions needed by the loader and caching system."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin
|
||||
from transformers import CLIPTokenizer
|
||||
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
||||
from invokeai.backend.lora import LoRAModelRaw
|
||||
from invokeai.backend.model_manager.config import AnyModel
|
||||
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
|
||||
from invokeai.backend.textual_inversion import TextualInversionModelRaw
|
||||
|
||||
|
||||
def calc_model_size_by_data(model: AnyModel) -> int:
|
||||
def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
|
||||
"""Get size of a model in memory in bytes."""
|
||||
# TODO(ryand): We should create a CacheableModel interface for all models, and move the size calculations down to
|
||||
# the models themselves.
|
||||
if isinstance(model, DiffusionPipeline):
|
||||
return _calc_pipeline_by_data(model)
|
||||
elif isinstance(model, torch.nn.Module):
|
||||
return _calc_model_by_data(model)
|
||||
return calc_module_size(model)
|
||||
elif isinstance(model, IAIOnnxRuntimeModel):
|
||||
return _calc_onnx_model_by_data(model)
|
||||
elif isinstance(model, SchedulerMixin):
|
||||
return 0
|
||||
elif isinstance(model, CLIPTokenizer):
|
||||
# TODO(ryand): Accurately calculate the tokenizer's size. It's small enough that it shouldn't matter for now.
|
||||
return 0
|
||||
elif isinstance(model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw)):
|
||||
return model.calc_size()
|
||||
else:
|
||||
# TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
|
||||
# supported model types.
|
||||
logger.error(
|
||||
f"Failed to calculate model size for unexpected model type: {type(model)}. The model will be treated as "
|
||||
"having size 0."
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
@ -30,11 +51,12 @@ def _calc_pipeline_by_data(pipeline: DiffusionPipeline) -> int:
|
||||
for submodel_key in pipeline.components.keys():
|
||||
submodel = getattr(pipeline, submodel_key)
|
||||
if submodel is not None and isinstance(submodel, torch.nn.Module):
|
||||
res += _calc_model_by_data(submodel)
|
||||
res += calc_module_size(submodel)
|
||||
return res
|
||||
|
||||
|
||||
def _calc_model_by_data(model: torch.nn.Module) -> int:
|
||||
def calc_module_size(model: torch.nn.Module) -> int:
|
||||
"""Calculate the size (in bytes) of a torch.nn.Module."""
|
||||
mem_params = sum([param.nelement() * param.element_size() for param in model.parameters()])
|
||||
mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()])
|
||||
mem: int = mem_params + mem_bufs # in bytes
|
||||
|
Reference in New Issue
Block a user