2024-02-17 14:41:04 +00:00
import threading
2024-01-13 07:02:58 +00:00
from dataclasses import dataclass
2024-02-15 09:43:41 +00:00
from pathlib import Path
2024-02-18 00:54:16 +00:00
from typing import TYPE_CHECKING , Optional
2024-01-13 07:02:58 +00:00
from PIL . Image import Image
from torch import Tensor
2024-02-18 21:56:46 +00:00
from invokeai . app . invocations . constants import IMAGE_MODES
2024-02-19 04:11:36 +00:00
from invokeai . app . invocations . fields import MetadataField , WithBoard , WithMetadata
2024-02-05 06:40:49 +00:00
from invokeai . app . services . boards . boards_common import BoardDTO
2024-01-13 07:02:58 +00:00
from invokeai . app . services . config . config_default import InvokeAIAppConfig
2024-02-07 05:33:55 +00:00
from invokeai . app . services . image_records . image_records_common import ImageCategory , ResourceOrigin
2024-01-13 07:02:58 +00:00
from invokeai . app . services . images . images_common import ImageDTO
from invokeai . app . services . invocation_services import InvocationServices
from invokeai . app . util . step_callback import stable_diffusion_step_callback
2024-02-15 09:43:41 +00:00
from invokeai . backend . model_manager . config import AnyModelConfig , BaseModelType , ModelFormat , ModelType , SubModelType
from invokeai . backend . model_manager . load . load_base import LoadedModel
from invokeai . backend . model_manager . metadata . metadata_base import AnyModelRepoMetadata
2024-01-13 07:02:58 +00:00
from invokeai . backend . stable_diffusion . diffusers_pipeline import PipelineIntermediateState
2024-01-14 23:41:25 +00:00
from invokeai . backend . stable_diffusion . diffusion . conditioning_data import ConditioningFieldData
2024-01-13 07:02:58 +00:00
if TYPE_CHECKING :
from invokeai . app . invocations . baseinvocation import BaseInvocation
2024-02-18 00:51:50 +00:00
from invokeai . app . services . session_queue . session_queue_common import SessionQueueItem
2024-01-13 07:02:58 +00:00
"""
The InvocationContext provides access to various services and data about the current invocation .
We do not provide the invocation services directly , as their methods are both dangerous and
inconvenient to use .
For example :
- The ` images ` service allows nodes to delete or unsafely modify existing images .
- The ` configuration ` service allows nodes to change the app ' s config at runtime.
- The ` events ` service allows nodes to emit arbitrary events .
Wrapping these services provides a simpler and safer interface for nodes to use .
When a node executes , a fresh ` InvocationContext ` is built for it , ensuring nodes cannot interfere
with each other .
2024-01-13 13:05:15 +00:00
Many of the wrappers have the same signature as the methods they wrap . This allows us to write
user - facing docstrings and not need to go and update the internal services to match .
2024-01-13 07:02:58 +00:00
Note : The docstrings are in weird places , but that ' s where they must be to get IDEs to see them.
"""
2024-01-16 09:02:38 +00:00
@dataclass
class InvocationContextData:
    """Bundle of data describing what is currently being executed: the queue item and the invocation."""

    queue_item: "SessionQueueItem"
    """The queue item currently being executed."""
    invocation: "BaseInvocation"
    """The invocation currently being executed."""
    source_invocation_id: str
    """ID of the invocation that the currently executing invocation was prepared from."""
2024-01-13 07:02:58 +00:00
2024-02-07 03:24:05 +00:00
class InvocationContextInterface:
    """Base class for the context interfaces; stores the services and context data the subclasses use."""

    def __init__(self, services: InvocationServices, data: InvocationContextData) -> None:
        """
        :param services: The invocation services that the interface wraps.
        :param data: Data about the queue item and invocation being executed.
        """
        # Both are stored under underscore-prefixed names; subclasses read them via `self._services`
        # and `self._data`.
        self._services, self._data = services, data
