InvokeAI/invokeai/backend/model_manager/load/model_util.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

159 lines
5.7 KiB
Python
Raw Normal View History

2024-02-01 04:37:59 +00:00
# Copyright (c) 2024 The InvokeAI Development Team
"""Various utility functions needed by the loader and caching system."""
import json
import logging
2024-02-01 04:37:59 +00:00
from pathlib import Path
from typing import Optional
2024-02-01 04:37:59 +00:00
import torch
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from transformers import CLIPTokenizer, T5Tokenizer, T5TokenizerFast
2024-02-01 04:37:59 +00:00
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
from invokeai.backend.image_util.grounding_dino.grounding_dino_pipeline import GroundingDinoPipeline
from invokeai.backend.image_util.segment_anything.segment_anything_pipeline import SegmentAnythingPipeline
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_manager.config import AnyModel
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.textual_inversion import TextualInversionModelRaw
2024-02-01 04:37:59 +00:00
def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
2024-02-01 04:37:59 +00:00
"""Get size of a model in memory in bytes."""
# TODO(ryand): We should create a CacheableModel interface for all models, and move the size calculations down to
# the models themselves.
2024-02-01 04:37:59 +00:00
if isinstance(model, DiffusionPipeline):
return _calc_pipeline_by_data(model)
elif isinstance(model, torch.nn.Module):
return calc_module_size(model)
2024-02-01 04:37:59 +00:00
elif isinstance(model, IAIOnnxRuntimeModel):
return _calc_onnx_model_by_data(model)
elif isinstance(model, SchedulerMixin):
return 0
elif isinstance(model, CLIPTokenizer):
# TODO(ryand): Accurately calculate the tokenizer's size. It's small enough that it shouldn't matter for now.
return 0
elif isinstance(
model,
(
TextualInversionModelRaw,
IPAdapter,
LoRAModelRaw,
SpandrelImageToImageModel,
GroundingDinoPipeline,
SegmentAnythingPipeline,
DepthAnythingPipeline,
),
):
return model.calc_size()
elif isinstance(
model,
(
T5TokenizerFast,
T5Tokenizer,
),
):
return len(model)
2024-02-01 04:37:59 +00:00
else:
# TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
# supported model types.
logger.warning(
f"Failed to calculate model size for unexpected model type: {type(model)}. The model will be treated as "
"having size 0."
)
2024-02-01 04:37:59 +00:00
return 0
def _calc_pipeline_by_data(pipeline: DiffusionPipeline) -> int:
res = 0
assert hasattr(pipeline, "components")
for submodel_key in pipeline.components.keys():
submodel = getattr(pipeline, submodel_key)
if submodel is not None and isinstance(submodel, torch.nn.Module):
res += calc_module_size(submodel)
2024-02-01 04:37:59 +00:00
return res
def calc_module_size(model: torch.nn.Module) -> int:
"""Calculate the size (in bytes) of a torch.nn.Module."""
2024-02-01 04:37:59 +00:00
mem_params = sum([param.nelement() * param.element_size() for param in model.parameters()])
mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()])
mem: int = mem_params + mem_bufs # in bytes
return mem
def _calc_onnx_model_by_data(model: IAIOnnxRuntimeModel) -> int:
tensor_size = model.tensors.size() * 2 # The session doubles this
mem = tensor_size # in bytes
return mem
def calc_model_size_by_fs(model_path: Path, subfolder: Optional[str] = None, variant: Optional[str] = None) -> int:
"""Estimate the size of a model on disk in bytes."""
if model_path.is_file():
return model_path.stat().st_size
2024-02-01 04:37:59 +00:00
if subfolder is not None:
model_path = model_path / subfolder
# this can happen when, for example, the safety checker is not downloaded.
if not model_path.exists():
return 0
all_files = [f for f in model_path.iterdir() if (model_path / f).is_file()]
fp16_files = {f for f in all_files if ".fp16." in f.name or ".fp16-" in f.name}
bit8_files = {f for f in all_files if ".8bit." in f.name or ".8bit-" in f.name}
other_files = set(all_files) - fp16_files - bit8_files
if not variant: # ModelRepoVariant.DEFAULT evaluates to empty string for compatability with HF
2024-02-01 04:37:59 +00:00
files = other_files
elif variant == "fp16":
files = fp16_files
elif variant == "8bit":
files = bit8_files
else:
raise NotImplementedError(f"Unknown variant: {variant}")
# try read from index if exists
index_postfix = ".index.json"
if variant is not None:
index_postfix = f".index.{variant}.json"
for file in files:
if not file.name.endswith(index_postfix):
continue
try:
with open(model_path / file, "r") as f:
index_data = json.loads(f.read())
return int(index_data["metadata"]["total_size"])
except Exception:
pass
# calculate files size if there is no index file
formats = [
(".safetensors",), # safetensors
(".bin",), # torch
(".onnx", ".pb"), # onnx
(".msgpack",), # flax
(".ckpt",), # tf
(".h5",), # tf2
]
for file_format in formats:
model_files = [f for f in files if f.suffix in file_format]
if len(model_files) == 0:
continue
model_size = 0
for model_file in model_files:
file_stats = (model_path / model_file).stat()
model_size += file_stats.st_size
return model_size
return 0 # scheduler/feature_extractor/tokenizer - models without loading to gpu