2024-02-17 14:41:04 +00:00
import threading
2024-01-13 07:02:58 +00:00
from dataclasses import dataclass
2024-02-15 09:43:41 +00:00
from pathlib import Path
2024-04-12 04:55:21 +00:00
from typing import TYPE_CHECKING , Any , Callable , Dict , Optional , Union
2024-01-13 07:02:58 +00:00
2024-04-12 04:55:21 +00:00
from picklescan . scanner import scan_file_path
2024-01-13 07:02:58 +00:00
from PIL . Image import Image
2024-04-04 03:26:48 +00:00
from pydantic . networks import AnyHttpUrl
2024-04-12 04:55:21 +00:00
from safetensors . torch import load_file as safetensors_load_file
2024-01-13 07:02:58 +00:00
from torch import Tensor
2024-04-12 04:55:21 +00:00
from torch import load as torch_load
2024-01-13 07:02:58 +00:00
2024-02-18 21:56:46 +00:00
from invokeai . app . invocations . constants import IMAGE_MODES
2024-02-19 04:11:36 +00:00
from invokeai . app . invocations . fields import MetadataField , WithBoard , WithMetadata
2024-02-05 06:40:49 +00:00
from invokeai . app . services . boards . boards_common import BoardDTO
2024-01-13 07:02:58 +00:00
from invokeai . app . services . config . config_default import InvokeAIAppConfig
2024-02-07 05:33:55 +00:00
from invokeai . app . services . image_records . image_records_common import ImageCategory , ResourceOrigin
2024-01-13 07:02:58 +00:00
from invokeai . app . services . images . images_common import ImageDTO
from invokeai . app . services . invocation_services import InvocationServices
2024-03-06 08:37:15 +00:00
from invokeai . app . services . model_records . model_records_base import UnknownModelException
2024-01-13 07:02:58 +00:00
from invokeai . app . util . step_callback import stable_diffusion_step_callback
2024-02-15 09:43:41 +00:00
from invokeai . backend . model_manager . config import AnyModelConfig , BaseModelType , ModelFormat , ModelType , SubModelType
from invokeai . backend . model_manager . load . load_base import LoadedModel
2024-01-13 07:02:58 +00:00
from invokeai . backend . stable_diffusion . diffusers_pipeline import PipelineIntermediateState
2024-01-14 23:41:25 +00:00
from invokeai . backend . stable_diffusion . diffusion . conditioning_data import ConditioningFieldData
2024-01-13 07:02:58 +00:00
if TYPE_CHECKING :
from invokeai . app . invocations . baseinvocation import BaseInvocation
2024-03-09 08:43:24 +00:00
from invokeai . app . invocations . model import ModelIdentifierField
2024-02-18 00:51:50 +00:00
from invokeai . app . services . session_queue . session_queue_common import SessionQueueItem
2024-01-13 07:02:58 +00:00
"""
The InvocationContext provides access to various services and data about the current invocation.

We do not provide the invocation services directly, as their methods are both dangerous and
inconvenient to use.

For example:
- The `images` service allows nodes to delete or unsafely modify existing images.
- The `configuration` service allows nodes to change the app's config at runtime.
- The `events` service allows nodes to emit arbitrary events.

Wrapping these services provides a simpler and safer interface for nodes to use.

When a node executes, a fresh `InvocationContext` is built for it, ensuring nodes cannot interfere
with each other.

Many of the wrappers have the same signature as the methods they wrap. This allows us to write
user-facing docstrings without having to update the internal services to match.

Note: The docstrings are in weird places, but that's where they must be to get IDEs to see them.
"""
2024-01-16 09:02:38 +00:00
@dataclass
class InvocationContextData:
    """Plain data describing the queue item and invocation currently being executed."""

    queue_item: "SessionQueueItem"
    """The session queue item whose graph is currently running."""
    invocation: "BaseInvocation"
    """The invocation (node) currently running."""
    source_invocation_id: str
    """ID of the source invocation from which the running invocation was prepared."""
2024-01-13 07:02:58 +00:00
2024-02-07 03:24:05 +00:00
class InvocationContextInterface:
    """Base class for the service wrappers exposed on an `InvocationContext`.

    Holds the wrapped services and the current invocation's context data for use by subclasses.
    """

    def __init__(self, services: InvocationServices, data: InvocationContextData) -> None:
        self._services, self._data = services, data
