Multiple refinements on loaders:

- Cache stat collection enabled.
- Implemented ONNX loading.
- Added the ability to specify the repo version variant in the installer CLI.
- If the caller asks for a repo version that doesn't exist, the code will
  fall back to the empty version rather than raising an error.
This commit is contained in:
Lincoln Stein
2024-02-05 21:55:11 -05:00
committed by psychedelicious
parent 0d3addc69b
commit 5745ce9c7d
18 changed files with 215 additions and 49 deletions

View File

@ -52,7 +52,7 @@ class ModelLoader(ModelLoaderBase):
self._logger = logger
self._ram_cache = ram_cache
self._convert_cache = convert_cache
self._torch_dtype = torch_dtype(choose_torch_device())
self._torch_dtype = torch_dtype(choose_torch_device(), app_config)
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""
@ -102,8 +102,10 @@ class ModelLoader(ModelLoaderBase):
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
) -> ModelLockerBase:
# TO DO: This is not thread safe!
if self._ram_cache.exists(config.key, submodel_type):
try:
return self._ram_cache.get(config.key, submodel_type)
except IndexError:
pass
model_variant = getattr(config, "repo_variant", None)
self._ram_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
@ -119,7 +121,11 @@ class ModelLoader(ModelLoaderBase):
size=calc_model_size_by_data(loaded_model),
)
return self._ram_cache.get(config.key, submodel_type)
return self._ram_cache.get(
key=config.key,
submodel_type=submodel_type,
stats_name=":".join([config.base, config.type, config.name, (submodel_type or "")]),
)
def get_size_fs(
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
@ -146,13 +152,21 @@ class ModelLoader(ModelLoaderBase):
# TO DO: Add exception handling
def _get_hf_load_class(self, model_path: Path, submodel_type: Optional[SubModelType] = None) -> ModelMixin:
if submodel_type:
config = self._load_diffusers_config(model_path, config_name="model_index.json")
module, class_name = config[submodel_type.value]
return self._hf_definition_to_type(module=module, class_name=class_name)
try:
config = self._load_diffusers_config(model_path, config_name="model_index.json")
module, class_name = config[submodel_type.value]
return self._hf_definition_to_type(module=module, class_name=class_name)
except KeyError as e:
raise InvalidModelConfigException(
f'The "{submodel_type}" submodel is not available for this model.'
) from e
else:
config = self._load_diffusers_config(model_path, config_name="config.json")
class_name = config["_class_name"]
return self._hf_definition_to_type(module="diffusers", class_name=class_name)
try:
config = self._load_diffusers_config(model_path, config_name="config.json")
class_name = config["_class_name"]
return self._hf_definition_to_type(module="diffusers", class_name=class_name)
except KeyError as e:
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
# This needs to be implemented in subclasses that handle checkpoints
def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path: