mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): do not hide services
in invocation context interfaces
This commit is contained in:
parent
0ff466ebc4
commit
0fde0d1ff7
@ -64,379 +64,338 @@ class InvocationContextData:
|
|||||||
"""The workflow associated with this queue item, if any."""
|
"""The workflow associated with this queue item, if any."""
|
||||||
|
|
||||||
|
|
||||||
class BoardsInterface:
|
class InvocationContextInterface:
|
||||||
def __init__(self, services: InvocationServices) -> None:
|
|
||||||
def create(board_name: str) -> BoardDTO:
|
|
||||||
"""
|
|
||||||
Creates a board.
|
|
||||||
|
|
||||||
:param board_name: The name of the board to create.
|
|
||||||
"""
|
|
||||||
return services.boards.create(board_name)
|
|
||||||
|
|
||||||
def get_dto(board_id: str) -> BoardDTO:
|
|
||||||
"""
|
|
||||||
Gets a board DTO.
|
|
||||||
|
|
||||||
:param board_id: The ID of the board to get.
|
|
||||||
"""
|
|
||||||
return services.boards.get_dto(board_id)
|
|
||||||
|
|
||||||
def get_all() -> list[BoardDTO]:
|
|
||||||
"""
|
|
||||||
Gets all boards.
|
|
||||||
"""
|
|
||||||
return services.boards.get_all()
|
|
||||||
|
|
||||||
def add_image_to_board(board_id: str, image_name: str) -> None:
|
|
||||||
"""
|
|
||||||
Adds an image to a board.
|
|
||||||
|
|
||||||
:param board_id: The ID of the board to add the image to.
|
|
||||||
:param image_name: The name of the image to add to the board.
|
|
||||||
"""
|
|
||||||
services.board_images.add_image_to_board(board_id, image_name)
|
|
||||||
|
|
||||||
def get_all_image_names_for_board(board_id: str) -> list[str]:
|
|
||||||
"""
|
|
||||||
Gets all image names for a board.
|
|
||||||
|
|
||||||
:param board_id: The ID of the board to get the image names for.
|
|
||||||
"""
|
|
||||||
return services.board_images.get_all_board_image_names_for_board(board_id)
|
|
||||||
|
|
||||||
self.create = create
|
|
||||||
self.get_dto = get_dto
|
|
||||||
self.get_all = get_all
|
|
||||||
self.add_image_to_board = add_image_to_board
|
|
||||||
self.get_all_image_names_for_board = get_all_image_names_for_board
|
|
||||||
|
|
||||||
|
|
||||||
class LoggerInterface:
|
|
||||||
def __init__(self, services: InvocationServices) -> None:
|
|
||||||
def debug(message: str) -> None:
|
|
||||||
"""
|
|
||||||
Logs a debug message.
|
|
||||||
|
|
||||||
:param message: The message to log.
|
|
||||||
"""
|
|
||||||
services.logger.debug(message)
|
|
||||||
|
|
||||||
def info(message: str) -> None:
|
|
||||||
"""
|
|
||||||
Logs an info message.
|
|
||||||
|
|
||||||
:param message: The message to log.
|
|
||||||
"""
|
|
||||||
services.logger.info(message)
|
|
||||||
|
|
||||||
def warning(message: str) -> None:
|
|
||||||
"""
|
|
||||||
Logs a warning message.
|
|
||||||
|
|
||||||
:param message: The message to log.
|
|
||||||
"""
|
|
||||||
services.logger.warning(message)
|
|
||||||
|
|
||||||
def error(message: str) -> None:
|
|
||||||
"""
|
|
||||||
Logs an error message.
|
|
||||||
|
|
||||||
:param message: The message to log.
|
|
||||||
"""
|
|
||||||
services.logger.error(message)
|
|
||||||
|
|
||||||
self.debug = debug
|
|
||||||
self.info = info
|
|
||||||
self.warning = warning
|
|
||||||
self.error = error
|
|
||||||
|
|
||||||
|
|
||||||
class ImagesInterface:
|
|
||||||
def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None:
|
def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None:
|
||||||
def save(
|
self._services = services
|
||||||
image: Image,
|
self._context_data = context_data
|
||||||
board_id: Optional[str] = None,
|
|
||||||
image_category: ImageCategory = ImageCategory.GENERAL,
|
|
||||||
metadata: Optional[MetadataField] = None,
|
|
||||||
) -> ImageDTO:
|
|
||||||
"""
|
|
||||||
Saves an image, returning its DTO.
|
|
||||||
|
|
||||||
If the current queue item has a workflow or metadata, it is automatically saved with the image.
|
|
||||||
|
|
||||||
:param image: The image to save, as a PIL image.
|
|
||||||
:param board_id: The board ID to add the image to, if it should be added.
|
|
||||||
:param image_category: The category of the image. Only the GENERAL category is added \
|
|
||||||
to the gallery.
|
|
||||||
:param metadata: The metadata to save with the image, if it should have any. If the \
|
|
||||||
invocation inherits from `WithMetadata`, that metadata will be used automatically. \
|
|
||||||
**Use this only if you want to override or provide metadata manually!**
|
|
||||||
"""
|
|
||||||
|
|
||||||
# If the invocation inherits metadata, use that. Else, use the metadata passed in.
|
|
||||||
metadata_ = (
|
|
||||||
context_data.invocation.metadata if isinstance(context_data.invocation, WithMetadata) else metadata
|
|
||||||
)
|
|
||||||
|
|
||||||
return services.images.create(
|
|
||||||
image=image,
|
|
||||||
is_intermediate=context_data.invocation.is_intermediate,
|
|
||||||
image_category=image_category,
|
|
||||||
board_id=board_id,
|
|
||||||
metadata=metadata_,
|
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
|
||||||
workflow=context_data.workflow,
|
|
||||||
session_id=context_data.session_id,
|
|
||||||
node_id=context_data.invocation.id,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_pil(image_name: str) -> Image:
|
|
||||||
"""
|
|
||||||
Gets an image as a PIL Image object.
|
|
||||||
|
|
||||||
:param image_name: The name of the image to get.
|
|
||||||
"""
|
|
||||||
return services.images.get_pil_image(image_name)
|
|
||||||
|
|
||||||
def get_metadata(image_name: str) -> Optional[MetadataField]:
|
|
||||||
"""
|
|
||||||
Gets an image's metadata, if it has any.
|
|
||||||
|
|
||||||
:param image_name: The name of the image to get the metadata for.
|
|
||||||
"""
|
|
||||||
return services.images.get_metadata(image_name)
|
|
||||||
|
|
||||||
def get_dto(image_name: str) -> ImageDTO:
|
|
||||||
"""
|
|
||||||
Gets an image as an ImageDTO object.
|
|
||||||
|
|
||||||
:param image_name: The name of the image to get.
|
|
||||||
"""
|
|
||||||
return services.images.get_dto(image_name)
|
|
||||||
|
|
||||||
def update(
|
|
||||||
image_name: str,
|
|
||||||
board_id: Optional[str] = None,
|
|
||||||
is_intermediate: Optional[bool] = False,
|
|
||||||
) -> ImageDTO:
|
|
||||||
"""
|
|
||||||
Updates an image, returning its updated DTO.
|
|
||||||
|
|
||||||
It is not suggested to update images saved by earlier nodes, as this can cause confusion for users.
|
|
||||||
|
|
||||||
If you use this method, you *must* return the image as an :class:`ImageOutput` for the gallery to
|
|
||||||
get the updated image.
|
|
||||||
|
|
||||||
:param image_name: The name of the image to update.
|
|
||||||
:param board_id: The board ID to add the image to, if it should be added.
|
|
||||||
:param is_intermediate: Whether the image is an intermediate. Intermediate images aren't added to the gallery.
|
|
||||||
"""
|
|
||||||
if is_intermediate is not None:
|
|
||||||
services.images.update(image_name, ImageRecordChanges(is_intermediate=is_intermediate))
|
|
||||||
if board_id is None:
|
|
||||||
services.board_images.remove_image_from_board(image_name)
|
|
||||||
else:
|
|
||||||
services.board_images.add_image_to_board(image_name, board_id)
|
|
||||||
return services.images.get_dto(image_name)
|
|
||||||
|
|
||||||
self.save = save
|
|
||||||
self.get_pil = get_pil
|
|
||||||
self.get_metadata = get_metadata
|
|
||||||
self.get_dto = get_dto
|
|
||||||
self.update = update
|
|
||||||
|
|
||||||
|
|
||||||
class LatentsInterface:
|
class BoardsInterface(InvocationContextInterface):
|
||||||
def __init__(
|
def create(self, board_name: str) -> BoardDTO:
|
||||||
|
"""
|
||||||
|
Creates a board.
|
||||||
|
|
||||||
|
:param board_name: The name of the board to create.
|
||||||
|
"""
|
||||||
|
return self._services.boards.create(board_name)
|
||||||
|
|
||||||
|
def get_dto(self, board_id: str) -> BoardDTO:
|
||||||
|
"""
|
||||||
|
Gets a board DTO.
|
||||||
|
|
||||||
|
:param board_id: The ID of the board to get.
|
||||||
|
"""
|
||||||
|
return self._services.boards.get_dto(board_id)
|
||||||
|
|
||||||
|
def get_all(self) -> list[BoardDTO]:
|
||||||
|
"""
|
||||||
|
Gets all boards.
|
||||||
|
"""
|
||||||
|
return self._services.boards.get_all()
|
||||||
|
|
||||||
|
def add_image_to_board(self, board_id: str, image_name: str) -> None:
|
||||||
|
"""
|
||||||
|
Adds an image to a board.
|
||||||
|
|
||||||
|
:param board_id: The ID of the board to add the image to.
|
||||||
|
:param image_name: The name of the image to add to the board.
|
||||||
|
"""
|
||||||
|
return self._services.board_images.add_image_to_board(board_id, image_name)
|
||||||
|
|
||||||
|
def get_all_image_names_for_board(self, board_id: str) -> list[str]:
|
||||||
|
"""
|
||||||
|
Gets all image names for a board.
|
||||||
|
|
||||||
|
:param board_id: The ID of the board to get the image names for.
|
||||||
|
"""
|
||||||
|
return self._services.board_images.get_all_board_image_names_for_board(board_id)
|
||||||
|
|
||||||
|
|
||||||
|
class LoggerInterface(InvocationContextInterface):
|
||||||
|
def debug(self, message: str) -> None:
|
||||||
|
"""
|
||||||
|
Logs a debug message.
|
||||||
|
|
||||||
|
:param message: The message to log.
|
||||||
|
"""
|
||||||
|
self._services.logger.debug(message)
|
||||||
|
|
||||||
|
def info(self, message: str) -> None:
|
||||||
|
"""
|
||||||
|
Logs an info message.
|
||||||
|
|
||||||
|
:param message: The message to log.
|
||||||
|
"""
|
||||||
|
self._services.logger.info(message)
|
||||||
|
|
||||||
|
def warning(self, message: str) -> None:
|
||||||
|
"""
|
||||||
|
Logs a warning message.
|
||||||
|
|
||||||
|
:param message: The message to log.
|
||||||
|
"""
|
||||||
|
self._services.logger.warning(message)
|
||||||
|
|
||||||
|
def error(self, message: str) -> None:
|
||||||
|
"""
|
||||||
|
Logs an error message.
|
||||||
|
|
||||||
|
:param message: The message to log.
|
||||||
|
"""
|
||||||
|
self._services.logger.error(message)
|
||||||
|
|
||||||
|
|
||||||
|
class ImagesInterface(InvocationContextInterface):
|
||||||
|
def save(
|
||||||
self,
|
self,
|
||||||
services: InvocationServices,
|
image: Image,
|
||||||
context_data: InvocationContextData,
|
board_id: Optional[str] = None,
|
||||||
) -> None:
|
image_category: ImageCategory = ImageCategory.GENERAL,
|
||||||
def save(tensor: Tensor) -> str:
|
metadata: Optional[MetadataField] = None,
|
||||||
"""
|
) -> ImageDTO:
|
||||||
Saves a latents tensor, returning its name.
|
"""
|
||||||
|
Saves an image, returning its DTO.
|
||||||
|
|
||||||
:param tensor: The latents tensor to save.
|
If the current queue item has a workflow or metadata, it is automatically saved with the image.
|
||||||
"""
|
|
||||||
|
|
||||||
# Previously, we added a suffix indicating the type of Tensor we were saving, e.g.
|
:param image: The image to save, as a PIL image.
|
||||||
# "mask", "noise", "masked_latents", etc.
|
:param board_id: The board ID to add the image to, if it should be added.
|
||||||
#
|
:param image_category: The category of the image. Only the GENERAL category is added \
|
||||||
# Retaining that capability in this wrapper would require either many different methods
|
to the gallery.
|
||||||
# to save latents, or extra args for this method. Instead of complicating the API, we
|
:param metadata: The metadata to save with the image, if it should have any. If the \
|
||||||
# will use the same naming scheme for all latents.
|
invocation inherits from `WithMetadata`, that metadata will be used automatically. \
|
||||||
#
|
**Use this only if you want to override or provide metadata manually!**
|
||||||
# This has a very minor impact as we don't use them after a session completes.
|
"""
|
||||||
|
|
||||||
# Previously, invocations chose the name for their latents. This is a bit risky, so we
|
# If the invocation inherits metadata, use that. Else, use the metadata passed in.
|
||||||
# will generate a name for them instead. We use a uuid to ensure the name is unique.
|
metadata_ = (
|
||||||
#
|
self._context_data.invocation.metadata
|
||||||
# Because the name of the latents file will includes the session and invocation IDs,
|
if isinstance(self._context_data.invocation, WithMetadata)
|
||||||
# we don't need to worry about collisions. A truncated UUIDv4 is fine.
|
else metadata
|
||||||
|
)
|
||||||
|
|
||||||
name = f"{context_data.session_id}__{context_data.invocation.id}__{uuid_string()[:7]}"
|
return self._services.images.create(
|
||||||
services.latents.save(
|
image=image,
|
||||||
name=name,
|
is_intermediate=self._context_data.invocation.is_intermediate,
|
||||||
data=tensor,
|
image_category=image_category,
|
||||||
)
|
board_id=board_id,
|
||||||
return name
|
metadata=metadata_,
|
||||||
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
|
workflow=self._context_data.workflow,
|
||||||
|
session_id=self._context_data.session_id,
|
||||||
|
node_id=self._context_data.invocation.id,
|
||||||
|
)
|
||||||
|
|
||||||
def get(latents_name: str) -> Tensor:
|
def get_pil(self, image_name: str) -> Image:
|
||||||
"""
|
"""
|
||||||
Gets a latents tensor by name.
|
Gets an image as a PIL Image object.
|
||||||
|
|
||||||
:param latents_name: The name of the latents tensor to get.
|
:param image_name: The name of the image to get.
|
||||||
"""
|
"""
|
||||||
return services.latents.get(latents_name)
|
return self._services.images.get_pil_image(image_name)
|
||||||
|
|
||||||
self.save = save
|
def get_metadata(self, image_name: str) -> Optional[MetadataField]:
|
||||||
self.get = get
|
"""
|
||||||
|
Gets an image's metadata, if it has any.
|
||||||
|
|
||||||
|
:param image_name: The name of the image to get the metadata for.
|
||||||
|
"""
|
||||||
|
return self._services.images.get_metadata(image_name)
|
||||||
|
|
||||||
class ConditioningInterface:
|
def get_dto(self, image_name: str) -> ImageDTO:
|
||||||
def __init__(
|
"""
|
||||||
|
Gets an image as an ImageDTO object.
|
||||||
|
|
||||||
|
:param image_name: The name of the image to get.
|
||||||
|
"""
|
||||||
|
return self._services.images.get_dto(image_name)
|
||||||
|
|
||||||
|
def update(
|
||||||
self,
|
self,
|
||||||
services: InvocationServices,
|
image_name: str,
|
||||||
context_data: InvocationContextData,
|
board_id: Optional[str] = None,
|
||||||
) -> None:
|
is_intermediate: Optional[bool] = False,
|
||||||
# TODO(psyche): We are (ab)using the latents storage service as a general pickle storage
|
) -> ImageDTO:
|
||||||
# service, but it is typed to work with Tensors only. We have to fudge the types here.
|
"""
|
||||||
|
Updates an image, returning its updated DTO.
|
||||||
|
|
||||||
def save(conditioning_data: ConditioningFieldData) -> str:
|
It is not suggested to update images saved by earlier nodes, as this can cause confusion for users.
|
||||||
"""
|
|
||||||
Saves a conditioning data object, returning its name.
|
|
||||||
|
|
||||||
:param conditioning_data: The conditioning data to save.
|
If you use this method, you *must* return the image as an :class:`ImageOutput` for the gallery to
|
||||||
"""
|
get the updated image.
|
||||||
|
|
||||||
# Conditioning data is *not* a Tensor, so we will suffix it to indicate this.
|
:param image_name: The name of the image to update.
|
||||||
#
|
:param board_id: The board ID to add the image to, if it should be added.
|
||||||
# See comment for `LatentsInterface.save` for more info about this method (it's very
|
:param is_intermediate: Whether the image is an intermediate. Intermediate images aren't added to the gallery.
|
||||||
# similar).
|
"""
|
||||||
|
if is_intermediate is not None:
|
||||||
name = f"{context_data.session_id}__{context_data.invocation.id}__{uuid_string()[:7]}__conditioning"
|
self._services.images.update(image_name, ImageRecordChanges(is_intermediate=is_intermediate))
|
||||||
services.latents.save(
|
if board_id is None:
|
||||||
name=name,
|
self._services.board_images.remove_image_from_board(image_name)
|
||||||
data=conditioning_data, # type: ignore [arg-type]
|
else:
|
||||||
)
|
self._services.board_images.add_image_to_board(image_name, board_id)
|
||||||
return name
|
return self._services.images.get_dto(image_name)
|
||||||
|
|
||||||
def get(conditioning_name: str) -> ConditioningFieldData:
|
|
||||||
"""
|
|
||||||
Gets conditioning data by name.
|
|
||||||
|
|
||||||
:param conditioning_name: The name of the conditioning data to get.
|
|
||||||
"""
|
|
||||||
|
|
||||||
return services.latents.get(conditioning_name) # type: ignore [return-value]
|
|
||||||
|
|
||||||
self.save = save
|
|
||||||
self.get = get
|
|
||||||
|
|
||||||
|
|
||||||
class ModelsInterface:
|
class LatentsInterface(InvocationContextInterface):
|
||||||
def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None:
|
def save(self, tensor: Tensor) -> str:
|
||||||
def exists(model_name: str, base_model: BaseModelType, model_type: ModelType) -> bool:
|
"""
|
||||||
"""
|
Saves a latents tensor, returning its name.
|
||||||
Checks if a model exists.
|
|
||||||
|
|
||||||
:param model_name: The name of the model to check.
|
:param tensor: The latents tensor to save.
|
||||||
:param base_model: The base model of the model to check.
|
"""
|
||||||
:param model_type: The type of the model to check.
|
|
||||||
"""
|
|
||||||
return services.model_manager.model_exists(model_name, base_model, model_type)
|
|
||||||
|
|
||||||
def load(
|
# Previously, we added a suffix indicating the type of Tensor we were saving, e.g.
|
||||||
model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None
|
# "mask", "noise", "masked_latents", etc.
|
||||||
) -> ModelInfo:
|
#
|
||||||
"""
|
# Retaining that capability in this wrapper would require either many different methods
|
||||||
Loads a model, returning its `ModelInfo` object.
|
# to save latents, or extra args for this method. Instead of complicating the API, we
|
||||||
|
# will use the same naming scheme for all latents.
|
||||||
|
#
|
||||||
|
# This has a very minor impact as we don't use them after a session completes.
|
||||||
|
|
||||||
:param model_name: The name of the model to get.
|
# Previously, invocations chose the name for their latents. This is a bit risky, so we
|
||||||
:param base_model: The base model of the model to get.
|
# will generate a name for them instead. We use a uuid to ensure the name is unique.
|
||||||
:param model_type: The type of the model to get.
|
#
|
||||||
:param submodel: The submodel of the model to get.
|
# Because the name of the latents file will includes the session and invocation IDs,
|
||||||
"""
|
# we don't need to worry about collisions. A truncated UUIDv4 is fine.
|
||||||
|
|
||||||
# During this call, the model manager emits events with model loading status. The model
|
name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}"
|
||||||
# manager itself has access to the events services, but does not have access to the
|
self._services.latents.save(
|
||||||
# required metadata for the events.
|
name=name,
|
||||||
#
|
data=tensor,
|
||||||
# For example, it needs access to the node's ID so that the events can be associated
|
)
|
||||||
# with the execution of a specific node.
|
return name
|
||||||
#
|
|
||||||
# While this is available within the node, it's tedious to need to pass it in on every
|
|
||||||
# call. We can avoid that by wrapping the method here.
|
|
||||||
|
|
||||||
return services.model_manager.get_model(
|
def get(self, latents_name: str) -> Tensor:
|
||||||
model_name, base_model, model_type, submodel, context_data=context_data
|
"""
|
||||||
)
|
Gets a latents tensor by name.
|
||||||
|
|
||||||
def get_info(model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
:param latents_name: The name of the latents tensor to get.
|
||||||
"""
|
"""
|
||||||
Gets a model's info, an dict-like object.
|
return self._services.latents.get(latents_name)
|
||||||
|
|
||||||
:param model_name: The name of the model to get.
|
|
||||||
:param base_model: The base model of the model to get.
|
|
||||||
:param model_type: The type of the model to get.
|
|
||||||
"""
|
|
||||||
return services.model_manager.model_info(model_name, base_model, model_type)
|
|
||||||
|
|
||||||
self.exists = exists
|
|
||||||
self.load = load
|
|
||||||
self.get_info = get_info
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigInterface:
|
class ConditioningInterface(InvocationContextInterface):
|
||||||
def __init__(self, services: InvocationServices) -> None:
|
# TODO(psyche): We are (ab)using the latents storage service as a general pickle storage
|
||||||
def get() -> InvokeAIAppConfig:
|
# service, but it is typed to work with Tensors only. We have to fudge the types here.
|
||||||
"""
|
def save(self, conditioning_data: ConditioningFieldData) -> str:
|
||||||
Gets the app's config. The config is read-only; attempts to mutate it will raise an error.
|
"""
|
||||||
"""
|
Saves a conditioning data object, returning its name.
|
||||||
|
|
||||||
# The config can be changed at runtime.
|
:param conditioning_context_data: The conditioning data to save.
|
||||||
#
|
"""
|
||||||
# We don't want nodes doing this, so we make a frozen copy.
|
|
||||||
|
|
||||||
config = services.configuration.get_config()
|
# Conditioning data is *not* a Tensor, so we will suffix it to indicate this.
|
||||||
# TODO(psyche): If config cannot be changed at runtime, should we cache this?
|
#
|
||||||
frozen_config = config.model_copy(update={"model_config": ConfigDict(frozen=True)})
|
# See comment for `LatentsInterface.save` for more info about this method (it's very
|
||||||
return frozen_config
|
# similar).
|
||||||
|
|
||||||
self.get = get
|
name = f"{self._context_data.session_id}__{self._context_data.invocation.id}__{uuid_string()[:7]}__conditioning"
|
||||||
|
self._services.latents.save(
|
||||||
|
name=name,
|
||||||
|
data=conditioning_data, # type: ignore [arg-type]
|
||||||
|
)
|
||||||
|
return name
|
||||||
|
|
||||||
|
def get(self, conditioning_name: str) -> ConditioningFieldData:
|
||||||
|
"""
|
||||||
|
Gets conditioning data by name.
|
||||||
|
|
||||||
|
:param conditioning_name: The name of the conditioning data to get.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return self._services.latents.get(conditioning_name) # type: ignore [return-value]
|
||||||
|
|
||||||
|
|
||||||
class UtilInterface:
|
class ModelsInterface(InvocationContextInterface):
|
||||||
def __init__(self, services: InvocationServices, context_data: InvocationContextData) -> None:
|
def exists(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> bool:
|
||||||
def sd_step_callback(
|
"""
|
||||||
intermediate_state: PipelineIntermediateState,
|
Checks if a model exists.
|
||||||
base_model: BaseModelType,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
The step callback emits a progress event with the current step, the total number of
|
|
||||||
steps, a preview image, and some other internal metadata.
|
|
||||||
|
|
||||||
This should be called after each denoising step.
|
:param model_name: The name of the model to check.
|
||||||
|
:param base_model: The base model of the model to check.
|
||||||
|
:param model_type: The type of the model to check.
|
||||||
|
"""
|
||||||
|
return self._services.model_manager.model_exists(model_name, base_model, model_type)
|
||||||
|
|
||||||
:param intermediate_state: The intermediate state of the diffusion pipeline.
|
def load(
|
||||||
:param base_model: The base model for the current denoising step.
|
self, model_name: str, base_model: BaseModelType, model_type: ModelType, submodel: Optional[SubModelType] = None
|
||||||
"""
|
) -> ModelInfo:
|
||||||
|
"""
|
||||||
|
Loads a model, returning its `ModelInfo` object.
|
||||||
|
|
||||||
# The step callback needs access to the events and the invocation queue services, but this
|
:param model_name: The name of the model to get.
|
||||||
# represents a dangerous level of access.
|
:param base_model: The base model of the model to get.
|
||||||
#
|
:param model_type: The type of the model to get.
|
||||||
# We wrap the step callback so that nodes do not have direct access to these services.
|
:param submodel: The submodel of the model to get.
|
||||||
|
"""
|
||||||
|
|
||||||
stable_diffusion_step_callback(
|
# During this call, the model manager emits events with model loading status. The model
|
||||||
context_data=context_data,
|
# manager itself has access to the events services, but does not have access to the
|
||||||
intermediate_state=intermediate_state,
|
# required metadata for the events.
|
||||||
base_model=base_model,
|
#
|
||||||
invocation_queue=services.queue,
|
# For example, it needs access to the node's ID so that the events can be associated
|
||||||
events=services.events,
|
# with the execution of a specific node.
|
||||||
)
|
#
|
||||||
|
# While this is available within the node, it's tedious to need to pass it in on every
|
||||||
|
# call. We can avoid that by wrapping the method here.
|
||||||
|
|
||||||
self.sd_step_callback = sd_step_callback
|
return self._services.model_manager.get_model(
|
||||||
|
model_name, base_model, model_type, submodel, context_data=self._context_data
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_info(self, model_name: str, base_model: BaseModelType, model_type: ModelType) -> dict:
|
||||||
|
"""
|
||||||
|
Gets a model's info, an dict-like object.
|
||||||
|
|
||||||
|
:param model_name: The name of the model to get.
|
||||||
|
:param base_model: The base model of the model to get.
|
||||||
|
:param model_type: The type of the model to get.
|
||||||
|
"""
|
||||||
|
return self._services.model_manager.model_info(model_name, base_model, model_type)
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigInterface(InvocationContextInterface):
|
||||||
|
def get(self) -> InvokeAIAppConfig:
|
||||||
|
"""
|
||||||
|
Gets the app's config. The config is read-only; attempts to mutate it will raise an error.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# The config can be changed at runtime.
|
||||||
|
#
|
||||||
|
# We don't want nodes doing this, so we make a frozen copy.
|
||||||
|
|
||||||
|
config = self._services.configuration.get_config()
|
||||||
|
# TODO(psyche): If config cannot be changed at runtime, should we cache this?
|
||||||
|
frozen_config = config.model_copy(update={"model_config": ConfigDict(frozen=True)})
|
||||||
|
return frozen_config
|
||||||
|
|
||||||
|
|
||||||
|
class UtilInterface(InvocationContextInterface):
|
||||||
|
def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None:
|
||||||
|
"""
|
||||||
|
The step callback emits a progress event with the current step, the total number of
|
||||||
|
steps, a preview image, and some other internal metadata.
|
||||||
|
|
||||||
|
This should be called after each denoising step.
|
||||||
|
|
||||||
|
:param intermediate_state: The intermediate state of the diffusion pipeline.
|
||||||
|
:param base_model: The base model for the current denoising step.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# The step callback needs access to the events and the invocation queue services, but this
|
||||||
|
# represents a dangerous level of access.
|
||||||
|
#
|
||||||
|
# We wrap the step callback so that nodes do not have direct access to these services.
|
||||||
|
|
||||||
|
stable_diffusion_step_callback(
|
||||||
|
context_data=self._context_data,
|
||||||
|
intermediate_state=intermediate_state,
|
||||||
|
base_model=base_model,
|
||||||
|
invocation_queue=self._services.queue,
|
||||||
|
events=self._services.events,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
deprecation_version = "3.7.0"
|
deprecation_version = "3.7.0"
|
||||||
@ -600,14 +559,14 @@ def build_invocation_context(
|
|||||||
:param invocation_context_data: The invocation context data.
|
:param invocation_context_data: The invocation context data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
logger = LoggerInterface(services=services)
|
logger = LoggerInterface(services=services, context_data=context_data)
|
||||||
images = ImagesInterface(services=services, context_data=context_data)
|
images = ImagesInterface(services=services, context_data=context_data)
|
||||||
latents = LatentsInterface(services=services, context_data=context_data)
|
latents = LatentsInterface(services=services, context_data=context_data)
|
||||||
models = ModelsInterface(services=services, context_data=context_data)
|
models = ModelsInterface(services=services, context_data=context_data)
|
||||||
config = ConfigInterface(services=services)
|
config = ConfigInterface(services=services, context_data=context_data)
|
||||||
util = UtilInterface(services=services, context_data=context_data)
|
util = UtilInterface(services=services, context_data=context_data)
|
||||||
conditioning = ConditioningInterface(services=services, context_data=context_data)
|
conditioning = ConditioningInterface(services=services, context_data=context_data)
|
||||||
boards = BoardsInterface(services=services)
|
boards = BoardsInterface(services=services, context_data=context_data)
|
||||||
|
|
||||||
ctx = InvocationContext(
|
ctx = InvocationContext(
|
||||||
images=images,
|
images=images,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user