2024-02-07 03:24:05 +00:00
class BoardsInterface(InvocationContextInterface):
    """Board operations available to invocations."""

    def create(self, board_name: str) -> BoardDTO:
        """Creates a new board.

        Args:
            board_name: Name to give the new board.

        Returns:
            The DTO of the newly created board.
        """
        boards_service = self._services.boards
        return boards_service.create(board_name)

    def get_dto(self, board_id: str) -> BoardDTO:
        """Looks up a single board.

        Args:
            board_id: ID of the board to look up.

        Returns:
            The board's DTO.
        """
        boards_service = self._services.boards
        return boards_service.get_dto(board_id)

    def get_all(self) -> list[BoardDTO]:
        """Lists every board.

        Returns:
            The DTOs of all boards.
        """
        boards_service = self._services.boards
        return boards_service.get_all()

    def add_image_to_board(self, board_id: str, image_name: str) -> None:
        """Places an image on a board.

        Args:
            board_id: ID of the destination board.
            image_name: Name of the image to place on the board.
        """
        board_images_service = self._services.board_images
        return board_images_service.add_image_to_board(board_id, image_name)

    def get_all_image_names_for_board(self, board_id: str) -> list[str]:
        """Lists the names of every image on a board.

        Args:
            board_id: ID of the board whose image names are wanted.

        Returns:
            The names of all images on that board.
        """
        board_images_service = self._services.board_images
        return board_images_service.get_all_board_image_names_for_board(board_id)
class LoggerInterface(InvocationContextInterface):
    """Forwards log messages to the application logger."""

    def debug(self, message: str) -> None:
        """Emits a message at DEBUG level.

        Args:
            message: Text to log.
        """
        app_logger = self._services.logger
        app_logger.debug(message)

    def info(self, message: str) -> None:
        """Emits a message at INFO level.

        Args:
            message: Text to log.
        """
        app_logger = self._services.logger
        app_logger.info(message)

    def warning(self, message: str) -> None:
        """Emits a message at WARNING level.

        Args:
            message: Text to log.
        """
        app_logger = self._services.logger
        app_logger.warning(message)

    def error(self, message: str) -> None:
        """Emits a message at ERROR level.

        Args:
            message: Text to log.
        """
        app_logger = self._services.logger
        app_logger.error(message)
class ImagesInterface(InvocationContextInterface):
    """Image persistence and retrieval for invocations."""

    def save(
        self,
        image: Image,
        board_id: Optional[str] = None,
        image_category: ImageCategory = ImageCategory.GENERAL,
        metadata: Optional[MetadataField] = None,
    ) -> ImageDTO:
        """Persists a PIL image and returns the resulting DTO.

        The workflow of the current queue item is stored alongside the image.

        Args:
            image: The PIL image to persist.
            board_id: Board to place the image on. If omitted and the invocation inherits from \
                `WithBoard`, the invocation's board is used. **Only pass this to override or supply \
                a board manually!**
            image_category: Category for the image. Only the GENERAL category shows up in the gallery.
            metadata: Metadata to attach. If omitted and the invocation inherits from `WithMetadata`, \
                the invocation's metadata is used. **Only pass this to override or supply metadata \
                manually!**

        Returns:
            The DTO of the persisted image.
        """
        invocation = self._data.invocation
        queue_item = self._data.queue_item

        # Explicit metadata takes precedence; otherwise fall back to the node's own metadata, if any.
        if metadata:
            resolved_metadata = metadata
        elif isinstance(invocation, WithMetadata):
            resolved_metadata = invocation.metadata
        else:
            resolved_metadata = None

        # Explicit board takes precedence; otherwise fall back to the node's board field, if set.
        if board_id:
            resolved_board_id = board_id
        elif isinstance(invocation, WithBoard) and invocation.board:
            resolved_board_id = invocation.board.board_id
        else:
            resolved_board_id = None

        return self._services.images.create(
            image=image,
            is_intermediate=invocation.is_intermediate,
            image_category=image_category,
            board_id=resolved_board_id,
            metadata=resolved_metadata,
            image_origin=ResourceOrigin.INTERNAL,
            workflow=queue_item.workflow,
            session_id=queue_item.session_id,
            node_id=invocation.id,
        )

    def get_pil(self, image_name: str, mode: IMAGE_MODES | None = None) -> Image:
        """Fetches an image as a PIL Image, optionally converting its color mode.

        Args:
            image_name: Name of the image to fetch.
            mode: Target color mode. When None (or already matching), the image is returned as stored.

        Returns:
            The PIL image. If conversion raises ValueError, a warning is logged and the unconverted
            image is returned.
        """
        pil_image = self._services.images.get_pil_image(image_name)
        if not mode or mode == pil_image.mode:
            return pil_image
        try:
            return pil_image.convert(mode)
        except ValueError:
            self._services.logger.warning(
                f"Could not convert image from {pil_image.mode} to {mode}. Using original mode instead."
            )
            return pil_image

    def get_metadata(self, image_name: str) -> Optional[MetadataField]:
        """Fetches the metadata stored with an image.

        Args:
            image_name: Name of the image whose metadata is wanted.

        Returns:
            The image's metadata, or None if it has none.
        """
        images_service = self._services.images
        return images_service.get_metadata(image_name)

    def get_dto(self, image_name: str) -> ImageDTO:
        """Fetches an image's DTO.

        Args:
            image_name: Name of the image to look up.

        Returns:
            The image's DTO.
        """
        images_service = self._services.images
        return images_service.get_dto(image_name)
