Add onnx models to the model manager UI

This commit is contained in:
Brandon Rising
2023-07-27 09:37:37 -04:00
parent 4d732e06de
commit 024f92f9a9
12 changed files with 465 additions and 109 deletions

View File

@ -23,7 +23,7 @@ class ModelProbeInfo(object):
variant_type: ModelVariantType
prediction_type: SchedulerPredictionType
upcast_attention: bool
format: Literal['diffusers','checkpoint', 'lycoris', 'olive']
format: Literal['diffusers','checkpoint', 'lycoris', 'olive', 'onnx']
image_size: int
class ProbeBase(object):

View File

@ -21,10 +21,15 @@ from .base import (
)
from invokeai.app.services.config import InvokeAIAppConfig
class StableDiffusionOnnxModelFormat(str, Enum):
Olive = "olive"
Onnx = "onnx"
class ONNXStableDiffusion1Model(DiffusersModel):
class Config(ModelConfigBase):
model_format: None
model_format: Literal[StableDiffusionOnnxModelFormat.Onnx]
variant: ModelVariantType
@ -69,7 +74,8 @@ class ONNXStableDiffusion1Model(DiffusersModel):
@classmethod
def detect_format(cls, model_path: str):
return None
# TODO: Detect onnx vs olive
return StableDiffusionOnnxModelFormat.Onnx
@classmethod
def convert_if_required(
@ -85,7 +91,7 @@ class ONNXStableDiffusion2Model(DiffusersModel):
# TODO: check that configs overwriten properly
class Config(ModelConfigBase):
model_format: None
model_format: Literal[StableDiffusionOnnxModelFormat.Onnx]
variant: ModelVariantType
prediction_type: SchedulerPredictionType
upcast_attention: bool
@ -142,7 +148,8 @@ class ONNXStableDiffusion2Model(DiffusersModel):
@classmethod
def detect_format(cls, model_path: str):
return None
# TODO: Detect onnx vs olive
return StableDiffusionOnnxModelFormat.Onnx
@classmethod
def convert_if_required(