mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
f31e4205aa
No reason to not have the whole thing in there.
474 lines
17 KiB
Python
474 lines
17 KiB
Python
import threading
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, Callable, Optional
|
|
|
|
from PIL.Image import Image
|
|
from torch import Tensor
|
|
|
|
from invokeai.app.invocations.constants import IMAGE_MODES
|
|
from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata
|
|
from invokeai.app.services.boards.boards_common import BoardDTO
|
|
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
|
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
|
from invokeai.app.services.images.images_common import ImageDTO
|
|
from invokeai.app.services.invocation_services import InvocationServices
|
|
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
|
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
|
|
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
|
from invokeai.backend.model_manager.metadata.metadata_base import AnyModelRepoMetadata
|
|
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
|
|
|
if TYPE_CHECKING:
|
|
from invokeai.app.invocations.baseinvocation import BaseInvocation
|
|
from invokeai.app.services.session_queue.session_queue_common import SessionQueueItem
|
|
|
|
"""
|
|
The InvocationContext provides access to various services and data about the current invocation.
|
|
|
|
We do not provide the invocation services directly, as their methods are both dangerous and
|
|
inconvenient to use.
|
|
|
|
For example:
|
|
- The `images` service allows nodes to delete or unsafely modify existing images.
|
|
- The `configuration` service allows nodes to change the app's config at runtime.
|
|
- The `events` service allows nodes to emit arbitrary events.
|
|
|
|
Wrapping these services provides a simpler and safer interface for nodes to use.
|
|
|
|
When a node executes, a fresh `InvocationContext` is built for it, ensuring nodes cannot interfere
|
|
with each other.
|
|
|
|
Many of the wrappers have the same signature as the methods they wrap. This allows us to write
|
|
user-facing docstrings and not need to go and update the internal services to match.
|
|
|
|
Note: The docstrings are in weird places, but that's where they must be to get IDEs to see them.
|
|
"""
|
|
|
|
|
|
@dataclass
class InvocationContextData:
    """Plain data holder describing the invocation that is currently executing."""

    queue_item: "SessionQueueItem"
    """The queue item currently being executed."""
    invocation: "BaseInvocation"
    """The invocation currently being executed."""
    source_invocation_id: str
    """The ID of the invocation that the currently executing invocation was prepared from."""
|
|
|
|
|
|
class InvocationContextInterface:
    """
    Base class for the service wrappers exposed on the invocation context.

    It only stores the application services and the data for the current invocation, so that
    subclasses can reach them through `self._services` and `self._context_data`.
    """

    def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None:
        """
        Stores the services and context data on the instance.

        :param services: The application services to wrap.
        :param context_data: The data describing the current queue item and invocation.
        """
        self._context_data = context_data
        self._services = services
|
|
|
|
|
|
class BoardsInterface(InvocationContextInterface):
    """Wraps the `boards` and `board_images` services for use by nodes."""

    def create(self, board_name: str) -> BoardDTO:
        """
        Creates a board and returns its DTO.

        :param board_name: The name of the board to create.
        """
        boards_service = self._services.boards
        return boards_service.create(board_name)

    def get_dto(self, board_id: str) -> BoardDTO:
        """
        Gets the DTO of a single board.

        :param board_id: The ID of the board to get.
        """
        boards_service = self._services.boards
        return boards_service.get_dto(board_id)

    def get_all(self) -> list[BoardDTO]:
        """
        Gets the DTOs of all boards.
        """
        boards_service = self._services.boards
        return boards_service.get_all()

    def add_image_to_board(self, board_id: str, image_name: str) -> None:
        """
        Adds an image to a board.

        :param board_id: The ID of the board to add the image to.
        :param image_name: The name of the image to add to the board.
        """
        board_images_service = self._services.board_images
        # The wrapped call's result is handed back unchanged.
        return board_images_service.add_image_to_board(board_id, image_name)

    def get_all_image_names_for_board(self, board_id: str) -> list[str]:
        """
        Gets the names of all images on a board.

        :param board_id: The ID of the board to get the image names for.
        """
        board_images_service = self._services.board_images
        return board_images_service.get_all_board_image_names_for_board(board_id)
