tidy(app): remove unnecessary changes in invocation_context

- Any mypy issues are a misconfiguration of mypy
- Use simple conditionals instead of ternaries
- Consistent & standards-compliant docstring formatting
- Use `dict` instead of `typing.Dict`
This commit is contained in:
psychedelicious 2024-06-03 08:35:23 +10:00
parent 2276f327e5
commit 132bbf330a

View File

@ -1,10 +1,10 @@
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Callable, Dict, Optional, Union from typing import TYPE_CHECKING, Callable, Optional, Union
import torch
from PIL.Image import Image from PIL.Image import Image
from pydantic.networks import AnyHttpUrl from pydantic.networks import AnyHttpUrl
from torch import Tensor
from invokeai.app.invocations.constants import IMAGE_MODES from invokeai.app.invocations.constants import IMAGE_MODES
from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata
@ -268,7 +268,7 @@ class ImagesInterface(InvocationContextInterface):
class TensorsInterface(InvocationContextInterface): class TensorsInterface(InvocationContextInterface):
def save(self, tensor: torch.Tensor) -> str: def save(self, tensor: Tensor) -> str:
"""Saves a tensor, returning its name. """Saves a tensor, returning its name.
Args: Args:
@ -281,7 +281,7 @@ class TensorsInterface(InvocationContextInterface):
name = self._services.tensors.save(obj=tensor) name = self._services.tensors.save(obj=tensor)
return name return name
def load(self, name: str) -> torch.Tensor: def load(self, name: str) -> Tensor:
"""Loads a tensor by name. """Loads a tensor by name.
Args: Args:
@ -333,13 +333,9 @@ class ModelsInterface(InvocationContextInterface):
True if the model exists, False if not. True if the model exists, False if not.
""" """
if isinstance(identifier, str): if isinstance(identifier, str):
# For some reason, Mypy is not getting the type annotations for many of return self._services.model_manager.store.exists(identifier)
# the model manager service calls and raises a "returning Any in typed
# context" error. Hence the extra typing hints here and below.
result: bool = self._services.model_manager.store.exists(identifier)
else: else:
result = self._services.model_manager.store.exists(identifier.key) return self._services.model_manager.store.exists(identifier.key)
return result
def load( def load(
self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None
@ -353,6 +349,7 @@ class ModelsInterface(InvocationContextInterface):
Returns: Returns:
An object representing the loaded model. An object representing the loaded model.
""" """
# The model manager emits events as it loads the model. It needs the context data to build # The model manager emits events as it loads the model. It needs the context data to build
# the event payloads. # the event payloads.
@ -379,6 +376,7 @@ class ModelsInterface(InvocationContextInterface):
Returns: Returns:
An object representing the loaded model. An object representing the loaded model.
""" """
configs = self._services.model_manager.store.search_by_attr(model_name=name, base_model=base, model_type=type) configs = self._services.model_manager.store.search_by_attr(model_name=name, base_model=base, model_type=type)
if len(configs) == 0: if len(configs) == 0:
raise UnknownModelException(f"No model found with name {name}, base {base}, and type {type}") raise UnknownModelException(f"No model found with name {name}, base {base}, and type {type}")
@ -398,10 +396,9 @@ class ModelsInterface(InvocationContextInterface):
The model's config. The model's config.
""" """
if isinstance(identifier, str): if isinstance(identifier, str):
result: AnyModelConfig = self._services.model_manager.store.get_model(identifier) return self._services.model_manager.store.get_model(identifier)
else: else:
result = self._services.model_manager.store.get_model(identifier.key) return self._services.model_manager.store.get_model(identifier.key)
return result
def search_by_path(self, path: Path) -> list[AnyModelConfig]: def search_by_path(self, path: Path) -> list[AnyModelConfig]:
"""Search for models by path. """Search for models by path.
@ -412,8 +409,7 @@ class ModelsInterface(InvocationContextInterface):
Returns: Returns:
A list of models that match the path. A list of models that match the path.
""" """
result: list[AnyModelConfig] = self._services.model_manager.store.search_by_path(path) return self._services.model_manager.store.search_by_path(path)
return result
def search_by_attrs( def search_by_attrs(
self, self,
@ -433,13 +429,13 @@ class ModelsInterface(InvocationContextInterface):
Returns: Returns:
A list of models that match the attributes. A list of models that match the attributes.
""" """
result: list[AnyModelConfig] = self._services.model_manager.store.search_by_attr(
return self._services.model_manager.store.search_by_attr(
model_name=name, model_name=name,
base_model=base, base_model=base,
model_type=type, model_type=type,
model_format=format, model_format=format,
) )
return result
def download_and_cache_model( def download_and_cache_model(
self, self,
@ -453,24 +449,18 @@ class ModelsInterface(InvocationContextInterface):
installed, the cached path will be returned. Otherwise it will be downloaded. installed, the cached path will be returned. Otherwise it will be downloaded.
Args: Args:
source: A model path, URL or repo_id. source: A model path, URL or repo_id.
Result:
Path to the downloaded model
May Raise: Returns:
HTTPError Path to the downloaded model
TimeoutError
""" """
installer = self._services.model_manager.install
path: Path = installer.download_and_cache_model( return self._services.model_manager.install.download_and_cache_model(source=source)
source=source,
)
return path
def load_and_cache_model( def load_and_cache_model(
self, self,
source: Path | str | AnyHttpUrl, source: Path | str | AnyHttpUrl,
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None, loader: Optional[Callable[[Path], dict[str, Tensor]]] = None,
) -> LoadedModel: ) -> LoadedModel:
""" """
Download, cache, and load the model file located at the indicated URL. Download, cache, and load the model file located at the indicated URL.
@ -485,24 +475,22 @@ class ModelsInterface(InvocationContextInterface):
Be aware that the LoadedModel object will have a `config` attribute of None. Be aware that the LoadedModel object will have a `config` attribute of None.
Args: Args:
source: A model Path, URL, or repoid. source: A model Path, URL, or repoid.
loader: A Callable that expects a Path and returns a Dict[str|int, Any] loader: A Callable that expects a Path and returns a dict[str|int, Any]
Returns: Returns:
A LoadedModel object. A LoadedModel object.
""" """
result: LoadedModel = (
self._services.model_manager.load.load_model_from_path(model_path=source, loader=loader) if isinstance(source, Path):
if isinstance(source, Path) return self._services.model_manager.load.load_model_from_path(model_path=source, loader=loader)
else self._services.model_manager.load_model_from_url(source=source, loader=loader) else:
) return self._services.model_manager.load_model_from_url(source=source, loader=loader)
return result
class ConfigInterface(InvocationContextInterface): class ConfigInterface(InvocationContextInterface):
def get(self) -> InvokeAIAppConfig: def get(self) -> InvokeAIAppConfig:
""" """Gets the app's config.
Gets the app's config.
Returns: Returns:
The app's config. The app's config.