2024-01-13 07:02:58 +00:00
from dataclasses import dataclass
2024-02-15 09:43:41 +00:00
from pathlib import Path
2024-06-02 22:35:23 +00:00
from typing import TYPE_CHECKING , Callable , Optional , Union
2024-01-13 07:02:58 +00:00
from PIL . Image import Image
2024-04-04 03:26:48 +00:00
from pydantic . networks import AnyHttpUrl
2024-06-02 22:35:23 +00:00
from torch import Tensor
2024-01-13 07:02:58 +00:00
2024-02-18 21:56:46 +00:00
from invokeai . app . invocations . constants import IMAGE_MODES
2024-02-19 04:11:36 +00:00
from invokeai . app . invocations . fields import MetadataField , WithBoard , WithMetadata
2024-02-05 06:40:49 +00:00
from invokeai . app . services . boards . boards_common import BoardDTO
2024-01-13 07:02:58 +00:00
from invokeai . app . services . config . config_default import InvokeAIAppConfig
2024-02-07 05:33:55 +00:00
from invokeai . app . services . image_records . image_records_common import ImageCategory , ResourceOrigin
2024-01-13 07:02:58 +00:00
from invokeai . app . services . images . images_common import ImageDTO
from invokeai . app . services . invocation_services import InvocationServices
2024-08-16 21:04:48 +00:00
from invokeai . app . services . model_records import ModelRecordChanges
2024-03-06 08:37:15 +00:00
from invokeai . app . services . model_records . model_records_base import UnknownModelException
2024-01-13 07:02:58 +00:00
from invokeai . app . util . step_callback import stable_diffusion_step_callback
2024-06-06 04:31:41 +00:00
from invokeai . backend . model_manager . config import (
AnyModel ,
AnyModelConfig ,
BaseModelType ,
ModelFormat ,
ModelType ,
SubModelType ,
)
2024-06-04 00:31:05 +00:00
from invokeai . backend . model_manager . load . load_base import LoadedModel , LoadedModelWithoutConfig
2024-01-13 07:02:58 +00:00
from invokeai . backend . stable_diffusion . diffusers_pipeline import PipelineIntermediateState
2024-01-14 23:41:25 +00:00
from invokeai . backend . stable_diffusion . diffusion . conditioning_data import ConditioningFieldData
2024-01-13 07:02:58 +00:00
if TYPE_CHECKING :
from invokeai . app . invocations . baseinvocation import BaseInvocation
2024-03-09 08:43:24 +00:00
from invokeai . app . invocations . model import ModelIdentifierField
2024-02-18 00:51:50 +00:00
from invokeai . app . services . session_queue . session_queue_common import SessionQueueItem
2024-01-13 07:02:58 +00:00
"""
The InvocationContext provides access to various services and data about the current invocation.

We do not provide the invocation services directly, as their methods are both dangerous and
inconvenient to use.

For example:
- The `images` service allows nodes to delete or unsafely modify existing images.
- The `configuration` service allows nodes to change the app's config at runtime.
- The `events` service allows nodes to emit arbitrary events.

Wrapping these services provides a simpler and safer interface for nodes to use.

When a node executes, a fresh `InvocationContext` is built for it, ensuring nodes cannot interfere
with each other.

Many of the wrappers have the same signature as the methods they wrap. This allows us to write
user-facing docstrings without needing to update the internal services to match.

Note: The docstrings are in unusual places, but that's where they must be for IDEs to see them.
"""
2024-01-16 09:02:38 +00:00
@dataclass
class InvocationContextData:
    """Data about the queue item and invocation for which an `InvocationContext` is built."""

    queue_item: "SessionQueueItem"
    """The queue item currently being executed."""

    invocation: "BaseInvocation"
    """The invocation currently being executed."""

    source_invocation_id: str
    """The ID of the invocation that the currently executing invocation was prepared from."""
