# Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779)

import inspect
import re
from contextlib import ExitStack
from typing import List, Literal, Optional, Union

import numpy as np
import torch
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import SchedulerMixin as Scheduler
from pydantic import BaseModel, Field, validator
from tqdm import tqdm

from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend import BaseModelType, ModelType, SubModelType

from ...backend.model_management import ONNXModelPatcher
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.util import choose_torch_device
from ..models.image import ImageCategory, ResourceOrigin
from .baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    FieldDescriptions,
    Input,
    InputField,
    InvocationContext,
    OutputField,
    UIComponent,
    UITypeHint,
    tags,
    title,
)
from .controlnet_image_processors import ControlField
from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, build_latents_output, get_scheduler
from .model import ClipField, ModelInfo, UNetField, VaeField

ORT_TO_NP_TYPE = {
    "tensor(bool)": np.bool_,
    "tensor(int8)": np.int8,
    "tensor(uint8)": np.uint8,
    "tensor(int16)": np.int16,
    "tensor(uint16)": np.uint16,
    "tensor(int32)": np.int32,
    "tensor(uint32)": np.uint32,
    "tensor(int64)": np.int64,
    "tensor(uint64)": np.uint64,
    "tensor(float16)": np.float16,
    "tensor(float)": np.float32,
    "tensor(double)": np.float64,
}

PRECISION_VALUES = Literal[tuple(list(ORT_TO_NP_TYPE.keys()))]
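
# ORT_TO_NP_TYPE maps ONNX Runtime element-type strings (as reported by
# InferenceSession.get_inputs()) to the numpy dtypes used when feeding arrays
# to a session; PRECISION_VALUES reuses the same strings as a Literal so that
# `precision` fields validate against exactly the names ONNX Runtime reports.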
@title("ONNX Prompt (Raw)")
|
|
@tags("onnx", "prompt")
|
|
class ONNXPromptInvocation(BaseInvocation):
|
|
type: Literal["prompt_onnx"] = "prompt_onnx"
|
|
|
|
prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea)
|
|
clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection)
|
|
|
|
def invoke(self, context: InvocationContext) -> ConditioningOutput:
|
|
tokenizer_info = context.services.model_manager.get_model(
|
|
**self.clip.tokenizer.dict(),
|
|
)
|
|
text_encoder_info = context.services.model_manager.get_model(
|
|
**self.clip.text_encoder.dict(),
|
|
)
|
|
with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder, ExitStack() as stack:
|
|
loras = [
|
|
(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
|
|
for lora in self.clip.loras
|
|
]
|
|
|
|
ti_list = []
|
|
for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
|
|
name = trigger[1:-1]
|
|
try:
|
|
ti_list.append(
|
|
(
|
|
name,
|
|
context.services.model_manager.get_model(
|
|
model_name=name,
|
|
base_model=self.clip.text_encoder.base_model,
|
|
model_type=ModelType.TextualInversion,
|
|
).context.model,
|
|
)
|
|
)
|
|
except Exception:
|
|
# print(e)
|
|
# import traceback
|
|
# print(traceback.format_exc())
|
|
print(f'Warn: trigger: "{trigger}" not found')
|
|
if loras or ti_list:
|
|
text_encoder.release_session()
|
|
with ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras), ONNXModelPatcher.apply_ti(
|
|
orig_tokenizer, text_encoder, ti_list
|
|
) as (tokenizer, ti_manager):
|
|
text_encoder.create_session()
|
|
|
|
# copy from
|
|
# https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L153
|
|
text_inputs = tokenizer(
|
|
self.prompt,
|
|
padding="max_length",
|
|
max_length=tokenizer.model_max_length,
|
|
truncation=True,
|
|
return_tensors="np",
|
|
)
|
|
text_input_ids = text_inputs.input_ids
|
|
"""
|
|
untruncated_ids = tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
|
|
|
|
if not np.array_equal(text_input_ids, untruncated_ids):
|
|
removed_text = self.tokenizer.batch_decode(
|
|
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
|
)
|
|
logger.warning(
|
|
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
|
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
|
)
|
|
"""
|
|
|
|
prompt_embeds = text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
|
|
|
|
conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"
|
|
|
|
# TODO: hacky but works ;D maybe rename latents somehow?
|
|
context.services.latents.save(conditioning_name, (prompt_embeds, None))
|
|
|
|
return ConditioningOutput(
|
|
conditioning=ConditioningField(
|
|
conditioning_name=conditioning_name,
|
|
),
|
|
)
|
|
|
|
|
|
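
# ONNXPromptInvocation stores its conditioning through the latents service under
# "{session_id}_{node_id}_conditioning"; ONNXTextToLatentsInvocation below reads
# the positive and negative embeddings back out of the same store by name.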


# Text to image
@title("ONNX Text to Latents")
@tags("latents", "inference", "txt2img", "onnx")
class ONNXTextToLatentsInvocation(BaseInvocation):
    """Generates latents from conditioning."""

    type: Literal["t2l_onnx"] = "t2l_onnx"

    # Inputs
    positive_conditioning: ConditioningField = InputField(
        description=FieldDescriptions.positive_cond,
        input=Input.Connection,
    )
    negative_conditioning: ConditioningField = InputField(
        description=FieldDescriptions.negative_cond,
        input=Input.Connection,
    )
    noise: LatentsField = InputField(
        description=FieldDescriptions.noise,
        input=Input.Connection,
    )
    steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
    cfg_scale: Union[float, List[float]] = InputField(
        default=7.5,
        ge=1,
        description=FieldDescriptions.cfg_scale,
        ui_type_hint=UITypeHint.Float,
    )
    scheduler: SAMPLER_NAME_VALUES = InputField(
        default="euler", description=FieldDescriptions.scheduler, input=Input.Direct
    )
    precision: PRECISION_VALUES = InputField(default="tensor(float16)", description=FieldDescriptions.precision)
    unet: UNetField = InputField(
        description=FieldDescriptions.unet,
        input=Input.Connection,
    )
    control: Optional[Union[ControlField, list[ControlField]]] = InputField(
        default=None,
        description=FieldDescriptions.control,
        ui_type_hint=UITypeHint.Control,
    )
    # seamless: bool = InputField(default=False, description="Whether or not to generate an image that can tile without seams", )
    # seamless_axes: str = InputField(default="", description="The axes to tile the image on, 'x' and/or 'y'")
@validator("cfg_scale")
|
|
def ge_one(cls, v):
|
|
"""validate that all cfg_scale values are >= 1"""
|
|
if isinstance(v, list):
|
|
for i in v:
|
|
if i < 1:
|
|
raise ValueError("cfg_scale must be greater than 1")
|
|
else:
|
|
if v < 1:
|
|
raise ValueError("cfg_scale must be greater than 1")
|
|
return v
|
|
|
|
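
    # Both forms of the Union above pass validation, e.g. (illustrative):
    #   cfg_scale=7.5              # a single guidance scale for the whole run
    #   cfg_scale=[7.5, 7.0, 6.5]  # a list, intended as per-step scales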

    # based on
    # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375
    def invoke(self, context: InvocationContext) -> LatentsOutput:
        c, _ = context.services.latents.get(self.positive_conditioning.conditioning_name)
        uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
        graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
        source_node_id = graph_execution_state.prepared_source_mapping[self.id]
        if isinstance(c, torch.Tensor):
            c = c.cpu().numpy()
        if isinstance(uc, torch.Tensor):
            uc = uc.cpu().numpy()
        device = torch.device(choose_torch_device())
        prompt_embeds = np.concatenate([uc, c])

        latents = context.services.latents.get(self.noise.latents_name)
        if isinstance(latents, torch.Tensor):
            latents = latents.cpu().numpy()

        # TODO: better execution device handling
        latents = latents.astype(ORT_TO_NP_TYPE[self.precision])

        # get the initial random noise unless the user supplied it
        do_classifier_free_guidance = True
        # latents_dtype = prompt_embeds.dtype
        # latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
        # if latents.shape != latents_shape:
        #     raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")

        scheduler = get_scheduler(
            context=context,
            scheduler_info=self.unet.scheduler,
            scheduler_name=self.scheduler,
            seed=0,  # TODO: refactor this node
        )

        def torch2numpy(latent: torch.Tensor):
            return latent.cpu().numpy()

        def numpy2torch(latent, device):
            return torch.from_numpy(latent).to(device)

        def dispatch_progress(
            self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState
        ) -> None:
            stable_diffusion_step_callback(
                context=context,
                intermediate_state=intermediate_state,
                node=self.dict(),
                source_node_id=source_node_id,
            )
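
        # The helpers above bridge the two runtimes in play: ONNX Runtime wants
        # numpy arrays, while the diffusers scheduler operates on torch tensors,
        # so each denoising step round-trips the latents between the two.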

        scheduler.set_timesteps(self.steps)
        latents = latents * np.float64(scheduler.init_noise_sigma)

        extra_step_kwargs = dict()
        if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
            extra_step_kwargs.update(
                eta=0.0,
            )

        unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())
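
        # Patching LoRA weights into an ONNX model requires tearing down the
        # existing inference session and building a new one at the target size.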

        with unet_info as unet, ExitStack() as stack:
            # loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
            loras = [
                (context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight)
                for lora in self.unet.loras
            ]

            if loras:
                unet.release_session()
            with ONNXModelPatcher.apply_lora_unet(unet, loras):
                # TODO:
                _, _, h, w = latents.shape
                unet.create_session(h, w)

                timestep_dtype = next(
                    (input.type for input in unet.session.get_inputs() if input.name == "timestep"), "tensor(float16)"
                )
                timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]
                for i in tqdm(range(len(scheduler.timesteps))):
                    t = scheduler.timesteps[i]
                    # expand the latents if we are doing classifier free guidance
                    latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
                    latent_model_input = scheduler.scale_model_input(numpy2torch(latent_model_input, device), t)
                    latent_model_input = latent_model_input.cpu().numpy()

                    # predict the noise residual
                    timestep = np.array([t], dtype=timestep_dtype)
                    noise_pred = unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)
                    noise_pred = noise_pred[0]

                    # perform guidance
                    if do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
                        noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)

                    # compute the previous noisy sample x_t -> x_t-1
                    scheduler_output = scheduler.step(
                        numpy2torch(noise_pred, device), t, numpy2torch(latents, device), **extra_step_kwargs
                    )
                    latents = torch2numpy(scheduler_output.prev_sample)

                    state = PipelineIntermediateState(
                        run_id="test", step=i, timestep=timestep, latents=scheduler_output.prev_sample
                    )
                    dispatch_progress(self, context=context, source_node_id=source_node_id, intermediate_state=state)

                    # call the callback, if provided
                    # if callback is not None and i % callback_steps == 0:
                    #     callback(i, t, latents)

        torch.cuda.empty_cache()
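
        # The denoised latents are saved under "{session_id}__{node_id}" so that
        # downstream nodes (e.g. ONNXLatentsToImageInvocation) can fetch them.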
name = f"{context.graph_execution_state_id}__{self.id}"
|
|
context.services.latents.save(name, latents)
|
|
return build_latents_output(latents_name=name, latents=torch.from_numpy(latents))
|
|
|
|
|
|


# Latent to image
@title("ONNX Latents to Image")
@tags("latents", "image", "vae", "onnx")
class ONNXLatentsToImageInvocation(BaseInvocation):
    """Generates an image from latents."""

    type: Literal["l2i_onnx"] = "l2i_onnx"

    # Inputs
    latents: LatentsField = InputField(
        description=FieldDescriptions.denoised_latents,
        input=Input.Connection,
    )
    vae: VaeField = InputField(
        description=FieldDescriptions.vae,
        input=Input.Connection,
    )
    metadata: Optional[CoreMetadata] = InputField(
        default=None,
        description=FieldDescriptions.core_metadata,
        ui_hidden=True,
    )
    # tiled: bool = InputField(default=False, description="Decode latents by overlapping tiles (less memory consumption)")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        latents = context.services.latents.get(self.latents.latents_name)

        if self.vae.vae.submodel != SubModelType.VaeDecoder:
            raise Exception(f"Expected vae_decoder, found: {self.vae.vae.submodel}")

        vae_info = context.services.model_manager.get_model(
            **self.vae.vae.dict(),
        )

        # clear memory as vae decode can request a lot
        torch.cuda.empty_cache()

        with vae_info as vae:
            vae.create_session()

            # copied from
            # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L427
            # 0.18215 is the Stable Diffusion v1 VAE scaling factor
            latents = 1 / 0.18215 * latents
            # image = self.vae_decoder(latent_sample=latents)[0]
            # the half-precision vae decoder produces strange results when batch size > 1,
            # so decode one sample at a time
            image = np.concatenate([vae(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])])

            # convert to NHWC float images in [0, 1] for PIL conversion
            image = np.clip(image / 2 + 0.5, 0, 1)
            image = image.transpose((0, 2, 3, 1))
            image = VaeImageProcessor.numpy_to_pil(image)[0]

        torch.cuda.empty_cache()

        image_dto = context.services.images.create(
            image=image,
            image_origin=ResourceOrigin.INTERNAL,
            image_category=ImageCategory.GENERAL,
            node_id=self.id,
            session_id=context.graph_execution_state_id,
            is_intermediate=self.is_intermediate,
            metadata=self.metadata.dict() if self.metadata else None,
        )

        return ImageOutput(
            image=ImageField(image_name=image_dto.image_name),
            width=image_dto.width,
            height=image_dto.height,
        )


class ONNXModelLoaderOutput(BaseInvocationOutput):
    """Model loader output"""

    # fmt: off
    type: Literal["model_loader_output_onnx"] = "model_loader_output_onnx"

    unet: UNetField = OutputField(default=None, description=FieldDescriptions.unet, title="UNet")
    clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP")
    vae_decoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Decoder")
    vae_encoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Encoder")
    # fmt: on


class OnnxModelField(BaseModel):
    """Onnx model field"""

    model_name: str = Field(description="Name of the model")
    base_model: BaseModelType = Field(description="Base model")
    model_type: ModelType = Field(description="Model Type")
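
# An example payload for the loader's `model` input (illustrative values only):
#   OnnxModelField(
#       model_name="stable-diffusion-v1-5",
#       base_model=BaseModelType.StableDiffusion1,
#       model_type=ModelType.ONNX,
#   )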
@title("ONNX Model Loader")
|
|
@tags("onnx", "model")
|
|
class OnnxModelLoaderInvocation(BaseInvocation):
|
|
"""Loads a main model, outputting its submodels."""
|
|
|
|
type: Literal["onnx_model_loader"] = "onnx_model_loader"
|
|
|
|
# Inputs
|
|
model: OnnxModelField = InputField(
|
|
description=FieldDescriptions.onnx_main_model, input=Input.Direct, ui_type_hint=UITypeHint.ONNXModel
|
|
)
|
|
|
|
def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
|
|
base_model = self.model.base_model
|
|
model_name = self.model.model_name
|
|
model_type = ModelType.ONNX
|
|
|
|
# TODO: not found exceptions
|
|
if not context.services.model_manager.model_exists(
|
|
model_name=model_name,
|
|
base_model=base_model,
|
|
model_type=model_type,
|
|
):
|
|
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
|
|
|
|
"""
|
|
if not context.services.model_manager.model_exists(
|
|
model_name=self.model_name,
|
|
model_type=SDModelType.Diffusers,
|
|
submodel=SDModelType.Tokenizer,
|
|
):
|
|
raise Exception(
|
|
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
|
|
)
|
|
|
|
if not context.services.model_manager.model_exists(
|
|
model_name=self.model_name,
|
|
model_type=SDModelType.Diffusers,
|
|
submodel=SDModelType.TextEncoder,
|
|
):
|
|
raise Exception(
|
|
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
|
|
)
|
|
|
|
if not context.services.model_manager.model_exists(
|
|
model_name=self.model_name,
|
|
model_type=SDModelType.Diffusers,
|
|
submodel=SDModelType.UNet,
|
|
):
|
|
raise Exception(
|
|
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
|
|
)
|
|
"""
|
|
|
|
return ONNXModelLoaderOutput(
|
|
unet=UNetField(
|
|
unet=ModelInfo(
|
|
model_name=model_name,
|
|
base_model=base_model,
|
|
model_type=model_type,
|
|
submodel=SubModelType.UNet,
|
|
),
|
|
scheduler=ModelInfo(
|
|
model_name=model_name,
|
|
base_model=base_model,
|
|
model_type=model_type,
|
|
submodel=SubModelType.Scheduler,
|
|
),
|
|
loras=[],
|
|
),
|
|
clip=ClipField(
|
|
tokenizer=ModelInfo(
|
|
model_name=model_name,
|
|
base_model=base_model,
|
|
model_type=model_type,
|
|
submodel=SubModelType.Tokenizer,
|
|
),
|
|
text_encoder=ModelInfo(
|
|
model_name=model_name,
|
|
base_model=base_model,
|
|
model_type=model_type,
|
|
submodel=SubModelType.TextEncoder,
|
|
),
|
|
loras=[],
|
|
skipped_layers=0,
|
|
),
|
|
vae_decoder=VaeField(
|
|
vae=ModelInfo(
|
|
model_name=model_name,
|
|
base_model=base_model,
|
|
model_type=model_type,
|
|
submodel=SubModelType.VaeDecoder,
|
|
),
|
|
),
|
|
vae_encoder=VaeField(
|
|
vae=ModelInfo(
|
|
model_name=model_name,
|
|
base_model=base_model,
|
|
model_type=model_type,
|
|
submodel=SubModelType.VaeEncoder,
|
|
),
|
|
),
|
|
)
|
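
# A typical ONNX text-to-image graph wires these nodes roughly as follows
# (a sketch; the noise source comes from elsewhere in the node set):
#
#   OnnxModelLoaderInvocation
#       clip ----------> ONNXPromptInvocation (positive and negative)
#       unet ----------> ONNXTextToLatentsInvocation
#       vae_decoder ---> ONNXLatentsToImageInvocation
#
#   conditioning + noise latents -> ONNXTextToLatentsInvocation
#   denoised latents -> ONNXLatentsToImageInvocation -> ImageOutput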