2024-01-13 13:05:15 +00:00
2024-02-07 06:41:23 +00:00
class TensorsInterface(InvocationContextInterface):
    """Tensor storage for invocations."""

    def save(self, tensor: Tensor) -> str:
        """Stores a tensor.

        Args:
            tensor: The tensor to store.

        Returns:
            The name under which the tensor was stored.
        """
        return self._services.tensors.save(obj=tensor)

    def load(self, name: str) -> Tensor:
        """Retrieves a previously stored tensor.

        Args:
            name: Name returned by `save`.

        Returns:
            The stored tensor.
        """
        tensor_store = self._services.tensors
        return tensor_store.load(name)
2024-01-13 07:02:58 +00:00
2024-02-07 03:24:05 +00:00
class ConditioningInterface(InvocationContextInterface):
    """Conditioning-data storage for invocations."""

    def save(self, conditioning_data: ConditioningFieldData) -> str:
        """Stores a conditioning data object.

        Args:
            conditioning_data: The conditioning data to store.

        Returns:
            The name under which the conditioning data was stored.
        """
        return self._services.conditioning.save(obj=conditioning_data)

    def load(self, name: str) -> ConditioningFieldData:
        """Retrieves previously stored conditioning data.

        Args:
            name: Name returned by `save`.

        Returns:
            The stored conditioning data.
        """
        conditioning_store = self._services.conditioning
        return conditioning_store.load(name)