2024-01-13 07:02:58 +00:00
2024-02-07 03:24:05 +00:00
class InvocationContextInterface:
    """Base class for the service wrappers exposed on an `InvocationContext`.

    Subclasses reach the wrapped services and the per-invocation data through the
    private `_services` and `_data` attributes assigned here.
    """

    def __init__(self, services: InvocationServices, data: InvocationContextData) -> None:
        self._data = data
        self._services = services
2024-02-07 03:24:05 +00:00
class BoardsInterface(InvocationContextInterface):
    """Board operations available to nodes."""

    def create(self, board_name: str) -> BoardDTO:
        """Creates a new board.

        Args:
            board_name: The name to give the new board.

        Returns:
            The DTO of the newly created board.
        """
        boards = self._services.boards
        return boards.create(board_name)

    def get_dto(self, board_id: str) -> BoardDTO:
        """Looks up a board by its ID.

        Args:
            board_id: The ID of the board to look up.

        Returns:
            The DTO of the requested board.
        """
        boards = self._services.boards
        return boards.get_dto(board_id)

    def get_all(self) -> list[BoardDTO]:
        """Lists every board.

        Returns:
            The DTOs of all boards.
        """
        boards = self._services.boards
        return boards.get_all()

    def add_image_to_board(self, board_id: str, image_name: str) -> None:
        """Places an image on a board.

        Args:
            board_id: The ID of the target board.
            image_name: The name of the image to place on the board.
        """
        board_images = self._services.board_images
        return board_images.add_image_to_board(board_id, image_name)

    def get_all_image_names_for_board(self, board_id: str) -> list[str]:
        """Lists the names of all images on a board.

        Args:
            board_id: The ID of the board whose image names are wanted.

        Returns:
            The names of all images on the board.
        """
        board_images = self._services.board_images
        return board_images.get_all_board_image_names_for_board(board_id)
class LoggerInterface(InvocationContextInterface):
    """Logging methods available to nodes; each one forwards to the app logger."""

    def debug(self, message: str) -> None:
        """Writes a message to the app logger at debug level.

        Args:
            message: The text to log.
        """
        log = self._services.logger
        log.debug(message)

    def info(self, message: str) -> None:
        """Writes a message to the app logger at info level.

        Args:
            message: The text to log.
        """
        log = self._services.logger
        log.info(message)

    def warning(self, message: str) -> None:
        """Writes a message to the app logger at warning level.

        Args:
            message: The text to log.
        """
        log = self._services.logger
        log.warning(message)

    def error(self, message: str) -> None:
        """Writes a message to the app logger at error level.

        Args:
            message: The text to log.
        """
        log = self._services.logger
        log.error(message)
