Fix some models treated as having size 0 in the model cache (#6571)
## Summary

This PR fixes a regression that caused the following models to be treated as having size 0 in the model cache: `TextualInversionModelRaw`, `IPAdapter`, and `LoRAModelRaw`.

Changes:
- Call the correct model size calculation for all supported model types.
- Log an error message if an unexpected model type is loaded, to prevent similar regressions in the future.

## QA Instructions

I tested the following features and verified that no models fell back to a size of 0 unexpectedly:
- Text-to-image
- Textual Inversion
- LoRA
- IP-Adapter
- ControlNet

(All tested with both SD1.5 and SDXL.)

I also compared the model cache's switching behavior before and after this change with a large number of LoRAs (10). Since LoRAs are small compared to the main models, the changes in behavior are minimal. Nonetheless, the fix is worth landing for correctness, and it may make a difference for some usage patterns with limited RAM.

## Merge Plan

No special instructions.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
This commit is contained in commit 7bbd793064.
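For orientation before the diff: every `torch.nn.Module`-based model now funnels into a single parameter-plus-buffer byte count (`calc_module_size` below). The following is a minimal standalone sketch of that calculation; the name `module_nbytes` is illustrative, not the repo's API.

```python
import torch


def module_nbytes(module: torch.nn.Module) -> int:
    """Sum the bytes held by a module's parameters and buffers."""
    param_bytes = sum(p.nelement() * p.element_size() for p in module.parameters())
    buffer_bytes = sum(b.nelement() * b.element_size() for b in module.buffers())
    return param_bytes + buffer_bytes


# A float32 Linear(1000, 1000) holds 1_000_000 weights + 1_000 biases, 4 bytes each.
layer = torch.nn.Linear(1000, 1000)
assert module_nbytes(layer) == 4 * (1000 * 1000 + 1000)
```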
```diff
@@ -136,11 +136,11 @@ class IPAdapter(RawModel):
         self._image_proj_model.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
         self.attn_weights.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
 
-    def calc_size(self):
-        # workaround for circular import
-        from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
+    def calc_size(self) -> int:
+        # HACK(ryand): Fix this issue with circular imports.
+        from invokeai.backend.model_manager.load.model_util import calc_module_size
 
-        return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)
+        return calc_module_size(self._image_proj_model) + calc_module_size(self.attn_weights)
 
     def _init_image_proj_model(
         self, state_dict: dict[str, torch.Tensor]
```
```diff
@@ -160,7 +160,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
         key = self._make_cache_key(key, submodel_type)
         if key in self._cached_models:
             return
-        size = calc_model_size_by_data(model)
+        size = calc_model_size_by_data(self.logger, model)
         self.make_room(size)
 
         state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
```
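The size computed here feeds `self.make_room(size)`, which frees cache space before the new model is registered, so a model that misreports its size as 0 never triggers eviction. Below is a toy sketch of that byte-budget invariant, assuming a simple FIFO eviction policy; it is not InvokeAI's actual cache implementation.

```python
from collections import OrderedDict


class ToyByteBudgetCache:
    """Evict the oldest entries until a new entry of `size` bytes fits the budget."""

    def __init__(self, max_bytes: int) -> None:
        self._max_bytes = max_bytes
        self._entries: OrderedDict[str, int] = OrderedDict()  # key -> size in bytes

    def make_room(self, size: int) -> None:
        while self._entries and sum(self._entries.values()) + size > self._max_bytes:
            self._entries.popitem(last=False)  # drop the oldest entry

    def put(self, key: str, size: int) -> None:
        if key in self._entries:
            return
        self.make_room(size)  # a reported size of 0 skips eviction entirely
        self._entries[key] = size
```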
```diff
@@ -2,25 +2,46 @@
 """Various utility functions needed by the loader and caching system."""
 
 import json
+import logging
 from pathlib import Path
 from typing import Optional
 
 import torch
-from diffusers import DiffusionPipeline
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
+from transformers import CLIPTokenizer
 
+from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
+from invokeai.backend.lora import LoRAModelRaw
 from invokeai.backend.model_manager.config import AnyModel
 from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
+from invokeai.backend.textual_inversion import TextualInversionModelRaw
 
 
-def calc_model_size_by_data(model: AnyModel) -> int:
+def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
     """Get size of a model in memory in bytes."""
+    # TODO(ryand): We should create a CacheableModel interface for all models, and move the size calculations down to
+    # the models themselves.
     if isinstance(model, DiffusionPipeline):
         return _calc_pipeline_by_data(model)
     elif isinstance(model, torch.nn.Module):
-        return _calc_model_by_data(model)
+        return calc_module_size(model)
     elif isinstance(model, IAIOnnxRuntimeModel):
         return _calc_onnx_model_by_data(model)
+    elif isinstance(model, SchedulerMixin):
+        return 0
+    elif isinstance(model, CLIPTokenizer):
+        # TODO(ryand): Accurately calculate the tokenizer's size. It's small enough that it shouldn't matter for now.
+        return 0
+    elif isinstance(model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw)):
+        return model.calc_size()
+    else:
+        # TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
+        # supported model types.
+        logger.error(
+            f"Failed to calculate model size for unexpected model type: {type(model)}. The model will be treated as "
+            "having size 0."
+        )
+        return 0
```
```diff
@@ -30,11 +51,12 @@ def _calc_pipeline_by_data(pipeline: DiffusionPipeline) -> int:
     for submodel_key in pipeline.components.keys():
         submodel = getattr(pipeline, submodel_key)
         if submodel is not None and isinstance(submodel, torch.nn.Module):
-            res += _calc_model_by_data(submodel)
+            res += calc_module_size(submodel)
     return res
 
 
-def _calc_model_by_data(model: torch.nn.Module) -> int:
+def calc_module_size(model: torch.nn.Module) -> int:
+    """Calculate the size (in bytes) of a torch.nn.Module."""
     mem_params = sum([param.nelement() * param.element_size() for param in model.parameters()])
     mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()])
     mem: int = mem_params + mem_bufs  # in bytes
```
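The new fallthrough branch is easy to exercise in isolation. Below is a small duck-typed sketch of the same log-and-return-0 fallback; the `UnknownModel` class and `size_or_zero` helper are hypothetical illustrations, while the real code dispatches on `isinstance` as shown above.

```python
import logging


class UnknownModel:  # hypothetical type the dispatcher does not recognize
    pass


def size_or_zero(logger: logging.Logger, model: object) -> int:
    """Delegate to a model's own calc_size() when present; otherwise log and return 0."""
    if hasattr(model, "calc_size"):
        # TextualInversionModelRaw, IPAdapter, and LoRAModelRaw all define calc_size().
        return model.calc_size()
    logger.error(
        f"Failed to calculate model size for unexpected model type: {type(model)}. "
        "The model will be treated as having size 0."
    )
    return 0


logging.basicConfig(level=logging.ERROR)
print(size_or_zero(logging.getLogger("model_cache"), UnknownModel()))  # logs an ERROR, prints 0
```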
```diff
@@ -77,6 +77,14 @@ class TextualInversionModelRaw(RawModel):
         if emb is not None:
             emb.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
+    def calc_size(self) -> int:
+        """Get the size of this model in bytes."""
+        embedding_size = self.embedding.element_size() * self.embedding.nelement()
+        embedding_2_size = 0
+        if self.embedding_2 is not None:
+            embedding_2_size = self.embedding_2.element_size() * self.embedding_2.nelement()
+        return embedding_size + embedding_2_size
+
 
 class TextualInversionManager(BaseTextualInversionManager):
     """TextualInversionManager implements the BaseTextualInversionManager ABC from the compel library."""
```
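The per-tensor arithmetic here is simply `element_size() * nelement()`. As a quick check, assuming a single 768-dimensional float16 embedding (768 is the token embedding width for SD1.5):

```python
import torch

# One textual-inversion token embedding: 768 float16 values, 2 bytes each.
embedding = torch.zeros(768, dtype=torch.float16)
print(embedding.element_size() * embedding.nelement())  # 2 * 768 = 1536 bytes
```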