2024-02-07 03:24:05 +00:00
class ModelsInterface(InvocationContextInterface):
    """Methods to check for, inspect, search, load, and install models."""

    def exists(self, identifier: Union[str, "ModelIdentifierField"]) -> bool:
        """Checks if a model exists.

        Args:
            identifier: The key or ModelIdentifierField representing the model.

        Returns:
            True if the model exists, False if not.
        """
        if isinstance(identifier, str):
            return self._services.model_manager.store.exists(identifier)

        return self._services.model_manager.store.exists(identifier.key)

    def load(
        self, identifier: Union[str, "ModelIdentifierField"], submodel_type: Optional[SubModelType] = None
    ) -> LoadedModel:
        """Loads a model.

        Args:
            identifier: The key or ModelIdentifierField representing the model.
            submodel_type: The submodel of the model to get. When `identifier` is a ModelIdentifierField
                and this is not given, the field's own `submodel_type` is used.

        Returns:
            An object representing the loaded model.
        """

        # The model manager emits events as it loads the model. It needs the context data to build
        # the event payloads.

        if isinstance(identifier, str):
            model = self._services.model_manager.store.get_model(identifier)
            return self._services.model_manager.load.load_model(model, submodel_type, self._data)
        else:
            _submodel_type = submodel_type or identifier.submodel_type
            model = self._services.model_manager.store.get_model(identifier.key)
            return self._services.model_manager.load.load_model(model, _submodel_type, self._data)

    def load_by_attrs(
        self, name: str, base: BaseModelType, type: ModelType, submodel_type: Optional[SubModelType] = None
    ) -> LoadedModel:
        """Loads a model by its attributes.

        Args:
            name: Name of the model.
            base: The model's base type, e.g. `BaseModelType.StableDiffusion1`, `BaseModelType.StableDiffusionXL`, etc.
            type: Type of the model, e.g. `ModelType.Main`, `ModelType.Vae`, etc.
            submodel_type: The type of submodel to load, e.g. `SubModelType.UNet`, `SubModelType.TextEncoder`, etc. Only main
                models have submodels.

        Returns:
            An object representing the loaded model.

        Raises:
            UnknownModelException: If no model matches the attributes.
            ValueError: If more than one model matches the attributes.
        """
        configs = self._services.model_manager.store.search_by_attr(model_name=name, base_model=base, model_type=type)
        if len(configs) == 0:
            raise UnknownModelException(f"No model found with name {name}, base {base}, and type {type}")

        if len(configs) > 1:
            raise ValueError(f"More than one model found with name {name}, base {base}, and type {type}")

        return self._services.model_manager.load.load_model(configs[0], submodel_type, self._data)

    def get_config(self, identifier: Union[str, "ModelIdentifierField"]) -> AnyModelConfig:
        """Gets a model's config.

        Args:
            identifier: The key or ModelIdentifierField representing the model.

        Returns:
            The model's config.
        """
        if isinstance(identifier, str):
            return self._services.model_manager.store.get_model(identifier)
        return self._services.model_manager.store.get_model(identifier.key)

    def search_by_path(self, path: Path) -> list[AnyModelConfig]:
        """Searches for models by path.

        Args:
            path: The path to search for.

        Returns:
            A list of models that match the path.
        """
        return self._services.model_manager.store.search_by_path(path)

    def search_by_attrs(
        self,
        name: Optional[str] = None,
        base: Optional[BaseModelType] = None,
        type: Optional[ModelType] = None,
        format: Optional[ModelFormat] = None,
    ) -> list[AnyModelConfig]:
        """Searches for models by attributes.

        Args:
            name: The name to search for (exact match).
            base: The base to search for, e.g. `BaseModelType.StableDiffusion1`, `BaseModelType.StableDiffusionXL`, etc.
            type: The type of model to search for, e.g. `ModelType.Main`, `ModelType.Vae`, etc.
            format: The format of model to search for, e.g. `ModelFormat.Checkpoint`, `ModelFormat.Diffusers`, etc.

        Returns:
            A list of models that match the attributes.
        """
        return self._services.model_manager.store.search_by_attr(
            model_name=name,
            base_model=base,
            model_type=type,
            model_format=format,
        )

    def install_model(
        self,
        source: str,
        config: Optional[Dict[str, Any]] = None,
        access_token: Optional[str] = None,
        inplace: Optional[bool] = False,
        timeout: Optional[int] = 0,
    ) -> str:
        """Install and register a model in the database.

        Args:
            source: String source; see below.
            config: Optional dict. Any fields in this dict will override the corresponding
                auto-assigned probe fields in the model's config record.
            access_token: Optional access token for remote sources.
            inplace: If true, installs a local model in place rather than copying
                it into the models directory.
            timeout: How long to wait on install (in seconds). A value of 0 (default)
                blocks indefinitely.

        The source can be:
        1. A local file path in POSIX or Windows format (`/foo/bar` or `C:\\foo\\bar`)
        2. An http or https URL (`https://foo.bar/foo`)
        3. A HuggingFace repo_id (`foo/bar`, `foo/bar:fp16`, `foo/bar:fp16:vae`)

        We extend the HuggingFace repo_id syntax to include the variant and the
        subfolder or path. The following are acceptable alternatives:
            stabilityai/stable-diffusion-v4
            stabilityai/stable-diffusion-v4:fp16
            stabilityai/stable-diffusion-v4:fp16:vae
            stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
            stabilityai/stable-diffusion-v4:onnx:vae

        Because a local file path can look like a HuggingFace repo_id, the logic
        first checks whether the path exists on disk, and if not, it is treated as
        a parseable HuggingFace repo.

        Returns:
            Key to the newly installed model.

        May Raise:
            ValueError -- bad source
            UnknownModelException -- remote model not found
            InvalidModelException -- what was retrieved from remote is not a model
            TimeoutError -- model could not be installed within timeout
            Exception -- another error condition (including a failed install job)
        """
        installer = self._services.model_manager.install
        job = installer.heuristic_import(
            source=source,
            config=config,
            access_token=access_token,
            inplace=inplace,
        )
        installer.wait_for_job(job, timeout)
        if job.errored:
            raise Exception(job.error)
        key: str = job.config_out.key
        return key

    def download_and_cache_ckpt(
        self,
        source: Union[str, AnyHttpUrl],
        access_token: Optional[str] = None,
        timeout: Optional[int] = 0,
    ) -> Path:
        """
        Download the model file located at source to the models cache and return its Path.

        This can be used to single-file install models and other resources of arbitrary types
        which should not get registered with the database. If the model is already
        installed, the cached path will be returned. Otherwise it will be downloaded.

        Args:
            source: A URL or a string that can be converted into one. Repo_ids
                do not work here.
            access_token: Optional access token for restricted resources.
            timeout: Wait up to the indicated number of seconds before timing
                out long downloads.

        Returns:
            Path of the downloaded model.

        May Raise:
            HTTPError
            TimeoutError
        """
        installer = self._services.model_manager.install
        path: Path = installer.download_and_cache(
            source=source,
            access_token=access_token,
            timeout=timeout,
        )
        return path

    def load_ckpt_from_url(
        self,
        source: Union[str, AnyHttpUrl],
        access_token: Optional[str] = None,
        timeout: Optional[int] = 0,
        loader: Optional[Callable[[Path], Dict[str | int, Any]]] = None,
    ) -> LoadedModel:
        """
        Load and cache the model file located at the indicated URL.

        This will check the model download cache for the model designated
        by the provided URL and download it if needed using download_and_cache_ckpt().
        It will then load the model into the RAM cache. If the optional loader
        argument is provided, the loader will be invoked to load the model into
        memory. Otherwise the method will call safetensors.torch.load_file() or
        torch.load() as appropriate to the file suffix (`.ckpt`, `.pt`, `.pth` and `.bin`,
        case-insensitive, are treated as pickle-based torch checkpoints and are
        malware-scanned before loading).

        Be aware that the LoadedModel object will have a `config` attribute of None.

        Args:
            source: A URL or a string that can be converted into one. Repo_ids
                do not work here.
            access_token: Optional access token for restricted resources.
            timeout: Wait up to the indicated number of seconds before timing
                out long downloads.
            loader: A Callable that expects a Path and returns a Dict[str|int, Any].

        Returns:
            A LoadedModel object.
        """
        ram_cache = self._services.model_manager.load.ram_cache
        cache_key = str(source)
        try:
            return LoadedModel(_locker=ram_cache.get(key=cache_key))
        except IndexError:
            # Not in the RAM cache yet; fall through to download and load it.
            pass

        def torch_load_file(checkpoint: Path) -> Dict[str | int, Any]:
            # Pickle-based checkpoints can execute arbitrary code when unpickled, so scan first.
            scan_result = scan_file_path(checkpoint)
            if scan_result.infected_files != 0:
                raise Exception(f"The model at {checkpoint} is potentially infected by malware. Aborting load.")
            # Load the very file that was scanned (not an enclosing-scope variable).
            return torch_load(checkpoint, map_location="cpu")

        path = self.download_and_cache_ckpt(source, access_token, timeout)
        if loader is None:
            loader = (
                torch_load_file
                if path.suffix.lower().endswith((".ckpt", ".pt", ".pth", ".bin"))
                else lambda checkpoint: safetensors_load_file(checkpoint, device="cpu")
            )

        raw_model = loader(path)
        ram_cache.put(key=cache_key, model=raw_model)
        return LoadedModel(_locker=ram_cache.get(key=cache_key))