class ImagesInterface(InvocationContextInterface):
    """Image operations available to nodes: saving images and reading images, metadata, DTOs and paths."""

    def save(
        self,
        image: Image,
        board_id: Optional[str] = None,
        image_category: ImageCategory = ImageCategory.GENERAL,
        metadata: Optional[MetadataField] = None,
    ) -> ImageDTO:
        """Saves an image, returning its DTO.

        If the current queue item has a workflow and/or its session has a graph, they are
        automatically saved with the image.

        Args:
            image: The image to save, as a PIL image.
            board_id: The board ID to add the image to, if it should be added. If the invocation
                inherits from `WithBoard`, that board will be used automatically. **Use this only if
                you want to override or provide a board manually!**
            image_category: The category of the image. Only the GENERAL category is added
                to the gallery.
            metadata: The metadata to save with the image, if it should have any. If the
                invocation inherits from `WithMetadata`, that metadata will be used automatically.
                **Use this only if you want to override or provide metadata manually!**

        Returns:
            The saved image DTO.
        """
        invocation = self._data.invocation
        queue_item = self._data.queue_item

        # If `metadata` is provided directly, use that. Else, use the metadata provided by `WithMetadata`,
        # falling back to None.
        metadata_ = None
        if metadata:
            metadata_ = metadata.model_dump_json()
        elif isinstance(invocation, WithMetadata) and invocation.metadata:
            metadata_ = invocation.metadata.model_dump_json()

        # If `board_id` is provided directly, use that. Else, use the board provided by `WithBoard`,
        # falling back to None.
        board_id_ = None
        if board_id:
            board_id_ = board_id
        elif isinstance(invocation, WithBoard) and invocation.board:
            board_id_ = invocation.board.board_id

        # Persist the queue item's workflow and the session graph alongside the image, when present.
        workflow_ = queue_item.workflow.model_dump_json() if queue_item.workflow else None
        graph_ = queue_item.session.graph.model_dump_json() if queue_item.session.graph else None

        return self._services.images.create(
            image=image,
            is_intermediate=invocation.is_intermediate,
            image_category=image_category,
            board_id=board_id_,
            metadata=metadata_,
            image_origin=ResourceOrigin.INTERNAL,
            workflow=workflow_,
            graph=graph_,
            session_id=queue_item.session_id,
            node_id=invocation.id,
        )

    def get_pil(self, image_name: str, mode: IMAGE_MODES | None = None) -> Image:
        """Gets an image as a PIL Image object.

        Args:
            image_name: The name of the image to get.
            mode: The color mode to convert the image to. If None, the original mode is used.

        Returns:
            The image as a PIL Image object. If converting to `mode` raises a ValueError, a warning
            is logged and the image is returned in its original mode.
        """
        image = self._services.images.get_pil_image(image_name)
        if mode and mode != image.mode:
            try:
                image = image.convert(mode)
            except ValueError:
                # `image` was not reassigned, so `image.mode` is still the original mode here.
                self._services.logger.warning(
                    f"Could not convert image from {image.mode} to {mode}. Using original mode instead."
                )
        return image

    def get_metadata(self, image_name: str) -> Optional[MetadataField]:
        """Gets an image's metadata, if it has any.

        Args:
            image_name: The name of the image to get the metadata for.

        Returns:
            The image's metadata, if it has any.
        """
        return self._services.images.get_metadata(image_name)

    def get_dto(self, image_name: str) -> ImageDTO:
        """Gets an image as an ImageDTO object.

        Args:
            image_name: The name of the image to get.

        Returns:
            The image as an ImageDTO object.
        """
        return self._services.images.get_dto(image_name)

    def get_path(self, image_name: str, thumbnail: bool = False) -> Path:
        """Gets the internal path to an image or thumbnail.

        Args:
            image_name: The name of the image to get the path of.
            thumbnail: Get the path of the thumbnail instead of the full image.

        Returns:
            The local path of the image or thumbnail.
        """
        return self._services.images.get_path(image_name, thumbnail)
2024-01-13 13:05:15 +00:00
2024-02-07 06:41:23 +00:00
class TensorsInterface(InvocationContextInterface):
    """Tensor storage for nodes: save tensors and load them back by name."""

    def save(self, tensor: Tensor) -> str:
        """Stores a tensor and returns the name it was stored under.

        Args:
            tensor: The tensor to store.

        Returns:
            The name under which the tensor was stored.
        """
        return self._services.tensors.save(obj=tensor)

    def load(self, name: str) -> Tensor:
        """Retrieves a previously stored tensor.

        Args:
            name: The name the tensor was stored under.

        Returns:
            The stored tensor.
        """
        tensor_store = self._services.tensors
        return tensor_store.load(name)
2024-01-13 07:02:58 +00:00
2024-02-07 03:24:05 +00:00
class ConditioningInterface(InvocationContextInterface):
    """Conditioning-data storage for nodes: save conditioning data and load it back by name."""

    def save(self, conditioning_data: ConditioningFieldData) -> str:
        """Stores a conditioning data object and returns the name it was stored under.

        Args:
            conditioning_data: The conditioning data to store.

        Returns:
            The name under which the conditioning data was stored.
        """
        return self._services.conditioning.save(obj=conditioning_data)

    def load(self, name: str) -> ConditioningFieldData:
        """Retrieves previously stored conditioning data.

        Args:
            name: The name the conditioning data was stored under.

        Returns:
            The stored conditioning data.
        """
        conditioning_store = self._services.conditioning
        return conditioning_store.load(name)
