2023-06-20 23:12:21 +00:00
# Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779)
from contextlib import ExitStack
from typing import List , Literal , Optional , Union
import re
import inspect
from pydantic import BaseModel , Field , validator
import torch
import numpy as np
from diffusers import ControlNetModel , DPMSolverMultistepScheduler
from diffusers . image_processor import VaeImageProcessor
from diffusers . schedulers import SchedulerMixin as Scheduler
from . . models . image import ImageCategory , ImageField , ResourceOrigin
2023-06-21 01:24:25 +00:00
from . . . backend . model_management import ONNXModelPatcher
2023-07-18 16:35:07 +00:00
from . . . backend . util import choose_torch_device
2023-06-20 23:12:21 +00:00
from . baseinvocation import ( BaseInvocation , BaseInvocationOutput ,
InvocationConfig , InvocationContext )
from . compel import ConditioningField
from . controlnet_image_processors import ControlField
from . image import ImageOutput
from . model import ModelInfo , UNetField , VaeField
2023-07-18 18:27:54 +00:00
from invokeai . app . invocations . metadata import CoreMetadata
2023-06-20 23:12:21 +00:00
from invokeai . backend import BaseModelType , ModelType , SubModelType
2023-07-18 16:35:07 +00:00
from invokeai . app . util . step_callback import stable_diffusion_step_callback
from . . . backend . stable_diffusion import PipelineIntermediateState
2023-06-20 23:12:21 +00:00
2023-06-21 01:24:25 +00:00
from tqdm import tqdm
2023-06-20 23:12:21 +00:00
from . model import ClipField
from . latent import LatentsField , LatentsOutput , build_latents_output , get_scheduler , SAMPLER_NAME_VALUES
from . compel import CompelOutput
# ONNX Runtime reports tensor element types as strings such as "tensor(float)";
# this table translates each of those strings to the matching numpy dtype.
ORT_TO_NP_TYPE = {
    "tensor(bool)": np.bool_,
    "tensor(int8)": np.int8,
    "tensor(uint8)": np.uint8,
    "tensor(int16)": np.int16,
    "tensor(uint16)": np.uint16,
    "tensor(int32)": np.int32,
    "tensor(uint32)": np.uint32,
    "tensor(int64)": np.int64,
    "tensor(uint64)": np.uint64,
    "tensor(float16)": np.float16,
    "tensor(float)": np.float32,
    "tensor(double)": np.float64,
}

# Literal type whose allowed values are exactly the ONNX type strings above
# (iterating a dict yields its keys in insertion order).
PRECISION_VALUES = Literal[tuple(ORT_TO_NP_TYPE)]
2023-06-20 23:12:21 +00:00
class ONNXPromptInvocation(BaseInvocation):
    """Encodes a prompt into conditioning with an ONNX text encoder."""

    type: Literal["prompt_onnx"] = "prompt_onnx"

    prompt: str = Field(default="", description="Prompt")
    clip: ClipField = Field(None, description="Clip to use")

    def invoke(self, context: InvocationContext) -> CompelOutput:
        """Encode ``self.prompt`` and store the embeddings in the latents service.

        LoRAs listed on ``self.clip`` and textual-inversion models named by
        ``<trigger>`` tokens in the prompt are applied to the tokenizer/text
        encoder via ``ONNXModelPatcher`` while the prompt is encoded.

        Returns:
            CompelOutput whose conditioning name refers to the saved
            ``(prompt_embeds, None)`` tuple.
        """
        tokenizer_info = context.services.model_manager.get_model(
            **self.clip.tokenizer.dict(),
        )
        text_encoder_info = context.services.model_manager.get_model(
            **self.clip.text_encoder.dict(),
        )
        with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder:
            # LoRA / TI models are fetched without entering their context managers;
            # only the underlying model objects are passed to the patcher.
            loras = [
                (
                    context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model,
                    lora.weight,
                )
                for lora in self.clip.loras
            ]

            ti_list = []
            for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
                name = trigger[1:-1]  # strip the surrounding angle brackets
                try:
                    ti_list.append(
                        context.services.model_manager.get_model(
                            model_name=name,
                            base_model=self.clip.text_encoder.base_model,
                            model_type=ModelType.TextualInversion,
                        ).context.model
                    )
                except Exception:
                    # Best effort: a missing textual inversion only produces a warning.
                    print(f"Warn: trigger: \"{trigger}\" not found")

            with ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras), \
                    ONNXModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager):
                text_encoder.create_session()
                try:
                    # copy from
                    # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L153
                    text_inputs = tokenizer(
                        self.prompt,
                        padding="max_length",
                        max_length=tokenizer.model_max_length,
                        truncation=True,
                        return_tensors="np",
                    )
                    text_input_ids = text_inputs.input_ids
                    prompt_embeds = text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
                finally:
                    # Release the ONNX session even if tokenization/encoding fails.
                    text_encoder.release_session()

        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"

        # TODO: hacky but works ;D maybe rename latents somehow?
        context.services.latents.save(conditioning_name, (prompt_embeds, None))

        return CompelOutput(
            conditioning=ConditioningField(
                conditioning_name=conditioning_name,
            ),
        )