2024-02-07 03:24:05 +00:00
class BoardsInterface(InvocationContextInterface):
    """Board operations: thin wrappers over the `boards` and `board_images` services."""

    def create(self, board_name: str) -> BoardDTO:
        """
        Create a new board.

        :param board_name: Name to give the new board.
        """
        boards_service = self._services.boards
        return boards_service.create(board_name)

    def get_dto(self, board_id: str) -> BoardDTO:
        """
        Fetch the DTO of a single board.

        :param board_id: ID of the board to fetch.
        """
        boards_service = self._services.boards
        return boards_service.get_dto(board_id)

    def get_all(self) -> list[BoardDTO]:
        """
        Fetch every board.
        """
        boards_service = self._services.boards
        return boards_service.get_all()

    def add_image_to_board(self, board_id: str, image_name: str) -> None:
        """
        Put an image on a board.

        :param board_id: ID of the board that should receive the image.
        :param image_name: Name of the image to put on the board.
        """
        board_images_service = self._services.board_images
        return board_images_service.add_image_to_board(board_id, image_name)

    def get_all_image_names_for_board(self, board_id: str) -> list[str]:
        """
        List the names of all images on a board.

        :param board_id: ID of the board whose image names are wanted.
        """
        board_images_service = self._services.board_images
        return board_images_service.get_all_board_image_names_for_board(board_id)
class LoggerInterface(InvocationContextInterface):
    """Logging operations: each method forwards to the same-named method of the services' logger."""

    def debug(self, message: str) -> None:
        """
        Emit a message at debug level.

        :param message: Text to log.
        """
        log = self._services.logger
        log.debug(message)

    def info(self, message: str) -> None:
        """
        Emit a message at info level.

        :param message: Text to log.
        """
        log = self._services.logger
        log.info(message)

    def warning(self, message: str) -> None:
        """
        Emit a message at warning level.

        :param message: Text to log.
        """
        log = self._services.logger
        log.warning(message)

    def error(self, message: str) -> None:
        """
        Emit a message at error level.

        :param message: Text to log.
        """
        log = self._services.logger
        log.error(message)
class ImagesInterface(InvocationContextInterface):
    """Image operations: saving images and reading back images, their DTOs and their metadata."""

    def save(
        self,
        image: Image,
        board_id: Optional[str] = None,
        image_category: ImageCategory = ImageCategory.GENERAL,
        metadata: Optional[MetadataField] = None,
    ) -> ImageDTO:
        """
        Save an image and return its DTO.

        The current queue item's workflow and session ID, and the current invocation's ID and
        `is_intermediate` flag, are passed along with the image automatically.

        :param image: The PIL image to save.
        :param board_id: ID of the board to add the image to, if any. If the invocation inherits from
            `WithBoard`, that board is used automatically. **Use this only if you want to override or
            provide a board manually!**
        :param image_category: The category of the image. Only the GENERAL category is added to the
            gallery.
        :param metadata: Metadata to save with the image, if any. If the invocation inherits from
            `WithMetadata`, that metadata is used automatically. **Use this only if you want to
            override or provide metadata manually!**
        """
        invocation = self._data.invocation
        queue_item = self._data.queue_item

        # Metadata: a (truthy) explicit argument wins, then whatever `WithMetadata` carries, else None.
        if metadata:
            resolved_metadata = metadata
        elif isinstance(invocation, WithMetadata):
            resolved_metadata = invocation.metadata
        else:
            resolved_metadata = None

        # Board: a (truthy) explicit argument wins, then the board carried by `WithBoard`, else None.
        if board_id:
            resolved_board_id = board_id
        elif isinstance(invocation, WithBoard) and invocation.board:
            resolved_board_id = invocation.board.board_id
        else:
            resolved_board_id = None

        return self._services.images.create(
            image=image,
            is_intermediate=invocation.is_intermediate,
            image_category=image_category,
            board_id=resolved_board_id,
            metadata=resolved_metadata,
            image_origin=ResourceOrigin.INTERNAL,
            workflow=queue_item.workflow,
            session_id=queue_item.session_id,
            node_id=invocation.id,
        )

    def get_pil(self, image_name: str, mode: IMAGE_MODES | None = None) -> Image:
        """
        Fetch an image as a PIL Image object, optionally converted to another color mode.

        :param image_name: Name of the image to fetch.
        :param mode: Color mode to convert the image to. If None, the image keeps its original mode.
        """
        pil_image = self._services.images.get_pil_image(image_name)

        # Nothing to convert when no mode was requested or the image is already in that mode.
        if not mode or mode == pil_image.mode:
            return pil_image

        try:
            return pil_image.convert(mode)
        except ValueError:
            # A failed conversion is not fatal: warn and hand back the unconverted image.
            self._services.logger.warning(
                f"Could not convert image from {pil_image.mode} to {mode}. Using original mode instead."
            )
            return pil_image

    def get_metadata(self, image_name: str) -> Optional[MetadataField]:
        """
        Fetch the metadata stored for an image, if there is any.

        :param image_name: Name of the image whose metadata is wanted.
        """
        images_service = self._services.images
        return images_service.get_metadata(image_name)

    def get_dto(self, image_name: str) -> ImageDTO:
        """
        Fetch an image as an ImageDTO object.

        :param image_name: Name of the image to fetch.
        """
        images_service = self._services.images
        return images_service.get_dto(image_name)
