Mirror of https://github.com/invoke-ai/InvokeAI
Synced 2024-08-30 20:32:17 +00:00
clean up ckpt handling

- remove legacy ckpt loading code from model_cache
- add placeholders for lora and textual inversion model loading
This commit is contained in:
parent 9cb962cad7
commit 3d85e769ce
@@ -54,8 +54,6 @@ class LoraType(dict):
     pass
 class TIType(dict):
     pass
-class CkptType(dict):
-    pass

 class SDModelType(Enum):
     diffusers=StableDiffusionGeneratorPipeline # whole pipeline
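LoraType and TIType above are empty dict subclasses that act purely as marker types, so a loaded payload's kind can be told apart by class rather than by a string tag. A minimal standalone sketch of that pattern follows; the names LoraWeights, TIEmbedding, and describe() are invented for illustration and are not part of the InvokeAI code:

import typing

class LoraWeights(dict):
    pass

class TIEmbedding(dict):
    pass

def describe(model: typing.Any) -> str:
    # Both payloads are plain dicts underneath; isinstance() on the
    # marker subclass is what tells them apart.
    if isinstance(model, LoraWeights):
        return "lora"
    if isinstance(model, TIEmbedding):
        return "textual inversion"
    return "unknown"

print(describe(LoraWeights()))  # -> lora
print(describe(TIEmbedding()))  # -> textual inversion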
@@ -70,7 +68,6 @@ class SDModelType(Enum):
     # distinguish them by class
     lora=LoraType
     textual_inversion=TIType
-    ckpt=CkptType

 class ModelStatus(Enum):
     unknown='unknown'
@@ -93,18 +90,12 @@ SIZE_GUESSTIMATE = {
     SDModelType.feature_extractor: 0.001,
     SDModelType.lora: 0.1,
     SDModelType.textual_inversion: 0.001,
-    SDModelType.ckpt: 4.2,
 }

 # The list of model classes we know how to fetch, for typechecking
 ModelClass = Union[tuple([x.value for x in SDModelType])]
 DiffusionClasses = (StableDiffusionGeneratorPipeline, AutoencoderKL, SchedulerMixin, UNet2DConditionModel)

-# Legacy information needed to load a legacy checkpoint file
-class LegacyInfo(BaseModel):
-    config_file: Path
-    vae_file: Path = None
-
 class UnsafeModelException(Exception):
     "Raised when a legacy model file fails the picklescan test"
     pass
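The ModelClass = Union[tuple([x.value for x in SDModelType])] line builds a typing Union dynamically from the enum's values; it works because subscripting Union with a tuple expands it. A self-contained sketch of the same trick, using toy classes rather than the real model types:

from enum import Enum
from typing import Union, get_args

class A: pass
class B: pass

class Kind(Enum):
    a = A
    b = B

# Union[(A, B)] is the same as Union[A, B], so a tuple of enum values
# can be turned into a Union type at runtime.
KindClass = Union[tuple(k.value for k in Kind)]

print(get_args(KindClass))  # -> (<class '__main__.A'>, <class '__main__.B'>)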
@@ -160,7 +151,6 @@ class ModelCache(object):
                   subfolder: Path=None,
                   submodel: SDModelType=None,
                   revision: str=None,
-                  legacy_info: LegacyInfo=None,
                   attach_model_part: Tuple[SDModelType, str] = (None,None),
                   gpu_load: bool=True,
                   )->ModelLocker: # ?? what does it return
@@ -210,13 +200,12 @@ class ModelCache(object):
         The model will be locked into GPU VRAM for the duration of the context.
         :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
         :param model_type: An SDModelType enum indicating the type of the (parent) model
         :param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
         :param submodel: an SDModelType enum indicating the model part to return, e.g. SDModelType.vae
         :param attach_model_part: load and attach a diffusers model component. Pass a tuple of format (SDModelType,repo_id)
         :param revision: model revision
         :param model_class: class of model to return
         :param gpu_load: load the model into GPU [default True]
-        :param legacy_info: a LegacyInfo object containing additional info needed to load a legacy ckpt
         '''
         key = self._model_key( # internal unique identifier for the model
             repo_id_or_path,
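get_model() returns a ModelLocker, and the docstring promises the model stays in VRAM for the duration of the context, which implies a context-manager protocol. The locker itself is not part of this diff; the following is only a hedged sketch of how such an object could behave, not the repo's actual ModelLocker:

import torch

class ModelLockerSketch:
    '''Illustrative only: pins a module on a device inside a with-block.'''
    def __init__(self, model: torch.nn.Module, execution_device: str):
        self.model = model
        self.execution_device = execution_device

    def __enter__(self) -> torch.nn.Module:
        self.model.to(self.execution_device)  # move weights onto the device
        return self.model

    def __exit__(self, *exc) -> None:
        self.model.to("cpu")  # release device memory when the block exits

# Usage sketch (assumes a CUDA device and a model named my_unet exist):
# with ModelLockerSketch(my_unet, "cuda") as model:
#     ...  # run inference while the model is guaranteed to be on-device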
@@ -256,10 +245,9 @@ class ModelCache(object):
                 gc.collect()
             model = self._load_model_from_storage(
                 repo_id_or_path=repo_id_or_path,
-                model_class=model_type.value,
+                model_type=model_type,
                 subfolder=subfolder,
                 revision=revision,
-                legacy_info=legacy_info,
             )

             if mem_used := self.calc_model_size(model):
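calc_model_size() is not shown in this diff; presumably it estimates the model's in-memory footprint so the cache can track RAM/VRAM pressure (note the walrus operator: mem_used receives the size and gates the block in one step). A common way to estimate a torch module's size, offered here only as a sketch and not as the repo's implementation:

import torch

def estimate_module_size_bytes(module: torch.nn.Module) -> int:
    # Sum the raw storage of parameters and buffers; ignores Python
    # object overhead and any shared storage.
    size = sum(p.numel() * p.element_size() for p in module.parameters())
    size += sum(b.numel() * b.element_size() for b in module.buffers())
    return size

layer = torch.nn.Linear(1024, 1024)
# fp32: (1024*1024 weights + 1024 biases) * 4 bytes each
print(estimate_module_size_bytes(layer))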
@@ -459,27 +447,30 @@ class ModelCache(object):
                                     repo_id_or_path: Union[str,Path],
                                     subfolder: Path=None,
                                     revision: str=None,
-                                    model_class: ModelClass=StableDiffusionGeneratorPipeline,
-                                    legacy_info: LegacyInfo=None,
+                                    model_type: SDModelType=SDModelType.diffusers,
                                     )->ModelClass:
         '''
         Load and return a HuggingFace model.
         :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
         :param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
         :param revision: model revision
-        :param model_class: class of model to return, defaults to StableDiffusionGeneratorPIpeline
-        :param legacy_info: a LegacyInfo object containing additional info needed to load a legacy ckpt
+        :param model_type: type of model to return, defaults to SDModelType.diffusers
         '''
         # silence transformer and diffuser warnings
         with SilenceWarnings():
-            model = self._load_diffusers_from_storage(
-                repo_id_or_path,
-                subfolder,
-                revision,
-                model_class,
-            )
-            if self.sequential_offload and isinstance(model,StableDiffusionGeneratorPipeline):
-                model.enable_offload_submodels(self.execution_device)
+            if model_type==SDModelType.lora:
+                model = self._load_lora_from_storage(repo_id_or_path)
+            elif model_type==SDModelType.textual_inversion:
+                model = self._load_ti_from_storage(repo_id_or_path)
+            else:
+                model = self._load_diffusers_from_storage(
+                    repo_id_or_path,
+                    subfolder,
+                    revision,
+                    model_type.value,
+                )
+                if self.sequential_offload and isinstance(model,StableDiffusionGeneratorPipeline):
+                    model.enable_offload_submodels(self.execution_device)
         return model

     def _load_diffusers_from_storage(
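The rewritten body replaces the single diffusers path with a dispatch on model_type: loras and textual inversions get their own (for now placeholder) loaders, and everything else falls through to the diffusers loader. A standalone sketch of that dispatch shape, with enum members and stub loaders invented for illustration:

from enum import Enum, auto

class ModelKind(Enum):
    diffusers = auto()
    lora = auto()
    textual_inversion = auto()

def load_lora(path: str) -> dict:
    return {"kind": "lora", "path": path}  # stub

def load_ti(path: str) -> dict:
    return {"kind": "textual_inversion", "path": path}  # stub

def load_diffusers(path: str) -> dict:
    return {"kind": "diffusers", "path": path}  # stub

def load_model(path: str, kind: ModelKind) -> dict:
    # Mirrors the if/elif/else added in this commit: special-case the
    # lightweight model types, fall through to the diffusers loader.
    if kind == ModelKind.lora:
        return load_lora(path)
    elif kind == ModelKind.textual_inversion:
        return load_ti(path)
    else:
        return load_diffusers(path)

print(load_model("models/some-lora", ModelKind.lora))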
@@ -519,30 +510,11 @@ class ModelCache(object):
                 pass
         return model

-    def _load_ckpt_from_storage(self,
-                                ckpt_path: Union[str,Path],
-                                legacy_info:LegacyInfo)->StableDiffusionGeneratorPipeline:
-        '''
-        Load a legacy checkpoint, convert it, and return a StableDiffusionGeneratorPipeline.
-        :param ckpt_path: string or Path pointing to the weights file (.ckpt or .safetensors)
-        :param legacy_info: LegacyInfo object containing paths to legacy config file and alternate vae if required
-        '''
-        if legacy_info is None or legacy_info.config_file is None:
-            if Path(ckpt_path).suffix == '.safetensors':
-                return safetensors.torch.load_file(ckpt_path)
-            else:
-                return torch.load(ckpt_path)
-        else:
-            # deferred loading to avoid circular import errors
-            from .convert_ckpt_to_diffusers import load_pipeline_from_original_stable_diffusion_ckpt
-            pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
-                checkpoint_path=ckpt_path,
-                original_config_file=legacy_info.config_file,
-                vae_path=legacy_info.vae_file,
-                return_generator_pipeline=True,
-                precision=self.precision,
-            )
-            return pipeline
+    def _load_lora_from_storage(self, lora_path: Path)->SDModelType.lora.value:
+        assert False,"_load_lora_from_storage() is not yet implemented"
+
+    def _load_ti_from_storage(self, lora_path: Path)->SDModelType.textual_inversion.value:
+        assert False,"_load_ti_from_storage() is not yet implemented"

     def _legacy_model_hash(self, checkpoint_path: Union[str,Path])->str:
         sha = hashlib.sha256()
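_legacy_model_hash() opens a SHA-256 digest; the rest of the method falls outside this hunk. Hashing a multi-gigabyte weights file is normally done in chunks so it never has to fit in memory; a sketch along those lines (the chunk size and hex-string return format are assumptions, not the repo's code):

import hashlib
from pathlib import Path
from typing import Union

def file_sha256(checkpoint_path: Union[str, Path], chunk_size: int = 2**20) -> str:
    # Stream the file through the digest one chunk at a time.
    sha = hashlib.sha256()
    with open(checkpoint_path, "rb") as f:
        while chunk := f.read(chunk_size):
            sha.update(chunk)
    return sha.hexdigest()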
@@ -296,7 +296,6 @@ class ModelManager(object):
             )

         model_parts = dict([(x.name,x) for x in SDModelType])
-        legacy = None

         if format == 'diffusers':
             # intercept stanzas that point to checkpoint weights and replace them
@@ -332,7 +331,6 @@ class ModelManager(object):
             model_type = model_type,
             revision = revision,
             subfolder = subfolder,
-            legacy_info = legacy,
             submodel = submodel,
             attach_model_part=vae,
         )