2024-02-07 03:24:05 +00:00
class ModelsInterface(InvocationContextInterface):
    """Common API for loading, downloading and managing models."""

    def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool:
        """Check if a model exists.

        Args:
            identifier: The key or ModelField representing the model.

        Returns:
            True if the model exists, False if not.
        """
        if isinstance(identifier, str):
            return self._services.model_manager.store.exists(identifier)
        else:
            return self._services.model_manager.store.exists(identifier.key)

    def load(
        self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None
    ) -> LoadedModel:
        """Load a model.

        Args:
            identifier: The key or ModelField representing the model.
            submodel_type: The submodel of the model to get. When `identifier` is a ModelField and this
                is not provided, the ModelField's own `submodel_type` is used.

        Returns:
            An object representing the loaded model.
        """
        if isinstance(identifier, str):
            model = self._services.model_manager.store.get_model(identifier)
            return self._services.model_manager.load.load_model(model, submodel_type)
        else:
            _submodel_type = submodel_type or identifier.submodel_type
            model = self._services.model_manager.store.get_model(identifier.key)
            return self._services.model_manager.load.load_model(model, _submodel_type)

    def load_by_attrs(
        self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
    ) -> LoadedModel:
        """Load a model by its attributes.

        Args:
            name: Name of the model.
            base: The model's base type, e.g. `BaseModelType.StableDiffusion1`, `BaseModelType.StableDiffusionXL`, etc.
            type: Type of the model, e.g. `ModelType.Main`, `ModelType.Vae`, etc.
            submodel_type: The type of submodel to load, e.g. `SubModelType.UNet`, `SubModelType.TextEncoder`, etc.
                Only main models have submodels.

        Returns:
            An object representing the loaded model.

        Raises:
            UnknownModelException: If no model matches the attributes.
            ValueError: If more than one model matches the attributes.
        """
        configs = self._services.model_manager.store.search_by_attr(model_name=name, base_model=base, model_type=type)
        if len(configs) == 0:
            raise UnknownModelException(f"No model found with name {name}, base {base}, and type {type}")

        if len(configs) > 1:
            raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")

        return self._services.model_manager.load.load_model(configs[0], submodel_type)

    def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
        """Get a model's config.

        Args:
            identifier: The key or ModelField representing the model.

        Returns:
            The model's config.
        """
        if isinstance(identifier, str):
            return self._services.model_manager.store.get_model(identifier)
        else:
            return self._services.model_manager.store.get_model(identifier.key)

    def search_by_path(self, path: Path) -> list[AnyModelConfig]:
        """Search for models by path.

        Args:
            path: The path to search for.

        Returns:
            A list of models that match the path.
        """
        return self._services.model_manager.store.search_by_path(path)

    def search_by_attrs(
        self,
        name: Optional[str] = None,
        base: Optional[BaseModelType] = None,
        type: Optional[ModelType] = None,
        format: Optional[ModelFormat] = None,
    ) -> list[AnyModelConfig]:
        """Search for models by attributes.

        Args:
            name: The name to search for (exact match).
            base: The base to search for, e.g. `BaseModelType.StableDiffusion1`, `BaseModelType.StableDiffusionXL`, etc.
            type: The type of model to search for, e.g. `ModelType.Main`, `ModelType.Vae`, etc.
            format: The format of model to search for, e.g. `ModelFormat.Checkpoint`, `ModelFormat.Diffusers`, etc.

        Returns:
            A list of models that match the attributes.
        """
        return self._services.model_manager.store.search_by_attr(
            model_name=name,
            base_model=base,
            model_type=type,
            model_format=format,
        )

    def download_and_cache_model(
        self,
        source: str | AnyHttpUrl,
    ) -> Path:
        """Download the model file located at source to the models cache and return its Path.

        This can be used to single-file install models and other resources of arbitrary types
        which should not get registered with the database. If the model is already
        installed, the cached path will be returned. Otherwise it will be downloaded.

        Args:
            source: A URL that points to the model, or a huggingface repo_id.

        Returns:
            Path to the downloaded model.
        """
        return self._services.model_manager.install.download_and_cache_model(source=source)

    def import_local_model(
        self,
        model_path: Path,
        config: Optional[ModelRecordChanges] = None,
        inplace: Optional[bool] = False,
    ):
        """Import the model file or directory at the given local path and return its ModelInstallJob.

        This can be used to import single-file models or model directories.

        Args:
            model_path: A pathlib.Path object pointing to a model file or directory.
            config: Optional ModelRecordChanges to define manual probe overrides.
            inplace: Optional boolean to declare whether or not to install the model in the models dir.

        Returns:
            ModelInstallJob object defining the install job to be used in tracking the job.

        Raises:
            FileNotFoundError: If `model_path` does not exist on disk. (A subclass of Exception, so
                existing `except Exception` handlers still catch it.)
        """
        if not model_path.exists():
            raise FileNotFoundError("Models provided to import_local_model must already exist on disk")
        return self._services.model_manager.install.heuristic_import(str(model_path), config=config, inplace=inplace)

    def load_local_model(
        self,
        model_path: Path,
        loader: Optional[Callable[[Path], AnyModel]] = None,
    ) -> LoadedModelWithoutConfig:
        """Load the model file located at the indicated path.

        If a loader callable is provided, it will be invoked to load the model. Otherwise,
        `safetensors.torch.load_file()` or `torch.load()` will be called to load the model.

        Be aware that the LoadedModelWithoutConfig object has no `config` attribute.

        Args:
            model_path: A model Path.
            loader: A Callable that takes the model Path and returns the loaded model.

        Returns:
            A LoadedModelWithoutConfig object.
        """
        return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)

    def load_remote_model(
        self,
        source: str | AnyHttpUrl,
        loader: Optional[Callable[[Path], AnyModel]] = None,
    ) -> LoadedModelWithoutConfig:
        """Download, cache, and load the model file located at the indicated URL or repo_id.

        If the model is already downloaded, it will be loaded from the cache.

        If a loader callable is provided, it will be invoked to load the model. Otherwise,
        `safetensors.torch.load_file()` or `torch.load()` will be called to load the model.

        Be aware that the LoadedModelWithoutConfig object has no `config` attribute.

        Args:
            source: A URL or huggingface repo_id.
            loader: A Callable that takes the downloaded model Path and returns the loaded model.

        Returns:
            A LoadedModelWithoutConfig object.
        """
        model_path = self._services.model_manager.install.download_and_cache_model(source=str(source))
        return self._services.model_manager.load.load_model_from_path(model_path=model_path, loader=loader)