2024-02-07 03:24:05 +00:00
class ConfigInterface(InvocationContextInterface):
    """Access to the application configuration."""

    def get(self) -> InvokeAIAppConfig:
        """Returns the application's configuration object.

        Returns:
            The app's config.
        """
        app_config = self._services.configuration
        return app_config
2024-02-07 03:24:05 +00:00
class UtilInterface(InvocationContextInterface):
    """Miscellaneous helpers: cancellation checks and diffusion step callbacks."""

    def __init__(
        self, services: InvocationServices, data: InvocationContextData, cancel_event: threading.Event
    ) -> None:
        super().__init__(services, data)
        # Event that is set when the current session has been canceled.
        self._cancel_event = cancel_event

    def is_canceled(self) -> bool:
        """Reports whether the current session has been canceled.

        Returns:
            True when the cancel event is set, otherwise False.
        """
        event = self._cancel_event
        return event.is_set()

    def sd_step_callback(self, intermediate_state: PipelineIntermediateState, base_model: BaseModelType) -> None:
        """
        Forwards a denoising step to `stable_diffusion_step_callback`, which emits a progress event
        (current step, total steps, a preview image and other internal metadata).

        Call this after each denoising step.

        Args:
            intermediate_state: The diffusion pipeline's intermediate state for this step.
            base_model: The base model used for the current denoising step.
        """
        stable_diffusion_step_callback(
            context_data=self._data,
            intermediate_state=intermediate_state,
            base_model=base_model,
            events=self._services.events,
            is_canceled=self.is_canceled,
        )
