Mirror of https://github.com/invoke-ai/InvokeAI

Commit 4d337f6abc (parent 92c86fd0b8): ONNX Model/runtime first implementation
@@ -55,8 +55,8 @@ class LatentsOutput(BaseInvocationOutput):
 def build_latents_output(latents_name: str, latents: torch.Tensor):
     return LatentsOutput(
         latents=LatentsField(latents_name=latents_name),
-        width=latents.size()[3] * 8,
-        height=latents.size()[2] * 8,
+        width=latents.shape[3] * 8,
+        height=latents.shape[2] * 8,
     )


 class NoiseOutput(BaseInvocationOutput):
@@ -271,9 +271,13 @@ class LoraLoaderInvocation(BaseInvocation):

     def invoke(self, context: InvocationContext) -> LoraLoaderOutput:

+        # TODO: ui rewrite
+        base_model = BaseModelType.StableDiffusion1

         if not context.services.model_manager.model_exists(
+            base_model=base_model,
             model_name=self.lora_name,
-            model_type=SDModelType.Lora,
+            model_type=ModelType.Lora,
         ):
             raise Exception(f"Unkown lora name: {self.lora_name}!")

@@ -289,8 +293,9 @@ class LoraLoaderInvocation(BaseInvocation):
             output.unet = copy.deepcopy(self.unet)
             output.unet.loras.append(
                 LoraInfo(
+                    base_model=base_model,
                     model_name=self.lora_name,
-                    model_type=SDModelType.Lora,
+                    model_type=ModelType.Lora,
                     submodel=None,
                     weight=self.weight,
                 )
@@ -300,8 +305,9 @@ class LoraLoaderInvocation(BaseInvocation):
             output.clip = copy.deepcopy(self.clip)
             output.clip.loras.append(
                 LoraInfo(
+                    base_model=base_model,
                     model_name=self.lora_name,
-                    model_type=SDModelType.Lora,
+                    model_type=ModelType.Lora,
                     submodel=None,
                     weight=self.weight,
                 )
invokeai/app/invocations/onnx.py (new file, 441 lines)
@@ -0,0 +1,441 @@
# Copyright (c) 2023 Borisov Sergey (https://github.com/StAlKeR7779)

from contextlib import ExitStack
from typing import List, Literal, Optional, Union

import re
import inspect

from pydantic import BaseModel, Field, validator
import torch
import numpy as np
from diffusers import ControlNetModel, DPMSolverMultistepScheduler
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import SchedulerMixin as Scheduler

from ..models.image import ImageCategory, ImageField, ResourceOrigin
from ...backend.model_management.lora import ONNXModelPatcher
from .baseinvocation import (BaseInvocation, BaseInvocationOutput,
                             InvocationConfig, InvocationContext)
from .compel import ConditioningField
from .controlnet_image_processors import ControlField
from .image import ImageOutput
from .model import ModelInfo, UNetField, VaeField

from invokeai.backend import BaseModelType, ModelType, SubModelType

from .model import ClipField
from .latent import LatentsField, LatentsOutput, build_latents_output, get_scheduler, SAMPLER_NAME_VALUES
from .compel import CompelOutput


ORT_TO_NP_TYPE = {
    "tensor(bool)": np.bool_,
    "tensor(int8)": np.int8,
    "tensor(uint8)": np.uint8,
    "tensor(int16)": np.int16,
    "tensor(uint16)": np.uint16,
    "tensor(int32)": np.int32,
    "tensor(uint32)": np.uint32,
    "tensor(int64)": np.int64,
    "tensor(uint64)": np.uint64,
    "tensor(float16)": np.float16,
    "tensor(float)": np.float32,
    "tensor(double)": np.float64,
}

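Note (not part of the diff): ORT_TO_NP_TYPE exists because ONNX Runtime reports input dtypes as strings such as "tensor(float16)", which have to be mapped back to numpy dtypes before building feed arrays. A minimal sketch of the lookup, assuming a session built from some exported UNet (the path is a placeholder):

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("unet/model.onnx", providers=["CPUExecutionProvider"])  # placeholder path
timestep_type = next(
    (i.type for i in sess.get_inputs() if i.name == "timestep"), "tensor(float)"
)
timestep = np.array([999], dtype=ORT_TO_NP_TYPE[timestep_type])  # e.g. np.float16 for a half-precision export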
class ONNXPromptInvocation(BaseInvocation):
    type: Literal["prompt_onnx"] = "prompt_onnx"

    prompt: str = Field(default="", description="Prompt")
    clip: ClipField = Field(None, description="Clip to use")

    def invoke(self, context: InvocationContext) -> CompelOutput:
        tokenizer_info = context.services.model_manager.get_model(
            **self.clip.tokenizer.dict(),
        )
        text_encoder_info = context.services.model_manager.get_model(
            **self.clip.text_encoder.dict(),
        )
        with tokenizer_info as orig_tokenizer,\
             text_encoder_info as text_encoder,\
             ExitStack() as stack:

            loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]

            ti_list = []
            for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
                name = trigger[1:-1]
                try:
                    ti_list.append(
                        stack.enter_context(
                            context.services.model_manager.get_model(
                                model_name=name,
                                base_model=self.clip.text_encoder.base_model,
                                model_type=ModelType.TextualInversion,
                            )
                        )
                    )
                except Exception:
                    #print(e)
                    #import traceback
                    #print(traceback.format_exc())
                    print(f"Warn: trigger: \"{trigger}\" not found")

            with ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras),\
                 ONNXModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager):

                text_encoder.create_session()

                text_inputs = tokenizer(
                    self.prompt,
                    padding="max_length",
                    max_length=tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="np",
                )
                text_input_ids = text_inputs.input_ids
                """
                untruncated_ids = tokenizer(prompt, padding="max_length", return_tensors="np").input_ids

                if not np.array_equal(text_input_ids, untruncated_ids):
                    removed_text = self.tokenizer.batch_decode(
                        untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                    )
                    logger.warning(
                        "The following part of your input was truncated because CLIP can only handle sequences up to"
                        f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                    )
                """

                prompt_embeds = text_encoder(input_ids=text_input_ids.astype(np.int32))[0]

                text_encoder.release_session()

        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"

        # TODO: hacky but works ;D maybe rename latents somehow?
        context.services.latents.save(conditioning_name, (prompt_embeds, None))

        return CompelOutput(
            conditioning=ConditioningField(
                conditioning_name=conditioning_name,
            ),
        )

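Note (not part of the diff): the trigger regex above pulls angle-bracketed embedding names out of the prompt before tokenization so the matching textual-inversion models can be loaded. A quick illustration with a made-up prompt:

import re

prompt = "a portrait in the style of <my-style>, hyperdetailed"
triggers = re.findall(r"<[a-zA-Z0-9., _-]+>", prompt)
# triggers == ['<my-style>']; trigger[1:-1] strips the brackets to get the model name "my-style"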
# Text to image
class ONNXTextToLatentsInvocation(BaseInvocation):
    """Generates latents from conditionings."""

    type: Literal["t2l_onnx"] = "t2l_onnx"

    # Inputs
    # fmt: off
    positive_conditioning: Optional[ConditioningField] = Field(description="Positive conditioning for generation")
    negative_conditioning: Optional[ConditioningField] = Field(description="Negative conditioning for generation")
    noise: Optional[LatentsField] = Field(description="The noise to use")
    steps: int = Field(default=10, gt=0, description="The number of steps to use to generate the image")
    cfg_scale: Union[float, List[float]] = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
    scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
    unet: UNetField = Field(default=None, description="UNet submodel")
    #control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
    #seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
    #seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
    # fmt: on

    @validator("cfg_scale")
    def ge_one(cls, v):
        """validate that all cfg_scale values are >= 1"""
        if isinstance(v, list):
            for i in v:
                if i < 1:
                    raise ValueError('cfg_scale must be greater than 1')
        else:
            if v < 1:
                raise ValueError('cfg_scale must be greater than 1')
        return v

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "tags": ["latents"],
                "type_hints": {
                    "model": "model",
                    # "cfg_scale": "float",
                    "cfg_scale": "number"
                }
            },
        }

    def invoke(self, context: InvocationContext) -> LatentsOutput:
        c, _ = context.services.latents.get(self.positive_conditioning.conditioning_name)
        uc, _ = context.services.latents.get(self.negative_conditioning.conditioning_name)
        if isinstance(c, torch.Tensor):
            c = c.cpu().numpy()
        if isinstance(uc, torch.Tensor):
            uc = uc.cpu().numpy()

        prompt_embeds = np.concatenate([uc, c])

        latents = context.services.latents.get(self.noise.latents_name)
        if isinstance(latents, torch.Tensor):
            latents = latents.cpu().numpy()

        # TODO: better execution device handling
        latents = latents.astype(np.float32)

        # get the initial random noise unless the user supplied it
        do_classifier_free_guidance = True
        #latents_dtype = prompt_embeds.dtype
        #latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
        #if latents.shape != latents_shape:
        #    raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")

        scheduler = get_scheduler(
            context=context,
            scheduler_info=self.unet.scheduler,
            scheduler_name=self.scheduler,
        )

        scheduler.set_timesteps(self.steps)
        latents = latents * np.float64(scheduler.init_noise_sigma)

        extra_step_kwargs = dict()
        if "eta" in set(inspect.signature(scheduler.step).parameters.keys()):
            extra_step_kwargs.update(
                eta=0.0,
            )

        unet_info = context.services.model_manager.get_model(**self.unet.unet.dict())

        with unet_info as unet,\
             ExitStack() as stack:

            loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]

            with ONNXModelPatcher.apply_lora_unet(unet, loras):
                # TODO:
                unet.create_session()

                timestep_dtype = next(
                    (input.type for input in unet.session.get_inputs() if input.name == "timestep"), "tensor(float)"
                )
                timestep_dtype = ORT_TO_NP_TYPE[timestep_dtype]

                from tqdm import tqdm
                for i in tqdm(range(len(scheduler.timesteps))):
                    t = scheduler.timesteps[i]
                    # expand the latents if we are doing classifier free guidance
                    latent_model_input = np.concatenate([latents] * 2) if do_classifier_free_guidance else latents
                    latent_model_input = scheduler.scale_model_input(torch.from_numpy(latent_model_input), t)
                    latent_model_input = latent_model_input.cpu().numpy()

                    # predict the noise residual
                    timestep = np.array([t], dtype=timestep_dtype)
                    noise_pred = unet(sample=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds)
                    noise_pred = noise_pred[0]

                    # perform guidance
                    if do_classifier_free_guidance:
                        noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
                        noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond)

                    # compute the previous noisy sample x_t -> x_t-1
                    scheduler_output = scheduler.step(
                        torch.from_numpy(noise_pred), t, torch.from_numpy(latents), **extra_step_kwargs
                    )
                    latents = scheduler_output.prev_sample.numpy()

                    # call the callback, if provided
                    #if callback is not None and i % callback_steps == 0:
                    #    callback(i, t, latents)

                unet.release_session()

        torch.cuda.empty_cache()

        name = f'{context.graph_execution_state_id}__{self.id}'
        context.services.latents.save(name, latents)
        return build_latents_output(latents_name=name, latents=latents)

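Note (not part of the diff): the guidance step in the loop above is the standard classifier-free guidance combination, noise = uncond + cfg_scale * (text - uncond). A standalone sketch on dummy arrays (shapes are illustrative only):

import numpy as np

cfg_scale = 7.5
noise_pred = np.random.randn(2, 4, 64, 64).astype(np.float32)  # batched [uncond, text] UNet output
noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
guided = noise_pred_uncond + cfg_scale * (noise_pred_text - noise_pred_uncond)
assert guided.shape == (1, 4, 64, 64)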
@staticmethod
def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a PIL image.
    """
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # special case for grayscale (single channel) images
        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
    else:
        pil_images = [Image.fromarray(image) for image in images]

    return pil_images

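Note (not part of the diff): the helper above assumes PIL's Image is in scope and converts NHWC floats in [0, 1] to 8-bit PIL images. A self-contained sketch of the same conversion on dummy data:

import numpy as np
from PIL import Image

images = np.random.rand(1, 512, 512, 3).astype(np.float32)  # NHWC floats in [0, 1]
pil_images = [Image.fromarray((img * 255).round().astype("uint8")) for img in images]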
# Latent to image
class ONNXLatentsToImageInvocation(BaseInvocation):
    """Generates an image from latents."""

    type: Literal["l2i_onnx"] = "l2i_onnx"

    # Inputs
    latents: Optional[LatentsField] = Field(description="The latents to generate an image from")
    vae: VaeField = Field(default=None, description="Vae submodel")
    #tiled: bool = Field(default=False, description="Decode latents by overlaping tiles(less memory consumption)")

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "tags": ["latents", "image"],
            },
        }

    def invoke(self, context: InvocationContext) -> ImageOutput:
        latents = context.services.latents.get(self.latents.latents_name)

        if self.vae.vae.submodel != SubModelType.VaeDecoder:
            raise Exception(f"Expected vae_decoder, found: {self.vae.vae.model_type}")

        vae_info = context.services.model_manager.get_model(
            **self.vae.vae.dict(),
        )

        # clear memory as vae decode can request a lot
        torch.cuda.empty_cache()

        with vae_info as vae:
            vae.create_session()

            latents = 1 / 0.18215 * latents
            # image = self.vae_decoder(latent_sample=latents)[0]
            # it seems likes there is a strange result for using half-precision vae decoder if batchsize>1
            image = np.concatenate(
                [vae(latent_sample=latents[i : i + 1])[0] for i in range(latents.shape[0])]
            )

            image = np.clip(image / 2 + 0.5, 0, 1)
            image = image.transpose((0, 2, 3, 1))

            image = VaeImageProcessor.numpy_to_pil(image)[0]

            vae.release_session()

        torch.cuda.empty_cache()

        image_dto = context.services.images.create(
            image=image,
            image_origin=ResourceOrigin.INTERNAL,
            image_category=ImageCategory.GENERAL,
            node_id=self.id,
            session_id=context.graph_execution_state_id,
        )

        return ImageOutput(
            image=ImageField(image_name=image_dto.image_name),
            width=image_dto.width,
            height=image_dto.height,
        )

class ONNXModelLoaderOutput(BaseInvocationOutput):
    """Model loader output"""

    #fmt: off
    type: Literal["model_loader_output_onnx"] = "model_loader_output_onnx"

    unet: UNetField = Field(default=None, description="UNet submodel")
    clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
    vae_decoder: VaeField = Field(default=None, description="Vae submodel")
    vae_encoder: VaeField = Field(default=None, description="Vae submodel")
    #fmt: on

class ONNXSD1ModelLoaderInvocation(BaseInvocation):
    """Loading submodels of selected model."""

    type: Literal["sd1_model_loader_onnx"] = "sd1_model_loader_onnx"

    model_name: str = Field(default="", description="Model to load")
    # TODO: precision?

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "tags": ["model", "loader"],
                "type_hints": {
                    "model_name": "model" # TODO: rename to model_name?
                }
            },
        }

    def invoke(self, context: InvocationContext) -> ONNXModelLoaderOutput:

        model_name = "stable-diffusion-v1-5"
        base_model = BaseModelType.StableDiffusion1

        # TODO: not found exceptions
        if not context.services.model_manager.model_exists(
            model_name=model_name,
            base_model=BaseModelType.StableDiffusion1,
            model_type=ModelType.ONNX,
        ):
            raise Exception(f"Unkown model name: {model_name}!")

        return ONNXModelLoaderOutput(
            unet=UNetField(
                unet=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=ModelType.ONNX,
                    submodel=SubModelType.UNet,
                ),
                scheduler=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=ModelType.ONNX,
                    submodel=SubModelType.Scheduler,
                ),
                loras=[],
            ),
            clip=ClipField(
                tokenizer=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=ModelType.ONNX,
                    submodel=SubModelType.Tokenizer,
                ),
                text_encoder=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=ModelType.ONNX,
                    submodel=SubModelType.TextEncoder,
                ),
                loras=[],
            ),
            vae_decoder=VaeField(
                vae=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=ModelType.ONNX,
                    submodel=SubModelType.VaeDecoder,
                ),
            ),
            vae_encoder=VaeField(
                vae=ModelInfo(
                    model_name=model_name,
                    base_model=base_model,
                    model_type=ModelType.ONNX,
                    submodel=SubModelType.VaeEncoder,
                ),
            )
        )
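Note (not part of the diff): conceptually the new nodes chain the same way as the torch path: the loader exposes the submodels, prompt_onnx turns text into conditioning, t2l_onnx denoises noise into latents, and l2i_onnx decodes latents into an image. A hedged sketch of constructing the invocations directly (field names come from the definitions above; the node ids are made up, and how edges are wired into a session graph is outside this diff):

loader = ONNXSD1ModelLoaderInvocation(id="loader", model_name="stable-diffusion-v1-5")
positive = ONNXPromptInvocation(id="pos", prompt="a lighthouse at dusk")  # clip fed from loader.clip
negative = ONNXPromptInvocation(id="neg", prompt="")                      # clip fed from loader.clip
denoise = ONNXTextToLatentsInvocation(id="t2l", steps=20, cfg_scale=7.5)  # unet from loader.unet, noise from a noise node
decode = ONNXLatentsToImageInvocation(id="l2i")                           # vae from loader.vae_decoder, latents from t2l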
@@ -11,6 +11,8 @@ from torch.utils.hooks import RemovableHandle

 from diffusers.models import UNet2DConditionModel
 from transformers import CLIPTextModel
+from onnx import numpy_helper
+import numpy as np

 from compel.embeddings_provider import BaseTextualInversionManager

@@ -70,7 +72,7 @@ class LoRALayerBase:
             op = torch.nn.functional.linear
             extra_args = {}

-        weight = self.get_weight(module)
+        weight = self.get_weight()

         bias = self.bias if self.bias is not None else 0
         scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
@@ -81,7 +83,7 @@ class LoRALayerBase:
             **extra_args,
         ) * multiplier * scale

-    def get_weight(self, module: torch.nn.Module):
+    def get_weight(self):
         raise NotImplementedError()

     def calc_size(self) -> int:
@@ -122,7 +124,7 @@ class LoRALayer(LoRALayerBase):

         self.rank = self.down.shape[0]

-    def get_weight(self, module: torch.nn.Module):
+    def get_weight(self):
         if self.mid is not None:
             up = self.up.reshape(up.shape[0], up.shape[1])
             down = self.down.reshape(up.shape[0], up.shape[1])
@@ -185,7 +187,7 @@ class LoHALayer(LoRALayerBase):

         self.rank = self.w1_b.shape[0]

-    def get_weight(self, module: torch.nn.Module):
+    def get_weight(self):
         if self.t1 is None:
             weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)

@@ -271,7 +273,7 @@ class LoKRLayer(LoRALayerBase):
         else:
             self.rank = None  # unscaled

-    def get_weight(self, module: torch.nn.Module):
+    def get_weight(self):
         w1 = self.w1
         if w1 is None:
             w1 = self.w1_a @ self.w1_b
@@ -286,7 +288,7 @@ class LoKRLayer(LoRALayerBase):
         if len(w2.shape) == 4:
             w1 = w1.unsqueeze(2).unsqueeze(2)
         w2 = w2.contiguous()
-        weight = torch.kron(w1, w2).reshape(module.weight.shape) # TODO: can we remove reshape?
+        weight = torch.kron(w1, w2)#.reshape(module.weight.shape) # TODO: can we remove reshape?

         return weight

@@ -676,3 +678,212 @@ class TextualInversionManager(BaseTextualInversionManager):

         return new_token_ids
+
+
+class ONNXModelPatcher:
+
+    @classmethod
+    @contextmanager
+    def apply_lora_unet(
+        cls,
+        unet: OnnxRuntimeModel,
+        loras: List[Tuple[LoRAModel, float]],
+    ):
+        with cls.apply_lora(unet, loras, "lora_unet_"):
+            yield
+
+    @classmethod
+    @contextmanager
+    def apply_lora_text_encoder(
+        cls,
+        text_encoder: OnnxRuntimeModel,
+        loras: List[Tuple[LoRAModel, float]],
+    ):
+        with cls.apply_lora(text_encoder, loras, "lora_te_"):
+            yield
+
+    @classmethod
+    @contextmanager
+    def apply_lora(
+        cls,
+        model: IAIOnnxRuntimeModel,
+        loras: List[Tuple[LoraModel, float]],
+        prefix: str,
+    ):
+        from .models.base import IAIOnnxRuntimeModel
+        if not isinstance(model, IAIOnnxRuntimeModel):
+            raise Exception("Only IAIOnnxRuntimeModel models supported")
+
+        base_model = model.proto
+        orig_nodes = dict()
+
+        try:
+            blended_loras = dict()
+
+            for lora, lora_weight in loras:
+                for layer_key, layer in lora.layers.items():
+                    if not layer_key.startswith(prefix):
+                        continue
+
+                    layer_key = layer_key.replace(prefix, "")
+                    layer_weight = layer.get_weight().detach().cpu().numpy() * lora_weight
+                    if layer_key is blended_loras:
+                        blended_loras[layer_key] += layer_weight
+                    else:
+                        blended_loras[layer_key] = layer_weight
+
+            initializer_idx = dict()
+            for idx, init in enumerate(base_model.graph.initializer):
+                initializer_idx[init.name.replace(".", "_")] = idx
+
+            node_idx = dict()
+            for idx, node in enumerate(base_model.graph.node):
+                node_idx[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = idx
+
+            for layer_key, weights in blended_loras.items():
+                conv_key = layer_key + "_Conv"
+                gemm_key = layer_key + "_Gemm"
+                matmul_key = layer_key + "_MatMul"
+
+                if conv_key in node_idx or gemm_key in node_idx:
+                    if conv_key in node_idx:
+                        conv_node = base_model.graph.node[node_idx[conv_key]]
+                    else:
+                        conv_node = base_model.graph.node[node_idx[gemm_key]]
+
+                    weight_name = [n for n in conv_node.input if ".weight" in n][0]
+                    weight_name = weight_name.replace(".", "_")
+
+                    weight_idx = initializer_idx[weight_name]
+                    weight_node = base_model.graph.initializer[weight_idx]
+
+                    orig_weights = numpy_helper.to_array(weight_node)
+
+                    if orig_weights.shape[-2:] == (1, 1):
+                        if weights.shape[-2:] == (1, 1):
+                            new_weights = orig_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
+                        else:
+                            new_weights = orig_weights.squeeze((3, 2)) + weights
+
+                        new_weights = np.expand_dims(new_weights, (2, 3))
+                    else:
+                        if orig_weights.shape != weights.shape:
+                            new_weights = orig_weights + weights.reshape(orig_weights.shape)
+                        else:
+                            new_weights = orig_weights + weights
+
+                    new_node = numpy_helper.from_array(new_weights.astype(orig_weights.dtype), weight_node.name)
+                    orig_nodes[weight_idx] = base_model.graph.initializer[weight_idx]
+                    del base_model.graph.initializer[weight_idx]
+                    base_model.graph.initializer.insert(weight_idx, new_node)
+
+                elif matmul_key in node_idx:
+                    weight_node = base_model.graph.node[node_idx[matmul_key]]
+
+                    matmul_name = [n for n in weight_node.input if "MatMul" in n][0]
+
+                    matmul_idx = initializer_idx[matmul_name]
+                    matmul_node = base_model.graph.initializer[matmul_idx]
+
+                    orig_weights = numpy_helper.to_array(matmul_node)
+
+                    new_weights = orig_weights + weights.transpose()
+
+                    # replace the original initializer
+                    new_node = numpy_helper.from_array(new_weights.astype(orig_weights.dtype), matmul_node.name)
+                    orig_nodes[matmul_idx] = base_model.graph.initializer[matmul_idx]
+                    del base_model.graph.initializer[matmul_idx]
+                    base_model.graph.initializer.insert(matmul_idx, new_node)
+
+                else:
+                    # warn? err?
+                    pass
+
+            yield
+
+        finally:
+            # restore original weights
+            for idx, orig_node in orig_nodes.items():
+                del base_model.graph.initializer[idx]
+                base_model.graph.initializer.insert(idx, orig_node)
+
+    @classmethod
+    @contextmanager
+    def apply_ti(
+        cls,
+        tokenizer: CLIPTokenizer,
+        text_encoder: IAIOnnxRuntimeModel,
+        ti_list: List[Any],
+    ) -> Tuple[CLIPTokenizer, TextualInversionManager]:
+        from .models.base import IAIOnnxRuntimeModel
+        if not isinstance(text_encoder, IAIOnnxRuntimeModel):
+            raise Exception("Only IAIOnnxRuntimeModel models supported")
+
+        init_tokens_count = None
+        new_tokens_added = None
+
+        try:
+            ti_tokenizer = copy.deepcopy(tokenizer)
+            ti_manager = TextualInversionManager(ti_tokenizer)
+
+            def _get_trigger(ti, index):
+                trigger = ti.name
+                if index > 0:
+                    trigger += f"-!pad-{i}"
+                return f"<{trigger}>"
+
+            # modify tokenizer
+            new_tokens_added = 0
+            for ti in ti_list:
+                for i in range(ti.embedding.shape[0]):
+                    new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
+
+            # modify text_encoder
+            for i in range(len(text_encoder.proto.graph.initializer)):
+                if text_encoder.proto.graph.initializer[i].name == "text_model.embeddings.token_embedding.weight":
+                    embeddings_node_idx = i
+                    break
+            else:
+                raise Exception("text_model.embeddings.token_embedding.weight node not found")
+
+            embeddings_node_orig = text_encoder.proto.graph.initializer[embeddings_node_idx]
+            base_weights = numpy_helper.to_array(embeddings_node_orig)
+
+            embedding_weights = np.concatenate((base_weights, np.zeros((new_tokens_added, base_weights.shape[1]))), axis=0)
+
+            for ti in ti_list:
+                ti_tokens = []
+                for i in range(ti.embedding.shape[0]):
+                    embedding = ti.embedding[i].detach().numpy()
+                    trigger = _get_trigger(ti, i)
+
+                    token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
+                    if token_id == ti_tokenizer.unk_token_id:
+                        raise RuntimeError(f"Unable to find token id for token '{trigger}'")
+
+                    if embedding_weights[token_id].shape != embedding.shape:
+                        raise ValueError(
+                            f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {embedding_weights[token_id].shape[0]}."
+                        )
+
+                    embedding_weights[token_id] = embedding
+                    ti_tokens.append(token_id)
+
+                if len(ti_tokens) > 1:
+                    ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]
+
+            new_embeddings_node = numpy_helper.from_array(embedding_weights.astype(base_weights.dtype), embeddings_node_orig.name)
+            del text_encoder.proto.graph.initializer[embeddings_node_idx]
+            text_encoder.proto.graph.initializer.insert(embeddings_node_idx, new_embeddings_node)
+
+            yield ti_tokenizer, ti_manager
+
+        finally:
+            # restore
+            if embeddings_node_orig is not None:
+                del text_encoder.proto.graph.initializer[embeddings_node_idx]
+                text_encoder.proto.graph.initializer.insert(embeddings_node_idx, embeddings_node_orig)
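Note (not part of the diff): apply_lora above works by swapping patched initializers into the in-memory ModelProto and putting the originals back in the finally block. A reduced sketch of that swap pattern on a single initializer (the path and the name predicate are placeholders, and the 0.1 delta stands in for a blended LoRA weight):

import onnx
from onnx import numpy_helper

proto = onnx.load("text_encoder/model.onnx")  # placeholder path
idx = next(i for i, init in enumerate(proto.graph.initializer) if init.name.endswith("MatMul"))  # placeholder predicate
backup = proto.graph.initializer[idx]
arr = numpy_helper.to_array(backup)
patched = numpy_helper.from_array((arr + 0.1).astype(arr.dtype), backup.name)  # toy LoRA delta
del proto.graph.initializer[idx]
proto.graph.initializer.insert(idx, patched)
# ... build and run a session from `proto`, then restore the original weights:
del proto.graph.initializer[idx]
proto.graph.initializer.insert(idx, backup)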
@@ -9,9 +9,12 @@ from .lora import LoRAModel
 from .controlnet import ControlNetModel # TODO:
 from .textual_inversion import TextualInversionModel

+from .stable_diffusion_onnx import ONNXStableDiffusion1Model, ONNXStableDiffusion2Model
+
 MODEL_CLASSES = {
     BaseModelType.StableDiffusion1: {
         ModelType.Pipeline: StableDiffusion1Model,
+        ModelType.ONNX: ONNXStableDiffusion1Model,
         ModelType.Vae: VaeModel,
         ModelType.Lora: LoRAModel,
         ModelType.ControlNet: ControlNetModel,
@@ -19,6 +22,7 @@ MODEL_CLASSES = {
     },
     BaseModelType.StableDiffusion2: {
         ModelType.Pipeline: StableDiffusion2Model,
+        ModelType.ONNX: ONNXStableDiffusion2Model,
         ModelType.Vae: VaeModel,
         ModelType.Lora: LoRAModel,
         ModelType.ControlNet: ControlNetModel,
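Note (not part of the diff): with these registry entries, a (base model, model type) pair now resolves to the new ONNX classes. A rough sketch, with the import path assumed from the package layout:

from invokeai.backend.model_management.models import MODEL_CLASSES, BaseModelType, ModelType

model_class = MODEL_CLASSES[BaseModelType.StableDiffusion1][ModelType.ONNX]
# model_class is ONNXStableDiffusion1Model; the StableDiffusion2 entry resolves analogously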
@@ -5,19 +5,27 @@ import inspect
 from enum import Enum
 from abc import ABCMeta, abstractmethod
 import torch
+import numpy as np
 import safetensors.torch
-from diffusers import DiffusionPipeline, ConfigMixin
+from pathlib import Path
+from diffusers import DiffusionPipeline, ConfigMixin, OnnxRuntimeModel
+
 from contextlib import suppress
 from pydantic import BaseModel, Field
 from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union

+import onnx
+from onnx import numpy_helper
+from onnx.external_data_helper import set_external_data
+from onnxruntime import InferenceSession, OrtValue, SessionOptions
+
 class BaseModelType(str, Enum):
     StableDiffusion1 = "sd-1"
     StableDiffusion2 = "sd-2"
     #Kandinsky2_1 = "kandinsky-2.1"

 class ModelType(str, Enum):
+    ONNX = "onnx"
     Pipeline = "pipeline"
     Vae = "vae"
     Lora = "lora"
@@ -29,6 +37,8 @@ class SubModelType(str, Enum):
     TextEncoder = "text_encoder"
     Tokenizer = "tokenizer"
     Vae = "vae"
+    VaeDecoder = "vae_decoder"
+    VaeEncoder = "vae_encoder"
     Scheduler = "scheduler"
     SafetyChecker = "safety_checker"
     #MoVQ = "movq"
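Note (not part of the diff): diffusers/optimum ONNX exports ship the VAE as two separate graphs, which is why the decoder and encoder get their own SubModelType members. A small sketch, with the module path assumed from the package layout (the subfolder names reflect a typical export, not this diff):

from invokeai.backend.model_management.models.base import SubModelType

# each submodel maps onto its own subfolder of the exported pipeline directory
for sub in (SubModelType.VaeEncoder, SubModelType.VaeDecoder):
    print(sub.value)  # "vae_encoder", "vae_decoder" -> e.g. <model_dir>/vae_decoder/model.onnx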
@@ -240,16 +250,18 @@ class DiffusersModel(ModelBase):
             try:
                 # TODO: set cache_dir to /dev/null to be sure that cache not used?
                 model = self.child_types[child_type].from_pretrained(
-                    self.model_path,
-                    subfolder=child_type.value,
+                    os.path.join(self.model_path, child_type.value),
+                    #subfolder=child_type.value,
                     torch_dtype=torch_dtype,
                     variant=variant,
                     local_files_only=True,
                 )
                 break
             except Exception as e:
-                #print("====ERR LOAD====")
-                #print(f"{variant}: {e}")
+                print("====ERR LOAD====")
+                print(f"{variant}: {e}")
+                import traceback
+                traceback.print_exc()
                 pass
         else:
             raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model")
@@ -413,3 +425,92 @@ class SilenceWarnings(object):
         transformers_logging.set_verbosity(self.transformers_verbosity)
         diffusers_logging.set_verbosity(self.diffusers_verbosity)
         warnings.simplefilter('default')
+
+def buffer_external_data_tensors(model):
+    external_data = dict()
+    for tensor in model.graph.initializer:
+        name = tensor.name
+
+        if tensor.HasField("raw_data"):
+            npt = numpy_helper.to_array(tensor)
+            orv = OrtValue.ortvalue_from_numpy(npt)
+            external_data[name] = orv
+            set_external_data(tensor, location="tmp.bin")
+            tensor.name = name
+            tensor.ClearField("raw_data")
+
+    return (model, external_data)
+
+ONNX_WEIGHTS_NAME = "model.onnx"
+class IAIOnnxRuntimeModel(OnnxRuntimeModel):
+    def __init__(self, model: tuple, **kwargs):
+        self.proto, self.provider, self.sess_options = model
+        self.session = None
+        self._external_data = dict()
+
+    def __call__(self, **kwargs):
+        if self.session is None:
+            raise Exception("You should call create_session before running model")
+
+        inputs = {k: np.array(v) for k, v in kwargs.items()}
+        return self.session.run(None, inputs)
+
+    def create_session(self):
+        if self.session is None:
+            #onnx.save(self.proto, "tmp.onnx")
+            #onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False)
+            (trimmed_model, external_data) = buffer_external_data_tensors(self.proto)
+            sess = SessionOptions()
+            self._external_data.update(**external_data)
+            sess.add_external_initializers(list(self._external_data.keys()), list(self._external_data.values()))
+            self.session = InferenceSession(trimmed_model.SerializeToString(), providers=[self.provider], sess_options=sess)
+            #self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options)
+
+    def release_session(self):
+        self.session = None
+        import gc
+        gc.collect()
+
+    @staticmethod
+    def load_model(path: Union[str, Path], provider=None, sess_options=None):
+        """
+        Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider`
+
+        Arguments:
+            path (`str` or `Path`):
+                Directory from which to load
+            provider(`str`, *optional*):
+                Onnxruntime execution provider to use for loading the model, defaults to `CPUExecutionProvider`
+        """
+        if provider is None:
+            #logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
+            print("No onnxruntime provider specified, using CPUExecutionProvider")
+            provider = "CPUExecutionProvider"
+
+        # TODO: check that provider available?
+        return (onnx.load(path), provider, sess_options)
+
+    @classmethod
+    def _from_pretrained(
+        cls,
+        model_id: Union[str, Path],
+        use_auth_token: Optional[Union[bool, str, None]] = None,
+        revision: Optional[Union[str, None]] = None,
+        force_download: bool = False,
+        cache_dir: Optional[str] = None,
+        file_name: Optional[str] = None,
+        provider: Optional[str] = None,
+        sess_options: Optional["SessionOptions"] = None,
+        **kwargs,
+    ):
+        model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
+        # load model from local directory
+        if not os.path.isdir(model_id):
+            raise Exception(f"Model not found: {model_id}")
+        model = IAIOnnxRuntimeModel.load_model(
+            os.path.join(model_id, model_file_name), provider=provider, sess_options=sess_options
+        )
+
+        return cls(model=model, **kwargs)
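Note (not part of the diff): a minimal usage sketch for the wrapper above. The model path is a placeholder, and the input name "input_ids" assumes a standard CLIP text-encoder export:

import numpy as np

model = IAIOnnxRuntimeModel(
    IAIOnnxRuntimeModel.load_model("text_encoder/model.onnx", provider="CPUExecutionProvider")
)
model.create_session()                        # builds the InferenceSession; weights are fed as external initializers
input_ids = np.ones((1, 77), dtype=np.int32)  # dummy token ids
hidden_states = model(input_ids=input_ids)[0]
model.release_session()                       # drops the session so the (possibly patched) proto can be rebuilt later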
New file (156 lines; the stable_diffusion_onnx module imported above)
@@ -0,0 +1,156 @@
import os
import json
from enum import Enum
from pydantic import Field
from pathlib import Path
from typing import Literal, Optional, Union
from .base import (
    ModelBase,
    ModelConfigBase,
    BaseModelType,
    ModelType,
    SubModelType,
    ModelVariantType,
    DiffusersModel,
    SchedulerPredictionType,
    SilenceWarnings,
    read_checkpoint_meta,
    classproperty,
    OnnxRuntimeModel,
    IAIOnnxRuntimeModel,
)
from invokeai.app.services.config import InvokeAIAppConfig

class ONNXStableDiffusion1Model(DiffusersModel):

    class Config(ModelConfigBase):
        model_format: None
        variant: ModelVariantType


    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        assert base_model == BaseModelType.StableDiffusion1
        assert model_type == ModelType.ONNX
        super().__init__(
            model_path=model_path,
            base_model=BaseModelType.StableDiffusion1,
            model_type=ModelType.ONNX,
        )

        for child_name, child_type in self.child_types.items():
            if child_type is OnnxRuntimeModel:
                self.child_types[child_name] = IAIOnnxRuntimeModel

        # TODO: check that no optimum models provided

    @classmethod
    def probe_config(cls, path: str, **kwargs):
        model_format = cls.detect_format(path)
        in_channels = 4 # TODO:

        if in_channels == 9:
            variant = ModelVariantType.Inpaint
        elif in_channels == 4:
            variant = ModelVariantType.Normal
        else:
            raise Exception("Unkown stable diffusion 1.* model format")


        return cls.create_config(
            path=path,
            model_format=model_format,

            variant=variant,
        )

    @classproperty
    def save_to_config(cls) -> bool:
        return True

    @classmethod
    def detect_format(cls, model_path: str):
        return None

    @classmethod
    def convert_if_required(
        cls,
        model_path: str,
        output_path: str,
        config: ModelConfigBase,
        base_model: BaseModelType,
    ) -> str:
        return model_path

class ONNXStableDiffusion2Model(DiffusersModel):

    # TODO: check that configs overwriten properly
    class Config(ModelConfigBase):
        model_format: None
        variant: ModelVariantType
        prediction_type: SchedulerPredictionType
        upcast_attention: bool


    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        assert base_model == BaseModelType.StableDiffusion2
        assert model_type == ModelType.ONNX
        super().__init__(
            model_path=model_path,
            base_model=BaseModelType.StableDiffusion2,
            model_type=ModelType.ONNX,
        )

        for child_name, child_type in self.child_types.items():
            if child_type is OnnxRuntimeModel:
                self.child_types[child_name] = IAIOnnxRuntimeModel
        # TODO: check that no optimum models provided

    @classmethod
    def probe_config(cls, path: str, **kwargs):
        model_format = cls.detect_format(path)
        in_channels = 4 # TODO:

        if in_channels == 9:
            variant = ModelVariantType.Inpaint
        elif in_channels == 5:
            variant = ModelVariantType.Depth
        elif in_channels == 4:
            variant = ModelVariantType.Normal
        else:
            raise Exception("Unkown stable diffusion 2.* model format")

        if variant == ModelVariantType.Normal:
            prediction_type = SchedulerPredictionType.VPrediction
            upcast_attention = True

        else:
            prediction_type = SchedulerPredictionType.Epsilon
            upcast_attention = False

        return cls.create_config(
            path=path,
            model_format=model_format,

            variant=variant,
            prediction_type=prediction_type,
            upcast_attention=upcast_attention,
        )

    @classproperty
    def save_to_config(cls) -> bool:
        return True

    @classmethod
    def detect_format(cls, model_path: str):
        return None

    @classmethod
    def convert_if_required(
        cls,
        model_path: str,
        output_path: str,
        config: ModelConfigBase,
        base_model: BaseModelType,
    ) -> str:
        return model_path