2024-01-13 13:05:15 +00:00
2024-02-07 06:41:23 +00:00
class TensorsInterface(InvocationContextInterface):
    """Tensor operations: thin wrappers over the `tensors` service."""

    def save(self, tensor: Tensor) -> str:
        """
        Store a tensor and return the name it was stored under.

        :param tensor: The tensor to store.
        """
        return self._services.tensors.save(obj=tensor)

    def load(self, name: str) -> Tensor:
        """
        Retrieve a previously stored tensor.

        :param name: Name of the tensor to retrieve.
        """
        tensors_service = self._services.tensors
        return tensors_service.load(name)
2024-01-13 07:02:58 +00:00
2024-02-07 03:24:05 +00:00
class ConditioningInterface(InvocationContextInterface):
    """Conditioning-data operations: thin wrappers over the `conditioning` service."""

    def save(self, conditioning_data: ConditioningFieldData) -> str:
        """
        Store a conditioning data object and return the name it was stored under.

        :param conditioning_data: The conditioning data to store.
        """
        return self._services.conditioning.save(obj=conditioning_data)

    def load(self, name: str) -> ConditioningFieldData:
        """
        Retrieve previously stored conditioning data.

        :param name: Name of the conditioning data to retrieve.
        """
        conditioning_service = self._services.conditioning
        return conditioning_service.load(name)