# Text to image
class ONNXTextToLatentsInvocation(BaseInvocation):
    """Generates latents from conditionings."""

    type: Literal["t2l_onnx"] = "t2l_onnx"

    # Inputs
    # fmt: off
    positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
    negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
    noise: Optional[LatentsField] = Field(description="The noise to use")
    steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
    cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
    scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
    precision: PRECISION_VALUES = Field(default = "tensor(float16)", description="The precision to use when generating latents")
    unet: UNetField = Field(default=None, description="UNet submodel")
    # TODO: control, seamless and seamless_axes inputs are not supported yet.
    # fmt: on

    @validator("cfg_scale")
    def ge_one(cls, v):
        """validate that all cfg_scale values are >= 1"""
        if isinstance(v, list):
            for i in v:
                if i < 1:
                    raise ValueError('cfg_scale must be greater than 1')
        else:
            if v < 1:
                raise ValueError('cfg_scale must be greater than 1')
        return v

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "tags": ["latents"],
                "type_hints": {
                    "model": "model",
                    # "cfg_scale": "float",
                    "cfg_scale": "number"
                }
            },
        }

    # based on
    # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        """Run the ONNX UNet denoising loop and save the resulting latents.

        The unconditioned and conditioned embeddings are concatenated, so each
        UNet call evaluates both halves for classifier-free guidance. The
        scheduler runs in torch; the UNet runs in ONNX Runtime on numpy arrays.

        Raises:
            ValueError: if ``cfg_scale`` is an empty list.
        """
        if isinstance(self.cfg_scale, list) and not self.cfg_scale:
            raise ValueError("cfg_scale list must not be empty")

        def cfg_for_step(step_index: int) -> float:
            # A list of guidance scales is applied per denoising iteration. If the
            # list is shorter than the scheduler's timesteps, the last value is reused.
            if isinstance(self.cfg_scale, list):
                return self.cfg_scale[min(step_index, len(self.cfg_scale) - 1)]
            return self.cfg_scale

        c, _ = context.services.latents.get(self.positive_conditioning.conditioning_name)
        uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
        source_node_id = graph_execution_state.prepared_source_mapping[self.id]

        if isinstance(c, torch.Tensor):
            c = c.cpu().numpy()
        if isinstance(uc, torch.Tensor):
            uc = uc.cpu().numpy()
        device = torch.device(choose_torch_device())
        # Order matters: the first half of each UNet output is treated as unconditioned.
        prompt_embeds = np.concatenate([uc, c])

        latents = context.services.latents.get(self.noise.latents_name)
        if isinstance(latents, torch.Tensor):
            latents = latents.cpu().numpy()

        # TODO: better execution device handling
        latents = latents.astype(ORT_TO_NP_TYPE[self.precision])

        do_classifier_free_guidance = True

        scheduler = get_scheduler(
            context=context,
            scheduler_info=self.unet.scheduler,
            scheduler_name=self.scheduler,
        )

        def torch2numpy(latent: torch.Tensor):
            return latent.cpu().numpy()

        def numpy2torch(latent, device):
            return torch.from_numpy(latent).to(device)

        def dispatch_progress(intermediate_state: PipelineIntermediateState) -> None:
            stable_diffusion_step_callback(
                context=context,
                intermediate_state=intermediate_state,
                node=self.dict(),
                source_node_id=source_node_id,
            )

        scheduler.set_timesteps(self.steps)
        latents = latents * np.float64(scheduler.init_noise_sigma)

        # Only pass `eta` to schedulers whose step() accepts it.
        extra_step_kwargs = dict()
        if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
            extra_step_kwargs.update(
                eta=0.0,
            )

        unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
        with unet_info as unet:
            # LoRA models are fetched without entering their context managers;
            # only the underlying model objects are passed to the patcher.
            loras = [
                (
                    context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model,
                    lora.weight,
                )
                for lora in self.unet.loras
            ]

            with ONNXModelPatcher.apply_lora_unet(unet, loras):
                # TODO:
                unet.create_session()
                try:
                    # Use the dtype the ONNX graph declares for its "timestep" input,
                    # falling back to float16 if no such input is found.
                    timestep_dtype = next(
                        (inp.type for inp in unet.session.get_inputs() if inp.name == "timestep"),
                        "tensor(float16)",
                    )
                    timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]

                    for i, t in enumerate(tqdm(scheduler.timesteps)):
                        # expand the latents if we are doing classifier free guidance
                        latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
                        latent_model_input = scheduler.scale_model_input(numpy2torch(latent_model_input, device), t)
                        latent_model_input = latent_model_input.cpu().numpy()

                        # predict the noise residual
                        timestep = np.array([t], dtype=timestep_dtype)
                        noise_pred = unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)
                        noise_pred = noise_pred[0]

                        # perform guidance
                        if do_classifier_free_guidance:
                            noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
                            noise_pred = noise_pred_uncond + cfg_for_step(i) * (noise_pred_text - noise_pred_uncond)

                        # compute the previous noisy sample x_t -> x_t-1
                        scheduler_output = scheduler.step(
                            numpy2torch(noise_pred, device), t, numpy2torch(latents, device), **extra_step_kwargs
                        )
                        latents = torch2numpy(scheduler_output.prev_sample)

                        state = PipelineIntermediateState(
                            run_id="test",
                            step=i,
                            timestep=timestep,
                            latents=scheduler_output.prev_sample
                        )
                        dispatch_progress(state)
                finally:
                    # Release the ONNX session even if denoising fails.
                    unet.release_session()

        torch.cuda.empty_cache()

        name = f'{context.graph_execution_state_id}__{self.id}'
        context.services.latents.save(name, latents)
        return build_latents_output(latents_name=name, latents=torch.from_numpy(latents))
2023-06-20 23:12:21 +00:00
# Latent to image
class ONNXLatentsToImageInvocation(BaseInvocation):
    """Generates an image from latents."""

    type: Literal["l2i_onnx"] = "l2i_onnx"

    # Inputs
    latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
    vae: VaeField = Field(default=None, description="Vae submodel")
    metadata: Optional[CoreMetadata] = Field(default=None, description="Optional core metadata to be written to the image")
    # TODO: tiled decoding (overlapping tiles, less memory consumption) is not supported yet.

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "tags": ["latents", "image"],
            },
        }

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Decode latents with the ONNX VAE decoder and store the first image.

        Raises:
            Exception: if the supplied VAE submodel is not the decoder.
        """
        latents = context.services.latents.get(self.latents.latents_name)
        # The ONNX decoder consumes numpy arrays; other nodes may have stored torch tensors.
        if isinstance(latents, torch.Tensor):
            latents = latents.cpu().numpy()

        if self.vae.vae.submodel != SubModelType.VaeDecoder:
            raise Exception(f"Expected vae_decoder, found: {self.vae.vae.submodel}")

        vae_info = context.services.model_manager.get_model(
            **self.vae.vae.dict(),
        )

        # clear memory as vae decode can request a lot
        torch.cuda.empty_cache()

        with vae_info as vae:
            vae.create_session()
            try:
                # copied from
                # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L427
                latents = 1 / 0.18215 * latents
                # image = self.vae_decoder(latent_sample=latents)[0]
                # It seems there is a strange result when using a half-precision vae decoder
                # with batch size > 1, so decode one latent at a time.
                image = np.concatenate(
                    [vae(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
                )
            finally:
                # Release the ONNX session even if decoding fails.
                vae.release_session()

        # Map decoder output from [-1, 1] to [0, 1], then NCHW -> NHWC for PIL.
        image = np.clip(image / 2 + 0.5, 0, 1)
        image = image.transpose((0, 2, 3, 1))
        image = VaeImageProcessor.numpy_to_pil(image)[0]

        torch.cuda.empty_cache()

        image_dto = context.services.images.create(
            image=image,
            image_origin=ResourceOrigin.INTERNAL,
            image_category=ImageCategory.GENERAL,
            node_id=self.id,
            session_id=context.graph_execution_state_id,
            is_intermediate=self.is_intermediate,
            metadata=self.metadata.dict() if self.metadata else None,
        )

        return ImageOutput(
            image=ImageField(image_name=image_dto.image_name),
            width=image_dto.width,
            height=image_dto.height,
        )
class ONNXModelLoaderOutput(BaseInvocationOutput):
    """Model loader output"""

    # fmt: off
    type: Literal["model_loader_output_onnx"] = "model_loader_output_onnx"

    # Each field references one submodel of the loaded ONNX model.
    unet: UNetField = Field(default=None, description="UNet submodel")
    clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
    vae_decoder: VaeField = Field(default=None, description="Vae decoder submodel")
    vae_encoder: VaeField = Field(default=None, description="Vae encoder submodel")
    # fmt: on
class ONNXSD1ModelLoaderInvocation(BaseInvocation):
    """Loading submodels of selected model."""

    type: Literal["sd1_model_loader_onnx"] = "sd1_model_loader_onnx"

    model_name: str = Field(default="", description="Model to load")
    # TODO: precision?

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "tags": ["model", "loader"],
                "type_hints": {
                    "model_name": "model"  # TODO: rename to model_name?
                }
            },
        }

    def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
        """Describe the submodels of the selected SD-1 ONNX model.

        An empty ``model_name`` (the default) selects "stable-diffusion-v1-5",
        which is the model this node previously always loaded.

        Raises:
            Exception: if the model manager does not know the model.
        """
        model_name = self.model_name or "stable-diffusion-v1-5"
        base_model = BaseModelType.StableDiffusion1
        model_type = ModelType.ONNX

        # TODO: not found exceptions
        if not context.services.model_manager.model_exists(
            model_name=model_name,
            base_model=base_model,
            model_type=model_type,
        ):
            raise Exception(f"Unknown model name: {model_name}!")

        def submodel_info(submodel: SubModelType) -> ModelInfo:
            # All submodels share the model's name, base and type; only `submodel` differs.
            return ModelInfo(
                model_name=model_name,
                base_model=base_model,
                model_type=model_type,
                submodel=submodel,
            )

        return ONNXModelLoaderOutput(
            unet=UNetField(
                unet=submodel_info(SubModelType.UNet),
                scheduler=submodel_info(SubModelType.Scheduler),
                loras=[],
            ),
            clip=ClipField(
                tokenizer=submodel_info(SubModelType.Tokenizer),
                text_encoder=submodel_info(SubModelType.TextEncoder),
                loras=[],
                # Same value OnnxModelLoaderInvocation passes for ClipField.
                skipped_layers=0,
            ),
            vae_decoder=VaeField(
                vae=submodel_info(SubModelType.VaeDecoder),
            ),
            vae_encoder=VaeField(
                vae=submodel_info(SubModelType.VaeEncoder),
            ),
        )
