diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index c932e66989..08ca207118 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -1,10 +1,10 @@ from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Callable, Dict, Optional, Union +from typing import TYPE_CHECKING, Callable, Optional, Union -import torch from PIL.Image import Image from pydantic.networks import AnyHttpUrl +from torch import Tensor from invokeai.app.invocations.constants import IMAGE_MODES from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata @@ -268,7 +268,7 @@ class ImagesInterface(InvocationContextInterface): class TensorsInterface(InvocationContextInterface): - def save(self, tensor: torch.Tensor) -> str: + def save(self, tensor: Tensor) -> str: """Saves a tensor, returning its name. Args: @@ -281,7 +281,7 @@ class TensorsInterface(InvocationContextInterface): name = self._services.tensors.save(obj=tensor) return name - def load(self, name: str) -> torch.Tensor: + def load(self, name: str) -> Tensor: """Loads a tensor by name. Args: @@ -333,13 +333,9 @@ class ModelsInterface(InvocationContextInterface): True if the model exists, False if not. """ if isinstance(identifier, str): - # For some reason, Mypy is not getting the type annotations for many of - # the model manager service calls and raises a "returning Any in typed - # context" error. Hence the extra typing hints here and below. - result: bool = self._services.model_manager.store.exists(identifier) + return self._services.model_manager.store.exists(identifier) else: - result = self._services.model_manager.store.exists(identifier.key) - return result + return self._services.model_manager.store.exists(identifier.key) def load( self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None @@ -353,6 +349,7 @@ class ModelsInterface(InvocationContextInterface): Returns: An object representing the loaded model. """ + # The model manager emits events as it loads the model. It needs the context data to build # the event payloads. @@ -379,6 +376,7 @@ class ModelsInterface(InvocationContextInterface): Returns: An object representing the loaded model. """ + configs = self._services.model_manager.store.search_by_attr(model_name=name, base_model=base, model_type=type) if len(configs) == 0: raise UnknownModelException(f"No model found with name {name}, base {base}, and type {type}") @@ -398,10 +396,9 @@ class ModelsInterface(InvocationContextInterface): The model's config. """ if isinstance(identifier, str): - result: AnyModelConfig = self._services.model_manager.store.get_model(identifier) + return self._services.model_manager.store.get_model(identifier) else: - result = self._services.model_manager.store.get_model(identifier.key) - return result + return self._services.model_manager.store.get_model(identifier.key) def search_by_path(self, path: Path) -> list[AnyModelConfig]: """Search for models by path. @@ -412,8 +409,7 @@ class ModelsInterface(InvocationContextInterface): Returns: A list of models that match the path. """ - result: list[AnyModelConfig] = self._services.model_manager.store.search_by_path(path) - return result + return self._services.model_manager.store.search_by_path(path) def search_by_attrs( self, @@ -433,13 +429,13 @@ class ModelsInterface(InvocationContextInterface): Returns: A list of models that match the attributes. """ - result: list[AnyModelConfig] = self._services.model_manager.store.search_by_attr( + + return self._services.model_manager.store.search_by_attr( model_name=name, base_model=base, model_type=type, model_format=format, ) - return result def download_and_cache_model( self, @@ -453,24 +449,18 @@ class ModelsInterface(InvocationContextInterface): installed, the cached path will be returned. Otherwise it will be downloaded. Args: - source: A model path, URL or repo_id. - Result: - Path to the downloaded model + source: A model path, URL or repo_id. - May Raise: - HTTPError - TimeoutError + Returns: + Path to the downloaded model """ - installer = self._services.model_manager.install - path: Path = installer.download_and_cache_model( - source=source, - ) - return path + + return self._services.model_manager.install.download_and_cache_model(source=source) def load_and_cache_model( self, source: Path | str | AnyHttpUrl, - loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None, + loader: Optional[Callable[[Path], dict[str, Tensor]]] = None, ) -> LoadedModel: """ Download, cache, and load the model file located at the indicated URL. @@ -485,24 +475,22 @@ class ModelsInterface(InvocationContextInterface): Be aware that the LoadedModel object will have a `config` attribute of None. Args: - source: A model Path, URL, or repoid. - loader: A Callable that expects a Path and returns a Dict[str|int, Any] + source: A model Path, URL, or repoid. + loader: A Callable that expects a Path and returns a dict[str|int, Any] Returns: - A LoadedModel object. + A LoadedModel object. """ - result: LoadedModel = ( - self._services.model_manager.load.load_model_from_path(model_path=source, loader=loader) - if isinstance(source, Path) - else self._services.model_manager.load_model_from_url(source=source, loader=loader) - ) - return result + + if isinstance(source, Path): + return self._services.model_manager.load.load_model_from_path(model_path=source, loader=loader) + else: + return self._services.model_manager.load_model_from_url(source=source, loader=loader) class ConfigInterface(InvocationContextInterface): def get(self) -> InvokeAIAppConfig: - """ - Gets the app's config. + """Gets the app's config. Returns: The app's config.