Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Commit: ONNX Model/runtime first implementation
@@ -11,6 +11,8 @@ from torch.utils.hooks import RemovableHandle
 from diffusers.models import UNet2DConditionModel
 from transformers import CLIPTextModel
+from onnx import numpy_helper
+import numpy as np

 from compel.embeddings_provider import BaseTextualInversionManager
@@ -70,7 +72,7 @@ class LoRALayerBase:
             op = torch.nn.functional.linear
             extra_args = {}

-        weight = self.get_weight(module)
+        weight = self.get_weight()

         bias = self.bias if self.bias is not None else 0
         scale = self.alpha / self.rank if (self.alpha and self.rank) else 1.0
@@ -81,7 +83,7 @@ class LoRALayerBase:
             **extra_args,
         ) * multiplier * scale

-    def get_weight(self, module: torch.nn.Module):
+    def get_weight(self):
         raise NotImplementedError()

     def calc_size(self) -> int:
@@ -122,7 +124,7 @@ class LoRALayer(LoRALayerBase):
         self.rank = self.down.shape[0]

-    def get_weight(self, module: torch.nn.Module):
+    def get_weight(self):
         if self.mid is not None:
             up = self.up.reshape(self.up.shape[0], self.up.shape[1])
             down = self.down.reshape(self.down.shape[0], self.down.shape[1])
@@ -185,7 +187,7 @@ class LoHALayer(LoRALayerBase):
         self.rank = self.w1_b.shape[0]

-    def get_weight(self, module: torch.nn.Module):
+    def get_weight(self):
         if self.t1 is None:
             weight = (self.w1_a @ self.w1_b) * (self.w2_a @ self.w2_b)
@@ -271,7 +273,7 @@ class LoKRLayer(LoRALayerBase):
         else:
             self.rank = None  # unscaled

-    def get_weight(self, module: torch.nn.Module):
+    def get_weight(self):
         w1 = self.w1
         if w1 is None:
             w1 = self.w1_a @ self.w1_b
@@ -286,7 +288,7 @@ class LoKRLayer(LoRALayerBase):
         if len(w2.shape) == 4:
             w1 = w1.unsqueeze(2).unsqueeze(2)
         w2 = w2.contiguous()
-        weight = torch.kron(w1, w2).reshape(module.weight.shape) # TODO: can we remove reshape?
+        weight = torch.kron(w1, w2)  # .reshape(module.weight.shape) # TODO: can we remove reshape?

         return weight
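A quick sanity check on the commented-out reshape (an illustrative sketch, not part of the commit): for 2-D factors, torch.kron already yields the combined shape, which is why the TODO asks whether the explicit reshape to module.weight.shape is still needed in the Linear case.

    import torch

    # Kronecker product of a (4, 8) matrix with a (16, 40) matrix has shape
    # (4*16, 8*40) = (64, 320), i.e. already the full 2-D weight shape.
    w1 = torch.randn(4, 8)
    w2 = torch.randn(16, 40)
    assert torch.kron(w1, w2).shape == (64, 320)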
@@ -676,3 +678,212 @@ class TextualInversionManager(BaseTextualInversionManager):
         return new_token_ids
+
+
+class ONNXModelPatcher:
+    @classmethod
+    @contextmanager
+    def apply_lora_unet(
+        cls,
+        unet: OnnxRuntimeModel,
+        loras: List[Tuple[LoRAModel, float]],
+    ):
+        with cls.apply_lora(unet, loras, "lora_unet_"):
+            yield
+
+    @classmethod
+    @contextmanager
+    def apply_lora_text_encoder(
+        cls,
+        text_encoder: OnnxRuntimeModel,
+        loras: List[Tuple[LoRAModel, float]],
+    ):
+        with cls.apply_lora(text_encoder, loras, "lora_te_"):
+            yield
+
+    @classmethod
+    @contextmanager
+    def apply_lora(
+        cls,
+        model: IAIOnnxRuntimeModel,
+        loras: List[Tuple[LoRAModel, float]],
+        prefix: str,
+    ):
+        from .models.base import IAIOnnxRuntimeModel
+        if not isinstance(model, IAIOnnxRuntimeModel):
+            raise Exception("Only IAIOnnxRuntimeModel models supported")
+
+        base_model = model.proto
+        orig_nodes = dict()
+
+        try:
+            blended_loras = dict()
+
+            for lora, lora_weight in loras:
+                for layer_key, layer in lora.layers.items():
+                    if not layer_key.startswith(prefix):
+                        continue
+
+                    layer_key = layer_key.replace(prefix, "")
+                    layer_weight = layer.get_weight().detach().cpu().numpy() * lora_weight
+                    if layer_key in blended_loras:
+                        blended_loras[layer_key] += layer_weight
+                    else:
+                        blended_loras[layer_key] = layer_weight
+
+            initializer_idx = dict()
+            for idx, init in enumerate(base_model.graph.initializer):
+                initializer_idx[init.name.replace(".", "_")] = idx
+
+            node_idx = dict()
+            for idx, node in enumerate(base_model.graph.node):
+                node_idx[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = idx
+
+            for layer_key, weights in blended_loras.items():
+                conv_key = layer_key + "_Conv"
+                gemm_key = layer_key + "_Gemm"
+                matmul_key = layer_key + "_MatMul"
+
+                if conv_key in node_idx or gemm_key in node_idx:
+                    if conv_key in node_idx:
+                        conv_node = base_model.graph.node[node_idx[conv_key]]
+                    else:
+                        conv_node = base_model.graph.node[node_idx[gemm_key]]
+
+                    weight_name = [n for n in conv_node.input if ".weight" in n][0]
+                    weight_name = weight_name.replace(".", "_")
+
+                    weight_idx = initializer_idx[weight_name]
+                    weight_node = base_model.graph.initializer[weight_idx]
+
+                    orig_weights = numpy_helper.to_array(weight_node)
+
+                    if orig_weights.shape[-2:] == (1, 1):
+                        if weights.shape[-2:] == (1, 1):
+                            new_weights = orig_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
+                        else:
+                            new_weights = orig_weights.squeeze((3, 2)) + weights
+
+                        new_weights = np.expand_dims(new_weights, (2, 3))
+                    else:
+                        if orig_weights.shape != weights.shape:
+                            new_weights = orig_weights + weights.reshape(orig_weights.shape)
+                        else:
+                            new_weights = orig_weights + weights
+
+                    new_node = numpy_helper.from_array(new_weights.astype(orig_weights.dtype), weight_node.name)
+                    orig_nodes[weight_idx] = base_model.graph.initializer[weight_idx]
+                    del base_model.graph.initializer[weight_idx]
+                    base_model.graph.initializer.insert(weight_idx, new_node)
+
+                elif matmul_key in node_idx:
+                    weight_node = base_model.graph.node[node_idx[matmul_key]]
+
+                    matmul_name = [n for n in weight_node.input if "MatMul" in n][0]
+
+                    matmul_idx = initializer_idx[matmul_name]
+                    matmul_node = base_model.graph.initializer[matmul_idx]
+
+                    orig_weights = numpy_helper.to_array(matmul_node)
+
+                    new_weights = orig_weights + weights.transpose()
+
+                    # replace the original initializer
+                    new_node = numpy_helper.from_array(new_weights.astype(orig_weights.dtype), matmul_node.name)
+                    orig_nodes[matmul_idx] = base_model.graph.initializer[matmul_idx]
+                    del base_model.graph.initializer[matmul_idx]
+                    base_model.graph.initializer.insert(matmul_idx, new_node)
+
+                else:
+                    # warn? err?
+                    pass
+
+            yield
+
+        finally:
+            # restore original weights
+            for idx, orig_node in orig_nodes.items():
+                del base_model.graph.initializer[idx]
+                base_model.graph.initializer.insert(idx, orig_node)
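For orientation, a minimal usage sketch of the LoRA patcher above; the `onnx_unet` and `lora_model` objects and the 0.75 weight are placeholders, not part of this commit:

    # Hypothetical usage: blend LoRA deltas into the ONNX UNet graph initializers
    # for the duration of the block, then restore the original weights.
    loras = [(lora_model, 0.75)]  # list of (LoRAModel, weight) pairs
    with ONNXModelPatcher.apply_lora_unet(onnx_unet, loras):
        onnx_unet.create_session()  # the session is built from the patched graph
        # ... run the denoising loop here ...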
+
+    @classmethod
+    @contextmanager
+    def apply_ti(
+        cls,
+        tokenizer: CLIPTokenizer,
+        text_encoder: IAIOnnxRuntimeModel,
+        ti_list: List[Any],
+    ) -> Tuple[CLIPTokenizer, TextualInversionManager]:
+        from .models.base import IAIOnnxRuntimeModel
+        if not isinstance(text_encoder, IAIOnnxRuntimeModel):
+            raise Exception("Only IAIOnnxRuntimeModel models supported")
+
+        init_tokens_count = None
+        new_tokens_added = None
+
+        try:
+            ti_tokenizer = copy.deepcopy(tokenizer)
+            ti_manager = TextualInversionManager(ti_tokenizer)
+
+            def _get_trigger(ti, index):
+                trigger = ti.name
+                if index > 0:
+                    trigger += f"-!pad-{index}"
+                return f"<{trigger}>"
+
+            # modify tokenizer
+            new_tokens_added = 0
+            for ti in ti_list:
+                for i in range(ti.embedding.shape[0]):
+                    new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
+
+            # modify text_encoder
+            for i in range(len(text_encoder.proto.graph.initializer)):
+                if text_encoder.proto.graph.initializer[i].name == "text_model.embeddings.token_embedding.weight":
+                    embeddings_node_idx = i
+                    break
+            else:
+                raise Exception("text_model.embeddings.token_embedding.weight node not found")
+
+            embeddings_node_orig = text_encoder.proto.graph.initializer[embeddings_node_idx]
+            base_weights = numpy_helper.to_array(embeddings_node_orig)
+
+            embedding_weights = np.concatenate((base_weights, np.zeros((new_tokens_added, base_weights.shape[1]))), axis=0)
+
+            for ti in ti_list:
+                ti_tokens = []
+                for i in range(ti.embedding.shape[0]):
+                    embedding = ti.embedding[i].detach().numpy()
+                    trigger = _get_trigger(ti, i)
+
+                    token_id = ti_tokenizer.convert_tokens_to_ids(trigger)
+                    if token_id == ti_tokenizer.unk_token_id:
+                        raise RuntimeError(f"Unable to find token id for token '{trigger}'")
+
+                    if embedding_weights[token_id].shape != embedding.shape:
+                        raise ValueError(
+                            f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {embedding_weights[token_id].shape[0]}."
+                        )
+
+                    embedding_weights[token_id] = embedding
+                    ti_tokens.append(token_id)
+
+                if len(ti_tokens) > 1:
+                    ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]
+
+            new_embeddings_node = numpy_helper.from_array(embedding_weights.astype(base_weights.dtype), embeddings_node_orig.name)
+            del text_encoder.proto.graph.initializer[embeddings_node_idx]
+            text_encoder.proto.graph.initializer.insert(embeddings_node_idx, new_embeddings_node)
+
+            yield ti_tokenizer, ti_manager
+
+        finally:
+            # restore
+            if embeddings_node_orig is not None:
+                del text_encoder.proto.graph.initializer[embeddings_node_idx]
+                text_encoder.proto.graph.initializer.insert(embeddings_node_idx, embeddings_node_orig)
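Likewise, a minimal usage sketch for the textual-inversion patcher; the `tokenizer`, `onnx_text_encoder`, `ti_list` objects and the trigger string are placeholders:

    # Hypothetical usage: register TI trigger tokens and extend the ONNX token
    # embedding table for the duration of the block.
    with ONNXModelPatcher.apply_ti(tokenizer, onnx_text_encoder, ti_list) as (ti_tokenizer, ti_manager):
        onnx_text_encoder.create_session()
        token_ids = ti_tokenizer.encode("a photo of <my-embedding>")
        # ti_manager pads multi-vector embeddings via the compel
        # BaseTextualInversionManager interface it inherits.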
@@ -9,9 +9,12 @@ from .lora import LoRAModel
 from .controlnet import ControlNetModel # TODO:
 from .textual_inversion import TextualInversionModel

+from .stable_diffusion_onnx import ONNXStableDiffusion1Model, ONNXStableDiffusion2Model
+
 MODEL_CLASSES = {
     BaseModelType.StableDiffusion1: {
         ModelType.Pipeline: StableDiffusion1Model,
+        ModelType.ONNX: ONNXStableDiffusion1Model,
         ModelType.Vae: VaeModel,
         ModelType.Lora: LoRAModel,
         ModelType.ControlNet: ControlNetModel,
@@ -19,6 +22,7 @@ MODEL_CLASSES = {
     },
     BaseModelType.StableDiffusion2: {
         ModelType.Pipeline: StableDiffusion2Model,
+        ModelType.ONNX: ONNXStableDiffusion2Model,
         ModelType.Vae: VaeModel,
         ModelType.Lora: LoRAModel,
         ModelType.ControlNet: ControlNetModel,
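An illustrative lookup against the extended registry (a sketch, assuming these names are importable from the models package; the snippet itself is not part of the diff):

    from invokeai.backend.model_management.models import MODEL_CLASSES, BaseModelType, ModelType

    # ONNX variants now resolve alongside the diffusers pipelines.
    onnx_sd1_class = MODEL_CLASSES[BaseModelType.StableDiffusion1][ModelType.ONNX]
    assert onnx_sd1_class.__name__ == "ONNXStableDiffusion1Model"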
@@ -5,19 +5,27 @@ import inspect
 from enum import Enum
 from abc import ABCMeta, abstractmethod
 import torch
+import numpy as np
 import safetensors.torch
-from diffusers import DiffusionPipeline, ConfigMixin
 from pathlib import Path
+from diffusers import DiffusionPipeline, ConfigMixin, OnnxRuntimeModel

 from contextlib import suppress
 from pydantic import BaseModel, Field
 from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union

+import onnx
+from onnx import numpy_helper
+from onnx.external_data_helper import set_external_data
+from onnxruntime import InferenceSession, OrtValue, SessionOptions
+
 class BaseModelType(str, Enum):
     StableDiffusion1 = "sd-1"
     StableDiffusion2 = "sd-2"
     #Kandinsky2_1 = "kandinsky-2.1"

 class ModelType(str, Enum):
+    ONNX = "onnx"
     Pipeline = "pipeline"
     Vae = "vae"
     Lora = "lora"
@@ -29,6 +37,8 @@ class SubModelType(str, Enum):
     TextEncoder = "text_encoder"
     Tokenizer = "tokenizer"
     Vae = "vae"
+    VaeDecoder = "vae_decoder"
+    VaeEncoder = "vae_encoder"
     Scheduler = "scheduler"
     SafetyChecker = "safety_checker"
     #MoVQ = "movq"
@@ -240,16 +250,18 @@ class DiffusersModel(ModelBase):
             try:
                 # TODO: set cache_dir to /dev/null to be sure that cache not used?
                 model = self.child_types[child_type].from_pretrained(
-                    self.model_path,
-                    subfolder=child_type.value,
+                    os.path.join(self.model_path, child_type.value),
+                    #subfolder=child_type.value,
                     torch_dtype=torch_dtype,
                     variant=variant,
                     local_files_only=True,
                 )
                 break
             except Exception as e:
-                #print("====ERR LOAD====")
-                #print(f"{variant}: {e}")
+                print("====ERR LOAD====")
+                print(f"{variant}: {e}")
+                import traceback
+                traceback.print_exc()
                 pass
         else:
             raise Exception(f"Failed to load {self.base_model}:{self.model_type}:{child_type} model")
@@ -413,3 +425,92 @@ class SilenceWarnings(object):
         transformers_logging.set_verbosity(self.transformers_verbosity)
         diffusers_logging.set_verbosity(self.diffusers_verbosity)
         warnings.simplefilter('default')
+
+def buffer_external_data_tensors(model):
+    external_data = dict()
+    for tensor in model.graph.initializer:
+        name = tensor.name
+
+        if tensor.HasField("raw_data"):
+            npt = numpy_helper.to_array(tensor)
+            orv = OrtValue.ortvalue_from_numpy(npt)
+            external_data[name] = orv
+            set_external_data(tensor, location="tmp.bin")
+            tensor.name = name
+            tensor.ClearField("raw_data")
+
+    return (model, external_data)
+
+ONNX_WEIGHTS_NAME = "model.onnx"
+
+class IAIOnnxRuntimeModel(OnnxRuntimeModel):
+    def __init__(self, model: tuple, **kwargs):
+        self.proto, self.provider, self.sess_options = model
+        self.session = None
+        self._external_data = dict()
+
+    def __call__(self, **kwargs):
+        if self.session is None:
+            raise Exception("You should call create_session before running model")
+
+        inputs = {k: np.array(v) for k, v in kwargs.items()}
+        return self.session.run(None, inputs)
+
+    def create_session(self):
+        if self.session is None:
+            #onnx.save(self.proto, "tmp.onnx")
+            #onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False)
+            (trimmed_model, external_data) = buffer_external_data_tensors(self.proto)
+            sess = SessionOptions()
+            self._external_data.update(**external_data)
+            sess.add_external_initializers(list(self._external_data.keys()), list(self._external_data.values()))
+            self.session = InferenceSession(trimmed_model.SerializeToString(), providers=[self.provider], sess_options=sess)
+            #self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options)
+
+    def release_session(self):
+        self.session = None
+        import gc
+        gc.collect()
+
+    @staticmethod
+    def load_model(path: Union[str, Path], provider=None, sess_options=None):
+        """
+        Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider`
+
+        Arguments:
+            path (`str` or `Path`):
+                Directory from which to load
+            provider(`str`, *optional*):
+                Onnxruntime execution provider to use for loading the model, defaults to `CPUExecutionProvider`
+        """
+        if provider is None:
+            #logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
+            print("No onnxruntime provider specified, using CPUExecutionProvider")
+            provider = "CPUExecutionProvider"
+
+        # TODO: check that provider available?
+        return (onnx.load(path), provider, sess_options)
+
+    @classmethod
+    def _from_pretrained(
+        cls,
+        model_id: Union[str, Path],
+        use_auth_token: Optional[Union[bool, str, None]] = None,
+        revision: Optional[Union[str, None]] = None,
+        force_download: bool = False,
+        cache_dir: Optional[str] = None,
+        file_name: Optional[str] = None,
+        provider: Optional[str] = None,
+        sess_options: Optional["SessionOptions"] = None,
+        **kwargs,
+    ):
+        model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
+        # load model from local directory
+        if not os.path.isdir(model_id):
+            raise Exception(f"Model not found: {model_id}")
+        model = IAIOnnxRuntimeModel.load_model(
+            os.path.join(model_id, model_file_name), provider=provider, sess_options=sess_options
+        )
+
+        return cls(model=model, **kwargs)
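For context, a hedged sketch of the load/run lifecycle implied by IAIOnnxRuntimeModel above. The model path and the input names (sample, timestep, encoder_hidden_states) are illustrative assumptions; the actual names depend on how the graph was exported:

    import numpy as np

    # Hypothetical lifecycle (path and input names are placeholders).
    model = IAIOnnxRuntimeModel(
        IAIOnnxRuntimeModel.load_model("/models/sd-1/onnx/unet/model.onnx", provider="CPUExecutionProvider")
    )
    model.create_session()  # buffers raw initializers as OrtValues and builds the InferenceSession

    latents = np.zeros((2, 4, 64, 64), dtype=np.float32)
    text_embeds = np.zeros((2, 77, 768), dtype=np.float32)
    outputs = model(sample=latents, timestep=np.array([999], dtype=np.int64), encoder_hidden_states=text_embeds)

    model.release_session()  # drop the session so gc can reclaim the memory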
@@ -0,0 +1,156 @@
+import os
+import json
+from enum import Enum
+from pydantic import Field
+from pathlib import Path
+from typing import Literal, Optional, Union
+from .base import (
+    ModelBase,
+    ModelConfigBase,
+    BaseModelType,
+    ModelType,
+    SubModelType,
+    ModelVariantType,
+    DiffusersModel,
+    SchedulerPredictionType,
+    SilenceWarnings,
+    read_checkpoint_meta,
+    classproperty,
+    OnnxRuntimeModel,
+    IAIOnnxRuntimeModel,
+)
+from invokeai.app.services.config import InvokeAIAppConfig
+
+class ONNXStableDiffusion1Model(DiffusersModel):
+
+    class Config(ModelConfigBase):
+        model_format: None
+        variant: ModelVariantType
+
+    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
+        assert base_model == BaseModelType.StableDiffusion1
+        assert model_type == ModelType.ONNX
+        super().__init__(
+            model_path=model_path,
+            base_model=BaseModelType.StableDiffusion1,
+            model_type=ModelType.ONNX,
+        )
+
+        for child_name, child_type in self.child_types.items():
+            if child_type is OnnxRuntimeModel:
+                self.child_types[child_name] = IAIOnnxRuntimeModel
+
+        # TODO: check that no optimum models provided
+
+    @classmethod
+    def probe_config(cls, path: str, **kwargs):
+        model_format = cls.detect_format(path)
+        in_channels = 4 # TODO:
+
+        if in_channels == 9:
+            variant = ModelVariantType.Inpaint
+        elif in_channels == 4:
+            variant = ModelVariantType.Normal
+        else:
+            raise Exception("Unknown stable diffusion 1.* model format")
+
+        return cls.create_config(
+            path=path,
+            model_format=model_format,
+
+            variant=variant,
+        )
+
+    @classproperty
+    def save_to_config(cls) -> bool:
+        return True
+
+    @classmethod
+    def detect_format(cls, model_path: str):
+        return None
+
+    @classmethod
+    def convert_if_required(
+        cls,
+        model_path: str,
+        output_path: str,
+        config: ModelConfigBase,
+        base_model: BaseModelType,
+    ) -> str:
+        return model_path
+
+class ONNXStableDiffusion2Model(DiffusersModel):
+
+    # TODO: check that configs overwritten properly
+    class Config(ModelConfigBase):
+        model_format: None
+        variant: ModelVariantType
+        prediction_type: SchedulerPredictionType
+        upcast_attention: bool
+
+    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
+        assert base_model == BaseModelType.StableDiffusion2
+        assert model_type == ModelType.ONNX
+        super().__init__(
+            model_path=model_path,
+            base_model=BaseModelType.StableDiffusion2,
+            model_type=ModelType.ONNX,
+        )
+
+        for child_name, child_type in self.child_types.items():
+            if child_type is OnnxRuntimeModel:
+                self.child_types[child_name] = IAIOnnxRuntimeModel
+        # TODO: check that no optimum models provided
+
+    @classmethod
+    def probe_config(cls, path: str, **kwargs):
+        model_format = cls.detect_format(path)
+        in_channels = 4 # TODO:
+
+        if in_channels == 9:
+            variant = ModelVariantType.Inpaint
+        elif in_channels == 5:
+            variant = ModelVariantType.Depth
+        elif in_channels == 4:
+            variant = ModelVariantType.Normal
+        else:
+            raise Exception("Unknown stable diffusion 2.* model format")
+
+        if variant == ModelVariantType.Normal:
+            prediction_type = SchedulerPredictionType.VPrediction
+            upcast_attention = True
+
+        else:
+            prediction_type = SchedulerPredictionType.Epsilon
+            upcast_attention = False
+
+        return cls.create_config(
+            path=path,
+            model_format=model_format,
+
+            variant=variant,
+            prediction_type=prediction_type,
+            upcast_attention=upcast_attention,
+        )
+
+    @classproperty
+    def save_to_config(cls) -> bool:
+        return True
+
+    @classmethod
+    def detect_format(cls, model_path: str):
+        return None
+
+    @classmethod
+    def convert_if_required(
+        cls,
+        model_path: str,
+        output_path: str,
+        config: ModelConfigBase,
+        base_model: BaseModelType,
+    ) -> str:
+        return model_path
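To tie the pieces together, a hedged sketch of constructing the new model class; the path is a placeholder, and the snippet only exercises the __init__ contract shown in this file:

    # Hypothetical construction: base model sd-1 with model type ONNX, as asserted above.
    onnx_model = ONNXStableDiffusion1Model(
        model_path="/path/to/sd-1/onnx-model",
        base_model=BaseModelType.StableDiffusion1,
        model_type=ModelType.ONNX,
    )
    # Any diffusers OnnxRuntimeModel children (unet, text_encoder, vae_decoder, ...)
    # are remapped to IAIOnnxRuntimeModel by the loop in __init__ above.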