ONNX Model/runtime first implementation

Sergey Borisov 2023-06-21 02:12:21 +03:00
parent 92c86fd0b8
commit 4d337f6abc
7 changed files with 935 additions and 16 deletions

View File

@@ -55,8 +55,8 @@ class LatentsOutput(BaseInvocationOutput):
def build_latents_output(latents_name: str, latents: torch.Tensor):
    return LatentsOutput(
        latents=LatentsField(latents_name=latents_name),
-       width=latents.size()[3] * 8,
-       height=latents.size()[2] * 8,
+       width=latents.shape[3] * 8,
+       height=latents.shape[2] * 8,
    )

class NoiseOutput(BaseInvocationOutput):

View File

@@ -271,9 +271,13 @@ class LoraLoaderInvocation(BaseInvocation):
    def invoke(self, context: InvocationContext) -> LoraLoaderOutput:
+       # TODO: ui rewrite
+       base_model = BaseModelType.StableDiffusion1

        if not context.services.model_manager.model_exists(
+           base_model=base_model,
            model_name=self.lora_name,
-           model_type=SDModelType.Lora,
+           model_type=ModelType.Lora,
        ):
            raise Exception(f"Unkown lora name: {self.lora_name}!")
@@ -289,8 +293,9 @@ class LoraLoaderInvocation(BaseInvocation):
            output.unet = copy.deepcopy(self.unet)
            output.unet.loras.append(
                LoraInfo(
+                   base_model=base_model,
                    model_name=self.lora_name,
-                   model_type=SDModelType.Lora,
+                   model_type=ModelType.Lora,
                    submodel=None,
                    weight=self.weight,
                )
@@ -300,8 +305,9 @@ class LoraLoaderInvocation(BaseInvocation):
            output.clip = copy.deepcopy(self.clip)
            output.clip.loras.append(
                LoraInfo(
+                   base_model=base_model,
                    model_name=self.lora_name,
-                   model_type=SDModelType.Lora,
+                   model_type=ModelType.Lora,
                    submodel=None,
                    weight=self.weight,
                )

View File

@@ -0,0 +1,441 @@
# Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779)

from contextlib import ExitStack
from typing import List, Literal, Optional, Union
import re
import inspect

from pydantic import BaseModel, Field, validator
import torch
import numpy as np
from PIL import Image  # needed by numpy_to_pil below
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import SchedulerMixin as Scheduler

from ..models.image import ImageCategory, ImageField, ResourceOrigin
from ...backend.model_management.lora import ONNXModelPatcher
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
                             InvocationConfig, InvocationContext)
from .compel import ConditioningField
from .controlnet_image_processors import ControlField
from .image import ImageOutput
from .model import ModelInfo, UNetField, VaeField

from invokeai.backend import BaseModelType, ModelType, SubModelType
from .model import ClipField
from .latent import LatentsField, LatentsOutput, build_latents_output, get_scheduler, SAMPLER_NAME_VALUES
from .compel import CompelOutput
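
# Maps ONNX Runtime element type strings (as reported by session.get_inputs()) to numpy
# dtypes; used below to cast the timestep input to whatever dtype the exported UNet expects.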
ORT_TO_NP_TYPE = {
    "tensor(bool)": np.bool_,
    "tensor(int8)": np.int8,
    "tensor(uint8)": np.uint8,
    "tensor(int16)": np.int16,
    "tensor(uint16)": np.uint16,
    "tensor(int32)": np.int32,
    "tensor(uint32)": np.uint32,
    "tensor(int64)": np.int64,
    "tensor(uint64)": np.uint64,
    "tensor(float16)": np.float16,
    "tensor(float)": np.float32,
    "tensor(double)": np.float64,
}

class ONNXPromptInvocation(BaseInvocation):
    type: Literal["prompt_onnx"] = "prompt_onnx"

    prompt: str = Field(default="", description="Prompt")
    clip: ClipField = Field(None, description="Clip to use")

    def invoke(self, context: InvocationContext) -> CompelOutput:
        tokenizer_info = context.services.model_manager.get_model(
            **self.clip.tokenizer.dict(),
        )
        text_encoder_info = context.services.model_manager.get_model(
            **self.clip.text_encoder.dict(),
        )
        with tokenizer_info as orig_tokenizer,\
             text_encoder_info as text_encoder,\
             ExitStack() as stack:

            loras = [
                (stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight)
                for lora in self.clip.loras
            ]

            ti_list = []
            for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
                name = trigger[1:-1]
                try:
                    ti_list.append(
                        stack.enter_context(
                            context.services.model_manager.get_model(
                                model_name=name,
                                base_model=self.clip.text_encoder.base_model,
                                model_type=ModelType.TextualInversion,
                            )
                        )
                    )
                except Exception:
                    # print(e)
                    # import traceback
                    # print(traceback.format_exc())
                    print(f"Warn: trigger: \"{trigger}\" not found")

            with ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras),\
                 ONNXModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager):
                text_encoder.create_session()

                text_inputs = tokenizer(
                    self.prompt,
                    padding="max_length",
                    max_length=tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="np",
                )
                text_input_ids = text_inputs.input_ids
                """
                untruncated_ids = tokenizer(prompt, padding="max_length", return_tensors="np").input_ids

                if not np.array_equal(text_input_ids, untruncated_ids):
                    removed_text = self.tokenizer.batch_decode(
                        untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                    )
                    logger.warning(
                        "The following part of your input was truncated because CLIP can only handle sequences up to"
                        f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                    )
                """
                prompt_embeds = text_encoder(input_ids=text_input_ids.astype(np.int32))[0]

                text_encoder.release_session()

        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"

        # TODO: hacky but works ;D maybe rename latents somehow?
        context.services.latents.save(conditioning_name, (prompt_embeds, None))

        return CompelOutput(
            conditioning=ConditioningField(
                conditioning_name=conditioning_name,
            ),
        )

# Text to image
class ONNXTextToLatentsInvocation(BaseInvocation):
    """Generates latents from conditionings."""

    type: Literal["t2l_onnx"] = "t2l_onnx"

    # Inputs
    # fmt: off
    positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
    negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
    noise: Optional[LatentsField] = Field(description="The noise to use")
    steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
    cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
    scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
    unet: UNetField = Field(default=None, description="UNet submodel")
    # control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
    # seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
    # seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
    # fmt: on

    @validator("cfg_scale")
    def ge_one(cls, v):
        """validate that all cfg_scale values are >= 1"""
        if isinstance(v, list):
            for i in v:
                if i < 1:
                    raise ValueError('cfg_scale must be greater than 1')
        else:
            if v < 1:
                raise ValueError('cfg_scale must be greater than 1')
        return v

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "tags": ["latents"],
                "type_hints": {
                    "model": "model",
                    # "cfg_scale": "float",
                    "cfg_scale": "number"
                }
            },
        }

    def invoke(self, context: InvocationContext) -> LatentsOutput:
        c, _ = context.services.latents.get(self.positive_conditioning.conditioning_name)
        uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)

        if isinstance(c, torch.Tensor):
            c = c.cpu().numpy()
        if isinstance(uc, torch.Tensor):
            uc = uc.cpu().numpy()
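
        # Classifier-free guidance: the negative and positive embeddings are stacked so that
        # each denoising step runs both through the UNet in one doubled batch.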
        prompt_embeds = np.concatenate([uc, c])

        latents = context.services.latents.get(self.noise.latents_name)
        if isinstance(latents, torch.Tensor):
            latents = latents.cpu().numpy()

        # TODO: better execution device handling
        latents = latents.astype(np.float32)

        # get the initial random noise unless the user supplied it
        do_classifier_free_guidance = True
        # latents_dtype = prompt_embeds.dtype
        # latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
        # if latents.shape != latents_shape:
        #     raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")

        scheduler = get_scheduler(
            context=context,
            scheduler_info=self.unet.scheduler,
            scheduler_name=self.scheduler,
        )

        scheduler.set_timesteps(self.steps)
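        # Schedulers expect the initial noise to be scaled by their starting sigma.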
        latents = latents * np.float64(scheduler.init_noise_sigma)

        extra_step_kwargs = dict()
        if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
            extra_step_kwargs.update(
                eta=0.0,
            )

        unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())

        with unet_info as unet,\
             ExitStack() as stack:

            loras = [
                (stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight)
                for lora in self.unet.loras
            ]

            with ONNXModelPatcher.apply_lora_unet(unet, loras):
                # TODO:
                unet.create_session()
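
                # Exported UNets differ in the dtype they expect for the timestep input
                # (float vs. int64); read it from the session metadata and cast accordingly.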
                timestep_dtype = next(
                    (input.type for input in unet.session.get_inputs() if input.name == "timestep"), "tensor(float)"
                )
                timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]

                from tqdm import tqdm
                for i in tqdm(range(len(scheduler.timesteps))):
                    t = scheduler.timesteps[i]
                    # expand the latents if we are doing classifier free guidance
                    latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
                    latent_model_input = scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
                    latent_model_input = latent_model_input.cpu().numpy()

                    # predict the noise residual
                    timestep = np.array([t], dtype=timestep_dtype)
                    noise_pred = unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)
                    noise_pred = noise_pred[0]

                    # perform guidance
                    if do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
                        noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)

                    # compute the previous noisy sample x_t -> x_t-1
                    scheduler_output = scheduler.step(
                        torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
                    )
                    latents = scheduler_output.prev_sample.numpy()

                    # call the callback, if provided
                    # if callback is not None and i % callback_steps == 0:
                    #     callback(i, t, latents)

                unet.release_session()

        torch.cuda.empty_cache()

        name = f'{context.graph_execution_state_id}__{self.id}'
        context.services.latents.save(name, latents)
        return build_latents_output(latents_name=name, latents=latents)

    @staticmethod
    def numpy_to_pil(images):
        """
        Convert a numpy image or a batch of images to a PIL image.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # special case for grayscale (single channel) images
            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
        else:
            pil_images = [Image.fromarray(image) for image in images]

        return pil_images

# Latent to image
class ONNXLatentsToImageInvocation(BaseInvocation):
    """Generates an image from latents."""

    type: Literal["l2i_onnx"] = "l2i_onnx"

    # Inputs
    latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
    vae: VaeField = Field(default=None, description="Vae submodel")
    # tiled: bool = Field(default=False, description="Decode latents by overlapping tiles (less memory consumption)")

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "tags": ["latents", "image"],
            },
        }

    def invoke(self, context: InvocationContext) -> ImageOutput:
        latents = context.services.latents.get(self.latents.latents_name)

        if self.vae.vae.submodel != SubModelType.VaeDecoder:
            raise Exception(f"Expected vae_decoder, found: {self.vae.vae.model_type}")

        vae_info = context.services.model_manager.get_model(
            **self.vae.vae.dict(),
        )

        # clear memory as vae decode can request a lot
        torch.cuda.empty_cache()

        with vae_info as vae:
            vae.create_session()
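
            # 0.18215 is the Stable Diffusion VAE scaling factor; undo it before decoding.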
            latents = 1 / 0.18215 * latents

            # image = self.vae_decoder(latent_sample=latents)[0]
            # it seems like there is a strange result when using a half-precision vae decoder if batch size > 1
            image = np.concatenate(
                [vae(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
            )

            image = np.clip(image / 2 + 0.5, 0, 1)
            image = image.transpose((0, 2, 3, 1))
            image = VaeImageProcessor.numpy_to_pil(image)[0]

            vae.release_session()

        torch.cuda.empty_cache()

        image_dto = context.services.images.create(
            image=image,
            image_origin=ResourceOrigin.INTERNAL,
            image_category=ImageCategory.GENERAL,
            node_id=self.id,
            session_id=context.graph_execution_state_id,
        )

        return ImageOutput(
            image=ImageField(image_name=image_dto.image_name),
            width=image_dto.width,
            height=image_dto.height,
        )

class ONNXModelLoaderOutput(BaseInvocationOutput):
    """Model loader output"""

    # fmt: off
    type: Literal["model_loader_output_onnx"] = "model_loader_output_onnx"

    unet: UNetField = Field(default=None, description="UNet submodel")
    clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
    vae_decoder: VaeField = Field(default=None, description="Vae submodel")
    vae_encoder: VaeField = Field(default=None, description="Vae submodel")
    # fmt: on


class ONNXSD1ModelLoaderInvocation(BaseInvocation):
    """Loading submodels of selected model."""

    type: Literal["sd1_model_loader_onnx"] = "sd1_model_loader_onnx"

    model_name: str = Field(default="", description="Model to load")
    # TODO: precision?

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "tags": ["model", "loader"],
                "type_hints": {
                    "model_name": "model"  # TODO: rename to model_name?
                }
            },
        }

    def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:
        model_name = "stable-diffusion-v1-5"
        base_model = BaseModelType.StableDiffusion1

        # TODO: not found exceptions
        if not context.services.model_manager.model_exists(
            model_name=model_name,
            base_model=BaseModelType.StableDiffusion1,
            model_type=ModelType.ONNX,
        ):
            raise Exception(f"Unknown model name: {model_name}!")

        return ONNXModelLoaderOutput(
            unet=UNetField(
                unet=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=ModelType.ONNX,
                    submodel=SubModelType.UNet,
                ),
                scheduler=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=ModelType.ONNX,
                    submodel=SubModelType.Scheduler,
                ),
                loras=[],
            ),
            clip=ClipField(
                tokenizer=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=ModelType.ONNX,
                    submodel=SubModelType.Tokenizer,
                ),
                text_encoder=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=ModelType.ONNX,
                    submodel=SubModelType.TextEncoder,
                ),
                loras=[],
            ),
            vae_decoder=VaeField(
                vae=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=ModelType.ONNX,
                    submodel=SubModelType.VaeDecoder,
                ),
            ),
            vae_encoder=VaeField(
                vae=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=ModelType.ONNX,
                    submodel=SubModelType.VaeEncoder,
                ),
            ),
        )

View File

@@ -11,6 +11,8 @@ from torch.utils.hooks import RemovableHandle
from diffusers.models import UNet2DConditionModel
from transformers import CLIPTextModel
+from onnx import numpy_helper
+import numpy as np

from compel.embeddings_provider import BaseTextualInversionManager
@@ -70,7 +72,7 @@ class LoRALayerBase:
            op = torch.nn.functional.linear
            extra_args = {}

-       weight = self.get_weight(module)
+       weight = self.get_weight()

        bias = self.bias if self.bias is not None else 0
        scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
@@ -81,7 +83,7 @@ class LoRALayerBase:
            **extra_args,
        ) * multiplier * scale

-   def get_weight(self, module: torch.nn.Module):
+   def get_weight(self):
        raise NotImplementedError()

    def calc_size(self) -> int:
@@ -122,7 +124,7 @@ class LoRALayer(LoRALayerBase):
        self.rank = self.down.shape[0]

-   def get_weight(self, module: torch.nn.Module):
+   def get_weight(self):
        if self.mid is not None:
            up = self.up.reshape(up.shape[0], up.shape[1])
            down = self.down.reshape(up.shape[0], up.shape[1])
@@ -185,7 +187,7 @@ class LoHALayer(LoRALayerBase):
        self.rank = self.w1_b.shape[0]

-   def get_weight(self, module: torch.nn.Module):
+   def get_weight(self):
        if self.t1 is None:
            weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
@@ -271,7 +273,7 @@ class LoKRLayer(LoRALayerBase):
        else:
            self.rank = None  # unscaled

-   def get_weight(self, module: torch.nn.Module):
+   def get_weight(self):
        w1 = self.w1
        if w1 is None:
            w1 = self.w1_a @ self.w1_b
@@ -286,7 +288,7 @@ class LoKRLayer(LoRALayerBase):
        if len(w2.shape) == 4:
            w1 = w1.unsqueeze(2).unsqueeze(2)
        w2 = w2.contiguous()
-       weight = torch.kron(w1, w2).reshape(module.weight.shape)  # TODO: can we remove reshape?
+       weight = torch.kron(w1, w2)  # .reshape(module.weight.shape)  # TODO: can we remove reshape?

        return weight
@@ -676,3 +678,212 @@ class TextualInversionManager(BaseTextualInversionManager):
        return new_token_ids

class ONNXModelPatcher:
    @classmethod
    @contextmanager
    def apply_lora_unet(
        cls,
        unet: OnnxRuntimeModel,
        loras: List[Tuple[LoRAModel, float]],
    ):
        with cls.apply_lora(unet, loras, "lora_unet_"):
            yield

    @classmethod
    @contextmanager
    def apply_lora_text_encoder(
        cls,
        text_encoder: OnnxRuntimeModel,
        loras: List[Tuple[LoRAModel, float]],
    ):
        with cls.apply_lora(text_encoder, loras, "lora_te_"):
            yield
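
    # An ONNX session cannot hook module forward passes the way the torch ModelPatcher does,
    # so the blended LoRA deltas are baked directly into the graph initializers for the
    # duration of the context, and the original initializers are restored on exit.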
    @classmethod
    @contextmanager
    def apply_lora(
        cls,
        model: IAIOnnxRuntimeModel,
        loras: List[Tuple[LoRAModel, float]],
        prefix: str,
    ):
        from .models.base import IAIOnnxRuntimeModel
        if not isinstance(model, IAIOnnxRuntimeModel):
            raise Exception("Only IAIOnnxRuntimeModel models supported")

        base_model = model.proto
        orig_nodes = dict()

        try:
            blended_loras = dict()

            for lora, lora_weight in loras:
                for layer_key, layer in lora.layers.items():
                    if not layer_key.startswith(prefix):
                        continue

                    layer_key = layer_key.replace(prefix, "")
                    layer_weight = layer.get_weight().detach().cpu().numpy() * lora_weight
                    if layer_key in blended_loras:
                        blended_loras[layer_key] += layer_weight
                    else:
                        blended_loras[layer_key] = layer_weight
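
            # The ONNX graph renames tensors ('.' and '/' become '_'), so build lookup tables
            # mapping the normalized initializer/node names back to their indices.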
            initializer_idx = dict()
            for idx, init in enumerate(base_model.graph.initializer):
                initializer_idx[init.name.replace(".", "_")] = idx

            node_idx = dict()
            for idx, node in enumerate(base_model.graph.node):
                node_idx[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = idx

            for layer_key, weights in blended_loras.items():
                conv_key = layer_key + "_Conv"
                gemm_key = layer_key + "_Gemm"
                matmul_key = layer_key + "_MatMul"

                if conv_key in node_idx or gemm_key in node_idx:
                    if conv_key in node_idx:
                        conv_node = base_model.graph.node[node_idx[conv_key]]
                    else:
                        conv_node = base_model.graph.node[node_idx[gemm_key]]

                    weight_name = [n for n in conv_node.input if ".weight" in n][0]
                    weight_name = weight_name.replace(".", "_")

                    weight_idx = initializer_idx[weight_name]
                    weight_node = base_model.graph.initializer[weight_idx]

                    orig_weights = numpy_helper.to_array(weight_node)

                    if orig_weights.shape[-2:] == (1, 1):
                        if weights.shape[-2:] == (1, 1):
                            new_weights = orig_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
                        else:
                            new_weights = orig_weights.squeeze((3, 2)) + weights

                        new_weights = np.expand_dims(new_weights, (2, 3))
                    else:
                        if orig_weights.shape != weights.shape:
                            new_weights = orig_weights + weights.reshape(orig_weights.shape)
                        else:
                            new_weights = orig_weights + weights

                    new_node = numpy_helper.from_array(new_weights.astype(orig_weights.dtype), weight_node.name)
                    orig_nodes[weight_idx] = base_model.graph.initializer[weight_idx]
                    del base_model.graph.initializer[weight_idx]
                    base_model.graph.initializer.insert(weight_idx, new_node)

                elif matmul_key in node_idx:
                    weight_node = base_model.graph.node[node_idx[matmul_key]]
                    matmul_name = [n for n in weight_node.input if "MatMul" in n][0]

                    matmul_idx = initializer_idx[matmul_name]
                    matmul_node = base_model.graph.initializer[matmul_idx]

                    orig_weights = numpy_helper.to_array(matmul_node)

                    new_weights = orig_weights + weights.transpose()

                    # replace the original initializer
                    new_node = numpy_helper.from_array(new_weights.astype(orig_weights.dtype), matmul_node.name)
                    orig_nodes[matmul_idx] = base_model.graph.initializer[matmul_idx]
                    del base_model.graph.initializer[matmul_idx]
                    base_model.graph.initializer.insert(matmul_idx, new_node)

                else:
                    # warn? err?
                    pass

            yield

        finally:
            # restore original weights
            for idx, orig_node in orig_nodes.items():
                del base_model.graph.initializer[idx]
                base_model.graph.initializer.insert(idx, orig_node)

    @classmethod
    @contextmanager
    def apply_ti(
        cls,
        tokenizer: CLIPTokenizer,
        text_encoder: IAIOnnxRuntimeModel,
        ti_list: List[Any],
    ) -> Tuple[CLIPTokenizer, TextualInversionManager]:
        from .models.base import IAIOnnxRuntimeModel
        if not isinstance(text_encoder, IAIOnnxRuntimeModel):
            raise Exception("Only IAIOnnxRuntimeModel models supported")

        init_tokens_count = None
        new_tokens_added = None
        embeddings_node_orig = None  # keeps the finally block safe if we fail before patching

        try:
            ti_tokenizer = copy.deepcopy(tokenizer)
            ti_manager = TextualInversionManager(ti_tokenizer)

            def _get_trigger(ti, index):
                trigger = ti.name
                if index > 0:
                    trigger += f"-!pad-{index}"
                return f"<{trigger}>"

            # modify tokenizer
            new_tokens_added = 0
            for ti in ti_list:
                for i in range(ti.embedding.shape[0]):
                    new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))

            # modify text_encoder
            for i in range(len(text_encoder.proto.graph.initializer)):
                if text_encoder.proto.graph.initializer[i].name == "text_model.embeddings.token_embedding.weight":
                    embeddings_node_idx = i
                    break
            else:
                raise Exception("text_model.embeddings.token_embedding.weight node not found")

            embeddings_node_orig = text_encoder.proto.graph.initializer[embeddings_node_idx]
            base_weights = numpy_helper.to_array(embeddings_node_orig)

            embedding_weights = np.concatenate((base_weights, np.zeros((new_tokens_added, base_weights.shape[1]))), axis=0)
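
            # The zero rows appended above correspond to the new trigger tokens; fill each row
            # with its textual-inversion embedding, tracking pad tokens for multi-vector embeddings.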
            for ti in ti_list:
                ti_tokens = []
                for i in range(ti.embedding.shape[0]):
                    embedding = ti.embedding[i].detach().numpy()
                    trigger = _get_trigger(ti, i)

                    token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
                    if token_id == ti_tokenizer.unk_token_id:
                        raise RuntimeError(f"Unable to find token id for token '{trigger}'")

                    if embedding_weights[token_id].shape != embedding.shape:
                        raise ValueError(
                            f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {embedding_weights[token_id].shape[0]}."
                        )

                    embedding_weights[token_id] = embedding
                    ti_tokens.append(token_id)

                if len(ti_tokens) > 1:
                    ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]

            new_embeddings_node = numpy_helper.from_array(embedding_weights.astype(base_weights.dtype), embeddings_node_orig.name)
            del text_encoder.proto.graph.initializer[embeddings_node_idx]
            text_encoder.proto.graph.initializer.insert(embeddings_node_idx, new_embeddings_node)

            yield ti_tokenizer, ti_manager

        finally:
            # restore
            if embeddings_node_orig is not None:
                del text_encoder.proto.graph.initializer[embeddings_node_idx]
                text_encoder.proto.graph.initializer.insert(embeddings_node_idx, embeddings_node_orig)

View File

@@ -9,9 +9,12 @@ from .lora import LoRAModel
from .controlnet import ControlNetModel  # TODO:
from .textual_inversion import TextualInversionModel
+from .stable_diffusion_onnx import ONNXStableDiffusion1Model, ONNXStableDiffusion2Model

MODEL_CLASSES = {
    BaseModelType.StableDiffusion1: {
        ModelType.Pipeline: StableDiffusion1Model,
+       ModelType.ONNX: ONNXStableDiffusion1Model,
        ModelType.Vae: VaeModel,
        ModelType.Lora: LoRAModel,
        ModelType.ControlNet: ControlNetModel,
@@ -19,6 +22,7 @@ MODEL_CLASSES = {
    },
    BaseModelType.StableDiffusion2: {
        ModelType.Pipeline: StableDiffusion2Model,
+       ModelType.ONNX: ONNXStableDiffusion2Model,
        ModelType.Vae: VaeModel,
        ModelType.Lora: LoRAModel,
        ModelType.ControlNet: ControlNetModel,

View File

@@ -5,19 +5,27 @@ import inspect
from enum import Enum
from abc import ABCMeta, abstractmethod

import torch
+import numpy as np
import safetensors.torch
-from diffusers import DiffusionPipeline, ConfigMixin
+from pathlib import Path
+from diffusers import DiffusionPipeline, ConfigMixin, OnnxRuntimeModel

from contextlib import suppress
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union

+import onnx
+from onnx import numpy_helper
+from onnx.external_data_helper import set_external_data
+from onnxruntime import InferenceSession, OrtValue, SessionOptions

class BaseModelType(str, Enum):
    StableDiffusion1 = "sd-1"
    StableDiffusion2 = "sd-2"
    #Kandinsky2_1 = "kandinsky-2.1"

class ModelType(str, Enum):
+   ONNX = "onnx"
    Pipeline = "pipeline"
    Vae = "vae"
    Lora = "lora"
@@ -29,6 +37,8 @@ class SubModelType(str, Enum):
    TextEncoder = "text_encoder"
    Tokenizer = "tokenizer"
    Vae = "vae"
+   VaeDecoder = "vae_decoder"
+   VaeEncoder = "vae_encoder"
    Scheduler = "scheduler"
    SafetyChecker = "safety_checker"
    #MoVQ = "movq"
@@ -240,16 +250,18 @@ class DiffusersModel(ModelBase):
            try:
                # TODO: set cache_dir to /dev/null to be sure that cache not used?
                model = self.child_types[child_type].from_pretrained(
-                   self.model_path,
-                   subfolder=child_type.value,
+                   os.path.join(self.model_path, child_type.value),
+                   #subfolder=child_type.value,
                    torch_dtype=torch_dtype,
                    variant=variant,
                    local_files_only=True,
                )
                break
            except Exception as e:
-               #print("====ERR LOAD====")
-               #print(f"{variant}: {e}")
+               print("====ERR LOAD====")
+               print(f"{variant}: {e}")
+               import traceback
+               traceback.print_exc()
                pass
        else:
            raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model")
@@ -413,3 +425,92 @@ class SilenceWarnings(object):
        transformers_logging.set_verbosity(self.transformers_verbosity)
        diffusers_logging.set_verbosity(self.diffusers_verbosity)
        warnings.simplefilter('default')
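
# ONNX protos with every weight embedded inline can exceed protobuf's 2GB serialization limit.
# The helper below moves each initializer's raw data into an OrtValue so the weights can be
# handed to the InferenceSession as external initializers instead of being re-serialized.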

def buffer_external_data_tensors(model):
    external_data = dict()
    for tensor in model.graph.initializer:
        name = tensor.name

        if tensor.HasField("raw_data"):
            npt = numpy_helper.to_array(tensor)
            orv = OrtValue.ortvalue_from_numpy(npt)
            external_data[name] = orv
            set_external_data(tensor, location="tmp.bin")
            tensor.name = name
            tensor.ClearField("raw_data")

    return (model, external_data)

ONNX_WEIGHTS_NAME = "model.onnx"


class IAIOnnxRuntimeModel(OnnxRuntimeModel):
    def __init__(self, model: tuple, **kwargs):
        self.proto, self.provider, self.sess_options = model
        self.session = None
        self._external_data = dict()

    def __call__(self, **kwargs):
        if self.session is None:
            raise Exception("You should call create_session before running model")

        inputs = {k: np.array(v) for k, v in kwargs.items()}
        return self.session.run(None, inputs)
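
    # The InferenceSession is created lazily, so a proto can sit in the model cache without an
    # active ORT session; the weights are supplied via add_external_initializers to keep the
    # serialized graph small.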
    def create_session(self):
        if self.session is None:
            # onnx.save(self.proto, "tmp.onnx")
            # onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False)
            (trimmed_model, external_data) = buffer_external_data_tensors(self.proto)
            sess = SessionOptions()
            self._external_data.update(**external_data)
            sess.add_external_initializers(list(self._external_data.keys()), list(self._external_data.values()))
            self.session = InferenceSession(trimmed_model.SerializeToString(), providers=[self.provider], sess_options=sess)
            # self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options)

    def release_session(self):
        self.session = None
        import gc

        gc.collect()
    @staticmethod
    def load_model(path: Union[str, Path], provider=None, sess_options=None):
        """
        Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider`

        Arguments:
            path (`str` or `Path`):
                Directory from which to load
            provider (`str`, *optional*):
                Onnxruntime execution provider to use for loading the model, defaults to `CPUExecutionProvider`
        """
        if provider is None:
            # logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
            print("No onnxruntime provider specified, using CPUExecutionProvider")
            provider = "CPUExecutionProvider"

        # TODO: check that provider available?
        return (onnx.load(path), provider, sess_options)

    @classmethod
    def _from_pretrained(
        cls,
        model_id: Union[str, Path],
        use_auth_token: Optional[Union[bool, str, None]] = None,
        revision: Optional[Union[str, None]] = None,
        force_download: bool = False,
        cache_dir: Optional[str] = None,
        file_name: Optional[str] = None,
        provider: Optional[str] = None,
        sess_options: Optional["SessionOptions"] = None,
        **kwargs,
    ):
        model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
        # load model from local directory
        if not os.path.isdir(model_id):
            raise Exception(f"Model not found: {model_id}")

        model = IAIOnnxRuntimeModel.load_model(
            os.path.join(model_id, model_file_name), provider=provider, sess_options=sess_options
        )

        return cls(model=model, **kwargs)

View File

@@ -0,0 +1,156 @@
import os
import json
from enum import Enum
from pydantic import Field
from pathlib import Path
from typing import Literal, Optional, Union
from .base import (
    ModelBase,
    ModelConfigBase,
    BaseModelType,
    ModelType,
    SubModelType,
    ModelVariantType,
    DiffusersModel,
    SchedulerPredictionType,
    SilenceWarnings,
    read_checkpoint_meta,
    classproperty,
    OnnxRuntimeModel,
    IAIOnnxRuntimeModel,
)
from invokeai.app.services.config import InvokeAIAppConfig

class ONNXStableDiffusion1Model(DiffusersModel):

    class Config(ModelConfigBase):
        model_format: None
        variant: ModelVariantType

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        assert base_model == BaseModelType.StableDiffusion1
        assert model_type == ModelType.ONNX
        super().__init__(
            model_path=model_path,
            base_model=BaseModelType.StableDiffusion1,
            model_type=ModelType.ONNX,
        )
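
        # DiffusersModel would load each ONNX child as diffusers' OnnxRuntimeModel; swap in the
        # IAIOnnxRuntimeModel wrapper so sessions are created lazily via create_session().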
        for child_name, child_type in self.child_types.items():
            if child_type is OnnxRuntimeModel:
                self.child_types[child_name] = IAIOnnxRuntimeModel
        # TODO: check that no optimum models provided

    @classmethod
    def probe_config(cls, path: str, **kwargs):
        model_format = cls.detect_format(path)
        in_channels = 4  # TODO:
        if in_channels == 9:
            variant = ModelVariantType.Inpaint
        elif in_channels == 4:
            variant = ModelVariantType.Normal
        else:
            raise Exception("Unknown stable diffusion 1.* model format")

        return cls.create_config(
            path=path,
            model_format=model_format,
            variant=variant,
        )

    @classproperty
    def save_to_config(cls) -> bool:
        return True

    @classmethod
    def detect_format(cls, model_path: str):
        return None

    @classmethod
    def convert_if_required(
        cls,
        model_path: str,
        output_path: str,
        config: ModelConfigBase,
        base_model: BaseModelType,
    ) -> str:
        return model_path


class ONNXStableDiffusion2Model(DiffusersModel):

    # TODO: check that configs overwritten properly
    class Config(ModelConfigBase):
        model_format: None
        variant: ModelVariantType
        prediction_type: SchedulerPredictionType
        upcast_attention: bool

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        assert base_model == BaseModelType.StableDiffusion2
        assert model_type == ModelType.ONNX
        super().__init__(
            model_path=model_path,
            base_model=BaseModelType.StableDiffusion2,
            model_type=ModelType.ONNX,
        )

        for child_name, child_type in self.child_types.items():
            if child_type is OnnxRuntimeModel:
                self.child_types[child_name] = IAIOnnxRuntimeModel
        # TODO: check that no optimum models provided

    @classmethod
    def probe_config(cls, path: str, **kwargs):
        model_format = cls.detect_format(path)
        in_channels = 4  # TODO:
        if in_channels == 9:
            variant = ModelVariantType.Inpaint
        elif in_channels == 5:
            variant = ModelVariantType.Depth
        elif in_channels == 4:
            variant = ModelVariantType.Normal
        else:
            raise Exception("Unknown stable diffusion 2.* model format")

        if variant == ModelVariantType.Normal:
            prediction_type = SchedulerPredictionType.VPrediction
            upcast_attention = True
        else:
            prediction_type = SchedulerPredictionType.Epsilon
            upcast_attention = False

        return cls.create_config(
            path=path,
            model_format=model_format,
            variant=variant,
            prediction_type=prediction_type,
            upcast_attention=upcast_attention,
        )

    @classproperty
    def save_to_config(cls) -> bool:
        return True

    @classmethod
    def detect_format(cls, model_path: str):
        return None

    @classmethod
    def convert_if_required(
        cls,
        model_path: str,
        output_path: str,
        config: ModelConfigBase,
        base_model: BaseModelType,
    ) -> str:
        return model_path