diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management/model_cache.py
index b8f44f82ec..64f95699c9 100644
--- a/invokeai/backend/model_management/model_cache.py
+++ b/invokeai/backend/model_management/model_cache.py
@@ -54,8 +54,6 @@ class LoraType(dict):
     pass
 class TIType(dict):
     pass
-class CkptType(dict):
-    pass
 
 class SDModelType(Enum):
     diffusers=StableDiffusionGeneratorPipeline   # whole pipeline
@@ -70,7 +68,6 @@ class SDModelType(Enum):
     # distinguish them by class
     lora=LoraType
     textual_inversion=TIType
-    ckpt=CkptType
 
 class ModelStatus(Enum):
     unknown='unknown'
@@ -93,18 +90,12 @@ SIZE_GUESSTIMATE = {
     SDModelType.feature_extractor: 0.001,
     SDModelType.lora: 0.1,
     SDModelType.textual_inversion: 0.001,
-    SDModelType.ckpt: 4.2,
 }
 
 # The list of model classes we know how to fetch, for typechecking
 ModelClass = Union[tuple([x.value for x in SDModelType])]
 DiffusionClasses = (StableDiffusionGeneratorPipeline, AutoencoderKL, SchedulerMixin, UNet2DConditionModel)
 
-# Legacy information needed to load a legacy checkpoint file
-class LegacyInfo(BaseModel):
-    config_file: Path
-    vae_file: Path = None
-
 class UnsafeModelException(Exception):
     "Raised when a legacy model file fails the picklescan test"
     pass
@@ -160,7 +151,6 @@ class ModelCache(object):
                   subfolder: Path=None,
                   submodel: SDModelType=None,
                   revision: str=None,
-                  legacy_info: LegacyInfo=None,
                   attach_model_part: Tuple[SDModelType, str] = (None,None),
                   gpu_load: bool=True,
                   )->ModelLocker:  # ?? what does it return
@@ -210,13 +200,12 @@ class ModelCache(object):
 
         The model will be locked into GPU VRAM for the duration of the context.
         :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
+        :param model_type: an SDModelType enum indicating the type of the (parent) model
         :param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
         :param submodel: an SDModelType enum indicating the model part to return, e.g. SDModelType.vae
         :param attach_model_part: load and attach a diffusers model component. Pass a tuple of format (SDModelType,repo_id)
         :param revision: model revision
-        :param model_class: class of model to return
         :param gpu_load: load the model into GPU [default True]
-        :param legacy_info: a LegacyInfo object containing additional info needed to load a legacy ckpt
         '''
         key = self._model_key( # internal unique identifier for the model
             repo_id_or_path,
@@ -256,10 +245,9 @@ class ModelCache(object):
                 gc.collect()
             model = self._load_model_from_storage(
                 repo_id_or_path=repo_id_or_path,
-                model_class=model_type.value,
+                model_type=model_type,
                 subfolder=subfolder,
                 revision=revision,
-                legacy_info=legacy_info,
             )
 
             if mem_used := self.calc_model_size(model):
@@ -459,27 +447,30 @@ class ModelCache(object):
                                  repo_id_or_path: Union[str,Path],
                                  subfolder: Path=None,
                                  revision: str=None,
-                                 model_class: ModelClass=StableDiffusionGeneratorPipeline,
-                                 legacy_info: LegacyInfo=None,
+                                 model_type: SDModelType=SDModelType.diffusers,
                                  )->ModelClass:
         '''
         Load and return a HuggingFace model.
         :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
         :param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
         :param revision: model revision
-        :param model_class: class of model to return, defaults to StableDiffusionGeneratorPIpeline
-        :param legacy_info: a LegacyInfo object containing additional info needed to load a legacy ckpt
+        :param model_type: type of model to return, defaults to SDModelType.diffusers
         '''
         # silence transformer and diffuser warnings
         with SilenceWarnings():
-            model = self._load_diffusers_from_storage(
-                repo_id_or_path,
-                subfolder,
-                revision,
-                model_class,
-            )
-            if self.sequential_offload and isinstance(model,StableDiffusionGeneratorPipeline):
-                model.enable_offload_submodels(self.execution_device)
+            if model_type==SDModelType.lora:
+                model = self._load_lora_from_storage(repo_id_or_path)
+            elif model_type==SDModelType.textual_inversion:
+                model = self._load_ti_from_storage(repo_id_or_path)
+            else:
+                model = self._load_diffusers_from_storage(
+                    repo_id_or_path,
+                    subfolder,
+                    revision,
+                    model_type.value,
+                )
+                if self.sequential_offload and isinstance(model,StableDiffusionGeneratorPipeline):
+                    model.enable_offload_submodels(self.execution_device)
         return model
 
     def _load_diffusers_from_storage(
@@ -519,30 +510,11 @@ class ModelCache(object):
             pass
         return model
 
-    def _load_ckpt_from_storage(self,
-                                ckpt_path: Union[str,Path],
-                                legacy_info:LegacyInfo)->StableDiffusionGeneratorPipeline:
-        '''
-        Load a legacy checkpoint, convert it, and return a StableDiffusionGeneratorPipeline.
-        :param ckpt_path: string or Path pointing to the weights file (.ckpt or .safetensors)
-        :param legacy_info: LegacyInfo object containing paths to legacy config file and alternate vae if required
-        '''
-        if legacy_info is None or legacy_info.config_file is None:
-            if Path(ckpt_path).suffix == '.safetensors':
-                return safetensors.torch.load_file(ckpt_path)
-            else:
-                return torch.load(ckpt_path)
-        else:
-            # deferred loading to avoid circular import errors
-            from .convert_ckpt_to_diffusers import load_pipeline_from_original_stable_diffusion_ckpt
-            pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
-                checkpoint_path=ckpt_path,
-                original_config_file=legacy_info.config_file,
-                vae_path=legacy_info.vae_file,
-                return_generator_pipeline=True,
-                precision=self.precision,
-            )
-            return pipeline
+    def _load_lora_from_storage(self, lora_path: Path)->SDModelType.lora.value:
+        assert False, "_load_lora_from_storage() is not yet implemented"
+
+    def _load_ti_from_storage(self, ti_path: Path)->SDModelType.textual_inversion.value:
+        assert False, "_load_ti_from_storage() is not yet implemented"
 
     def _legacy_model_hash(self, checkpoint_path: Union[str,Path])->str:
         sha = hashlib.sha256()
diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py
index 4c0b1b3ad9..0905d5bf40 100644
--- a/invokeai/backend/model_management/model_manager.py
+++ b/invokeai/backend/model_management/model_manager.py
@@ -296,7 +296,6 @@ class ModelManager(object):
             )
 
         model_parts = dict([(x.name,x) for x in SDModelType])
-        legacy = None
 
         if format == 'diffusers':
             # intercept stanzas that point to checkpoint weights and replace them
@@ -332,7 +331,6 @@ class ModelManager(object):
                 model_type = model_type,
                 revision = revision,
                 subfolder = subfolder,
-                legacy_info = legacy,
                 submodel = submodel,
                 attach_model_part=vae,
             )