clean up ckpt handling

- remove legacy ckpt loading code from model_cache
- add placeholders for lora and textual inversion model loading
Lincoln Stein 2023-05-09 22:44:58 -04:00
parent 9cb962cad7
commit 3d85e769ce
2 changed files with 22 additions and 52 deletions

model_cache.py

@@ -54,8 +54,6 @@ class LoraType(dict):
pass
class TIType(dict):
pass
class CkptType(dict):
pass
class SDModelType(Enum):
diffusers=StableDiffusionGeneratorPipeline # whole pipeline
@@ -70,7 +68,6 @@ class SDModelType(Enum):
# distinguish them by class
lora=LoraType
textual_inversion=TIType
ckpt=CkptType
class ModelStatus(Enum):
unknown='unknown'
@@ -93,18 +90,12 @@ SIZE_GUESSTIMATE = {
SDModelType.feature_extractor: 0.001,
SDModelType.lora: 0.1,
SDModelType.textual_inversion: 0.001,
SDModelType.ckpt: 4.2,
}
# The list of model classes we know how to fetch, for typechecking
ModelClass = Union[tuple([x.value for x in SDModelType])]
DiffusionClasses = (StableDiffusionGeneratorPipeline, AutoencoderKL, SchedulerMixin, UNet2DConditionModel)
# Legacy information needed to load a legacy checkpoint file
class LegacyInfo(BaseModel):
config_file: Path
vae_file: Path = None
class UnsafeModelException(Exception):
"Raised when a legacy model file fails the picklescan test"
pass
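For context, a minimal sketch of how this exception might be raised, assuming the picklescan package (scan_file_path and its infected_files field come from that library); the helper name and surrounding logic are illustrative and not part of this commit:

from pathlib import Path
from typing import Union
from picklescan.scanner import scan_file_path

def assert_legacy_model_is_safe(checkpoint_path: Union[str, Path]):
    # scan the pickled weights for malicious payloads before unpickling them
    scan_result = scan_file_path(str(checkpoint_path))
    if scan_result.infected_files != 0:
        raise UnsafeModelException(f'{checkpoint_path} failed the picklescan test')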
@@ -160,7 +151,6 @@ class ModelCache(object):
subfolder: Path=None,
submodel: SDModelType=None,
revision: str=None,
legacy_info: LegacyInfo=None,
attach_model_part: Tuple[SDModelType, str] = (None,None),
gpu_load: bool=True,
)->ModelLocker: # ?? what does it return
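For orientation, a minimal usage sketch of the slimmed-down call, with legacy_info gone; the ModelCache constructor arguments are assumed, and the context-manager form follows the docstring further down:

cache = ModelCache()  # constructor arguments assumed
with cache.get_model(
    'runwayml/stable-diffusion-v1-5',  # HuggingFace repo_id or local path
    model_type=SDModelType.diffusers,
) as pipeline:
    ...  # model stays locked in GPU VRAM for the duration of the context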
@@ -210,13 +200,12 @@ class ModelCache(object):
The model will be locked into GPU VRAM for the duration of the context.
:param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
:param model_type: An SDModelType enum indicating the type of the (parent) model
:param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
:param submodel: an SDModelType enum indicating the model part to return, e.g. SDModelType.vae
:param attach_model_part: load and attach a diffusers model component. Pass a tuple of format (SDModelType,repo_id)
:param revision: model revision
:param model_class: class of model to return
:param gpu_load: load the model into GPU [default True]
:param legacy_info: a LegacyInfo object containing additional info needed to load a legacy ckpt
'''
key = self._model_key( # internal unique identifier for the model
repo_id_or_path,
@@ -256,10 +245,9 @@ class ModelCache(object):
gc.collect()
model = self._load_model_from_storage(
repo_id_or_path=repo_id_or_path,
model_class=model_type.value,
model_type=model_type,
subfolder=subfolder,
revision=revision,
legacy_info=legacy_info,
)
if mem_used := self.calc_model_size(model):
@@ -459,27 +447,30 @@ class ModelCache(object):
repo_id_or_path: Union[str,Path],
subfolder: Path=None,
revision: str=None,
model_class: ModelClass=StableDiffusionGeneratorPipeline,
legacy_info: LegacyInfo=None,
model_type: SDModelType=SDModelType.diffusers,
)->ModelClass:
'''
Load and return a HuggingFace model.
:param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
:param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
:param revision: model revision
:param model_class: class of model to return, defaults to StableDiffusionGeneratorPIpeline
:param legacy_info: a LegacyInfo object containing additional info needed to load a legacy ckpt
:param model_type: type of model to return, defaults to SDModelType.diffusers
'''
# silence transformer and diffuser warnings
with SilenceWarnings():
model = self._load_diffusers_from_storage(
repo_id_or_path,
subfolder,
revision,
model_class,
)
if self.sequential_offload and isinstance(model,StableDiffusionGeneratorPipeline):
model.enable_offload_submodels(self.execution_device)
if model_type==SDModelType.lora:
model = self._load_lora_from_storage(repo_id_or_path)
elif model_type==SDModelType.textual_inversion:
model = self._load_ti_from_storage(repo_id_or_path)
else:
model = self._load_diffusers_from_storage(
repo_id_or_path,
subfolder,
revision,
model_type.value,
)
if self.sequential_offload and isinstance(model,StableDiffusionGeneratorPipeline):
model.enable_offload_submodels(self.execution_device)
return model
def _load_diffusers_from_storage(
@@ -519,30 +510,11 @@ class ModelCache(object):
pass
return model
def _load_ckpt_from_storage(self,
ckpt_path: Union[str,Path],
legacy_info:LegacyInfo)->StableDiffusionGeneratorPipeline:
'''
Load a legacy checkpoint, convert it, and return a StableDiffusionGeneratorPipeline.
:param ckpt_path: string or Path pointing to the weights file (.ckpt or .safetensors)
:param legacy_info: LegacyInfo object containing paths to legacy config file and alternate vae if required
'''
if legacy_info is None or legacy_info.config_file is None:
if Path(ckpt_path).suffix == '.safetensors':
return safetensors.torch.load_file(ckpt_path)
else:
return torch.load(ckpt_path)
else:
# deferred loading to avoid circular import errors
from .convert_ckpt_to_diffusers import load_pipeline_from_original_stable_diffusion_ckpt
pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
checkpoint_path=ckpt_path,
original_config_file=legacy_info.config_file,
vae_path=legacy_info.vae_file,
return_generator_pipeline=True,
precision=self.precision,
)
return pipeline
def _load_lora_from_storage(self, lora_path: Path)->SDModelType.lora.value:
assert False,"_load_lora_from_storage() is not yet implemented"
def _load_ti_from_storage(self, lora_path: Path)->SDModelType.textual_inversion.value:
assert False,"_load_ti_from_storage() is not yet implemented"
def _legacy_model_hash(self, checkpoint_path: Union[str,Path])->str:
sha = hashlib.sha256()
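The two new loader placeholders above fail fast for now. A minimal sketch of what _load_lora_from_storage() might eventually do, modeled on the weight-loading pattern of the removed checkpoint loader (the body is an assumption, not part of this commit):

def _load_lora_from_storage(self, lora_path: Path)->SDModelType.lora.value:
    # assumption: LoRA weights ship as a bare state dict in .safetensors or
    # torch-pickle format, as the removed _load_ckpt_from_storage() handled them
    if Path(lora_path).suffix == '.safetensors':
        return LoraType(safetensors.torch.load_file(lora_path))
    return LoraType(torch.load(lora_path, map_location='cpu'))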

model_manager.py

@@ -296,7 +296,6 @@ class ModelManager(object):
)
model_parts = dict([(x.name,x) for x in SDModelType])
legacy = None
if format == 'diffusers':
# intercept stanzas that point to checkpoint weights and replace them
@@ -332,7 +331,6 @@ class ModelManager(object):
model_type = model_type,
revision = revision,
subfolder = subfolder,
legacy_info = legacy,
submodel = submodel,
attach_model_part=vae,
)