mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
doc(model_manager): docstrings
This commit is contained in:
parent
e3519052ae
commit
bacdf985f1
@ -385,6 +385,11 @@ class ModelManager(object):
|
|||||||
def model_exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType, *, rescan=False) -> bool:
|
def model_exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType, *, rescan=False) -> bool:
|
||||||
"""
|
"""
|
||||||
Given a model name, returns True if it is a valid identifier.
|
Given a model name, returns True if it is a valid identifier.
|
||||||
|
|
||||||
|
:param model_name: symbolic name of the model in models.yaml
|
||||||
|
:param model_type: ModelType enum indicating the type of model to return
|
||||||
|
:param base_model: BaseModelType enum indicating the base model used by this model
|
||||||
|
:param rescan: if True, scan_models_directory
|
||||||
"""
|
"""
|
||||||
model_key = self.create_key(model_name, base_model, model_type)
|
model_key = self.create_key(model_name, base_model, model_type)
|
||||||
exists = model_key in self.models
|
exists = model_key in self.models
|
||||||
@ -470,7 +475,7 @@ class ModelManager(object):
|
|||||||
|
|
||||||
else:
|
else:
|
||||||
self.models.pop(model_key, None)
|
self.models.pop(model_key, None)
|
||||||
raise ModelNotFoundException(f"Model not found - {model_key}")
|
raise ModelNotFoundException(f'Files for model "{model_key}" not found at {model_path}')
|
||||||
|
|
||||||
# TODO: path
|
# TODO: path
|
||||||
# TODO: is it accurate to use path as id
|
# TODO: is it accurate to use path as id
|
||||||
@ -508,7 +513,13 @@ class ModelManager(object):
|
|||||||
_cache=self.cache,
|
_cache=self.cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_model_path(self, model_config: ModelConfigBase, submodel_type: SubModelType = None) -> (Path, bool):
|
def _get_model_path(
|
||||||
|
self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None
|
||||||
|
) -> (Path, bool):
|
||||||
|
"""Extract a model's filesystem path from its config.
|
||||||
|
|
||||||
|
:return: The fully qualified Path of the module (or submodule).
|
||||||
|
"""
|
||||||
model_path = model_config.path
|
model_path = model_config.path
|
||||||
is_submodel_override = False
|
is_submodel_override = False
|
||||||
|
|
||||||
@ -523,6 +534,7 @@ class ModelManager(object):
|
|||||||
return model_path, is_submodel_override
|
return model_path, is_submodel_override
|
||||||
|
|
||||||
def _get_model_config(self, base_model, model_name, model_type) -> ModelConfigBase:
|
def _get_model_config(self, base_model, model_name, model_type) -> ModelConfigBase:
|
||||||
|
"""Get a model's config object."""
|
||||||
model_key = self.create_key(model_name, base_model, model_type)
|
model_key = self.create_key(model_name, base_model, model_type)
|
||||||
try:
|
try:
|
||||||
model_config = self.models[model_key]
|
model_config = self.models[model_key]
|
||||||
@ -531,12 +543,18 @@ class ModelManager(object):
|
|||||||
return model_config
|
return model_config
|
||||||
|
|
||||||
def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]:
|
def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]:
|
||||||
|
"""Get the concrete implementation class for a specific model type."""
|
||||||
model_class = MODEL_CLASSES[base_model][model_type]
|
model_class = MODEL_CLASSES[base_model][model_type]
|
||||||
return model_class
|
return model_class
|
||||||
|
|
||||||
def _instantiate(
|
def _instantiate(
|
||||||
self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel_type: SubModelType = None
|
self,
|
||||||
|
model_name: str,
|
||||||
|
base_model: BaseModelType,
|
||||||
|
model_type: ModelType,
|
||||||
|
submodel_type: Optional[SubModelType] = None,
|
||||||
) -> ModelBase:
|
) -> ModelBase:
|
||||||
|
"""Make a new instance of this model, without loading it."""
|
||||||
model_config = self._get_model_config(base_model, model_name, model_type)
|
model_config = self._get_model_config(base_model, model_name, model_type)
|
||||||
model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
|
model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
|
||||||
# FIXME: do non-overriden submodels get the right class?
|
# FIXME: do non-overriden submodels get the right class?
|
||||||
|
@ -1,9 +1,14 @@
|
|||||||
import os
|
import os
|
||||||
import torch
|
|
||||||
import safetensors
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Union, Literal
|
from typing import Optional
|
||||||
|
|
||||||
|
import safetensors
|
||||||
|
import torch
|
||||||
|
from diffusers.utils import is_safetensors_available
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from .base import (
|
from .base import (
|
||||||
ModelBase,
|
ModelBase,
|
||||||
ModelConfigBase,
|
ModelConfigBase,
|
||||||
@ -18,9 +23,6 @@ from .base import (
|
|||||||
InvalidModelException,
|
InvalidModelException,
|
||||||
ModelNotFoundException,
|
ModelNotFoundException,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
|
||||||
from diffusers.utils import is_safetensors_available
|
|
||||||
from omegaconf import OmegaConf
|
|
||||||
|
|
||||||
|
|
||||||
class VaeModelFormat(str, Enum):
|
class VaeModelFormat(str, Enum):
|
||||||
@ -80,7 +82,7 @@ class VaeModel(ModelBase):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def detect_format(cls, path: str):
|
def detect_format(cls, path: str):
|
||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
raise ModelNotFoundException()
|
raise ModelNotFoundException(f"Does not exist as local file: {path}")
|
||||||
|
|
||||||
if os.path.isdir(path):
|
if os.path.isdir(path):
|
||||||
if os.path.exists(os.path.join(path, "config.json")):
|
if os.path.exists(os.path.join(path, "config.json")):
|
||||||
|
Loading…
Reference in New Issue
Block a user