diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py deleted file mode 100644 index e7b4d3d9fc..0000000000 --- a/invokeai/app/invocations/onnx.py +++ /dev/null @@ -1,510 +0,0 @@ -# Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779) - -import inspect - -# from contextlib import ExitStack -from typing import List, Literal, Union - -import numpy as np -import torch -from diffusers.image_processor import VaeImageProcessor -from pydantic import BaseModel, ConfigDict, Field, field_validator -from tqdm import tqdm - -from invokeai.app.invocations.fields import ( - FieldDescriptions, - Input, - InputField, - OutputField, - UIComponent, - UIType, - WithMetadata, -) -from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput -from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin -from invokeai.app.util.step_callback import stable_diffusion_step_callback -from invokeai.backend import BaseModelType, ModelType, SubModelType - -from ...backend.model_management import ONNXModelPatcher -from ...backend.stable_diffusion import PipelineIntermediateState -from ...backend.util import choose_torch_device -from ..util.ti_utils import extract_ti_triggers_from_prompt -from .baseinvocation import ( - BaseInvocation, - BaseInvocationOutput, - InvocationContext, - invocation, - invocation_output, -) -from .controlnet_image_processors import ControlField -from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, get_scheduler -from .model import ClipField, ModelInfo, UNetField, VaeField - -ORT_TO_NP_TYPE = { - "tensor(bool)": np.bool_, - "tensor(int8)": np.int8, - "tensor(uint8)": np.uint8, - "tensor(int16)": np.int16, - "tensor(uint16)": np.uint16, - "tensor(int32)": np.int32, - "tensor(uint32)": np.uint32, - "tensor(int64)": np.int64, - "tensor(uint64)": np.uint64, - "tensor(float16)": np.float16, - "tensor(float)": np.float32, - "tensor(double)": np.float64, -} - -PRECISION_VALUES = Literal[tuple(ORT_TO_NP_TYPE.keys())] - - -@invocation("prompt_onnx", title="ONNX Prompt (Raw)", tags=["prompt", "onnx"], category="conditioning", version="1.0.0") -class ONNXPromptInvocation(BaseInvocation): - prompt: str = InputField(default="", description=FieldDescriptions.raw_prompt, ui_component=UIComponent.Textarea) - clip: ClipField = InputField(description=FieldDescriptions.clip, input=Input.Connection) - - def invoke(self, context: InvocationContext) -> ConditioningOutput: - tokenizer_info = context.services.model_manager.get_model( - **self.clip.tokenizer.model_dump(), - ) - text_encoder_info = context.services.model_manager.get_model( - **self.clip.text_encoder.model_dump(), - ) - with tokenizer_info as orig_tokenizer, text_encoder_info as text_encoder: # , ExitStack() as stack: - loras = [ - ( - context.services.model_manager.get_model(**lora.model_dump(exclude={"weight"})).context.model, - lora.weight, - ) - for lora in self.clip.loras - ] - - ti_list = [] - for trigger in extract_ti_triggers_from_prompt(self.prompt): - name = trigger[1:-1] - try: - ti_list.append( - ( - name, - context.services.model_manager.get_model( - model_name=name, - base_model=self.clip.text_encoder.base_model, - model_type=ModelType.TextualInversion, - ).context.model, - ) - ) - except Exception: - # print(e) - # import traceback - # print(traceback.format_exc()) - print(f'Warn: trigger: "{trigger}" not found') - if loras or ti_list: - text_encoder.release_session() - with ( - ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras), - ONNXModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager), - ): - text_encoder.create_session() - - # copy from - # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L153 - text_inputs = tokenizer( - self.prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="np", - ) - text_input_ids = text_inputs.input_ids - """ - untruncated_ids = tokenizer(prompt, padding="max_length", return_tensors="np").input_ids - - if not np.array_equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - """ - - prompt_embeds = text_encoder(input_ids=text_input_ids.astype(np.int32))[0] - - conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" - - # TODO: hacky but works ;D maybe rename latents somehow? - context.services.latents.save(conditioning_name, (prompt_embeds, None)) - - return ConditioningOutput( - conditioning=ConditioningField( - conditioning_name=conditioning_name, - ), - ) - - -# Text to image -@invocation( - "t2l_onnx", - title="ONNX Text to Latents", - tags=["latents", "inference", "txt2img", "onnx"], - category="latents", - version="1.0.0", -) -class ONNXTextToLatentsInvocation(BaseInvocation): - """Generates latents from conditionings.""" - - positive_conditioning: ConditioningField = InputField( - description=FieldDescriptions.positive_cond, - input=Input.Connection, - ) - negative_conditioning: ConditioningField = InputField( - description=FieldDescriptions.negative_cond, - input=Input.Connection, - ) - noise: LatentsField = InputField( - description=FieldDescriptions.noise, - input=Input.Connection, - ) - steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps) - cfg_scale: Union[float, List[float]] = InputField( - default=7.5, - ge=1, - description=FieldDescriptions.cfg_scale, - ) - scheduler: SAMPLER_NAME_VALUES = InputField( - default="euler", description=FieldDescriptions.scheduler, input=Input.Direct, ui_type=UIType.Scheduler - ) - precision: PRECISION_VALUES = InputField(default="tensor(float16)", description=FieldDescriptions.precision) - unet: UNetField = InputField( - description=FieldDescriptions.unet, - input=Input.Connection, - ) - control: Union[ControlField, list[ControlField]] = InputField( - default=None, - description=FieldDescriptions.control, - ) - # seamless: bool = InputField(default=False, description="Whether or not to generate an image that can tile without seams", ) - # seamless_axes: str = InputField(default="", description="The axes to tile the image on, 'x' and/or 'y'") - - @field_validator("cfg_scale") - def ge_one(cls, v): - """validate that all cfg_scale values are >= 1""" - if isinstance(v, list): - for i in v: - if i < 1: - raise ValueError("cfg_scale must be greater than 1") - else: - if v < 1: - raise ValueError("cfg_scale must be greater than 1") - return v - - # based on - # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L375 - def invoke(self, context: InvocationContext) -> LatentsOutput: - c, _ = context.services.latents.get(self.positive_conditioning.conditioning_name) - uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name) - graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) - source_node_id = graph_execution_state.prepared_source_mapping[self.id] - if isinstance(c, torch.Tensor): - c = c.cpu().numpy() - if isinstance(uc, torch.Tensor): - uc = uc.cpu().numpy() - device = torch.device(choose_torch_device()) - prompt_embeds = np.concatenate([uc, c]) - - latents = context.services.latents.get(self.noise.latents_name) - if isinstance(latents, torch.Tensor): - latents = latents.cpu().numpy() - - # TODO: better execution device handling - latents = latents.astype(ORT_TO_NP_TYPE[self.precision]) - - # get the initial random noise unless the user supplied it - do_classifier_free_guidance = True - # latents_dtype = prompt_embeds.dtype - # latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8) - # if latents.shape != latents_shape: - # raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") - - scheduler = get_scheduler( - context=context, - scheduler_info=self.unet.scheduler, - scheduler_name=self.scheduler, - seed=0, # TODO: refactor this node - ) - - def torch2numpy(latent: torch.Tensor): - return latent.cpu().numpy() - - def numpy2torch(latent, device): - return torch.from_numpy(latent).to(device) - - def dispatch_progress( - self, context: InvocationContext, source_node_id: str, intermediate_state: PipelineIntermediateState - ) -> None: - stable_diffusion_step_callback( - context=context, - intermediate_state=intermediate_state, - node=self.model_dump(), - source_node_id=source_node_id, - ) - - scheduler.set_timesteps(self.steps) - latents = latents * np.float64(scheduler.init_noise_sigma) - - extra_step_kwargs = {} - if "eta" in set(inspect.signature(scheduler.step).parameters.keys()): - extra_step_kwargs.update( - eta=0.0, - ) - - unet_info = context.services.model_manager.get_model(**self.unet.unet.model_dump()) - - with unet_info as unet: # , ExitStack() as stack: - # loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras] - loras = [ - ( - context.services.model_manager.get_model(**lora.model_dump(exclude={"weight"})).context.model, - lora.weight, - ) - for lora in self.unet.loras - ] - - if loras: - unet.release_session() - with ONNXModelPatcher.apply_lora_unet(unet, loras): - # TODO: - _, _, h, w = latents.shape - unet.create_session(h, w) - - timestep_dtype = next( - (input.type for input in unet.session.get_inputs() if input.name == "timestep"), "tensor(float16)" - ) - timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] - for i in tqdm(range(len(scheduler.timesteps))): - t = scheduler.timesteps[i] - # expand the latents if we are doing classifier free guidance - latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = scheduler.scale_model_input(numpy2torch(latent_model_input, device), t) - latent_model_input = latent_model_input.cpu().numpy() - - # predict the noise residual - timestep = np.array([t], dtype=timestep_dtype) - noise_pred = unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds) - noise_pred = noise_pred[0] - - # perform guidance - if do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) - noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond) - - # compute the previous noisy sample x_t -> x_t-1 - scheduler_output = scheduler.step( - numpy2torch(noise_pred, device), t, numpy2torch(latents, device), **extra_step_kwargs - ) - latents = torch2numpy(scheduler_output.prev_sample) - - state = PipelineIntermediateState( - run_id="test", step=i, timestep=timestep, latents=scheduler_output.prev_sample - ) - dispatch_progress(self, context=context, source_node_id=source_node_id, intermediate_state=state) - - # call the callback, if provided - # if callback is not None and i % callback_steps == 0: - # callback(i, t, latents) - - torch.cuda.empty_cache() - - name = f"{context.graph_execution_state_id}__{self.id}" - context.services.latents.save(name, latents) - # return build_latents_output(latents_name=name, latents=torch.from_numpy(latents)) - - -# Latent to image -@invocation( - "l2i_onnx", - title="ONNX Latents to Image", - tags=["latents", "image", "vae", "onnx"], - category="image", - version="1.2.0", -) -class ONNXLatentsToImageInvocation(BaseInvocation, WithMetadata): - """Generates an image from latents.""" - - latents: LatentsField = InputField( - description=FieldDescriptions.denoised_latents, - input=Input.Connection, - ) - vae: VaeField = InputField( - description=FieldDescriptions.vae, - input=Input.Connection, - ) - # tiled: bool = InputField(default=False, description="Decode latents by overlaping tiles(less memory consumption)") - - def invoke(self, context: InvocationContext) -> ImageOutput: - latents = context.services.latents.get(self.latents.latents_name) - - if self.vae.vae.submodel != SubModelType.VaeDecoder: - raise Exception(f"Expected vae_decoder, found: {self.vae.vae.model_type}") - - vae_info = context.services.model_manager.get_model( - **self.vae.vae.model_dump(), - ) - - # clear memory as vae decode can request a lot - torch.cuda.empty_cache() - - with vae_info as vae: - vae.create_session() - - # copied from - # https://github.com/huggingface/diffusers/blob/3ebbaf7c96801271f9e6c21400033b6aa5ffcf29/src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py#L427 - latents = 1 / 0.18215 * latents - # image = self.vae_decoder(latent_sample=latents)[0] - # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 - image = np.concatenate([vae(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]) - - image = np.clip(image / 2 + 0.5, 0, 1) - image = image.transpose((0, 2, 3, 1)) - image = VaeImageProcessor.numpy_to_pil(image)[0] - - torch.cuda.empty_cache() - - image_dto = context.services.images.create( - image=image, - image_origin=ResourceOrigin.INTERNAL, - image_category=ImageCategory.GENERAL, - node_id=self.id, - session_id=context.graph_execution_state_id, - is_intermediate=self.is_intermediate, - metadata=self.metadata, - workflow=context.workflow, - ) - - return ImageOutput( - image=ImageField(image_name=image_dto.image_name), - width=image_dto.width, - height=image_dto.height, - ) - - -@invocation_output("model_loader_output_onnx") -class ONNXModelLoaderOutput(BaseInvocationOutput): - """Model loader output""" - - unet: UNetField = OutputField(default=None, description=FieldDescriptions.unet, title="UNet") - clip: ClipField = OutputField(default=None, description=FieldDescriptions.clip, title="CLIP") - vae_decoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Decoder") - vae_encoder: VaeField = OutputField(default=None, description=FieldDescriptions.vae, title="VAE Encoder") - - -class OnnxModelField(BaseModel): - """Onnx model field""" - - model_name: str = Field(description="Name of the model") - base_model: BaseModelType = Field(description="Base model") - model_type: ModelType = Field(description="Model Type") - - model_config = ConfigDict(protected_namespaces=()) - - -@invocation("onnx_model_loader", title="ONNX Main Model", tags=["onnx", "model"], category="model", version="1.0.0") -class OnnxModelLoaderInvocation(BaseInvocation): - """Loads a main model, outputting its submodels.""" - - model: OnnxModelField = InputField( - description=FieldDescriptions.onnx_main_model, input=Input.Direct, ui_type=UIType.ONNXModel - ) - - def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput: - base_model = self.model.base_model - model_name = self.model.model_name - model_type = ModelType.ONNX - - # TODO: not found exceptions - if not context.services.model_manager.model_exists( - model_name=model_name, - base_model=base_model, - model_type=model_type, - ): - raise Exception(f"Unknown {base_model} {model_type} model: {model_name}") - - """ - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.Tokenizer, - ): - raise Exception( - f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted" - ) - - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.TextEncoder, - ): - raise Exception( - f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted" - ) - - if not context.services.model_manager.model_exists( - model_name=self.model_name, - model_type=SDModelType.Diffusers, - submodel=SDModelType.UNet, - ): - raise Exception( - f"Failed to find unet submodel from {self.model_name}! Check if model corrupted" - ) - """ - - return ONNXModelLoaderOutput( - unet=UNetField( - unet=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.UNet, - ), - scheduler=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.Scheduler, - ), - loras=[], - ), - clip=ClipField( - tokenizer=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.Tokenizer, - ), - text_encoder=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.TextEncoder, - ), - loras=[], - skipped_layers=0, - ), - vae_decoder=VaeField( - vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.VaeDecoder, - ), - ), - vae_encoder=VaeField( - vae=ModelInfo( - model_name=model_name, - base_model=base_model, - model_type=model_type, - submodel=SubModelType.VaeEncoder, - ), - ), - )