ONNX Model/runtime first implementation

Sergey Borisov
2023-06-21 02:12:21 +03:00
parent 92c86fd0b8
commit 4d337f6abc
7 changed files with 935 additions and 16 deletions

invokeai/backend/model_management/models/__init__.py

@@ -9,9 +9,12 @@ from .lora import LoRAModel
from .controlnet import ControlNetModel  # TODO:
from .textual_inversion import TextualInversionModel
from .stable_diffusion_onnx import ONNXStableDiffusion1Model, ONNXStableDiffusion2Model

MODEL_CLASSES = {
    BaseModelType.StableDiffusion1: {
        ModelType.Pipeline: StableDiffusion1Model,
        ModelType.ONNX: ONNXStableDiffusion1Model,
        ModelType.Vae: VaeModel,
        ModelType.Lora: LoRAModel,
        ModelType.ControlNet: ControlNetModel,
@@ -19,6 +22,7 @@ MODEL_CLASSES = {
    },
    BaseModelType.StableDiffusion2: {
        ModelType.Pipeline: StableDiffusion2Model,
        ModelType.ONNX: ONNXStableDiffusion2Model,
        ModelType.Vae: VaeModel,
        ModelType.Lora: LoRAModel,
        ModelType.ControlNet: ControlNetModel,
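
For orientation, a minimal sketch of how this registry is consumed — the lookup pattern is implied by the structure above, but the call site itself is an assumption:

    # Hypothetical lookup: MODEL_CLASSES maps (base model, model type) to a class.
    model_class = MODEL_CLASSES[BaseModelType.StableDiffusion1][ModelType.ONNX]
    assert model_class is ONNXStableDiffusion1Model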

invokeai/backend/model_management/models/base.py

@@ -5,19 +5,27 @@ import inspect
from enum import Enum
from abc import ABCMeta, abstractmethod

import torch
import numpy as np
import safetensors.torch
from diffusers import DiffusionPipeline, ConfigMixin
from pathlib import Path
from diffusers import DiffusionPipeline, ConfigMixin, OnnxRuntimeModel
from contextlib import suppress
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union

import onnx
from onnx import numpy_helper
from onnx.external_data_helper import set_external_data
from onnxruntime import InferenceSession, OrtValue, SessionOptions

class BaseModelType(str, Enum):
    StableDiffusion1 = "sd-1"
    StableDiffusion2 = "sd-2"
    # Kandinsky2_1 = "kandinsky-2.1"

class ModelType(str, Enum):
    ONNX = "onnx"
    Pipeline = "pipeline"
    Vae = "vae"
    Lora = "lora"
@@ -29,6 +37,8 @@ class SubModelType(str, Enum):
    TextEncoder = "text_encoder"
    Tokenizer = "tokenizer"
    Vae = "vae"
    VaeDecoder = "vae_decoder"
    VaeEncoder = "vae_encoder"
    Scheduler = "scheduler"
    SafetyChecker = "safety_checker"
    # MoVQ = "movq"
@@ -240,16 +250,18 @@ class DiffusersModel(ModelBase):
        try:
            # TODO: set cache_dir to /dev/null to be sure that cache not used?
            model = self.child_types[child_type].from_pretrained(
                self.model_path,
                subfolder=child_type.value,
                os.path.join(self.model_path, child_type.value),
                #subfolder=child_type.value,
                torch_dtype=torch_dtype,
                variant=variant,
                local_files_only=True,
            )
            break
        except Exception as e:
            #print("====ERR LOAD====")
            #print(f"{variant}: {e}")
            print("====ERR LOAD====")
            print(f"{variant}: {e}")
            import traceback
            traceback.print_exc()
            pass
    else:
        raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model")
@@ -413,3 +425,92 @@ class SilenceWarnings(object):
        transformers_logging.set_verbosity(self.transformers_verbosity)
        diffusers_logging.set_verbosity(self.diffusers_verbosity)
        warnings.simplefilter('default')
def buffer_external_data_tensors(model):
    external_data = dict()
    for tensor in model.graph.initializer:
        name = tensor.name
        if tensor.HasField("raw_data"):
            # Copy the raw weight bytes into an ORT-owned OrtValue, keyed by tensor name.
            npt = numpy_helper.to_array(tensor)
            orv = OrtValue.ortvalue_from_numpy(npt)
            external_data[name] = orv
            # Mark the tensor as externally stored and drop its inline bytes so the
            # serialized proto stays small. "tmp.bin" is only a placeholder location:
            # the data is later supplied via add_external_initializers, not read from disk.
            set_external_data(tensor, location="tmp.bin")
            tensor.name = name
            tensor.ClearField("raw_data")
    return (model, external_data)
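
For context, the helper relies on onnxruntime's OrtValue to keep each weight alive in ORT-owned memory while the proto itself is stripped. A minimal round-trip sketch (the array here is made up):

    import numpy as np
    from onnxruntime import OrtValue

    npt = np.zeros((2, 2), dtype=np.float32)
    orv = OrtValue.ortvalue_from_numpy(npt)  # buffer now owned by ONNX Runtime
    assert (orv.numpy() == npt).all()        # contents survive the round trip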
ONNX_WEIGHTS_NAME = "model.onnx"
class IAIOnnxRuntimeModel(OnnxRuntimeModel):
    def __init__(self, model: tuple, **kwargs):
        self.proto, self.provider, self.sess_options = model
        self.session = None
        self._external_data = dict()

    def __call__(self, **kwargs):
        if self.session is None:
            raise Exception("create_session must be called before running the model")
        inputs = {k: np.array(v) for k, v in kwargs.items()}
        return self.session.run(None, inputs)
    def create_session(self):
        if self.session is None:
            #onnx.save(self.proto, "tmp.onnx")
            #onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False)
            # Strip the weights out of the proto and hand them to ORT as external
            # initializers, so the serialized proto passed to InferenceSession stays small.
            (trimmed_model, external_data) = buffer_external_data_tensors(self.proto)
            sess = SessionOptions()
            self._external_data.update(**external_data)
            sess.add_external_initializers(list(self._external_data.keys()), list(self._external_data.values()))
            self.session = InferenceSession(trimmed_model.SerializeToString(), providers=[self.provider], sess_options=sess)
            #self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options)
    def release_session(self):
        self.session = None
        import gc

        gc.collect()
    @staticmethod
    def load_model(path: Union[str, Path], provider=None, sess_options=None):
        """
        Loads an ONNX model proto; the InferenceSession is created later via create_session.
        Arguments:
            path (`str` or `Path`):
                Path of the ONNX model file to load.
            provider (`str`, *optional*):
                ONNX Runtime execution provider to run the model with, defaults to `CPUExecutionProvider`.
        """
        if provider is None:
            #logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
            print("No onnxruntime provider specified, using CPUExecutionProvider")
            provider = "CPUExecutionProvider"

        # TODO: check that provider available?
        return (onnx.load(path), provider, sess_options)
    @classmethod
    def _from_pretrained(
        cls,
        model_id: Union[str, Path],
        use_auth_token: Optional[Union[bool, str]] = None,
        revision: Optional[str] = None,
        force_download: bool = False,
        cache_dir: Optional[str] = None,
        file_name: Optional[str] = None,
        provider: Optional[str] = None,
        sess_options: Optional["SessionOptions"] = None,
        **kwargs,
    ):
        model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME

        # load model from local directory
        if not os.path.isdir(model_id):
            raise Exception(f"Model not found: {model_id}")

        model = IAIOnnxRuntimeModel.load_model(
            os.path.join(model_id, model_file_name), provider=provider, sess_options=sess_options
        )
        return cls(model=model, **kwargs)
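
A hedged usage sketch of the lifecycle above; the path, input names, and shapes are hypothetical (the real input names depend on the exported graph):

    import numpy as np

    model = IAIOnnxRuntimeModel._from_pretrained("/path/to/unet", provider="CPUExecutionProvider")
    model.create_session()  # builds the InferenceSession from the buffered proto
    out = model(
        sample=np.zeros((1, 4, 64, 64), dtype=np.float32),
        timestep=np.array([1], dtype=np.int64),
        encoder_hidden_states=np.zeros((1, 77, 768), dtype=np.float32),
    )
    model.release_session()  # drop the session to free memory between invocations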

invokeai/backend/model_management/models/stable_diffusion_onnx.py

@@ -0,0 +1,156 @@
import os
import json
from enum import Enum

from pydantic import Field
from pathlib import Path
from typing import Literal, Optional, Union
from .base import (
    ModelBase,
    ModelConfigBase,
    BaseModelType,
    ModelType,
    SubModelType,
    ModelVariantType,
    DiffusersModel,
    SchedulerPredictionType,
    SilenceWarnings,
    read_checkpoint_meta,
    classproperty,
    OnnxRuntimeModel,
    IAIOnnxRuntimeModel,
)
from invokeai.app.services.config import InvokeAIAppConfig
class ONNXStableDiffusion1Model(DiffusersModel):

    class Config(ModelConfigBase):
        model_format: None
        variant: ModelVariantType

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        assert base_model == BaseModelType.StableDiffusion1
        assert model_type == ModelType.ONNX
        super().__init__(
            model_path=model_path,
            base_model=BaseModelType.StableDiffusion1,
            model_type=ModelType.ONNX,
        )

        for child_name, child_type in self.child_types.items():
            if child_type is OnnxRuntimeModel:
                self.child_types[child_name] = IAIOnnxRuntimeModel
        # TODO: check that no optimum models provided

    @classmethod
    def probe_config(cls, path: str, **kwargs):
        model_format = cls.detect_format(path)
        in_channels = 4  # TODO:
        if in_channels == 9:
            variant = ModelVariantType.Inpaint
        elif in_channels == 4:
            variant = ModelVariantType.Normal
        else:
            raise Exception("Unknown stable diffusion 1.* model format")

        return cls.create_config(
            path=path,
            model_format=model_format,
            variant=variant,
        )

    @classproperty
    def save_to_config(cls) -> bool:
        return True

    @classmethod
    def detect_format(cls, model_path: str):
        return None

    @classmethod
    def convert_if_required(
        cls,
        model_path: str,
        output_path: str,
        config: ModelConfigBase,
        base_model: BaseModelType,
    ) -> str:
        return model_path
class ONNXStableDiffusion2Model(DiffusersModel):
    # TODO: check that configs are overwritten properly
    class Config(ModelConfigBase):
        model_format: None
        variant: ModelVariantType
        prediction_type: SchedulerPredictionType
        upcast_attention: bool

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        assert base_model == BaseModelType.StableDiffusion2
        assert model_type == ModelType.ONNX
        super().__init__(
            model_path=model_path,
            base_model=BaseModelType.StableDiffusion2,
            model_type=ModelType.ONNX,
        )

        for child_name, child_type in self.child_types.items():
            if child_type is OnnxRuntimeModel:
                self.child_types[child_name] = IAIOnnxRuntimeModel
        # TODO: check that no optimum models provided

    @classmethod
    def probe_config(cls, path: str, **kwargs):
        model_format = cls.detect_format(path)
        in_channels = 4  # TODO:
        if in_channels == 9:
            variant = ModelVariantType.Inpaint
        elif in_channels == 5:
            variant = ModelVariantType.Depth
        elif in_channels == 4:
            variant = ModelVariantType.Normal
        else:
            raise Exception("Unknown stable diffusion 2.* model format")

        if variant == ModelVariantType.Normal:
            prediction_type = SchedulerPredictionType.VPrediction
            upcast_attention = True
        else:
            prediction_type = SchedulerPredictionType.Epsilon
            upcast_attention = False

        return cls.create_config(
            path=path,
            model_format=model_format,
            variant=variant,
            prediction_type=prediction_type,
            upcast_attention=upcast_attention,
        )

    @classproperty
    def save_to_config(cls) -> bool:
        return True

    @classmethod
    def detect_format(cls, model_path: str):
        return None

    @classmethod
    def convert_if_required(
        cls,
        model_path: str,
        output_path: str,
        config: ModelConfigBase,
        base_model: BaseModelType,
    ) -> str:
        return model_path
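
As the TODOs note, probing is still a stub: in_channels is hard-coded to 4, so any SD 2.* ONNX model currently probes as the Normal variant with v-prediction and upcast attention. A sketch of the resulting behavior (the path below is hypothetical):

    config = ONNXStableDiffusion2Model.probe_config("/path/to/sd-2-onnx")
    assert config.variant == ModelVariantType.Normal
    assert config.prediction_type == SchedulerPredictionType.VPrediction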