Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
clean up ckpt handling

- remove legacy ckpt loading code from model_cache
- add placeholders for lora and textual inversion model loading
This commit is contained in:
parent 9cb962cad7 · commit 3d85e769ce
@@ -54,8 +54,6 @@ class LoraType(dict):
     pass
 class TIType(dict):
     pass
-class CkptType(dict):
-    pass

 class SDModelType(Enum):
     diffusers=StableDiffusionGeneratorPipeline # whole pipeline
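Note: SDModelType maps each model-type name to the class that implements it, so callers can recover the class via `.value` — the pattern `model_type.value` appears in the hunks below. A minimal, self-contained sketch of the idiom with stand-in classes:

    from enum import Enum

    class LoraType(dict):        # stand-in container type, as in the hunk above
        pass

    class SDModelType(Enum):
        lora = LoraType          # the member's value is a class, not a string

    model_type = SDModelType.lora
    cls = model_type.value       # -> LoraType
    model = cls()                # instantiate the class mapped to this type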
@@ -70,7 +68,6 @@ class SDModelType(Enum):
     # distinguish them by class
     lora=LoraType
     textual_inversion=TIType
-    ckpt=CkptType

 class ModelStatus(Enum):
     unknown='unknown'
@@ -93,18 +90,12 @@ SIZE_GUESSTIMATE = {
     SDModelType.feature_extractor: 0.001,
     SDModelType.lora: 0.1,
     SDModelType.textual_inversion: 0.001,
-    SDModelType.ckpt: 4.2,
 }

 # The list of model classes we know how to fetch, for typechecking
 ModelClass = Union[tuple([x.value for x in SDModelType])]
 DiffusionClasses = (StableDiffusionGeneratorPipeline, AutoencoderKL, SchedulerMixin, UNet2DConditionModel)

-# Legacy information needed to load a legacy checkpoint file
-class LegacyInfo(BaseModel):
-    config_file: Path
-    vae_file: Path = None
-
 class UnsafeModelException(Exception):
     "Raised when a legacy model file fails the picklescan test"
     pass
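Note: `Union[tuple([x.value for x in SDModelType])]` works because `typing.Union` can be subscripted with a tuple of types; the comprehension gathers every class registered in the enum, so the alias tracks the enum automatically (dropping `ckpt` here shrinks `ModelClass` too). A runnable sketch with toy stand-in types:

    from enum import Enum
    from typing import Union

    class SDModelType(Enum):     # toy stand-ins for the real model classes
        diffusers = str
        lora = dict
        textual_inversion = list

    # Union[(str, dict, list)] is equivalent to Union[str, dict, list]
    ModelClass = Union[tuple([x.value for x in SDModelType])]
    print(ModelClass)            # typing.Union[str, dict, list]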
@@ -160,7 +151,6 @@ class ModelCache(object):
                   subfolder: Path=None,
                   submodel: SDModelType=None,
                   revision: str=None,
-                  legacy_info: LegacyInfo=None,
                   attach_model_part: Tuple[SDModelType, str] = (None,None),
                   gpu_load: bool=True,
                   )->ModelLocker: # ?? what does it return
@@ -210,13 +200,12 @@ class ModelCache(object):

         The model will be locked into GPU VRAM for the duration of the context.
         :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
+        :param model_type: An SDModelType enum indicating the type of the (parent) model
         :param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
         :param submodel: an SDModelType enum indicating the model part to return, e.g. SDModelType.vae
         :param attach_model_part: load and attach a diffusers model component. Pass a tuple of format (SDModelType,repo_id)
         :param revision: model revision
-        :param model_class: class of model to return
         :param gpu_load: load the model into GPU [default True]
-        :param legacy_info: a LegacyInfo object containing additional info needed to load a legacy ckpt
         '''
         key = self._model_key( # internal unique identifier for the model
             repo_id_or_path,
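Note: per the docstring, the method returns a ModelLocker that pins the model in VRAM for the duration of a `with` block. A hypothetical call site for the revised signature — the enclosing method's name is not visible in this hunk, and `get_model`, the repo id, and the `cache` instance are illustrative assumptions:

    # Illustrative only: method name and repo id are assumptions.
    cache = ModelCache()
    with cache.get_model(
        "runwayml/stable-diffusion-v1-5",    # repo_id_or_path
        model_type=SDModelType.diffusers,    # new arg; replaces model_class/legacy_info
        submodel=SDModelType.vae,            # return just the VAE part
    ) as vae:
        ...                                  # locked into GPU VRAM only inside this block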
@@ -256,10 +245,9 @@ class ModelCache(object):
             gc.collect()
             model = self._load_model_from_storage(
                 repo_id_or_path=repo_id_or_path,
-                model_class=model_type.value,
+                model_type=model_type,
                 subfolder=subfolder,
                 revision=revision,
-                legacy_info=legacy_info,
             )

             if mem_used := self.calc_model_size(model):
@@ -459,24 +447,27 @@ class ModelCache(object):
                                  repo_id_or_path: Union[str,Path],
                                  subfolder: Path=None,
                                  revision: str=None,
-                                 model_class: ModelClass=StableDiffusionGeneratorPipeline,
-                                 legacy_info: LegacyInfo=None,
+                                 model_type: SDModelType=SDModelType.diffusers,
                                  )->ModelClass:
         '''
         Load and return a HuggingFace model.
         :param repo_id_or_path: either the HuggingFace repo_id or a Path to a local model
         :param subfolder: name of a subfolder in which the model can be found, e.g. "vae"
         :param revision: model revision
-        :param model_class: class of model to return, defaults to StableDiffusionGeneratorPipeline
-        :param legacy_info: a LegacyInfo object containing additional info needed to load a legacy ckpt
+        :param model_type: type of model to return, defaults to SDModelType.diffusers
         '''
         # silence transformer and diffuser warnings
         with SilenceWarnings():
-            model = self._load_diffusers_from_storage(
-                repo_id_or_path,
-                subfolder,
-                revision,
-                model_class,
-            )
+            if model_type==SDModelType.lora:
+                model = self._load_lora_from_storage(repo_id_or_path)
+            elif model_type==SDModelType.textual_inversion:
+                model = self._load_ti_from_storage(repo_id_or_path)
+            else:
+                model = self._load_diffusers_from_storage(
+                    repo_id_or_path,
+                    subfolder,
+                    revision,
+                    model_type.value,
+                )
         if self.sequential_offload and isinstance(model,StableDiffusionGeneratorPipeline):
             model.enable_offload_submodels(self.execution_device)
@@ -519,30 +510,11 @@ class ModelCache(object):
             pass
         return model

-    def _load_ckpt_from_storage(self,
-                                ckpt_path: Union[str,Path],
-                                legacy_info:LegacyInfo)->StableDiffusionGeneratorPipeline:
-        '''
-        Load a legacy checkpoint, convert it, and return a StableDiffusionGeneratorPipeline.
-        :param ckpt_path: string or Path pointing to the weights file (.ckpt or .safetensors)
-        :param legacy_info: LegacyInfo object containing paths to legacy config file and alternate vae if required
-        '''
-        if legacy_info is None or legacy_info.config_file is None:
-            if Path(ckpt_path).suffix == '.safetensors':
-                return safetensors.torch.load_file(ckpt_path)
-            else:
-                return torch.load(ckpt_path)
-        else:
-            # deferred loading to avoid circular import errors
-            from .convert_ckpt_to_diffusers import load_pipeline_from_original_stable_diffusion_ckpt
-            pipeline = load_pipeline_from_original_stable_diffusion_ckpt(
-                checkpoint_path=ckpt_path,
-                original_config_file=legacy_info.config_file,
-                vae_path=legacy_info.vae_file,
-                return_generator_pipeline=True,
-                precision=self.precision,
-            )
-            return pipeline
+    def _load_lora_from_storage(self, lora_path: Path)->SDModelType.lora.value:
+        assert False,"_load_lora_from_storage() is not yet implemented"
+
+    def _load_ti_from_storage(self, lora_path: Path)->SDModelType.textual_inversion.value:
+        assert False,"_load_ti_from_storage() is not yet implemented"

     def _legacy_model_hash(self, checkpoint_path: Union[str,Path])->str:
         sha = hashlib.sha256()
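Note: the removed `_load_ckpt_from_storage` contains a raw-weights pattern the new placeholders will likely need: `safetensors.torch.load_file()` for `.safetensors` files, `torch.load()` otherwise. A hedged sketch of what `_load_lora_from_storage` might eventually look like if it reuses that pattern — not part of this commit, and everything beyond the suffix dispatch is an assumption:

    import torch
    import safetensors.torch
    from pathlib import Path

    def _load_lora_from_storage(self, lora_path: Path) -> LoraType:
        # Sketch only: mirrors the suffix dispatch removed above.
        if Path(lora_path).suffix == '.safetensors':
            state_dict = safetensors.torch.load_file(lora_path)
        else:
            # weights_only guards against pickle payloads in untrusted files
            state_dict = torch.load(lora_path, map_location='cpu', weights_only=True)
        return LoraType(state_dict)  # LoraType is a dict subclass, so this wraps directly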
@@ -296,7 +296,6 @@ class ModelManager(object):
         )

         model_parts = dict([(x.name,x) for x in SDModelType])
-        legacy = None

         if format == 'diffusers':
             # intercept stanzas that point to checkpoint weights and replace them
@@ -332,7 +331,6 @@ class ModelManager(object):
             model_type = model_type,
             revision = revision,
             subfolder = subfolder,
-            legacy_info = legacy,
             submodel = submodel,
             attach_model_part=vae,
         )