From 4d337f6abca92a4273313f6c8f9110479e1caa47 Mon Sep 17 00:00:00 2001 From: Sergey Borisov Date: Wed, 21 Jun 2023 02:12:21 +0300 Subject: [PATCH] ONNX Model/runtime first implementation --- invokeai/app/invocations/latent.py | 4 +- invokeai/app/invocations/model.py | 12 +- invokeai/app/invocations/onnx.py | 441 ++++++++++++++++++ invokeai/backend/model_management/lora.py | 223 ++++++++- .../model_management/models/__init__.py | 4 + .../backend/model_management/models/base.py | 111 ++++- .../models/stable_diffusion_onnx.py | 156 +++++++ 7 files changed, 935 insertions(+), 16 deletions(-) create mode 100644 invokeai/app/invocations/onnx.py create mode 100644 invokeai/backend/model_management/models/stable_diffusion_onnx.py diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py index 63db3d925c..a8b9131775 100644 --- a/invokeai/app/invocations/latent.py +++ b/invokeai/app/invocations/latent.py @@ -55,8 +55,8 @@ class LatentsOutput(BaseInvocationOutput): def build_latents_output(latents_name: str, latents: torch.Tensor): return LatentsOutput( latents=LatentsField(latents_name=latents_name), - width=latents.size()[3] * 8, - height=latents.size()[2] * 8, + width=latents.shape[3] * 8, + height=latents.shape[2] * 8, ) class NoiseOutput(BaseInvocationOutput): diff --git a/invokeai/app/invocations/model.py b/invokeai/app/invocations/model.py index 9d77cadf8c..bd9ab67271 100644 --- a/invokeai/app/invocations/model.py +++ b/invokeai/app/invocations/model.py @@ -271,9 +271,13 @@ class LoraLoaderInvocation(BaseInvocation): def invoke(self, context: InvocationContext) -> LoraLoaderOutput: + # TODO: ui rewrite + base_model = BaseModelType.StableDiffusion1 + if not context.services.model_manager.model_exists( + base_model=base_model, model_name=self.lora_name, - model_type=SDModelType.Lora, + model_type=ModelType.Lora, ): raise Exception(f"Unkown lora name: {self.lora_name}!") @@ -289,8 +293,9 @@ class LoraLoaderInvocation(BaseInvocation): output.unet = copy.deepcopy(self.unet) output.unet.loras.append( LoraInfo( + base_model=base_model, model_name=self.lora_name, - model_type=SDModelType.Lora, + model_type=ModelType.Lora, submodel=None, weight=self.weight, ) @@ -300,8 +305,9 @@ class LoraLoaderInvocation(BaseInvocation): output.clip = copy.deepcopy(self.clip) output.clip.loras.append( LoraInfo( + base_model=base_model, model_name=self.lora_name, - model_type=SDModelType.Lora, + model_type=ModelType.Lora, submodel=None, weight=self.weight, ) diff --git a/invokeai/app/invocations/onnx.py b/invokeai/app/invocations/onnx.py new file mode 100644 index 0000000000..12a928c849 --- /dev/null +++ b/invokeai/app/invocations/onnx.py @@ -0,0 +1,441 @@ +# Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779) + +from contextlib import ExitStack +from typing import List, Literal, Optional, Union + +import re +import inspect + +from pydantic import BaseModel, Field, validator +import torch +import numpy as np +from diffusers import ControlNetModel, DPMSolverMultistepScheduler +from diffusers.image_processor import VaeImageProcessor +from diffusers.schedulers import SchedulerMixin as Scheduler + +from ..models.image import ImageCategory, ImageField, ResourceOrigin +from ...backend.model_management.lora import ONNXModelPatcher +from .baseinvocation import (BaseInvocation, BaseInvocationOutput, + InvocationConfig, InvocationContext) +from .compel import ConditioningField +from .controlnet_image_processors import ControlField +from .image import ImageOutput +from .model import ModelInfo, UNetField, VaeField + +from invokeai.backend import BaseModelType, ModelType, SubModelType + +from .model import ClipField +from .latent import LatentsField, LatentsOutput, build_latents_output, get_scheduler, SAMPLER_NAME_VALUES +from .compel import CompelOutput + + +ORT_TO_NP_TYPE = { + "tensor(bool)": np.bool_, + "tensor(int8)": np.int8, + "tensor(uint8)": np.uint8, + "tensor(int16)": np.int16, + "tensor(uint16)": np.uint16, + "tensor(int32)": np.int32, + "tensor(uint32)": np.uint32, + "tensor(int64)": np.int64, + "tensor(uint64)": np.uint64, + "tensor(float16)": np.float16, + "tensor(float)": np.float32, + "tensor(double)": np.float64, +} + + +class ONNXPromptInvocation(BaseInvocation): + type: Literal["prompt_onnx"] = "prompt_onnx" + + prompt: str = Field(default="", description="Prompt") + clip: ClipField = Field(None, description="Clip to use") + + def invoke(self, context: InvocationContext) -> CompelOutput: + tokenizer_info = context.services.model_manager.get_model( + **self.clip.tokenizer.dict(), + ) + text_encoder_info = context.services.model_manager.get_model( + **self.clip.text_encoder.dict(), + ) + with tokenizer_info as orig_tokenizer,\ + text_encoder_info as text_encoder,\ + ExitStack() as stack: + + loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras] + + ti_list = [] + for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt): + name = trigger[1:-1] + try: + ti_list.append( + stack.enter_context( + context.services.model_manager.get_model( + model_name=name, + base_model=self.clip.text_encoder.base_model, + model_type=ModelType.TextualInversion, + ) + ) + ) + except Exception: + #print(e) + #import traceback + #print(traceback.format_exc()) + print(f"Warn: trigger: \"{trigger}\" not found") + + with ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras),\ + ONNXModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager): + + text_encoder.create_session() + + text_inputs = tokenizer( + self.prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="np", + ) + text_input_ids = text_inputs.input_ids + """ + untruncated_ids = tokenizer(prompt, padding="max_length", return_tensors="np").input_ids + + if not np.array_equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + """ + + prompt_embeds = text_encoder(input_ids=text_input_ids.astype(np.int32))[0] + + text_encoder.release_session() + + conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning" + + # TODO: hacky but works ;D maybe rename latents somehow? + context.services.latents.save(conditioning_name, (prompt_embeds, None)) + + return CompelOutput( + conditioning=ConditioningField( + conditioning_name=conditioning_name, + ), + ) + +# Text to image +class ONNXTextToLatentsInvocation(BaseInvocation): + """Generates latents from conditionings.""" + + type: Literal["t2l_onnx"] = "t2l_onnx" + + # Inputs + # fmt: off + positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation") + negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation") + noise: Optional[LatentsField] = Field(description="The noise to use") + steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image") + cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", ) + scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" ) + unet: UNetField = Field(default=None, description="UNet submodel") + #control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use") + #seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", ) + #seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'") + # fmt: on + + @validator("cfg_scale") + def ge_one(cls, v): + """validate that all cfg_scale values are >= 1""" + if isinstance(v, list): + for i in v: + if i < 1: + raise ValueError('cfg_scale must be greater than 1') + else: + if v < 1: + raise ValueError('cfg_scale must be greater than 1') + return v + + # Schema customisation + class Config(InvocationConfig): + schema_extra = { + "ui": { + "tags": ["latents"], + "type_hints": { + "model": "model", + # "cfg_scale": "float", + "cfg_scale": "number" + } + }, + } + + def invoke(self, context: InvocationContext) -> LatentsOutput: + c, _ = context.services.latents.get(self.positive_conditioning.conditioning_name) + uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name) + if isinstance(c, torch.Tensor): + c = c.cpu().numpy() + if isinstance(uc, torch.Tensor): + uc = uc.cpu().numpy() + + prompt_embeds = np.concatenate([uc, c]) + + latents = context.services.latents.get(self.noise.latents_name) + if isinstance(latents, torch.Tensor): + latents = latents.cpu().numpy() + + # TODO: better execution device handling + latents = latents.astype(np.float32) + + # get the initial random noise unless the user supplied it + do_classifier_free_guidance = True + #latents_dtype = prompt_embeds.dtype + #latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8) + #if latents.shape != latents_shape: + # raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}") + + scheduler = get_scheduler( + context=context, + scheduler_info=self.unet.scheduler, + scheduler_name=self.scheduler, + ) + + scheduler.set_timesteps(self.steps) + latents = latents * np.float64(scheduler.init_noise_sigma) + + extra_step_kwargs = dict() + if "eta" in set(inspect.signature(scheduler.step).parameters.keys()): + extra_step_kwargs.update( + eta=0.0, + ) + + unet_info = context.services.model_manager.get_model(**self.unet.unet.dict()) + + with unet_info as unet,\ + ExitStack() as stack: + + loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras] + + with ONNXModelPatcher.apply_lora_unet(unet, loras): + # TODO: + unet.create_session() + + timestep_dtype = next( + (input.type for input in unet.session.get_inputs() if input.name == "timestep"), "tensor(float)" + ) + timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype] + + from tqdm import tqdm + for i in tqdm(range(len(scheduler.timesteps))): + t = scheduler.timesteps[i] + # expand the latents if we are doing classifier free guidance + latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = scheduler.scale_model_input(torch.from_numpy(latent_model_input), t) + latent_model_input = latent_model_input.cpu().numpy() + + # predict the noise residual + timestep = np.array([t], dtype=timestep_dtype) + noise_pred = unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds) + noise_pred = noise_pred[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2) + noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + scheduler_output = scheduler.step( + torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs + ) + latents = scheduler_output.prev_sample.numpy() + + # call the callback, if provided + #if callback is not None and i % callback_steps == 0: + # callback(i, t, latents) + + unet.release_session() + + torch.cuda.empty_cache() + + name = f'{context.graph_execution_state_id}__{self.id}' + context.services.latents.save(name, latents) + return build_latents_output(latents_name=name, latents=latents) + + +@staticmethod +def numpy_to_pil(images): + """ + Convert a numpy image or a batch of images to a PIL image. + """ + if images.ndim == 3: + images = images[None, ...] + images = (images * 255).round().astype("uint8") + if images.shape[-1] == 1: + # special case for grayscale (single channel) images + pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] + else: + pil_images = [Image.fromarray(image) for image in images] + + return pil_images + + +# Latent to image +class ONNXLatentsToImageInvocation(BaseInvocation): + """Generates an image from latents.""" + + type: Literal["l2i_onnx"] = "l2i_onnx" + + # Inputs + latents: Optional[LatentsField] = Field(description="The latents to generate an image from") + vae: VaeField = Field(default=None, description="Vae submodel") + #tiled: bool = Field(default=False, description="Decode latents by overlaping tiles(less memory consumption)") + + # Schema customisation + class Config(InvocationConfig): + schema_extra = { + "ui": { + "tags": ["latents", "image"], + }, + } + + def invoke(self, context: InvocationContext) -> ImageOutput: + latents = context.services.latents.get(self.latents.latents_name) + + if self.vae.vae.submodel != SubModelType.VaeDecoder: + raise Exception(f"Expected vae_decoder, found: {self.vae.vae.model_type}") + + vae_info = context.services.model_manager.get_model( + **self.vae.vae.dict(), + ) + + # clear memory as vae decode can request a lot + torch.cuda.empty_cache() + + with vae_info as vae: + + vae.create_session() + + latents = 1 / 0.18215 * latents + # image = self.vae_decoder(latent_sample=latents)[0] + # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1 + image = np.concatenate( + [vae(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])] + ) + + image = np.clip(image / 2 + 0.5, 0, 1) + image = image.transpose((0, 2, 3, 1)) + + image = VaeImageProcessor.numpy_to_pil(image)[0] + + vae.release_session() + + + + torch.cuda.empty_cache() + + image_dto = context.services.images.create( + image=image, + image_origin=ResourceOrigin.INTERNAL, + image_category=ImageCategory.GENERAL, + node_id=self.id, + session_id=context.graph_execution_state_id, + ) + + return ImageOutput( + image=ImageField(image_name=image_dto.image_name), + width=image_dto.width, + height=image_dto.height, + ) + +class ONNXModelLoaderOutput(BaseInvocationOutput): + """Model loader output""" + + #fmt: off + type: Literal["model_loader_output_onnx"] = "model_loader_output_onnx" + + unet: UNetField = Field(default=None, description="UNet submodel") + clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels") + vae_decoder: VaeField = Field(default=None, description="Vae submodel") + vae_encoder: VaeField = Field(default=None, description="Vae submodel") + #fmt: on + +class ONNXSD1ModelLoaderInvocation(BaseInvocation): + """Loading submodels of selected model.""" + + type: Literal["sd1_model_loader_onnx"] = "sd1_model_loader_onnx" + + model_name: str = Field(default="", description="Model to load") + # TODO: precision? + + # Schema customisation + class Config(InvocationConfig): + schema_extra = { + "ui": { + "tags": ["model", "loader"], + "type_hints": { + "model_name": "model" # TODO: rename to model_name? + } + }, + } + + def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput: + + model_name = "stable-diffusion-v1-5" + base_model = BaseModelType.StableDiffusion1 + + # TODO: not found exceptions + if not context.services.model_manager.model_exists( + model_name=model_name, + base_model=BaseModelType.StableDiffusion1, + model_type=ModelType.ONNX, + ): + raise Exception(f"Unkown model name: {model_name}!") + + + return ONNXModelLoaderOutput( + unet=UNetField( + unet=ModelInfo( + model_name=model_name, + base_model=base_model, + model_type=ModelType.ONNX, + submodel=SubModelType.UNet, + ), + scheduler=ModelInfo( + model_name=model_name, + base_model=base_model, + model_type=ModelType.ONNX, + submodel=SubModelType.Scheduler, + ), + loras=[], + ), + clip=ClipField( + tokenizer=ModelInfo( + model_name=model_name, + base_model=base_model, + model_type=ModelType.ONNX, + submodel=SubModelType.Tokenizer, + ), + text_encoder=ModelInfo( + model_name=model_name, + base_model=base_model, + model_type=ModelType.ONNX, + submodel=SubModelType.TextEncoder, + ), + loras=[], + ), + vae_decoder=VaeField( + vae=ModelInfo( + model_name=model_name, + base_model=base_model, + model_type=ModelType.ONNX, + submodel=SubModelType.VaeDecoder, + ), + ), + vae_encoder=VaeField( + vae=ModelInfo( + model_name=model_name, + base_model=base_model, + model_type=ModelType.ONNX, + submodel=SubModelType.VaeEncoder, + ), + ) + ) \ No newline at end of file diff --git a/invokeai/backend/model_management/lora.py b/invokeai/backend/model_management/lora.py index c351a76590..6f64141610 100644 --- a/invokeai/backend/model_management/lora.py +++ b/invokeai/backend/model_management/lora.py @@ -11,6 +11,8 @@ from torch.utils.hooks import RemovableHandle from diffusers.models import UNet2DConditionModel from transformers import CLIPTextModel +from onnx import numpy_helper +import numpy as np from compel.embeddings_provider import BaseTextualInversionManager @@ -70,7 +72,7 @@ class LoRALayerBase: op = torch.nn.functional.linear extra_args = {} - weight = self.get_weight(module) + weight = self.get_weight() bias = self.bias if self.bias is not None else 0 scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0 @@ -81,7 +83,7 @@ class LoRALayerBase: **extra_args, ) * multiplier * scale - def get_weight(self, module: torch.nn.Module): + def get_weight(self): raise NotImplementedError() def calc_size(self) -> int: @@ -122,7 +124,7 @@ class LoRALayer(LoRALayerBase): self.rank = self.down.shape[0] - def get_weight(self, module: torch.nn.Module): + def get_weight(self): if self.mid is not None: up = self.up.reshape(up.shape[0], up.shape[1]) down = self.down.reshape(up.shape[0], up.shape[1]) @@ -185,7 +187,7 @@ class LoHALayer(LoRALayerBase): self.rank = self.w1_b.shape[0] - def get_weight(self, module: torch.nn.Module): + def get_weight(self): if self.t1 is None: weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b) @@ -271,7 +273,7 @@ class LoKRLayer(LoRALayerBase): else: self.rank = None # unscaled - def get_weight(self, module: torch.nn.Module): + def get_weight(self): w1 = self.w1 if w1 is None: w1 = self.w1_a @ self.w1_b @@ -286,7 +288,7 @@ class LoKRLayer(LoRALayerBase): if len(w2.shape) == 4: w1 = w1.unsqueeze(2).unsqueeze(2) w2 = w2.contiguous() - weight = torch.kron(w1, w2).reshape(module.weight.shape) # TODO: can we remove reshape? + weight = torch.kron(w1, w2)#.reshape(module.weight.shape) # TODO: can we remove reshape? return weight @@ -676,3 +678,212 @@ class TextualInversionManager(BaseTextualInversionManager): return new_token_ids + +class ONNXModelPatcher: + + @classmethod + @contextmanager + def apply_lora_unet( + cls, + unet: OnnxRuntimeModel, + loras: List[Tuple[LoRAModel, float]], + ): + with cls.apply_lora(unet, loras, "lora_unet_"): + yield + + + @classmethod + @contextmanager + def apply_lora_text_encoder( + cls, + text_encoder: OnnxRuntimeModel, + loras: List[Tuple[LoRAModel, float]], + ): + with cls.apply_lora(text_encoder, loras, "lora_te_"): + yield + + + @classmethod + @contextmanager + def apply_lora( + cls, + model: IAIOnnxRuntimeModel, + loras: List[Tuple[LoraModel, float]], + prefix: str, + ): + from .models.base import IAIOnnxRuntimeModel + if not isinstance(model, IAIOnnxRuntimeModel): + raise Exception("Only IAIOnnxRuntimeModel models supported") + + base_model = model.proto + orig_nodes = dict() + + try: + blended_loras = dict() + + for lora, lora_weight in loras: + for layer_key, layer in lora.layers.items(): + if not layer_key.startswith(prefix): + continue + + layer_key = layer_key.replace(prefix, "") + layer_weight = layer.get_weight().detach().cpu().numpy() * lora_weight + if layer_key is blended_loras: + blended_loras[layer_key] += layer_weight + else: + blended_loras[layer_key] = layer_weight + + initializer_idx = dict() + for idx, init in enumerate(base_model.graph.initializer): + initializer_idx[init.name.replace(".", "_")] = idx + + node_idx = dict() + for idx, node in enumerate(base_model.graph.node): + node_idx[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = idx + + for layer_key, weights in blended_loras.items(): + conv_key = layer_key + "_Conv" + gemm_key = layer_key + "_Gemm" + matmul_key = layer_key + "_MatMul" + + if conv_key in node_idx or gemm_key in node_idx: + if conv_key in node_idx: + conv_node = base_model.graph.node[node_idx[conv_key]] + else: + conv_node = base_model.graph.node[node_idx[gemm_key]] + + weight_name = [n for n in conv_node.input if ".weight" in n][0] + weight_name = weight_name.replace(".", "_") + + weight_idx = initializer_idx[weight_name] + weight_node = base_model.graph.initializer[weight_idx] + + orig_weights = numpy_helper.to_array(weight_node) + + if orig_weights.shape[-2:] == (1, 1): + if weights.shape[-2:] == (1, 1): + new_weights = orig_weights.squeeze((3, 2)) + weights.squeeze((3, 2)) + else: + new_weights = orig_weights.squeeze((3, 2)) + weights + + new_weights = np.expand_dims(new_weights, (2, 3)) + else: + if orig_weights.shape != weights.shape: + new_weights = orig_weights + weights.reshape(orig_weights.shape) + else: + new_weights = orig_weights + weights + + new_node = numpy_helper.from_array(new_weights.astype(orig_weights.dtype), weight_node.name) + orig_nodes[weight_idx] = base_model.graph.initializer[weight_idx] + del base_model.graph.initializer[weight_idx] + base_model.graph.initializer.insert(weight_idx, new_node) + + elif matmul_key in node_idx: + weight_node = base_model.graph.node[node_idx[matmul_key]] + + matmul_name = [n for n in weight_node.input if "MatMul" in n][0] + + matmul_idx = initializer_idx[matmul_name] + matmul_node = base_model.graph.initializer[matmul_idx] + + orig_weights = numpy_helper.to_array(matmul_node) + + new_weights = orig_weights + weights.transpose() + + # replace the original initializer + new_node = numpy_helper.from_array(new_weights.astype(orig_weights.dtype), matmul_node.name) + orig_nodes[matmul_idx] = base_model.graph.initializer[matmul_idx] + del base_model.graph.initializer[matmul_idx] + base_model.graph.initializer.insert(matmul_idx, new_node) + + else: + # warn? err? + pass + + yield + + finally: + # restore original weights + for idx, orig_node in orig_nodes.items(): + del base_model.graph.initializer[idx] + base_model.graph.initializer.insert(idx, orig_node) + + + + @classmethod + @contextmanager + def apply_ti( + cls, + tokenizer: CLIPTokenizer, + text_encoder: IAIOnnxRuntimeModel, + ti_list: List[Any], + ) -> Tuple[CLIPTokenizer, TextualInversionManager]: + from .models.base import IAIOnnxRuntimeModel + if not isinstance(text_encoder, IAIOnnxRuntimeModel): + raise Exception("Only IAIOnnxRuntimeModel models supported") + + init_tokens_count = None + new_tokens_added = None + + try: + ti_tokenizer = copy.deepcopy(tokenizer) + ti_manager = TextualInversionManager(ti_tokenizer) + + def _get_trigger(ti, index): + trigger = ti.name + if index > 0: + trigger += f"-!pad-{i}" + return f"<{trigger}>" + + # modify tokenizer + new_tokens_added = 0 + for ti in ti_list: + for i in range(ti.embedding.shape[0]): + new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i)) + + # modify text_encoder + for i in range(len(text_encoder.proto.graph.initializer)): + if text_encoder.proto.graph.initializer[i].name == "text_model.embeddings.token_embedding.weight": + embeddings_node_idx = i + break + else: + raise Exception("text_model.embeddings.token_embedding.weight node not found") + + embeddings_node_orig = text_encoder.proto.graph.initializer[embeddings_node_idx] + base_weights = numpy_helper.to_array(embeddings_node_orig) + + embedding_weights = np.concatenate((base_weights, np.zeros((new_tokens_added, base_weights.shape[1]))), axis=0) + + for ti in ti_list: + ti_tokens = [] + for i in range(ti.embedding.shape[0]): + embedding = ti.embedding[i].detach().numpy() + trigger = _get_trigger(ti, i) + + token_id = ti_tokenizer.convert_tokens_to_ids(trigger) + if token_id == ti_tokenizer.unk_token_id: + raise RuntimeError(f"Unable to find token id for token '{trigger}'") + + if embedding_weights[token_id].shape != embedding.shape: + raise ValueError( + f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {embedding_weights[token_id].shape[0]}." + ) + + embedding_weights[token_id] = embedding + ti_tokens.append(token_id) + + if len(ti_tokens) > 1: + ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:] + + + new_embeddings_node = numpy_helper.from_array(embedding_weights.astype(base_weights.dtype), embeddings_node_orig.name) + del text_encoder.proto.graph.initializer[embeddings_node_idx] + text_encoder.proto.graph.initializer.insert(embeddings_node_idx, new_embeddings_node) + + yield ti_tokenizer, ti_manager + + finally: + # restore + if embeddings_node_orig is not None: + del text_encoder.proto.graph.initializer[embeddings_node_idx] + text_encoder.proto.graph.initializer.insert(embeddings_node_idx, embeddings_node_orig) diff --git a/invokeai/backend/model_management/models/__init__.py b/invokeai/backend/model_management/models/__init__.py index 6975d45f93..0b8cdbfa0d 100644 --- a/invokeai/backend/model_management/models/__init__.py +++ b/invokeai/backend/model_management/models/__init__.py @@ -9,9 +9,12 @@ from .lora import LoRAModel from .controlnet import ControlNetModel # TODO: from .textual_inversion import TextualInversionModel +from .stable_diffusion_onnx import ONNXStableDiffusion1Model, ONNXStableDiffusion2Model + MODEL_CLASSES = { BaseModelType.StableDiffusion1: { ModelType.Pipeline: StableDiffusion1Model, + ModelType.ONNX: ONNXStableDiffusion1Model, ModelType.Vae: VaeModel, ModelType.Lora: LoRAModel, ModelType.ControlNet: ControlNetModel, @@ -19,6 +22,7 @@ MODEL_CLASSES = { }, BaseModelType.StableDiffusion2: { ModelType.Pipeline: StableDiffusion2Model, + ModelType.ONNX: ONNXStableDiffusion2Model, ModelType.Vae: VaeModel, ModelType.Lora: LoRAModel, ModelType.ControlNet: ControlNetModel, diff --git a/invokeai/backend/model_management/models/base.py b/invokeai/backend/model_management/models/base.py index ef354ecc07..0b22e380ba 100644 --- a/invokeai/backend/model_management/models/base.py +++ b/invokeai/backend/model_management/models/base.py @@ -5,19 +5,27 @@ import inspect from enum import Enum from abc import ABCMeta, abstractmethod import torch +import numpy as np import safetensors.torch -from diffusers import DiffusionPipeline, ConfigMixin +from pathlib import Path +from diffusers import DiffusionPipeline, ConfigMixin, OnnxRuntimeModel from contextlib import suppress from pydantic import BaseModel, Field from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union +import onnx +from onnx import numpy_helper +from onnx.external_data_helper import set_external_data +from onnxruntime import InferenceSession, OrtValue, SessionOptions + class BaseModelType(str, Enum): StableDiffusion1 = "sd-1" StableDiffusion2 = "sd-2" #Kandinsky2_1 = "kandinsky-2.1" class ModelType(str, Enum): + ONNX = "onnx" Pipeline = "pipeline" Vae = "vae" Lora = "lora" @@ -29,6 +37,8 @@ class SubModelType(str, Enum): TextEncoder = "text_encoder" Tokenizer = "tokenizer" Vae = "vae" + VaeDecoder = "vae_decoder" + VaeEncoder = "vae_encoder" Scheduler = "scheduler" SafetyChecker = "safety_checker" #MoVQ = "movq" @@ -240,16 +250,18 @@ class DiffusersModel(ModelBase): try: # TODO: set cache_dir to /dev/null to be sure that cache not used? model = self.child_types[child_type].from_pretrained( - self.model_path, - subfolder=child_type.value, + os.path.join(self.model_path, child_type.value), + #subfolder=child_type.value, torch_dtype=torch_dtype, variant=variant, local_files_only=True, ) break except Exception as e: - #print("====ERR LOAD====") - #print(f"{variant}: {e}") + print("====ERR LOAD====") + print(f"{variant}: {e}") + import traceback + traceback.print_exc() pass else: raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model") @@ -413,3 +425,92 @@ class SilenceWarnings(object): transformers_logging.set_verbosity(self.transformers_verbosity) diffusers_logging.set_verbosity(self.diffusers_verbosity) warnings.simplefilter('default') + +def buffer_external_data_tensors(model): + external_data = dict() + for tensor in model.graph.initializer: + name = tensor.name + + if tensor.HasField("raw_data"): + npt = numpy_helper.to_array(tensor) + orv = OrtValue.ortvalue_from_numpy(npt) + external_data[name] = orv + set_external_data(tensor, location="tmp.bin") + tensor.name = name + tensor.ClearField("raw_data") + + return (model, external_data) + +ONNX_WEIGHTS_NAME = "model.onnx" +class IAIOnnxRuntimeModel(OnnxRuntimeModel): + def __init__(self, model: tuple, **kwargs): + self.proto, self.provider, self.sess_options = model + self.session = None + self._external_data = dict() + + def __call__(self, **kwargs): + if self.session is None: + raise Exception("You should call create_session before running model") + + inputs = {k: np.array(v) for k, v in kwargs.items()} + return self.session.run(None, inputs) + + def create_session(self): + if self.session is None: + #onnx.save(self.proto, "tmp.onnx") + #onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False) + (trimmed_model, external_data) = buffer_external_data_tensors(self.proto) + sess = SessionOptions() + self._external_data.update(**external_data) + sess.add_external_initializers(list(self._external_data.keys()), list(self._external_data.values())) + self.session = InferenceSession(trimmed_model.SerializeToString(), providers=[self.provider], sess_options=sess) + #self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options) + + def release_session(self): + self.session = None + import gc + gc.collect() + + @staticmethod + def load_model(path: Union[str, Path], provider=None, sess_options=None): + """ + Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider` + + Arguments: + path (`str` or `Path`): + Directory from which to load + provider(`str`, *optional*): + Onnxruntime execution provider to use for loading the model, defaults to `CPUExecutionProvider` + """ + if provider is None: + #logger.info("No onnxruntime provider specified, using CPUExecutionProvider") + print("No onnxruntime provider specified, using CPUExecutionProvider") + provider = "CPUExecutionProvider" + + # TODO: check that provider available? + return (onnx.load(path), provider, sess_options) + + @classmethod + def _from_pretrained( + cls, + model_id: Union[str, Path], + use_auth_token: Optional[Union[bool, str, None]] = None, + revision: Optional[Union[str, None]] = None, + force_download: bool = False, + cache_dir: Optional[str] = None, + file_name: Optional[str] = None, + provider: Optional[str] = None, + sess_options: Optional["SessionOptions"] = None, + **kwargs, + ): + model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME + # load model from local directory + if not os.path.isdir(model_id): + raise Exception(f"Model not found: {model_id}") + model = IAIOnnxRuntimeModel.load_model( + os.path.join(model_id, model_file_name), provider=provider, sess_options=sess_options + ) + + return cls(model=model, **kwargs) + + diff --git a/invokeai/backend/model_management/models/stable_diffusion_onnx.py b/invokeai/backend/model_management/models/stable_diffusion_onnx.py new file mode 100644 index 0000000000..24111bea36 --- /dev/null +++ b/invokeai/backend/model_management/models/stable_diffusion_onnx.py @@ -0,0 +1,156 @@ +import os +import json +from enum import Enum +from pydantic import Field +from pathlib import Path +from typing import Literal, Optional, Union +from .base import ( + ModelBase, + ModelConfigBase, + BaseModelType, + ModelType, + SubModelType, + ModelVariantType, + DiffusersModel, + SchedulerPredictionType, + SilenceWarnings, + read_checkpoint_meta, + classproperty, + OnnxRuntimeModel, + IAIOnnxRuntimeModel, +) +from invokeai.app.services.config import InvokeAIAppConfig + +class ONNXStableDiffusion1Model(DiffusersModel): + + class Config(ModelConfigBase): + model_format: None + variant: ModelVariantType + + + def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): + assert base_model == BaseModelType.StableDiffusion1 + assert model_type == ModelType.ONNX + super().__init__( + model_path=model_path, + base_model=BaseModelType.StableDiffusion1, + model_type=ModelType.ONNX, + ) + + for child_name, child_type in self.child_types.items(): + if child_type is OnnxRuntimeModel: + self.child_types[child_name] = IAIOnnxRuntimeModel + + # TODO: check that no optimum models provided + + @classmethod + def probe_config(cls, path: str, **kwargs): + model_format = cls.detect_format(path) + in_channels = 4 # TODO: + + if in_channels == 9: + variant = ModelVariantType.Inpaint + elif in_channels == 4: + variant = ModelVariantType.Normal + else: + raise Exception("Unkown stable diffusion 1.* model format") + + + return cls.create_config( + path=path, + model_format=model_format, + + variant=variant, + ) + + @classproperty + def save_to_config(cls) -> bool: + return True + + @classmethod + def detect_format(cls, model_path: str): + return None + + @classmethod + def convert_if_required( + cls, + model_path: str, + output_path: str, + config: ModelConfigBase, + base_model: BaseModelType, + ) -> str: + return model_path + +class ONNXStableDiffusion2Model(DiffusersModel): + + # TODO: check that configs overwriten properly + class Config(ModelConfigBase): + model_format: None + variant: ModelVariantType + prediction_type: SchedulerPredictionType + upcast_attention: bool + + + def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType): + assert base_model == BaseModelType.StableDiffusion2 + assert model_type == ModelType.ONNX + super().__init__( + model_path=model_path, + base_model=BaseModelType.StableDiffusion2, + model_type=ModelType.ONNX, + ) + + for child_name, child_type in self.child_types.items(): + if child_type is OnnxRuntimeModel: + self.child_types[child_name] = IAIOnnxRuntimeModel + # TODO: check that no optimum models provided + + @classmethod + def probe_config(cls, path: str, **kwargs): + model_format = cls.detect_format(path) + in_channels = 4 # TODO: + + if in_channels == 9: + variant = ModelVariantType.Inpaint + elif in_channels == 5: + variant = ModelVariantType.Depth + elif in_channels == 4: + variant = ModelVariantType.Normal + else: + raise Exception("Unkown stable diffusion 2.* model format") + + if variant == ModelVariantType.Normal: + prediction_type = SchedulerPredictionType.VPrediction + upcast_attention = True + + else: + prediction_type = SchedulerPredictionType.Epsilon + upcast_attention = False + + return cls.create_config( + path=path, + model_format=model_format, + + variant=variant, + prediction_type=prediction_type, + upcast_attention=upcast_attention, + ) + + @classproperty + def save_to_config(cls) -> bool: + return True + + @classmethod + def detect_format(cls, model_path: str): + return None + + @classmethod + def convert_if_required( + cls, + model_path: str, + output_path: str, + config: ModelConfigBase, + base_model: BaseModelType, + ) -> str: + return model_path +