mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tidy(app): remove unnecessary changes in invocation_context
- Any mypy issues are a misconfiguration of mypy - Use simple conditionals instead of ternaries - Consistent & standards-compliant docstring formatting - Use `dict` instead of `typing.Dict`
This commit is contained in:
parent
2276f327e5
commit
132bbf330a
@ -1,10 +1,10 @@
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Callable, Dict, Optional, Union
|
||||
from typing import TYPE_CHECKING, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
from PIL.Image import Image
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from torch import Tensor
|
||||
|
||||
from invokeai.app.invocations.constants import IMAGE_MODES
|
||||
from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata
|
||||
@ -268,7 +268,7 @@ class ImagesInterface(InvocationContextInterface):
|
||||
|
||||
|
||||
class TensorsInterface(InvocationContextInterface):
|
||||
def save(self, tensor: torch.Tensor) -> str:
|
||||
def save(self, tensor: Tensor) -> str:
|
||||
"""Saves a tensor, returning its name.
|
||||
|
||||
Args:
|
||||
@ -281,7 +281,7 @@ class TensorsInterface(InvocationContextInterface):
|
||||
name = self._services.tensors.save(obj=tensor)
|
||||
return name
|
||||
|
||||
def load(self, name: str) -> torch.Tensor:
|
||||
def load(self, name: str) -> Tensor:
|
||||
"""Loads a tensor by name.
|
||||
|
||||
Args:
|
||||
@ -333,13 +333,9 @@ class ModelsInterface(InvocationContextInterface):
|
||||
True if the model exists, False if not.
|
||||
"""
|
||||
if isinstance(identifier, str):
|
||||
# For some reason, Mypy is not getting the type annotations for many of
|
||||
# the model manager service calls and raises a "returning Any in typed
|
||||
# context" error. Hence the extra typing hints here and below.
|
||||
result: bool = self._services.model_manager.store.exists(identifier)
|
||||
return self._services.model_manager.store.exists(identifier)
|
||||
else:
|
||||
result = self._services.model_manager.store.exists(identifier.key)
|
||||
return result
|
||||
return self._services.model_manager.store.exists(identifier.key)
|
||||
|
||||
def load(
|
||||
self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None
|
||||
@ -353,6 +349,7 @@ class ModelsInterface(InvocationContextInterface):
|
||||
Returns:
|
||||
An object representing the loaded model.
|
||||
"""
|
||||
|
||||
# The model manager emits events as it loads the model. It needs the context data to build
|
||||
# the event payloads.
|
||||
|
||||
@ -379,6 +376,7 @@ class ModelsInterface(InvocationContextInterface):
|
||||
Returns:
|
||||
An object representing the loaded model.
|
||||
"""
|
||||
|
||||
configs = self._services.model_manager.store.search_by_attr(model_name=name, base_model=base, model_type=type)
|
||||
if len(configs) == 0:
|
||||
raise UnknownModelException(f"No model found with name {name}, base {base}, and type {type}")
|
||||
@ -398,10 +396,9 @@ class ModelsInterface(InvocationContextInterface):
|
||||
The model's config.
|
||||
"""
|
||||
if isinstance(identifier, str):
|
||||
result: AnyModelConfig = self._services.model_manager.store.get_model(identifier)
|
||||
return self._services.model_manager.store.get_model(identifier)
|
||||
else:
|
||||
result = self._services.model_manager.store.get_model(identifier.key)
|
||||
return result
|
||||
return self._services.model_manager.store.get_model(identifier.key)
|
||||
|
||||
def search_by_path(self, path: Path) -> list[AnyModelConfig]:
|
||||
"""Search for models by path.
|
||||
@ -412,8 +409,7 @@ class ModelsInterface(InvocationContextInterface):
|
||||
Returns:
|
||||
A list of models that match the path.
|
||||
"""
|
||||
result: list[AnyModelConfig] = self._services.model_manager.store.search_by_path(path)
|
||||
return result
|
||||
return self._services.model_manager.store.search_by_path(path)
|
||||
|
||||
def search_by_attrs(
|
||||
self,
|
||||
@ -433,13 +429,13 @@ class ModelsInterface(InvocationContextInterface):
|
||||
Returns:
|
||||
A list of models that match the attributes.
|
||||
"""
|
||||
result: list[AnyModelConfig] = self._services.model_manager.store.search_by_attr(
|
||||
|
||||
return self._services.model_manager.store.search_by_attr(
|
||||
model_name=name,
|
||||
base_model=base,
|
||||
model_type=type,
|
||||
model_format=format,
|
||||
)
|
||||
return result
|
||||
|
||||
def download_and_cache_model(
|
||||
self,
|
||||
@ -453,24 +449,18 @@ class ModelsInterface(InvocationContextInterface):
|
||||
installed, the cached path will be returned. Otherwise it will be downloaded.
|
||||
|
||||
Args:
|
||||
source: A model path, URL or repo_id.
|
||||
Result:
|
||||
Path to the downloaded model
|
||||
source: A model path, URL or repo_id.
|
||||
|
||||
May Raise:
|
||||
HTTPError
|
||||
TimeoutError
|
||||
Returns:
|
||||
Path to the downloaded model
|
||||
"""
|
||||
installer = self._services.model_manager.install
|
||||
path: Path = installer.download_and_cache_model(
|
||||
source=source,
|
||||
)
|
||||
return path
|
||||
|
||||
return self._services.model_manager.install.download_and_cache_model(source=source)
|
||||
|
||||
def load_and_cache_model(
|
||||
self,
|
||||
source: Path | str | AnyHttpUrl,
|
||||
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
|
||||
loader: Optional[Callable[[Path], dict[str, Tensor]]] = None,
|
||||
) -> LoadedModel:
|
||||
"""
|
||||
Download, cache, and load the model file located at the indicated URL.
|
||||
@ -485,24 +475,22 @@ class ModelsInterface(InvocationContextInterface):
|
||||
Be aware that the LoadedModel object will have a `config` attribute of None.
|
||||
|
||||
Args:
|
||||
source: A model Path, URL, or repoid.
|
||||
loader: A Callable that expects a Path and returns a Dict[str|int, Any]
|
||||
source: A model Path, URL, or repoid.
|
||||
loader: A Callable that expects a Path and returns a dict[str|int, Any]
|
||||
|
||||
Returns:
|
||||
A LoadedModel object.
|
||||
A LoadedModel object.
|
||||
"""
|
||||
result: LoadedModel = (
|
||||
self._services.model_manager.load.load_model_from_path(model_path=source, loader=loader)
|
||||
if isinstance(source, Path)
|
||||
else self._services.model_manager.load_model_from_url(source=source, loader=loader)
|
||||
)
|
||||
return result
|
||||
|
||||
if isinstance(source, Path):
|
||||
return self._services.model_manager.load.load_model_from_path(model_path=source, loader=loader)
|
||||
else:
|
||||
return self._services.model_manager.load_model_from_url(source=source, loader=loader)
|
||||
|
||||
|
||||
class ConfigInterface(InvocationContextInterface):
|
||||
def get(self) -> InvokeAIAppConfig:
|
||||
"""
|
||||
Gets the app's config.
|
||||
"""Gets the app's config.
|
||||
|
||||
Returns:
|
||||
The app's config.
|
||||
|
Loading…
Reference in New Issue
Block a user