tidy(app): remove unnecessary changes in invocation_context

- Any mypy issues are a misconfiguration of mypy
- Use simple conditionals instead of ternaries
- Consistent & standards-compliant docstring formatting
- Use `dict` instead of `typing.Dict`
psychedelicious 2024-06-03 08:35:23 +10:00
parent 2276f327e5
commit 132bbf330a
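
In code terms, the tidy applies three small patterns; a minimal before/after sketch (illustrative only, not taken from the file) looks like this:

# Minimal sketch of the patterns applied in the diff below (illustrative, not InvokeAI source).
from typing import Optional


class Store:
    """Stand-in for a typed service; its annotations let mypy infer return types."""

    def exists(self, key: str) -> bool:
        return bool(key)


def exists_before(store: Store, key: str, fallback: Optional[str] = None) -> bool:
    # Old style: a ternary plus a temporary annotated only to appease mypy.
    result: bool = store.exists(key) if fallback is None else store.exists(fallback)
    return result


def exists_after(store: Store, key: str, fallback: Optional[str] = None) -> bool:
    # Tidied style: a simple conditional with direct returns and no redundant annotation.
    if fallback is None:
        return store.exists(key)
    else:
        return store.exists(fallback)


# Likewise, builtin generics (dict[str, int], list[str]) replace typing.Dict / typing.List.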


@@ -1,10 +1,10 @@
 from dataclasses import dataclass
 from pathlib import Path
-from typing import TYPE_CHECKING, Callable, Dict, Optional, Union
+from typing import TYPE_CHECKING, Callable, Optional, Union
 import torch
 from PIL.Image import Image
 from pydantic.networks import AnyHttpUrl
 from torch import Tensor
 from invokeai.app.invocations.constants import IMAGE_MODES
 from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata
@@ -268,7 +268,7 @@ class ImagesInterface(InvocationContextInterface):
 class TensorsInterface(InvocationContextInterface):
-    def save(self, tensor: torch.Tensor) -> str:
+    def save(self, tensor: Tensor) -> str:
         """Saves a tensor, returning its name.
         Args:
@@ -281,7 +281,7 @@ class TensorsInterface(InvocationContextInterface):
         name = self._services.tensors.save(obj=tensor)
         return name
-    def load(self, name: str) -> torch.Tensor:
+    def load(self, name: str) -> Tensor:
         """Loads a tensor by name.
         Args:
@@ -333,13 +333,9 @@ class ModelsInterface(InvocationContextInterface):
             True if the model exists, False if not.
         """
         if isinstance(identifier, str):
-            # For some reason, Mypy is not getting the type annotations for many of
-            # the model manager service calls and raises a "returning Any in typed
-            # context" error. Hence the extra typing hints here and below.
-            result: bool = self._services.model_manager.store.exists(identifier)
+            return self._services.model_manager.store.exists(identifier)
         else:
-            result = self._services.model_manager.store.exists(identifier.key)
-        return result
+            return self._services.model_manager.store.exists(identifier.key)
     def load(
         self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None
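
A note on the comment removed above: the "returning Any in typed context" warning comes from mypy's warn_return_any check and is triggered by an effectively untyped callee; routing the value through an annotated temporary only hides it. A minimal sketch of the distinction (hypothetical classes, not InvokeAI code):

from typing import Optional


class UntypedStore:
    def exists(self, key):  # no annotations, so mypy treats the return value as Any
        return True


class TypedStore:
    def exists(self, key: str) -> bool:
        return True


def check_untyped(store: UntypedStore, key: str) -> bool:
    # With --warn-return-any this line is flagged; `result: bool = ...` would only
    # hide the symptom by turning it into an implicit Any-to-bool assignment.
    return store.exists(key)


def check_typed(store: TypedStore, key: str) -> bool:
    # With a fully typed callee, a direct return is fine and needs no extra annotation.
    return store.exists(key)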
@@ -353,6 +349,7 @@ class ModelsInterface(InvocationContextInterface):
         Returns:
             An object representing the loaded model.
         """
         # The model manager emits events as it loads the model. It needs the context data to build
         # the event payloads.
@@ -379,6 +376,7 @@ class ModelsInterface(InvocationContextInterface):
         Returns:
             An object representing the loaded model.
         """
         configs = self._services.model_manager.store.search_by_attr(model_name=name, base_model=base, model_type=type)
         if len(configs) == 0:
             raise UnknownModelException(f"No model found with name {name}, base {base}, and type {type}")
@@ -398,10 +396,9 @@ class ModelsInterface(InvocationContextInterface):
             The model's config.
         """
         if isinstance(identifier, str):
-            result: AnyModelConfig = self._services.model_manager.store.get_model(identifier)
+            return self._services.model_manager.store.get_model(identifier)
         else:
-            result = self._services.model_manager.store.get_model(identifier.key)
-        return result
+            return self._services.model_manager.store.get_model(identifier.key)
     def search_by_path(self, path: Path) -> list[AnyModelConfig]:
         """Search for models by path.
@@ -412,8 +409,7 @@ class ModelsInterface(InvocationContextInterface):
         Returns:
             A list of models that match the path.
         """
-        result: list[AnyModelConfig] = self._services.model_manager.store.search_by_path(path)
-        return result
+        return self._services.model_manager.store.search_by_path(path)
     def search_by_attrs(
         self,
@@ -433,13 +429,13 @@ class ModelsInterface(InvocationContextInterface):
         Returns:
             A list of models that match the attributes.
         """
-        result: list[AnyModelConfig] = self._services.model_manager.store.search_by_attr(
+        return self._services.model_manager.store.search_by_attr(
             model_name=name,
             base_model=base,
             model_type=type,
             model_format=format,
         )
-        return result
     def download_and_cache_model(
         self,
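
The remaining hunks also apply the docstring-formatting bullet: the summary line moves onto the opening quotes, entries under the section headings get a consistent indent, and the non-standard "Result:" heading becomes "Returns:". A minimal example of the target shape (hypothetical function, not from the file):

def clamp_size(width: int, height: int) -> tuple[int, int]:
    """Clamp a requested size to the minimum supported dimensions.

    Args:
        width: Requested width in pixels.
        height: Requested height in pixels.

    Returns:
        The clamped (width, height) pair.
    """
    return (max(width, 64), max(height, 64))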
@@ -453,24 +449,18 @@ class ModelsInterface(InvocationContextInterface):
         installed, the cached path will be returned. Otherwise it will be downloaded.
         Args:
-           source: A model path, URL or repo_id.
-        Result:
-           Path to the downloaded model
+            source: A model path, URL or repo_id.
         May Raise:
             HTTPError
             TimeoutError
+        Returns:
+            Path to the downloaded model
         """
-        installer = self._services.model_manager.install
-        path: Path = installer.download_and_cache_model(
-            source=source,
-        )
-        return path
+        return self._services.model_manager.install.download_and_cache_model(source=source)
     def load_and_cache_model(
         self,
         source: Path | str | AnyHttpUrl,
-        loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
+        loader: Optional[Callable[[Path], dict[str, Tensor]]] = None,
     ) -> LoadedModel:
         """
         Download, cache, and load the model file located at the indicated URL.
@@ -485,24 +475,22 @@ class ModelsInterface(InvocationContextInterface):
         Be aware that the LoadedModel object will have a `config` attribute of None.
         Args:
-           source: A model Path, URL, or repoid.
-           loader: A Callable that expects a Path and returns a Dict[str|int, Any]
+            source: A model Path, URL, or repoid.
+            loader: A Callable that expects a Path and returns a dict[str|int, Any]
         Returns:
-          A LoadedModel object.
+            A LoadedModel object.
         """
-        result: LoadedModel = (
-            self._services.model_manager.load.load_model_from_path(model_path=source, loader=loader)
-            if isinstance(source, Path)
-            else self._services.model_manager.load_model_from_url(source=source, loader=loader)
-        )
-        return result
+        if isinstance(source, Path):
+            return self._services.model_manager.load.load_model_from_path(model_path=source, loader=loader)
+        else:
+            return self._services.model_manager.load_model_from_url(source=source, loader=loader)
 class ConfigInterface(InvocationContextInterface):
     def get(self) -> InvokeAIAppConfig:
-        """
-        Gets the app's config.
+        """Gets the app's config.
         Returns:
             The app's config.
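
For reference, a hedged sketch of how the tidied interfaces are called from an invocation; only methods shown in the diff above are used, and the attribute names (context.tensors, context.models, context.config) are assumed from how InvocationContext is used elsewhere in InvokeAI. Argument values are made up.

from pathlib import Path

import torch


def example(context) -> None:
    # TensorsInterface after the tidy: plain Tensor annotations, same behaviour.
    name = context.tensors.save(torch.zeros(1, 4, 64, 64))
    latents = context.tensors.load(name)

    # ModelsInterface: values are returned directly, with no annotated temporaries.
    matches = context.models.search_by_path(Path("/models/example.safetensors"))
    cached_path = context.models.download_and_cache_model("https://example.com/model.safetensors")

    # ConfigInterface: returns the app's config object.
    app_config = context.config.get()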