2024-06-02 22:51:21 +00:00
2024-04-12 04:55:21 +00:00
2024-02-07 03:24:05 +00:00
class ConfigInterface(InvocationContextInterface):
    """Read access to the application configuration."""

    def get(self) -> InvokeAIAppConfig:
        """Returns the app's config.

        Returns:
            The app's config, as held by the invocation services.
        """
        app_config = self._services.configuration
        return app_config
2024-02-07 03:24:05 +00:00
class UtilInterface(InvocationContextInterface):
    """Utility methods for nodes: session-cancellation checks and the denoising step callback."""

    def __init__(
        self, services: InvocationServices, data: InvocationContextData, is_canceled: Callable[[], bool]
    ) -> None:
        super().__init__(services, data)
        # Zero-argument callable supplied by the caller that reports session cancellation.
        self._is_canceled = is_canceled

    def is_canceled(self) -> bool:
        """Reports whether the current session has been canceled.

        Returns:
            True if the current session has been canceled, False otherwise.
        """
        check_canceled = self._is_canceled
        return check_canceled()

    def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None:
        """Reports progress for one denoising step.

        Emits a progress event carrying the current step, the total number of steps,
        a preview image, and some other internal metadata. Call this after each denoising step.

        Args:
            intermediate_state: The intermediate state of the diffusion pipeline.
            base_model: The base model for the current denoising step.
        """
        events_service = self._services.events
        stable_diffusion_step_callback(
            context_data=self._data,
            intermediate_state=intermediate_state,
            base_model=base_model,
            events=events_service,
            is_canceled=self.is_canceled,
        )
