clean up ckpt handling

- remove legacy ckpt loading code from model_cache
- add placeholders for lora and textual inversion model loading
Lincoln Stein 2023-05-09 22:44:58 -04:00
parent 9cb962cad7
commit 3d85e769ce
2 changed files with 22 additions and 52 deletions
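
For orientation, a before/after sketch of the caller-facing change this commit makes to ModelCache.get_model(). The identifiers come from the diff below; the constructor arguments, repo ids, and file paths are hypothetical stand-ins, not part of this commit:

    # Hypothetical usage; repo ids and paths invented for illustration.
    cache = ModelCache()

    # Before: loading a legacy checkpoint required a LegacyInfo side-channel.
    # model = cache.get_model(
    #     'models/v1-5.ckpt',
    #     legacy_info=LegacyInfo(config_file=Path('configs/v1-inference.yaml')),
    # )

    # After: the SDModelType enum alone selects the loading path.
    model = cache.get_model(
        'runwayml/stable-diffusion-v1-5',
        model_type=SDModelType.diffusers,
    )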


@@ -54,8 +54,6 @@ class LoraType(dict):
     pass

 class TIType(dict):
     pass

-class CkptType(dict):
-    pass

 class SDModelType(Enum):
     diffusers=StableDiffusionGeneratorPipeline   # whole pipeline
@@ -70,7 +68,6 @@ class SDModelType(Enum):
     # distinguish them by class
     lora=LoraType
     textual_inversion=TIType
-    ckpt=CkptType

 class ModelStatus(Enum):
     unknown='unknown'
@@ -93,18 +90,12 @@ SIZE_GUESSTIMATE = {
     SDModelType.feature_extractor: 0.001,
     SDModelType.lora: 0.1,
     SDModelType.textual_inversion: 0.001,
-    SDModelType.ckpt: 4.2,
 }

 # The list of model classes we know how to fetch, for typechecking
 ModelClass = Union[tuple([x.value for x in SDModelType])]
 DiffusionClasses = (StableDiffusionGeneratorPipeline, AutoencoderKL, SchedulerMixin, UNet2DConditionModel)

-# Legacy information needed to load a legacy checkpoint file
-class LegacyInfo(BaseModel):
-    config_file: Path
-    vae_file: Path = None

 class UnsafeModelException(Exception):
     "Raised when a legacy model file fails the picklescan test"
     pass
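
A side note on the ModelClass alias kept above: typing.Union can be subscripted with a tuple of types, so the alias stays in sync with SDModelType as members such as ckpt are added or removed. A self-contained illustration with toy stand-in classes (not the real imports):

    from enum import Enum
    from typing import Union, get_args

    class DiffusersToy(dict): pass   # stand-in for the real pipeline class
    class LoraType(dict): pass
    class TIType(dict): pass

    class SDModelType(Enum):
        diffusers = DiffusersToy
        lora = LoraType
        textual_inversion = TIType

    # Union[...] accepts a tuple of types, so this builds
    # Union[DiffusersToy, LoraType, TIType] dynamically from the enum.
    ModelClass = Union[tuple([x.value for x in SDModelType])]
    print(get_args(ModelClass))   # (DiffusersToy, LoraType, TIType)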
@@ -160,7 +151,6 @@ class ModelCache(object):
             subfolder: Path=None,
             submodel: SDModelType=None,
             revision: str=None,
-            legacy_info: LegacyInfo=None,
             attach_model_part: Tuple[SDModelType, str] = (None,None),
             gpu_load: bool=True,
             )->ModelLocker:  # ?? what does it return
@@ -210,13 +200,12 @@ class ModelCache(object):
         The model will be locked into GPU VRAM for the duration of the context.
         :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
+        :param model_type: An SDModelType enum indicating the type of the (parent) model
         :param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
         :param submodel: an SDModelType enum indicating the model part to return, e.g. SDModelType.vae
         :param attach_model_part: load and attach a diffusers model component. Pass a tuple of format (SDModelType,repo_id)
         :param revision: model revision
-        :param model_class: class of model to return
         :param gpu_load: load the model into GPU [default True]
-        :param legacy_info: a LegacyInfo object containing additional info needed to load a legacy ckpt
         '''
         key = self._model_key(   # internal unique identifier for the model
             repo_id_or_path,
@@ -256,10 +245,9 @@ class ModelCache(object):
                 gc.collect()
             model = self._load_model_from_storage(
                 repo_id_or_path=repo_id_or_path,
-                model_class=model_type.value,
+                model_type=model_type,
                 subfolder=subfolder,
                 revision=revision,
-                legacy_info=legacy_info,
             )
         if mem_used := self.calc_model_size(model):
@@ -459,27 +447,30 @@ class ModelCache(object):
             repo_id_or_path: Union[str,Path],
             subfolder: Path=None,
             revision: str=None,
-            model_class: ModelClass=StableDiffusionGeneratorPipeline,
-            legacy_info: LegacyInfo=None,
+            model_type: SDModelType=SDModelType.diffusers,
             )->ModelClass:
         '''
         Load and return a HuggingFace model.
         :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
         :param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
         :param revision: model revision
-        :param model_class: class of model to return, defaults to StableDiffusionGeneratorPIpeline
+        :param model_type: type of model to return, defaults to SDModelType.diffusers
         '''
         # silence transformer and diffuser warnings
         with SilenceWarnings():
-            model = self._load_diffusers_from_storage(
-                repo_id_or_path,
-                subfolder,
-                revision,
-                model_class,
-            )
-            if self.sequential_offload and isinstance(model,StableDiffusionGeneratorPipeline):
-                model.enable_offload_submodels(self.execution_device)
+            if model_type==SDModelType.lora:
+                model = self._load_lora_from_storage(repo_id_or_path)
+            elif model_type==SDModelType.textual_inversion:
+                model = self._load_ti_from_storage(repo_id_or_path)
+            else:
+                model = self._load_diffusers_from_storage(
+                    repo_id_or_path,
+                    subfolder,
+                    revision,
+                    model_type.value,
+                )
+                if self.sequential_offload and isinstance(model,StableDiffusionGeneratorPipeline):
+                    model.enable_offload_submodels(self.execution_device)
         return model

     def _load_diffusers_from_storage(
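
The reworked _load_model_from_storage() routes on the enum instead of taking a class plus LegacyInfo, and the sequential-offload hook now runs only on the diffusers branch (the lora and textual-inversion placeholders can never satisfy the isinstance check). A standalone sketch of the same dispatch shape, reusing the toy SDModelType from the earlier sketch, with stand-in loaders rather than the real ModelCache methods:

    def load_model(path, model_type=SDModelType.diffusers):
        # Route on the enum member; model_type.value carries the class.
        if model_type == SDModelType.lora:
            raise NotImplementedError('lora loading')      # placeholder branch
        elif model_type == SDModelType.textual_inversion:
            raise NotImplementedError('ti loading')        # placeholder branch
        # Diffusers branch: toy stand-in for model_class.from_pretrained(...)
        return model_type.value()

    print(load_model('some/repo'))   # -> {} (a DiffusersToy instance)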
@@ -519,30 +510,11 @@ class ModelCache(object):
                 pass
         return model

-    def _load_ckpt_from_storage(self,
-                                ckpt_path: Union[str,Path],
-                                legacy_info:LegacyInfo)->StableDiffusionGeneratorPipeline:
-        '''
-        Load a legacy checkpoint, convert it, and return a StableDiffusionGeneratorPipeline.
-        :param ckpt_path: string or Path pointing to the weights file (.ckpt or .safetensors)
-        :param legacy_info: LegacyInfo object containing paths to legacy config file and alternate vae if required
-        '''
-        if legacy_info is None or legacy_info.config_file is None:
-            if Path(ckpt_path).suffix == '.safetensors':
-                return safetensors.torch.load_file(ckpt_path)
-            else:
-                return torch.load(ckpt_path)
-        else:
-            # deferred loading to avoid circular import errors
-            from .convert_ckpt_to_diffusers import load_pipeline_from_original_stable_diffusion_ckpt
-            pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
-                checkpoint_path=ckpt_path,
-                original_config_file=legacy_info.config_file,
-                vae_path=legacy_info.vae_file,
-                return_generator_pipeline=True,
-                precision=self.precision,
-            )
-        return pipeline
+    def _load_lora_from_storage(self, lora_path: Path)->SDModelType.lora.value:
+        assert False,"_load_lora_from_storage() is not yet implemented"
+
+    def _load_ti_from_storage(self, lora_path: Path)->SDModelType.textual_inversion.value:
+        assert False,"_load_ti_from_storage() is not yet implemented"

     def _legacy_model_hash(self, checkpoint_path: Union[str,Path])->str:
         sha = hashlib.sha256()
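
Two small observations on the placeholders above: the textual-inversion loader's parameter is still named lora_path (apparently a copy-paste leftover), and assert statements are stripped when Python runs with -O, which would turn the failure into a silent None return. If that matters, the conventional stub raises instead; a suggested shape, not what this commit does:

    def _load_ti_from_storage(self, ti_path: Path) -> SDModelType.textual_inversion.value:
        raise NotImplementedError("_load_ti_from_storage() is not yet implemented")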


@@ -296,7 +296,6 @@ class ModelManager(object):
             )
         model_parts = dict([(x.name,x) for x in SDModelType])
-        legacy = None

         if format == 'diffusers':
             # intercept stanzas that point to checkpoint weights and replace them
@@ -332,7 +331,6 @@ class ModelManager(object):
             model_type = model_type,
             revision = revision,
             subfolder = subfolder,
-            legacy_info = legacy,
             submodel = submodel,
             attach_model_part=vae,
         )