class OnnxModelField(BaseModel):
    """Onnx model field"""

    # Name used to look the model up in the model manager.
    model_name: str = Field(..., description="Name of the model")
    # Base architecture of the model (a BaseModelType member).
    base_model: BaseModelType = Field(..., description="Base model")
    # Model type as a ModelType member.
    model_type: ModelType = Field(..., description="Model Type")
2023-07-14 18:24:15 +00:00
class OnnxModelLoaderInvocation(BaseInvocation):
    """Loads a main model, outputting its submodels."""

    type: Literal["onnx_model_loader"] = "onnx_model_loader"

    model: OnnxModelField = Field(description="The model to load")
    # TODO: precision?

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Onnx Model Loader",
                "tags": ["model", "loader"],
                "type_hints": {"model": "model"},
            },
        }

    def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
        """Describe the submodels of the selected ONNX model.

        ``self.model.model_type`` is not consulted; the lookup always uses
        ``ModelType.ONNX``.

        Raises:
            Exception: if the model manager does not know the model.
        """
        base_model = self.model.base_model
        model_name = self.model.model_name
        model_type = ModelType.ONNX

        # TODO: not found exceptions
        if not context.services.model_manager.model_exists(
            model_name=model_name,
            base_model=base_model,
            model_type=model_type,
        ):
            raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")

        def submodel_info(submodel: SubModelType) -> ModelInfo:
            # All submodels share the model's name, base and type; only `submodel` differs.
            return ModelInfo(
                model_name=model_name,
                base_model=base_model,
                model_type=model_type,
                submodel=submodel,
            )

        return ONNXModelLoaderOutput(
            unet=UNetField(
                unet=submodel_info(SubModelType.UNet),
                scheduler=submodel_info(SubModelType.Scheduler),
                loras=[],
            ),
            clip=ClipField(
                tokenizer=submodel_info(SubModelType.Tokenizer),
                text_encoder=submodel_info(SubModelType.TextEncoder),
                loras=[],
                skipped_layers=0,
            ),
            vae_decoder=VaeField(
                vae=submodel_info(SubModelType.VaeDecoder),
            ),
            vae_encoder=VaeField(
                vae=submodel_info(SubModelType.VaeEncoder),
            ),
        )