|
|
|
|
|
|
class LoggerInterface(InvocationContextInterface):
    """Wraps the app logger, forwarding each message to the matching logger method."""

    def debug(self, message: str) -> None:
        """
        Logs a message at debug level.

        :param message: The message to log.
        """
        app_logger = self._services.logger
        app_logger.debug(message)

    def info(self, message: str) -> None:
        """
        Logs a message at info level.

        :param message: The message to log.
        """
        app_logger = self._services.logger
        app_logger.info(message)

    def warning(self, message: str) -> None:
        """
        Logs a message at warning level.

        :param message: The message to log.
        """
        app_logger = self._services.logger
        app_logger.warning(message)

    def error(self, message: str) -> None:
        """
        Logs a message at error level.

        :param message: The message to log.
        """
        app_logger = self._services.logger
        app_logger.error(message)
|
|
|
|
|
|
class ImagesInterface(InvocationContextInterface):
    """Wraps the `images` service, exposing only saving and read access to nodes."""

    def save(
        self,
        image: Image,
        board_id: Optional[str] = None,
        image_category: ImageCategory = ImageCategory.GENERAL,
        metadata: Optional[MetadataField] = None,
    ) -> ImageDTO:
        """
        Saves an image, returning its DTO.

        The current queue item's workflow and session ID, and the current invocation's ID and
        `is_intermediate` flag, are passed along with the image automatically.

        :param image: The image to save, as a PIL image.
        :param board_id: The board ID to add the image to, if it should be added. If the invocation
            inherits from `WithBoard`, that board will be used automatically. **Use this only if
            you want to override or provide a board manually!**
        :param image_category: The category of the image. Only the GENERAL category is added
            to the gallery.
        :param metadata: The metadata to save with the image, if it should have any. If the
            invocation inherits from `WithMetadata`, that metadata will be used automatically.
            **Use this only if you want to override or provide metadata manually!**
        """
        invocation = self._context_data.invocation
        queue_item = self._context_data.queue_item

        # Metadata precedence: explicit argument, then the `WithMetadata` mixin, then nothing.
        # The check is a truthiness check, so a falsy `metadata` counts as "not provided".
        if metadata:
            resolved_metadata = metadata
        elif isinstance(invocation, WithMetadata):
            resolved_metadata = invocation.metadata
        else:
            resolved_metadata = None

        # Board precedence: explicit argument, then the `WithBoard` mixin (when a board is set), then nothing.
        # A falsy `board_id` (e.g. an empty string) counts as "not provided".
        if board_id:
            resolved_board_id = board_id
        elif isinstance(invocation, WithBoard) and invocation.board:
            resolved_board_id = invocation.board.board_id
        else:
            resolved_board_id = None

        return self._services.images.create(
            image=image,
            is_intermediate=invocation.is_intermediate,
            image_category=image_category,
            board_id=resolved_board_id,
            metadata=resolved_metadata,
            image_origin=ResourceOrigin.INTERNAL,
            workflow=queue_item.workflow,
            session_id=queue_item.session_id,
            node_id=invocation.id,
        )

    def get_pil(self, image_name: str, mode: IMAGE_MODES | None = None) -> Image:
        """
        Gets an image as a PIL Image object.

        If the requested conversion raises a `ValueError`, a warning is logged and the image is
        returned in its original mode.

        :param image_name: The name of the image to get.
        :param mode: The color mode to convert the image to. If None, the original mode is used.
        """
        image = self._services.images.get_pil_image(image_name)

        # Nothing to convert: no mode requested, or the image is already in that mode.
        if not mode or mode == image.mode:
            return image

        try:
            return image.convert(mode)
        except ValueError:
            self._services.logger.warning(
                f"Could not convert image from {image.mode} to {mode}. Using original mode instead."
            )
            return image

    def get_metadata(self, image_name: str) -> Optional[MetadataField]:
        """
        Gets the metadata of an image, if it has any.

        :param image_name: The name of the image to get the metadata for.
        """
        images_service = self._services.images
        return images_service.get_metadata(image_name)

    def get_dto(self, image_name: str) -> ImageDTO:
        """
        Gets an image as an ImageDTO object.

        :param image_name: The name of the image to get.
        """
        images_service = self._services.images
        return images_service.get_dto(image_name)
