# Origin: git format-patch 3e1af5173713a43e01cca72c68258109ec99ff4a
# Author: Ryan Dick
# Date: Thu, 4 Apr 2024 13:15:19 -0400
# Subject: [PATCH] Add skeleton of a PeftModel class.
# Target file: invokeai/backend/peft/peft_model.py (new file)

from pathlib import Path
from typing import Optional, Union

import torch
from safetensors.torch import load_file

from invokeai.backend.model_manager.config import BaseModelType


class PeftModel:
    """A class for loading and managing parameter-efficient fine-tuning models.

    An instance holds a display name and the state dict (tensor name -> tensor) read from a checkpoint file.
    """

    def __init__(
        self,
        name: str,
        state_dict: dict[str, torch.Tensor],
    ):
        """Store the model name and its state dict.

        Args:
            name: Name of the model.
            state_dict: Mapping from parameter name to tensor. It is stored by reference, not copied.
        """
        self._name = name
        self._state_dict = state_dict

    @property
    def name(self) -> str:
        """The name given to this model at construction time."""
        return self._name

    def calc_size(self) -> int:
        """Return the total size, in bytes, of all tensors in the state dict.

        An empty state dict yields 0.
        """
        return sum(tensor.nelement() * tensor.element_size() for tensor in self._state_dict.values())

    @classmethod
    def from_checkpoint(
        cls,
        file_path: Union[str, Path],
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        base_model: Optional[BaseModelType] = None,
    ) -> "PeftModel":
        """Load a checkpoint file and wrap its state dict in a new instance.

        Files with a ".safetensors" suffix (matched case-insensitively) are read with safetensors' `load_file`; any
        other file is read with `torch.load`. In both cases the tensors are first loaded onto the CPU.

        Args:
            file_path: Path to the checkpoint file. The file stem becomes the model name.
            device: If given, every tensor is moved to this device after loading. If omitted, tensors stay on the
                CPU.
            dtype: If given, every floating-point tensor is cast to this dtype after loading. If omitted, the dtypes
                stored in the checkpoint are kept.
            base_model: Currently unused; reserved for the key-format conversion described in the TODO below.

        Returns:
            A new instance named after the file stem.
        """
        file_path = Path(file_path)

        # TODO(ryand): Implement a helper function for this. This logic is duplicated repeatedly.
        if file_path.suffix.lower() == ".safetensors":
            state_dict = load_file(file_path, device="cpu")
        else:
            # NOTE(review): torch.load unpickles the file, which can execute arbitrary code when the checkpoint comes
            # from an untrusted source. Consider passing weights_only=True once the minimum supported torch version
            # allows it.
            state_dict = torch.load(file_path, map_location="cpu")

        # TODO(ryand):
        # - Detect state_dict format
        # - Convert state_dict to diffusers format if necessary

        # if base_model == BaseModelType.StableDiffusionXL:
        #     state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)

        # Honor an explicitly requested device / dtype. When neither is given the state dict is returned exactly as
        # loaded.
        if device is not None or dtype is not None:
            converted = {}
            for key, value in state_dict.items():
                if isinstance(value, torch.Tensor):
                    # Only floating-point tensors are cast, so integer/bool tensors are not silently converted.
                    target_dtype = dtype if value.is_floating_point() else None
                    value = value.to(device=device, dtype=target_dtype)
                converted[key] = value
            state_dict = converted

        return cls(name=file_path.stem, state_dict=state_dict)