Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00.
Multiple refinements on loaders:

- Enable cache stat collection.
- Implement ONNX loading.
- Add the ability to specify the repo variant in the installer CLI.
- If the caller asks for a repo variant that doesn't exist, fall back to the default (empty) variant rather than raising an error.
commit 5745ce9c7d
parent 0d3addc69b
committed by psychedelicious
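The variant fallback described in the last bullet point can be illustrated with a small sketch. This is not the commit's installer code; the `ModelRepoVariant` enum and `resolve_variant` helper below are hypothetical names chosen for illustration only.

    from enum import Enum
    from typing import Optional, Set


    class ModelRepoVariant(str, Enum):
        # Hypothetical HuggingFace repo variants; the empty string is
        # the "default" variant stored at the repo root.
        DEFAULT = ""
        FP16 = "fp16"
        ONNX = "onnx"


    def resolve_variant(
        requested: Optional[ModelRepoVariant],
        available: Set[ModelRepoVariant],
    ) -> ModelRepoVariant:
        # Fall back to the empty/default variant instead of raising
        # when the repo does not publish the requested variant.
        if requested is not None and requested in available:
            return requested
        return ModelRepoVariant.DEFAULT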
@@ -52,7 +52,7 @@ class ModelLoader(ModelLoaderBase):
         self._logger = logger
         self._ram_cache = ram_cache
         self._convert_cache = convert_cache
-        self._torch_dtype = torch_dtype(choose_torch_device())
+        self._torch_dtype = torch_dtype(choose_torch_device(), app_config)
 
     def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
         """
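The only change in this hunk is threading `app_config` into the dtype selection, presumably so a user-configured precision can override the per-device default. Below is a minimal sketch of what such a helper could look like; the `precision` attribute and its values are assumptions, not the actual `torch_dtype` implementation.

    import torch


    def torch_dtype(device: torch.device, app_config=None) -> torch.dtype:
        # Sketch: honor an explicit precision setting if one is configured...
        precision = getattr(app_config, "precision", "auto")
        if precision == "float32":
            return torch.float32
        if precision == "float16":
            return torch.float16
        # ...otherwise default to half precision on accelerators only.
        return torch.float16 if device.type in ("cuda", "mps") else torch.float32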
@@ -102,8 +102,10 @@ class ModelLoader(ModelLoaderBase):
         self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
     ) -> ModelLockerBase:
         # TO DO: This is not thread safe!
-        if self._ram_cache.exists(config.key, submodel_type):
-            return self._ram_cache.get(config.key, submodel_type)
+        try:
+            return self._ram_cache.get(config.key, submodel_type)
+        except IndexError:
+            pass
 
         model_variant = getattr(config, "repo_variant", None)
         self._ram_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
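This hunk swaps a look-before-you-leap `exists()`/`get()` pair for a single `get()` guarded by `try/except IndexError`. That removes one of the two racy calls flagged by the `TO DO` comment: an entry can no longer be evicted between the existence check and the fetch, although the method as a whole is still not thread safe. The pattern in isolation, as a hedged sketch (the `cache` protocol here is assumed):

    from typing import Any, Optional


    def cached_or_none(cache: Any, key: str) -> Optional[Any]:
        # EAFP probe: one call, one failure point. A miss (or a concurrent
        # eviction) surfaces as IndexError, which we treat as "not cached".
        try:
            return cache.get(key)
        except IndexError:
            return None  # caller falls through to loading from disk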
@@ -119,7 +121,11 @@ class ModelLoader(ModelLoaderBase):
             size=calc_model_size_by_data(loaded_model),
         )
 
-        return self._ram_cache.get(config.key, submodel_type)
+        return self._ram_cache.get(
+            key=config.key,
+            submodel_type=submodel_type,
+            stats_name=":".join([config.base, config.type, config.name, (submodel_type or "")]),
+        )
 
     def get_size_fs(
         self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
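The new keyword form of `get()` carries a `stats_name`, a human-readable `base:type:name:submodel` label, which is what enables the "cache stat collection" from the commit message: statistics can be aggregated per model rather than per opaque config key. One plausible shape for such a collector (all names below are assumptions, not the cache's actual API):

    from collections import Counter


    class CacheStats:
        # Hypothetical hit/miss tallies keyed by the human-readable
        # stats_name rather than the opaque config key.
        def __init__(self) -> None:
            self.hits = Counter()
            self.misses = Counter()

        def record(self, stats_name: str, hit: bool) -> None:
            (self.hits if hit else self.misses)[stats_name] += 1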
@@ -146,13 +152,21 @@ class ModelLoader(ModelLoaderBase):
     # TO DO: Add exception handling
     def _get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin:
         if submodel_type:
-            config = self._load_diffusers_config(model_path, config_name="model_index.json")
-            module, class_name = config[submodel_type.value]
-            return self._hf_definition_to_type(module=module, class_name=class_name)
+            try:
+                config = self._load_diffusers_config(model_path, config_name="model_index.json")
+                module, class_name = config[submodel_type.value]
+                return self._hf_definition_to_type(module=module, class_name=class_name)
+            except KeyError as e:
+                raise InvalidModelConfigException(
+                    f'The "{submodel_type}" submodel is not available for this model.'
+                ) from e
         else:
-            config = self._load_diffusers_config(model_path, config_name="config.json")
-            class_name = config["_class_name"]
-            return self._hf_definition_to_type(module="diffusers", class_name=class_name)
+            try:
+                config = self._load_diffusers_config(model_path, config_name="config.json")
+                class_name = config["_class_name"]
+                return self._hf_definition_to_type(module="diffusers", class_name=class_name)
+            except KeyError as e:
+                raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
 
     # This needs to be implemented in subclasses that handle checkpoints
     def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
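Both branches now translate a bare `KeyError` (a missing submodel entry in `model_index.json`, or a missing `_class_name` in `config.json`) into the domain-specific `InvalidModelConfigException`, so callers can report a malformed model folder without catching generic exceptions. A hedged sketch of such a caller; the import path and the `loader` object are assumptions:

    from pathlib import Path

    # Import path is an assumption made for illustration.
    from invokeai.backend.model_manager.config import InvalidModelConfigException


    def describe_load_class(loader, model_path: Path, submodel_type=None) -> str:
        # Surface a clear, user-facing message for malformed model folders
        # instead of letting a bare KeyError traceback escape.
        try:
            load_class = loader._get_hf_load_class(model_path, submodel_type)
            return f"{model_path.name} loads via {load_class.__name__}"
        except InvalidModelConfigException as e:
            return f"cannot load {model_path.name}: {e}"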