doc(model_manager): docstrings

This commit is contained in:
Kevin Turner 2023-07-31 09:08:46 -07:00
parent e3519052ae
commit bacdf985f1
2 changed files with 30 additions and 10 deletions

View File

@ -385,6 +385,11 @@ class ModelManager(object):
def model_exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType, *, rescan=False) -> bool: def model_exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType, *, rescan=False) -> bool:
""" """
Given a model name, returns True if it is a valid identifier. Given a model name, returns True if it is a valid identifier.
:param model_name: symbolic name of the model in models.yaml
:param model_type: ModelType enum indicating the type of model to return
:param base_model: BaseModelType enum indicating the base model used by this model
:param rescan: if True, scan_models_directory
""" """
model_key = self.create_key(model_name, base_model, model_type) model_key = self.create_key(model_name, base_model, model_type)
exists = model_key in self.models exists = model_key in self.models
@ -470,7 +475,7 @@ class ModelManager(object):
else: else:
self.models.pop(model_key, None) self.models.pop(model_key, None)
raise ModelNotFoundException(f"Model not found - {model_key}") raise ModelNotFoundException(f'Files for model "{model_key}" not found at {model_path}')
# TODO: path # TODO: path
# TODO: is it accurate to use path as id # TODO: is it accurate to use path as id
@ -508,7 +513,13 @@ class ModelManager(object):
_cache=self.cache, _cache=self.cache,
) )
def _get_model_path(self, model_config: ModelConfigBase, submodel_type: SubModelType = None) -> (Path, bool): def _get_model_path(
self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None
) -> (Path, bool):
"""Extract a model's filesystem path from its config.
:return: The fully qualified Path of the module (or submodule).
"""
model_path = model_config.path model_path = model_config.path
is_submodel_override = False is_submodel_override = False
@ -523,6 +534,7 @@ class ModelManager(object):
return model_path, is_submodel_override return model_path, is_submodel_override
def _get_model_config(self, base_model, model_name, model_type) -> ModelConfigBase: def _get_model_config(self, base_model, model_name, model_type) -> ModelConfigBase:
"""Get a model's config object."""
model_key = self.create_key(model_name, base_model, model_type) model_key = self.create_key(model_name, base_model, model_type)
try: try:
model_config = self.models[model_key] model_config = self.models[model_key]
@ -531,12 +543,18 @@ class ModelManager(object):
return model_config return model_config
def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]: def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]:
"""Get the concrete implementation class for a specific model type."""
model_class = MODEL_CLASSES[base_model][model_type] model_class = MODEL_CLASSES[base_model][model_type]
return model_class return model_class
def _instantiate( def _instantiate(
self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel_type: SubModelType = None self,
model_name: str,
base_model: BaseModelType,
model_type: ModelType,
submodel_type: Optional[SubModelType] = None,
) -> ModelBase: ) -> ModelBase:
"""Make a new instance of this model, without loading it."""
model_config = self._get_model_config(base_model, model_name, model_type) model_config = self._get_model_config(base_model, model_name, model_type)
model_path, is_submodel_override = self._get_model_path(model_config, submodel_type) model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
# FIXME: do non-overriden submodels get the right class? # FIXME: do non-overriden submodels get the right class?

View File

@ -1,9 +1,14 @@
import os import os
import torch
import safetensors
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Optional, Union, Literal from typing import Optional
import safetensors
import torch
from diffusers.utils import is_safetensors_available
from omegaconf import OmegaConf
from invokeai.app.services.config import InvokeAIAppConfig
from .base import ( from .base import (
ModelBase, ModelBase,
ModelConfigBase, ModelConfigBase,
@ -18,9 +23,6 @@ from .base import (
InvalidModelException, InvalidModelException,
ModelNotFoundException, ModelNotFoundException,
) )
from invokeai.app.services.config import InvokeAIAppConfig
from diffusers.utils import is_safetensors_available
from omegaconf import OmegaConf
class VaeModelFormat(str, Enum): class VaeModelFormat(str, Enum):
@ -80,7 +82,7 @@ class VaeModel(ModelBase):
@classmethod @classmethod
def detect_format(cls, path: str): def detect_format(cls, path: str):
if not os.path.exists(path): if not os.path.exists(path):
raise ModelNotFoundException() raise ModelNotFoundException(f"Does not exist as local file: {path}")
if os.path.isdir(path): if os.path.isdir(path):
if os.path.exists(os.path.join(path, "config.json")): if os.path.exists(os.path.join(path, "config.json")):