2024-01-13 07:02:58 +00:00
class InvocationContext:
    """Bundles the service wrappers and data made available to a node while it executes.

    Attributes:
        images (ImagesInterface): Methods to save, get and update images and their metadata.
        tensors (TensorsInterface): Methods to save and get tensors, including image, noise, masks, and masked images.
        conditioning (ConditioningInterface): Methods to save and get conditioning data.
        models (ModelsInterface): Methods to check if a model exists, get a model, and get a model's info.
        logger (LoggerInterface): The app logger.
        config (ConfigInterface): The app config.
        util (UtilInterface): Utility methods, including a method to check if an invocation was canceled and step callbacks.
        boards (BoardsInterface): Methods to interact with boards.
    """

    def __init__(
        self,
        images: ImagesInterface,
        tensors: TensorsInterface,
        conditioning: ConditioningInterface,
        models: ModelsInterface,
        logger: LoggerInterface,
        config: ConfigInterface,
        util: UtilInterface,
        boards: BoardsInterface,
        data: InvocationContextData,
        services: InvocationServices,
    ) -> None:
        # Each attribute is followed by a string literal so IDEs surface it as the attribute's docstring.
        self.images = images
        """Methods to save, get and update images and their metadata."""
        self.tensors = tensors
        """Methods to save and get tensors, including image, noise, masks, and masked images."""
        self.conditioning = conditioning
        """Methods to save and get conditioning data."""
        self.models = models
        """Methods to check if a model exists, get a model, and get a model's info."""
        self.logger = logger
        """The app logger."""
        self.config = config
        """The app config."""
        self.util = util
        """Utility methods, including a method to check if an invocation was canceled and step callbacks."""
        self.boards = boards
        """Methods to interact with boards."""
        self._data = data
        """An internal API providing access to data about the current queue item and invocation. You probably shouldn't use this. It may change without warning."""
        self._services = services
        """An internal API providing access to all application services. You probably shouldn't use this. It may change without warning."""
2024-01-14 09:16:51 +00:00
2024-01-13 07:02:58 +00:00
def build_invocation_context(
    services: InvocationServices,
    data: InvocationContextData,
    is_canceled: Callable[[], bool],
) -> InvocationContext:
    """Builds the invocation context for a specific invocation execution.

    Args:
        services: The invocation services to wrap.
        data: The invocation context data.
        is_canceled: A zero-argument callable that returns True if the current session has been
            canceled. Nodes reach it through `context.util.is_canceled()`.

    Returns:
        The invocation context.
    """
    # Every wrapper receives the same services and data; only UtilInterface needs extra state.
    shared = {"services": services, "data": data}
    return InvocationContext(
        images=ImagesInterface(**shared),
        tensors=TensorsInterface(**shared),
        conditioning=ConditioningInterface(**shared),
        models=ModelsInterface(**shared),
        logger=LoggerInterface(**shared),
        config=ConfigInterface(**shared),
        util=UtilInterface(**shared, is_canceled=is_canceled),
        boards=BoardsInterface(**shared),
        data=data,
        services=services,
    )