|
|
|
|
|
|
class TensorsInterface(InvocationContextInterface):
    """Wraps the `tensors` service, letting nodes save and load tensors by name."""

    def save(self, tensor: Tensor) -> str:
        """
        Saves a tensor and returns the name it was saved under.

        :param tensor: The tensor to save.
        """
        return self._services.tensors.save(obj=tensor)

    def load(self, name: str) -> Tensor:
        """
        Loads a previously saved tensor.

        :param name: The name of the tensor to load.
        """
        tensors_service = self._services.tensors
        return tensors_service.load(name)
|
|
|
|
|
|
class ConditioningInterface(InvocationContextInterface):
    """Wraps the `conditioning` service, letting nodes save and load conditioning data by name."""

    def save(self, conditioning_data: ConditioningFieldData) -> str:
        """
        Saves a conditioning data object and returns the name it was saved under.

        :param conditioning_data: The conditioning data to save.
        """
        return self._services.conditioning.save(obj=conditioning_data)

    def load(self, name: str) -> ConditioningFieldData:
        """
        Loads previously saved conditioning data.

        :param name: The name of the conditioning data to load.
        """
        conditioning_service = self._services.conditioning
        return conditioning_service.load(name)
|
|
|
|
|
|
class ModelsInterface(InvocationContextInterface):
    """Wraps the model manager, exposing model lookup, loading and search to nodes."""

    def exists(self, key: str) -> bool:
        """
        Checks whether a model exists in the model store.

        :param key: The key of the model.
        """
        store = self._services.model_manager.store
        return store.exists(key)

    def load(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
        """
        Loads a model by its key.

        :param key: The key of the model.
        :param submodel_type: The submodel of the model to get.
        :returns: An object representing the loaded model.
        """
        manager = self._services.model_manager
        # The model manager emits events as it loads the model. It needs the context data to
        # build the event payloads.
        return manager.load_model_by_key(key=key, submodel_type=submodel_type, context_data=self._context_data)

    def load_by_attrs(
        self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None
    ) -> LoadedModel:
        """
        Loads a model by its attributes.

        :param model_name: Name of the model to be fetched.
        :param base_model: Base model
        :param model_type: Type of the model
        :param submodel: For main (pipeline models), the submodel to fetch
        """
        manager = self._services.model_manager
        # The context data is forwarded here as well, exactly as in `load`.
        return manager.load_model_by_attr(
            model_name=model_name,
            base_model=base_model,
            model_type=model_type,
            submodel=submodel,
            context_data=self._context_data,
        )

    def get_config(self, key: str) -> AnyModelConfig:
        """
        Gets a model's config from the model store.

        :param key: The key of the model.
        """
        store = self._services.model_manager.store
        return store.get_model(key=key)

    def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
        """
        Gets a model's metadata, if it has any.

        :param key: The key of the model.
        """
        store = self._services.model_manager.store
        return store.get_metadata(key=key)

    def search_by_path(self, path: Path) -> list[AnyModelConfig]:
        """
        Searches the model store for models by path.

        :param path: The path to search for.
        """
        store = self._services.model_manager.store
        return store.search_by_path(path)

    def search_by_attrs(
        self,
        model_name: Optional[str] = None,
        base_model: Optional[BaseModelType] = None,
        model_type: Optional[ModelType] = None,
        model_format: Optional[ModelFormat] = None,
    ) -> list[AnyModelConfig]:
        """
        Searches the model store for models by attributes. All attributes are optional and are
        forwarded to the store unchanged.

        :param model_name: Name of the model to search for.
        :param base_model: Base model
        :param model_type: Type of the model
        :param model_format: Format of the model
        """
        store = self._services.model_manager.store
        return store.search_by_attr(
            model_name=model_name,
            base_model=base_model,
            model_type=model_type,
            model_format=model_format,
        )
|
|
|
|
|
|
class ConfigInterface(InvocationContextInterface):
    """Wraps the `configuration` service, exposing only a getter for the app config."""

    def get(self) -> InvokeAIAppConfig:
        """Gets the app's config."""
        configuration_service = self._services.configuration
        return configuration_service.get_config()
|
|
|
|
|
|
class UtilInterface(InvocationContextInterface):
    """Utility methods for nodes; currently the denoising step callback."""

    def __init__(
        self, services: InvocationServices, context_data: InvocationContextData, is_canceled: Callable[[], bool]
    ) -> None:
        """
        Stores the services, the context data and the cancellation check.

        :param services: The application services to wrap.
        :param context_data: The data describing the current queue item and invocation.
        :param is_canceled: A callable returning whether the current invocation has been canceled.
        """
        super().__init__(services, context_data)
        self._is_canceled = is_canceled

    def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None:
        """
        The step callback emits a progress event with the current step, the total number of
        steps, a preview image, and some other internal metadata.

        This should be called after each denoising step.

        :param intermediate_state: The intermediate state of the diffusion pipeline.
        :param base_model: The base model for the current denoising step.
        """
        # The shared helper does the work; the events service, the context data and the
        # cancellation check are forwarded to it.
        stable_diffusion_step_callback(
            intermediate_state=intermediate_state,
            base_model=base_model,
            context_data=self._context_data,
            events=self._services.events,
            is_canceled=self._is_canceled,
        )
|
|
|
|
|
|
class InvocationContext:
    """
    The `InvocationContext` bundles the service wrappers and data available to the current invocation.
    """

    def __init__(
        self,
        images: ImagesInterface,
        tensors: TensorsInterface,
        conditioning: ConditioningInterface,
        models: ModelsInterface,
        logger: LoggerInterface,
        config: ConfigInterface,
        util: UtilInterface,
        boards: BoardsInterface,
        context_data: InvocationContextData,
        services: InvocationServices,
        is_canceled: Callable[[], bool],
    ) -> None:
        # NOTE: The string following each assignment is an attribute docstring. It must stay
        # directly below its attribute for IDEs to pick it up.
        self.images = images
        """Provides methods to save and get images and their metadata."""
        self.tensors = tensors
        """Provides methods to save and get tensors, including image, noise, masks, and masked images."""
        self.conditioning = conditioning
        """Provides methods to save and get conditioning data."""
        self.models = models
        """Provides methods to check if a model exists, load a model, search for models, and get a model's info."""
        self.logger = logger
        """Provides access to the app logger."""
        self.config = config
        """Provides access to the app's config."""
        self.util = util
        """Provides utility methods."""
        self.boards = boards
        """Provides methods to interact with boards."""
        self.is_canceled = is_canceled
        """Checks if the current invocation has been canceled."""
        self._data = context_data
        """Data about the current queue item and invocation. This is an internal API and may change without warning."""
        self._services = services
        """The full application services. This is an internal API and may change without warning."""
|
|
|
|
|
|
def build_invocation_context(
    services: InvocationServices,
    context_data: InvocationContextData,
    cancel_event: threading.Event,
) -> InvocationContext:
    """
    Builds the invocation context for a specific invocation execution.

    Every interface is constructed with the same `services` and `context_data`.

    :param services: The invocation services to wrap.
    :param context_data: The invocation context data.
    :param cancel_event: The event whose set state marks the invocation as canceled.
    """

    def is_canceled() -> bool:
        # Looked up on every call, so the check always reflects the event's current state.
        return cancel_event.is_set()

    return InvocationContext(
        images=ImagesInterface(services=services, context_data=context_data),
        logger=LoggerInterface(services=services, context_data=context_data),
        config=ConfigInterface(services=services, context_data=context_data),
        tensors=TensorsInterface(services=services, context_data=context_data),
        models=ModelsInterface(services=services, context_data=context_data),
        context_data=context_data,
        util=UtilInterface(services=services, context_data=context_data, is_canceled=is_canceled),
        conditioning=ConditioningInterface(services=services, context_data=context_data),
        services=services,
        boards=BoardsInterface(services=services, context_data=context_data),
        is_canceled=is_canceled,
    )
|