2024-01-13 07:02:58 +00:00
class InvocationContext:
    """Bundles the service wrappers and data available to the running invocation.

    Attributes:
        images (ImagesInterface): Save, fetch and inspect images and their metadata.
        tensors (TensorsInterface): Save and fetch tensors, including image, noise, masks, and masked images.
        conditioning (ConditioningInterface): Save and fetch conditioning data.
        models (ModelsInterface): Check for, inspect, load and install models.
        logger (LoggerInterface): The app logger.
        config (ConfigInterface): The app config.
        util (UtilInterface): Utilities, including cancellation checks and step callbacks.
        boards (BoardsInterface): Board operations.
    """

    def __init__(
        self,
        images: ImagesInterface,
        tensors: TensorsInterface,
        conditioning: ConditioningInterface,
        models: ModelsInterface,
        logger: LoggerInterface,
        config: ConfigInterface,
        util: UtilInterface,
        boards: BoardsInterface,
        data: InvocationContextData,
        services: InvocationServices,
    ) -> None:
        # The bare strings after each assignment are attribute docstrings picked up by IDEs.
        self.images = images
        """Save, fetch and inspect images and their metadata."""
        self.tensors = tensors
        """Save and fetch tensors, including image, noise, masks, and masked images."""
        self.conditioning = conditioning
        """Save and fetch conditioning data."""
        self.models = models
        """Check for, inspect, load and install models."""
        self.logger = logger
        """The app logger."""
        self.config = config
        """The app config."""
        self.util = util
        """Utilities, including cancellation checks and step callbacks."""
        self.boards = boards
        """Board operations."""
        self._data = data
        """Internal: data about the current queue item and invocation. Avoid relying on it; it may change without warning."""
        self._services = services
        """Internal: all application services. Avoid relying on it; it may change without warning."""
2024-01-14 09:16:51 +00:00
2024-01-13 07:02:58 +00:00
def build_invocation_context(
    services: InvocationServices,
    data: InvocationContextData,
    cancel_event: threading.Event,
) -> InvocationContext:
    """Assembles a fresh invocation context for one invocation execution.

    Args:
        services: The invocation services to wrap.
        data: Data about the current queue item and invocation.
        cancel_event: Event consulted by `context.util.is_canceled()` to detect session cancellation.

    Returns:
        The invocation context.
    """
    shared = {"services": services, "data": data}
    # Keyword order mirrors the order in which the wrappers have always been constructed.
    return InvocationContext(
        logger=LoggerInterface(**shared),
        images=ImagesInterface(**shared),
        tensors=TensorsInterface(**shared),
        models=ModelsInterface(**shared),
        config=ConfigInterface(**shared),
        util=UtilInterface(**shared, cancel_event=cancel_event),
        conditioning=ConditioningInterface(**shared),
        boards=BoardsInterface(**shared),
        data=data,
        services=services,
    )