mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tidy(app): remove unnecessary changes in invocation_context
- Any mypy issues are a misconfiguration of mypy - Use simple conditionals instead of ternaries - Consistent & standards-compliant docstring formatting - Use `dict` instead of `typing.Dict`
This commit is contained in:
parent
2276f327e5
commit
132bbf330a
@ -1,10 +1,10 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Callable, Dict, Optional, Union
|
from typing import TYPE_CHECKING, Callable, Optional, Union
|
||||||
|
|
||||||
import torch
|
|
||||||
from PIL.Image import Image
|
from PIL.Image import Image
|
||||||
from pydantic.networks import AnyHttpUrl
|
from pydantic.networks import AnyHttpUrl
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
from invokeai.app.invocations.constants import IMAGE_MODES
|
from invokeai.app.invocations.constants import IMAGE_MODES
|
||||||
from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata
|
from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata
|
||||||
@ -268,7 +268,7 @@ class ImagesInterface(InvocationContextInterface):
|
|||||||
|
|
||||||
|
|
||||||
class TensorsInterface(InvocationContextInterface):
|
class TensorsInterface(InvocationContextInterface):
|
||||||
def save(self, tensor: torch.Tensor) -> str:
|
def save(self, tensor: Tensor) -> str:
|
||||||
"""Saves a tensor, returning its name.
|
"""Saves a tensor, returning its name.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -281,7 +281,7 @@ class TensorsInterface(InvocationContextInterface):
|
|||||||
name = self._services.tensors.save(obj=tensor)
|
name = self._services.tensors.save(obj=tensor)
|
||||||
return name
|
return name
|
||||||
|
|
||||||
def load(self, name: str) -> torch.Tensor:
|
def load(self, name: str) -> Tensor:
|
||||||
"""Loads a tensor by name.
|
"""Loads a tensor by name.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -333,13 +333,9 @@ class ModelsInterface(InvocationContextInterface):
|
|||||||
True if the model exists, False if not.
|
True if the model exists, False if not.
|
||||||
"""
|
"""
|
||||||
if isinstance(identifier, str):
|
if isinstance(identifier, str):
|
||||||
# For some reason, Mypy is not getting the type annotations for many of
|
return self._services.model_manager.store.exists(identifier)
|
||||||
# the model manager service calls and raises a "returning Any in typed
|
|
||||||
# context" error. Hence the extra typing hints here and below.
|
|
||||||
result: bool = self._services.model_manager.store.exists(identifier)
|
|
||||||
else:
|
else:
|
||||||
result = self._services.model_manager.store.exists(identifier.key)
|
return self._services.model_manager.store.exists(identifier.key)
|
||||||
return result
|
|
||||||
|
|
||||||
def load(
|
def load(
|
||||||
self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None
|
self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None
|
||||||
@ -353,6 +349,7 @@ class ModelsInterface(InvocationContextInterface):
|
|||||||
Returns:
|
Returns:
|
||||||
An object representing the loaded model.
|
An object representing the loaded model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# The model manager emits events as it loads the model. It needs the context data to build
|
# The model manager emits events as it loads the model. It needs the context data to build
|
||||||
# the event payloads.
|
# the event payloads.
|
||||||
|
|
||||||
@ -379,6 +376,7 @@ class ModelsInterface(InvocationContextInterface):
|
|||||||
Returns:
|
Returns:
|
||||||
An object representing the loaded model.
|
An object representing the loaded model.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
configs = self._services.model_manager.store.search_by_attr(model_name=name, base_model=base, model_type=type)
|
configs = self._services.model_manager.store.search_by_attr(model_name=name, base_model=base, model_type=type)
|
||||||
if len(configs) == 0:
|
if len(configs) == 0:
|
||||||
raise UnknownModelException(f"No model found with name {name}, base {base}, and type {type}")
|
raise UnknownModelException(f"No model found with name {name}, base {base}, and type {type}")
|
||||||
@ -398,10 +396,9 @@ class ModelsInterface(InvocationContextInterface):
|
|||||||
The model's config.
|
The model's config.
|
||||||
"""
|
"""
|
||||||
if isinstance(identifier, str):
|
if isinstance(identifier, str):
|
||||||
result: AnyModelConfig = self._services.model_manager.store.get_model(identifier)
|
return self._services.model_manager.store.get_model(identifier)
|
||||||
else:
|
else:
|
||||||
result = self._services.model_manager.store.get_model(identifier.key)
|
return self._services.model_manager.store.get_model(identifier.key)
|
||||||
return result
|
|
||||||
|
|
||||||
def search_by_path(self, path: Path) -> list[AnyModelConfig]:
|
def search_by_path(self, path: Path) -> list[AnyModelConfig]:
|
||||||
"""Search for models by path.
|
"""Search for models by path.
|
||||||
@ -412,8 +409,7 @@ class ModelsInterface(InvocationContextInterface):
|
|||||||
Returns:
|
Returns:
|
||||||
A list of models that match the path.
|
A list of models that match the path.
|
||||||
"""
|
"""
|
||||||
result: list[AnyModelConfig] = self._services.model_manager.store.search_by_path(path)
|
return self._services.model_manager.store.search_by_path(path)
|
||||||
return result
|
|
||||||
|
|
||||||
def search_by_attrs(
|
def search_by_attrs(
|
||||||
self,
|
self,
|
||||||
@ -433,13 +429,13 @@ class ModelsInterface(InvocationContextInterface):
|
|||||||
Returns:
|
Returns:
|
||||||
A list of models that match the attributes.
|
A list of models that match the attributes.
|
||||||
"""
|
"""
|
||||||
result: list[AnyModelConfig] = self._services.model_manager.store.search_by_attr(
|
|
||||||
|
return self._services.model_manager.store.search_by_attr(
|
||||||
model_name=name,
|
model_name=name,
|
||||||
base_model=base,
|
base_model=base,
|
||||||
model_type=type,
|
model_type=type,
|
||||||
model_format=format,
|
model_format=format,
|
||||||
)
|
)
|
||||||
return result
|
|
||||||
|
|
||||||
def download_and_cache_model(
|
def download_and_cache_model(
|
||||||
self,
|
self,
|
||||||
@ -453,24 +449,18 @@ class ModelsInterface(InvocationContextInterface):
|
|||||||
installed, the cached path will be returned. Otherwise it will be downloaded.
|
installed, the cached path will be returned. Otherwise it will be downloaded.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
source: A model path, URL or repo_id.
|
source: A model path, URL or repo_id.
|
||||||
Result:
|
|
||||||
Path to the downloaded model
|
|
||||||
|
|
||||||
May Raise:
|
Returns:
|
||||||
HTTPError
|
Path to the downloaded model
|
||||||
TimeoutError
|
|
||||||
"""
|
"""
|
||||||
installer = self._services.model_manager.install
|
|
||||||
path: Path = installer.download_and_cache_model(
|
return self._services.model_manager.install.download_and_cache_model(source=source)
|
||||||
source=source,
|
|
||||||
)
|
|
||||||
return path
|
|
||||||
|
|
||||||
def load_and_cache_model(
|
def load_and_cache_model(
|
||||||
self,
|
self,
|
||||||
source: Path | str | AnyHttpUrl,
|
source: Path | str | AnyHttpUrl,
|
||||||
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
|
loader: Optional[Callable[[Path], dict[str, Tensor]]] = None,
|
||||||
) -> LoadedModel:
|
) -> LoadedModel:
|
||||||
"""
|
"""
|
||||||
Download, cache, and load the model file located at the indicated URL.
|
Download, cache, and load the model file located at the indicated URL.
|
||||||
@ -485,24 +475,22 @@ class ModelsInterface(InvocationContextInterface):
|
|||||||
Be aware that the LoadedModel object will have a `config` attribute of None.
|
Be aware that the LoadedModel object will have a `config` attribute of None.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
source: A model Path, URL, or repoid.
|
source: A model Path, URL, or repoid.
|
||||||
loader: A Callable that expects a Path and returns a Dict[str|int, Any]
|
loader: A Callable that expects a Path and returns a dict[str|int, Any]
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A LoadedModel object.
|
A LoadedModel object.
|
||||||
"""
|
"""
|
||||||
result: LoadedModel = (
|
|
||||||
self._services.model_manager.load.load_model_from_path(model_path=source, loader=loader)
|
if isinstance(source, Path):
|
||||||
if isinstance(source, Path)
|
return self._services.model_manager.load.load_model_from_path(model_path=source, loader=loader)
|
||||||
else self._services.model_manager.load_model_from_url(source=source, loader=loader)
|
else:
|
||||||
)
|
return self._services.model_manager.load_model_from_url(source=source, loader=loader)
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigInterface(InvocationContextInterface):
|
class ConfigInterface(InvocationContextInterface):
|
||||||
def get(self) -> InvokeAIAppConfig:
|
def get(self) -> InvokeAIAppConfig:
|
||||||
"""
|
"""Gets the app's config.
|
||||||
Gets the app's config.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The app's config.
|
The app's config.
|
||||||
|
Loading…
Reference in New Issue
Block a user