2024-02-07 03:24:05 +00:00
class ModelsInterface(InvocationContextInterface):
    """Model operations: existence checks, loading, config/metadata lookup and searching."""

    def exists(self, key: str) -> bool:
        """
        Checks if a model exists.

        :param key: The key of the model.
        """
        return self._services.model_manager.store.exists(key)

    def load(self, key: str, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
        """
        Loads a model.

        :param key: The key of the model.
        :param submodel_type: The submodel of the model to get.
        :returns: An object representing the loaded model.
        """
        # The model manager emits events as it loads the model. It needs the context data to build
        # the event payloads.
        return self._services.model_manager.load_model_by_key(
            key=key, submodel_type=submodel_type, context_data=self._data
        )

    def load_by_attrs(
        self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
    ) -> LoadedModel:
        """
        Loads a model by its attributes.

        :param name: The name of the model to load.
        :param base: The base model type of the model.
        :param type: The type of the model.
        :param submodel_type: For main (pipeline) models, the submodel to fetch.
        :returns: An object representing the loaded model.
        """
        # The public parameter names (`name`, `base`, `type`, `submodel_type`) differ from the keyword
        # names the model manager expects, so they are mapped explicitly here. As in `load`, the
        # context data is passed along with the request.
        return self._services.model_manager.load_model_by_attr(
            model_name=name,
            base_model=base,
            model_type=type,
            submodel=submodel_type,
            context_data=self._data,
        )

    def get_config(self, key: str) -> AnyModelConfig:
        """
        Gets a model's info, a dict-like object.

        :param key: The key of the model.
        """
        return self._services.model_manager.store.get_model(key=key)

    def get_metadata(self, key: str) -> Optional[AnyModelRepoMetadata]:
        """
        Gets a model's metadata, if it has any.

        :param key: The key of the model.
        """
        return self._services.model_manager.store.get_metadata(key=key)

    def search_by_path(self, path: Path) -> list[AnyModelConfig]:
        """
        Searches for models by path.

        :param path: The path to search for.
        """
        return self._services.model_manager.store.search_by_path(path)

    def search_by_attrs(
        self,
        name: Optional[str] = None,
        base: Optional[BaseModelType] = None,
        type: Optional[ModelType] = None,
        format: Optional[ModelFormat] = None,
    ) -> list[AnyModelConfig]:
        """
        Searches for models by attributes.

        :param name: The name of the model to search for, if any.
        :param base: The base model type to search for, if any.
        :param type: The model type to search for, if any.
        :param format: The model format to search for, if any.
        :returns: The configs of the models found by the search.
        """
        # NOTE(review): presumably an attribute left as None is not used to filter - confirm against
        # the model record store's `search_by_attr`.
        return self._services.model_manager.store.search_by_attr(
            model_name=name,
            base_model=base,
            model_type=type,
            model_format=format,
        )
2024-02-07 03:24:05 +00:00
class ConfigInterface(InvocationContextInterface):
    """Read access to the application's config."""

    def get(self) -> InvokeAIAppConfig:
        """Return the app's config, as provided by the `configuration` service."""
        configuration_service = self._services.configuration
        return configuration_service.get_config()
2024-02-07 03:24:05 +00:00
class UtilInterface(InvocationContextInterface):
    """Utility operations: cancellation check and the stable diffusion step callback."""

    def __init__(
        self, services: InvocationServices, data: InvocationContextData, cancel_event: threading.Event
    ) -> None:
        """
        :param services: The invocation services that the interface wraps.
        :param data: Data about the queue item and invocation being executed.
        :param cancel_event: Event whose set state is reported by `is_canceled`.
        """
        super().__init__(services=services, data=data)
        self._cancel_event = cancel_event

    def is_canceled(self) -> bool:
        """Report whether the current invocation has been canceled, i.e. whether the cancel event is set."""
        cancel_requested = self._cancel_event.is_set()
        return cancel_requested

    def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None:
        """
        Step callback for denoising. It emits a progress event with the current step, the total number
        of steps, a preview image, and some other internal metadata.

        This should be called after each denoising step.

        :param intermediate_state: The intermediate state of the diffusion pipeline.
        :param base_model: The base model for the current denoising step.
        """
        events_service = self._services.events
        stable_diffusion_step_callback(
            context_data=self._data,
            intermediate_state=intermediate_state,
            base_model=base_model,
            events=events_service,
            # The bound method itself is passed, not its current result.
            is_canceled=self.is_canceled,
        )
2024-01-13 07:02:58 +00:00
class InvocationContext:
    """
    Entry point that gives an invocation access to the wrapped services and to data about the
    current invocation.
    """

    def __init__(
        self,
        images: ImagesInterface,
        tensors: TensorsInterface,
        conditioning: ConditioningInterface,
        models: ModelsInterface,
        logger: LoggerInterface,
        config: ConfigInterface,
        util: UtilInterface,
        boards: BoardsInterface,
        data: InvocationContextData,
        services: InvocationServices,
    ) -> None:
        # Public interfaces. Each attribute docstring sits directly below its assignment.
        self.images = images
        """Save images, and get images, their DTOs and their metadata."""
        self.tensors = tensors
        """Save and load tensors."""
        self.conditioning = conditioning
        """Save and load conditioning data."""
        self.models = models
        """Check whether a model exists, load models, get a model's config or metadata, and search for models."""
        self.logger = logger
        """The app logger."""
        self.config = config
        """The app config."""
        self.util = util
        """Utilities: check whether the invocation was canceled, and the denoising step callback."""
        self.boards = boards
        """Create and get boards, and manage the images on them."""

        # Internal handles, underscore-prefixed on purpose.
        self._data = data
        """Internal API: data about the current queue item and invocation. You probably shouldn't use this. It may change without warning."""
        self._services = services
        """Internal API: all application services. You probably shouldn't use this. It may change without warning."""
2024-01-14 09:16:51 +00:00
2024-01-13 07:02:58 +00:00
def build_invocation_context(
    services: InvocationServices,
    data: InvocationContextData,
    cancel_event: threading.Event,
) -> InvocationContext:
    """
    Assemble the invocation context for one execution of an invocation.

    Every interface is constructed from the same `services` and `data`.

    :param services: The invocation services to wrap.
    :param data: The invocation context data.
    :param cancel_event: Event handed to the util interface, which uses it to report cancellation.
    """
    # Constructor arguments common to all interfaces.
    shared = {"services": services, "data": data}

    logger = LoggerInterface(**shared)
    images = ImagesInterface(**shared)
    tensors = TensorsInterface(**shared)
    models = ModelsInterface(**shared)
    config = ConfigInterface(**shared)
    # The util interface is the only one that also needs the cancel event.
    util = UtilInterface(cancel_event=cancel_event, **shared)
    conditioning = ConditioningInterface(**shared)
    boards = BoardsInterface(**shared)

    return InvocationContext(
        images=images,
        logger=logger,
        config=config,
        tensors=tensors,
        models=models,
        data=data,
        util=util,
        conditioning=conditioning,
        services=services,
        boards=boards,
    )