Update calc_model_size_by_data(...) to handle all expected model types, and to log an error if an unexpected model type is received.

This commit is contained in:
Ryan Dick
2024-07-02 21:14:12 -04:00
parent 0fe92cd406
commit 414750a45d
4 changed files with 40 additions and 10 deletions

View File

@ -77,6 +77,14 @@ class TextualInversionModelRaw(RawModel):
if emb is not None:
emb.to(device=device, dtype=dtype, non_blocking=non_blocking)
def calc_size(self) -> int:
"""Get the size of this model in bytes."""
embedding_size = self.embedding.element_size() * self.embedding.nelement()
embedding_2_size = 0
if self.embedding_2 is not None:
embedding_2_size = self.embedding_2.element_size() * self.embedding_2.nelement()
return embedding_size + embedding_2_size
class TextualInversionManager(BaseTextualInversionManager):
"""TextualInversionManager implements the BaseTextualInversionManager ABC from the compel library."""