Update ONNX model structure, change code accordingly

This commit is contained in:
parent 7759b3f75a
commit 6c7668aaca
@@ -63,20 +63,26 @@ class ONNXPromptInvocation(BaseInvocation):
              text_encoder_info as text_encoder,\
              ExitStack() as stack:
 
-            loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
+            #loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.clip.loras]
+            loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.clip.loras]
 
             ti_list = []
             for trigger in re.findall(r"<[a-zA-Z0-9., _-]+>", self.prompt):
                 name = trigger[1:-1]
                 try:
                     ti_list.append(
-                        stack.enter_context(
-                            context.services.model_manager.get_model(
-                                model_name=name,
-                                base_model=self.clip.text_encoder.base_model,
-                                model_type=ModelType.TextualInversion,
-                            )
-                        )
+                        #stack.enter_context(
+                        #    context.services.model_manager.get_model(
+                        #        model_name=name,
+                        #        base_model=self.clip.text_encoder.base_model,
+                        #        model_type=ModelType.TextualInversion,
+                        #    )
+                        #)
+                        context.services.model_manager.get_model(
+                            model_name=name,
+                            base_model=self.clip.text_encoder.base_model,
+                            model_type=ModelType.TextualInversion,
+                        ).context.model
                     )
                 except Exception:
                     #print(e)
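Note on the change above (mirrored for the UNet in the next hunk): LoRA and textual-inversion models are no longer held open via `stack.enter_context(...)`; the code now reads the already-loaded model off the model-manager handle with `.context.model`, because the ONNX patchers copy and restore weights themselves. A minimal sketch of the two access patterns, using a hypothetical `get_model()` stand-in for InvokeAI's model manager (not the real API):

```python
from contextlib import ExitStack

# Hypothetical stand-in for context.services.model_manager.get_model():
# the real call returns an info object whose .context manages load/unload.
class _ModelInfo:
    def __init__(self, model):
        self.model = model
        self.context = self  # toy: the info object doubles as its own context

    def __enter__(self):
        return self.model

    def __exit__(self, *exc):
        return False

def get_model(name):
    return _ModelInfo(model=f"weights-of-{name}")

# Old pattern: keep the model checked out for the lifetime of the ExitStack.
with ExitStack() as stack:
    lora = stack.enter_context(get_model("lora-a"))

# New pattern: take the loaded model object directly off the handle.
lora = get_model("lora-a").context.model
```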
@@ -218,7 +224,8 @@ class ONNXTextToLatentsInvocation(BaseInvocation):
         with unet_info as unet,\
              ExitStack() as stack:
 
-            loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
+            #loras = [(stack.enter_context(context.services.model_manager.get_model(**lora.dict(exclude={"weight"}))), lora.weight) for lora in self.unet.loras]
+            loras = [(context.services.model_manager.get_model(**lora.dict(exclude={"weight"})).context.model, lora.weight) for lora in self.unet.loras]
 
             with ONNXModelPatcher.apply_lora_unet(unet, loras):
                 # TODO:
@@ -12,6 +12,7 @@ from torch.utils.hooks import RemovableHandle
 from diffusers.models import UNet2DConditionModel
 from transformers import CLIPTextModel
 from onnx import numpy_helper
+from onnxruntime import OrtValue
 import numpy as np
 
 from compel.embeddings_provider import BaseTextualInversionManager
@@ -718,8 +719,7 @@ class ONNXModelPatcher:
         if not isinstance(model, IAIOnnxRuntimeModel):
             raise Exception("Only IAIOnnxRuntimeModel models supported")
 
-        base_model = model.proto
-        orig_nodes = dict()
+        orig_weights = dict()
 
         try:
            blended_loras = dict()
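The hunks above and below replace direct protobuf-graph surgery with the dict-like `model.nodes` / `model.tensors` views defined later in this diff. Under the hood, both approaches need named access to ONNX's positional repeated fields; a small sketch of that underlying pattern on a throwaway graph (names and shapes here are illustrative only):

```python
import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper

# Throwaway model with a single named initializer.
graph = helper.make_graph(
    [helper.make_node("Identity", ["W"], ["Y"])],
    "demo",
    [],
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2])],
    [numpy_helper.from_array(np.zeros(2, dtype=np.float32), "W")],
)
model = helper.make_model(graph)

# Protobuf repeated fields are positional, so build a name -> index map.
initializer_idx = {init.name: idx for idx, init in enumerate(model.graph.initializer)}

# Swapping an initializer in place means delete + insert at the same index.
idx = initializer_idx["W"]
new_node = numpy_helper.from_array(np.ones(2, dtype=np.float32), "W")
del model.graph.initializer[idx]
model.graph.initializer.insert(idx, new_node)
assert numpy_helper.to_array(model.graph.initializer[idx])[0] == 1.0
```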
@@ -736,68 +736,49 @@ class ONNXModelPatcher:
                 else:
                     blended_loras[layer_key] = layer_weight
 
-            initializer_idx = dict()
-            for idx, init in enumerate(base_model.graph.initializer):
-                initializer_idx[init.name.replace(".", "_")] = idx
-
-            node_idx = dict()
-            for idx, node in enumerate(base_model.graph.node):
-                node_idx[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = idx
+            node_names = dict()
+            for node in model.nodes.values():
+                node_names[node.name.replace("/", "_").replace(".", "_").lstrip("_")] = node.name
 
-            for layer_key, weights in blended_loras.items():
+            for layer_key, lora_weight in blended_loras.items():
                 conv_key = layer_key + "_Conv"
                 gemm_key = layer_key + "_Gemm"
                 matmul_key = layer_key + "_MatMul"
 
-                if conv_key in node_idx or gemm_key in node_idx:
-                    if conv_key in node_idx:
-                        conv_node = base_model.graph.node[node_idx[conv_key]]
+                if conv_key in node_names or gemm_key in node_names:
+                    if conv_key in node_names:
+                        conv_node = model.nodes[node_names[conv_key]]
                     else:
-                        conv_node = base_model.graph.node[node_idx[gemm_key]]
+                        conv_node = model.nodes[node_names[gemm_key]]
 
                     weight_name = [n for n in conv_node.input if ".weight" in n][0]
-                    weight_name = weight_name.replace(".", "_")
-
-                    weight_idx = initializer_idx[weight_name]
-                    weight_node = base_model.graph.initializer[weight_idx]
-
-                    orig_weights = numpy_helper.to_array(weight_node)
+                    orig_weight = model.tensors[weight_name]
 
-                    if orig_weights.shape[-2:] == (1, 1):
-                        if weights.shape[-2:] == (1, 1):
-                            new_weights = orig_weights.squeeze((3, 2)) + weights.squeeze((3, 2))
+                    if orig_weight.shape[-2:] == (1, 1):
+                        if lora_weight.shape[-2:] == (1, 1):
+                            new_weight = orig_weight.squeeze((3, 2)) + lora_weight.squeeze((3, 2))
                         else:
-                            new_weights = orig_weights.squeeze((3, 2)) + weights
+                            new_weight = orig_weight.squeeze((3, 2)) + lora_weight
 
-                        new_weights = np.expand_dims(new_weights, (2, 3))
+                        new_weight = np.expand_dims(new_weight, (2, 3))
                     else:
-                        if orig_weights.shape != weights.shape:
-                            new_weights = orig_weights + weights.reshape(orig_weights.shape)
+                        if orig_weight.shape != lora_weight.shape:
+                            new_weight = orig_weight + lora_weight.reshape(orig_weight.shape)
                         else:
-                            new_weights = orig_weights + weights
+                            new_weight = orig_weight + lora_weight
 
-                    new_node = numpy_helper.from_array(new_weights.astype(orig_weights.dtype), weight_node.name)
-                    orig_nodes[weight_idx] = base_model.graph.initializer[weight_idx]
-                    del base_model.graph.initializer[weight_idx]
-                    base_model.graph.initializer.insert(weight_idx, new_node)
-
-                elif matmul_key in node_idx:
-                    weight_node = base_model.graph.node[node_idx[matmul_key]]
+                    orig_weights[weight_name] = orig_weight
+                    model.tensors[weight_name] = new_weight.astype(orig_weight.dtype)
 
+                elif matmul_key in node_names:
+                    weight_node = model.nodes[node_names[matmul_key]]
                     matmul_name = [n for n in weight_node.input if "MatMul" in n][0]
 
-                    matmul_idx = initializer_idx[matmul_name]
-                    matmul_node = base_model.graph.initializer[matmul_idx]
-
-                    orig_weights = numpy_helper.to_array(matmul_node)
-
-                    new_weights = orig_weights + weights.transpose()
-
-                    # replace the original initializer
-                    new_node = numpy_helper.from_array(new_weights.astype(orig_weights.dtype), matmul_node.name)
-                    orig_nodes[matmul_idx] = base_model.graph.initializer[matmul_idx]
-                    del base_model.graph.initializer[matmul_idx]
-                    base_model.graph.initializer.insert(matmul_idx, new_node)
+                    orig_weight = model.tensors[matmul_name]
+                    new_weight = orig_weight + lora_weight.transpose()
+
+                    orig_weights[matmul_name] = orig_weight
+                    model.tensors[matmul_name] = new_weight.astype(orig_weight.dtype)
 
                 else:
                     # warn? err?
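The blending arithmetic itself is unchanged by this commit: a 1x1 convolution kernel of shape (out, in, 1, 1) is squeezed to 2-D so the LoRA delta can be added and then expanded back, while a MatMul initializer stores its weight transposed relative to the LoRA layout. A standalone numpy sketch of both cases (shapes chosen arbitrarily for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)

# Case 1: 1x1 Conv kernel, shape (out_ch, in_ch, 1, 1).
conv_weight = rng.standard_normal((8, 4, 1, 1), dtype=np.float32)
lora_delta = rng.standard_normal((8, 4), dtype=np.float32)

blended = conv_weight.squeeze((3, 2)) + lora_delta  # drop the 1x1 spatial dims
blended = np.expand_dims(blended, (2, 3))           # restore (out, in, 1, 1)
assert blended.shape == conv_weight.shape

# Case 2: MatMul initializer, stored as (in_features, out_features),
# i.e. transposed relative to the (out, in) LoRA delta.
matmul_weight = rng.standard_normal((4, 8), dtype=np.float32)
lora_delta2 = rng.standard_normal((8, 4), dtype=np.float32)

blended2 = matmul_weight + lora_delta2.transpose()
assert blended2.shape == matmul_weight.shape
```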
@@ -807,9 +788,8 @@ class ONNXModelPatcher:
 
         finally:
             # restore original weights
-            for idx, orig_node in orig_nodes.items():
-                del base_model.graph.initializer[idx]
-                base_model.graph.initializer.insert(idx, orig_node)
+            for name, orig_weight in orig_weights.items():
+                model.tensors[name] = orig_weight
 
 
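Because `model.tensors` now behaves like a mutable mapping, apply/undo reduces to a plain save-then-restore around try/finally. A minimal sketch of that lifecycle, with an ordinary dict standing in for the tensor view and a hypothetical `patch_tensors` helper (not part of the commit):

```python
from contextlib import contextmanager

@contextmanager
def patch_tensors(tensors, patches):
    """Apply `patches` to the dict-like `tensors`, restoring originals on exit."""
    saved = {}
    try:
        for name, new_value in patches.items():
            saved[name] = tensors[name]
            tensors[name] = new_value
        yield tensors
    finally:
        for name, old_value in saved.items():
            tensors[name] = old_value

tensors = {"w": 1.0}
with patch_tensors(tensors, {"w": 2.0}):
    assert tensors["w"] == 2.0
assert tensors["w"] == 1.0
```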
@@ -825,8 +805,7 @@ class ONNXModelPatcher:
         if not isinstance(text_encoder, IAIOnnxRuntimeModel):
             raise Exception("Only IAIOnnxRuntimeModel models supported")
 
-        init_tokens_count = None
-        new_tokens_added = None
+        orig_embeddings = None
 
         try:
             ti_tokenizer = copy.deepcopy(tokenizer)
@@ -845,17 +824,15 @@ class ONNXModelPatcher:
                     new_tokens_added += ti_tokenizer.add_tokens(_get_trigger(ti, i))
 
             # modify text_encoder
-            for i in range(len(text_encoder.proto.graph.initializer)):
-                if text_encoder.proto.graph.initializer[i].name == "text_model.embeddings.token_embedding.weight":
-                    embeddings_node_idx = i
-                    break
-            else:
-                raise Exception("text_model.embeddings.token_embedding.weight node not found")
-
-            embeddings_node_orig = text_encoder.proto.graph.initializer[embeddings_node_idx]
-            base_weights = numpy_helper.to_array(embeddings_node_orig)
-
-            embedding_weights = np.concatenate((base_weights, np.zeros((new_tokens_added, base_weights.shape[1]))), axis=0)
+            orig_embeddings = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]
+
+            embeddings = np.concatenate(
+                (
+                    np.copy(orig_embeddings),
+                    np.zeros((new_tokens_added, orig_embeddings.shape[1]))
+                ),
+                axis=0,
+            )
 
             for ti in ti_list:
                 ti_tokens = []
@@ -867,26 +844,22 @@ class ONNXModelPatcher:
                     if token_id == ti_tokenizer.unk_token_id:
                         raise RuntimeError(f"Unable to find token id for token '{trigger}'")
 
-                    if embedding_weights[token_id].shape != embedding.shape:
+                    if embeddings[token_id].shape != embedding.shape:
                         raise ValueError(
-                            f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {embedding_weights[token_id].shape[0]}."
+                            f"Cannot load embedding for {trigger}. It was trained on a model with token dimension {embedding.shape[0]}, but the current model has token dimension {embeddings[token_id].shape[0]}."
                         )
 
-                    embedding_weights[token_id] = embedding
+                    embeddings[token_id] = embedding
                     ti_tokens.append(token_id)
 
                 if len(ti_tokens) > 1:
                     ti_manager.pad_tokens[ti_tokens[0]] = ti_tokens[1:]
 
-            new_embeddings_node = numpy_helper.from_array(embedding_weights.astype(base_weights.dtype), embeddings_node_orig.name)
-            del text_encoder.proto.graph.initializer[embeddings_node_idx]
-            text_encoder.proto.graph.initializer.insert(embeddings_node_idx, new_embeddings_node)
+            text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = embeddings.astype(orig_embeddings.dtype)
 
             yield ti_tokenizer, ti_manager
 
         finally:
             # restore
-            if embeddings_node_orig is not None:
-                del text_encoder.proto.graph.initializer[embeddings_node_idx]
-                text_encoder.proto.graph.initializer.insert(embeddings_node_idx, embeddings_node_orig)
+            if orig_embeddings is not None:
+                text_encoder.tensors["text_model.embeddings.token_embedding.weight"] = orig_embeddings
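The textual-inversion path is the same technique applied to the token embedding table: extend the (vocab, dim) matrix with zero rows for the newly added trigger tokens, then overwrite those rows with the learned vectors. A self-contained numpy sketch (vocabulary size and embedding dimension are arbitrary here):

```python
import numpy as np

vocab_size, dim = 49408, 768
orig_embeddings = np.zeros((vocab_size, dim), dtype=np.float32)

# Two learned textual-inversion vectors for two freshly added tokens.
new_tokens_added = 2
learned = np.ones((new_tokens_added, dim), dtype=np.float32)

# Grow the table with zero rows, then fill in the new token ids.
embeddings = np.concatenate(
    (np.copy(orig_embeddings), np.zeros((new_tokens_added, dim))),
    axis=0,
)
for i in range(new_tokens_added):
    token_id = vocab_size + i  # ids the tokenizer assigns to added tokens
    embeddings[token_id] = learned[i]

assert embeddings.shape == (vocab_size + new_tokens_added, dim)
```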
@@ -426,27 +426,127 @@ class SilenceWarnings(object):
         diffusers_logging.set_verbosity(self.diffusers_verbosity)
         warnings.simplefilter('default')
 
-def buffer_external_data_tensors(model):
-    external_data = dict()
-    for tensor in model.graph.initializer:
-        name = tensor.name
-
-        if tensor.HasField("raw_data"):
-            npt = numpy_helper.to_array(tensor)
-            orv = OrtValue.ortvalue_from_numpy(npt)
-            external_data[name] = orv
-            set_external_data(tensor, location="tmp.bin")
-            tensor.name = name
-            tensor.ClearField("raw_data")
-
-    return (model, external_data)
-
 ONNX_WEIGHTS_NAME = "model.onnx"
 
-class IAIOnnxRuntimeModel(OnnxRuntimeModel):
-    def __init__(self, model: tuple, **kwargs):
-        self.proto, self.provider, self.sess_options = model
-        self.session = None
-        self._external_data = dict()
+class IAIOnnxRuntimeModel:
+    class _tensor_access:
+        def __init__(self, model):
+            self.model = model
+            self.indexes = dict()
+            for idx, obj in enumerate(self.model.proto.graph.initializer):
+                self.indexes[obj.name] = idx
+
+        def __getitem__(self, key: str):
+            return self.model.data[key].numpy()
+
+        def __setitem__(self, key: str, value: np.ndarray):
+            new_node = numpy_helper.from_array(value)
+            set_external_data(new_node, location="in-memory-location")
+            new_node.name = key
+            new_node.ClearField("raw_data")
+            del self.model.proto.graph.initializer[self.indexes[key]]
+            self.model.proto.graph.initializer.insert(self.indexes[key], new_node)
+            self.model.data[key] = OrtValue.ortvalue_from_numpy(value)
+
+        # __delitem__
+
+        def __contains__(self, key: str):
+            return key in self.model.data
+
+        def items(self):
+            raise NotImplementedError("tensor.items")
+            #return [(obj.name, obj) for obj in self.raw_proto]
+
+        def keys(self):
+            return self.model.data.keys()
+
+        def values(self):
+            raise NotImplementedError("tensor.values")
+            #return [obj for obj in self.raw_proto]
+
+    class _access_helper:
+        def __init__(self, raw_proto):
+            self.indexes = dict()
+            self.raw_proto = raw_proto
+            for idx, obj in enumerate(raw_proto):
+                self.indexes[obj.name] = idx
+
+        def __getitem__(self, key: str):
+            return self.raw_proto[self.indexes[key]]
+
+        def __setitem__(self, key: str, value):
+            index = self.indexes[key]
+            del self.raw_proto[index]
+            self.raw_proto.insert(index, value)
+
+        # __delitem__
+
+        def __contains__(self, key: str):
+            return key in self.indexes
+
+        def items(self):
+            return [(obj.name, obj) for obj in self.raw_proto]
+
+        def keys(self):
+            return self.indexes.keys()
+
+        def values(self):
+            return [obj for obj in self.raw_proto]
+
+    def __init__(self, model_path: str, provider: Optional[str]):
+        self.path = model_path
+        self.session = None
+        self.provider = provider or "CPUExecutionProvider"
+        """
+        self.data_path = self.path + "_data"
+        if not os.path.exists(self.data_path):
+            print(f"Moving model tensors to separate file: {self.data_path}")
+            tmp_proto = onnx.load(model_path, load_external_data=True)
+            onnx.save_model(tmp_proto, self.path, save_as_external_data=True, all_tensors_to_one_file=True, location=os.path.basename(self.data_path), size_threshold=1024, convert_attribute=False)
+            del tmp_proto
+            gc.collect()
+
+        self.proto = onnx.load(model_path, load_external_data=False)
+        """
+
+        self.proto = onnx.load(model_path, load_external_data=True)
+        self.data = dict()
+        for tensor in self.proto.graph.initializer:
+            name = tensor.name
+
+            if tensor.HasField("raw_data"):
+                npt = numpy_helper.to_array(tensor)
+                orv = OrtValue.ortvalue_from_numpy(npt)
+                self.data[name] = orv
+                set_external_data(tensor, location="in-memory-location")
+                tensor.name = name
+                tensor.ClearField("raw_data")
+
+        self.nodes = self._access_helper(self.proto.graph.node)
+        self.initializers = self._access_helper(self.proto.graph.initializer)
+
+        self.tensors = self._tensor_access(self)
+
+    # TODO: integrate with model manager/cache
+    def create_session(self):
+        if self.session is None:
+            #onnx.save(self.proto, "tmp.onnx")
+            #onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False)
+            # TODO: something to be able to get weight when they already moved outside of model proto
+            #(trimmed_model, external_data) = buffer_external_data_tensors(self.proto)
+            sess = SessionOptions()
+            #self._external_data.update(**external_data)
+            sess.add_external_initializers(list(self.data.keys()), list(self.data.values()))
+            self.session = InferenceSession(self.proto.SerializeToString(), providers=[self.provider], sess_options=sess)
+            #self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options)
+
+    def release_session(self):
+        self.session = None
+        import gc
+        gc.collect()
 
     def __call__(self, **kwargs):
         if self.session is None:
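The core trick in the new class: raw tensor bytes are stripped out of the protobuf (each initializer is flagged as external data at a dummy location) and handed to onnxruntime as in-memory OrtValues through `SessionOptions.add_external_initializers`, so patched weights reach the session without re-serializing multi-gigabyte protos. A self-contained sketch of the same technique on a one-node model (names and shapes are illustrative; the opset is pinned only so the demo loads on older runtimes):

```python
import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper
from onnx.external_data_helper import set_external_data
from onnxruntime import InferenceSession, OrtValue, SessionOptions

# Tiny graph: Y = X @ W, with W stored inline as raw data.
weight = np.ones((2, 2), dtype=np.float32)
graph = helper.make_graph(
    [helper.make_node("MatMul", ["X", "W"], ["Y"])],
    "demo",
    [helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 2])],
    [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2])],
    [numpy_helper.from_array(weight, "W")],
)
model = helper.make_model(graph, opset_imports=[helper.make_operatorsetid("", 13)])

# Strip the weight out of the proto: keep it as an OrtValue, and mark the
# initializer as "external" so the proto no longer carries the bytes.
data = {}
for tensor in model.graph.initializer:
    data[tensor.name] = OrtValue.ortvalue_from_numpy(numpy_helper.to_array(tensor))
    set_external_data(tensor, location="in-memory-location")
    tensor.ClearField("raw_data")

# The session resolves the "external" initializer from our in-memory OrtValue.
opts = SessionOptions()
opts.add_external_initializers(list(data.keys()), list(data.values()))
session = InferenceSession(
    model.SerializeToString(), sess_options=opts, providers=["CPUExecutionProvider"]
)
print(session.run(None, {"X": np.array([[1.0, 2.0]], dtype=np.float32)}))  # [[3. 3.]]
```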
@@ -455,63 +555,32 @@ class IAIOnnxRuntimeModel(OnnxRuntimeModel):
         inputs = {k: np.array(v) for k, v in kwargs.items()}
         return self.session.run(None, inputs)
 
-    def create_session(self):
-        if self.session is None:
-            #onnx.save(self.proto, "tmp.onnx")
-            #onnx.save_model(self.proto, "tmp.onnx", save_as_external_data=True, all_tensors_to_one_file=True, location="tmp.onnx_data", size_threshold=1024, convert_attribute=False)
-            # TODO: something to be able to get weight when they already moved outside of model proto
-            (trimmed_model, external_data) = buffer_external_data_tensors(self.proto)
-            sess = SessionOptions()
-            self._external_data.update(**external_data)
-            sess.add_external_initializers(list(self._external_data.keys()), list(self._external_data.values()))
-            self.session = InferenceSession(trimmed_model.SerializeToString(), providers=[self.provider], sess_options=sess)
-            #self.session = InferenceSession("tmp.onnx", providers=[self.provider], sess_options=self.sess_options)
-
-    def release_session(self):
-        self.session = None
-        import gc
-        gc.collect()
-
-    @staticmethod
-    def load_model(path: Union[str, Path], provider=None, sess_options=None):
-        """
-        Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider`
-
-        Arguments:
-            path (`str` or `Path`):
-                Directory from which to load
-            provider(`str`, *optional*):
-                Onnxruntime execution provider to use for loading the model, defaults to `CPUExecutionProvider`
-        """
-        if provider is None:
-            #logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
-            print("No onnxruntime provider specified, using CPUExecutionProvider")
-            provider = "CPUExecutionProvider"
-
-        # TODO: check that provider available?
-        return (onnx.load(path), provider, sess_options)
-
+    # compatability with diffusers load code
     @classmethod
-    def _from_pretrained(
+    def from_pretrained(
         cls,
         model_id: Union[str, Path],
-        use_auth_token: Optional[Union[bool, str, None]] = None,
-        revision: Optional[Union[str, None]] = None,
-        force_download: bool = False,
-        cache_dir: Optional[str] = None,
+        subfolder: Union[str, Path] = None,
         file_name: Optional[str] = None,
         provider: Optional[str] = None,
         sess_options: Optional["SessionOptions"] = None,
         **kwargs,
     ):
-        model_file_name = file_name if file_name is not None else ONNX_WEIGHTS_NAME
+        file_name = file_name or ONNX_WEIGHTS_NAME
+
+        if os.path.isdir(model_id):
+            model_path = model_id
+            if subfolder is not None:
+                model_path = os.path.join(model_path, subfolder)
+            model_path = os.path.join(model_path, file_name)
+        else:
+            model_path = model_id
 
         # load model from local directory
-        if not os.path.isdir(model_id):
-            raise Exception(f"Model not found: {model_id}")
-        model = IAIOnnxRuntimeModel.load_model(
-            os.path.join(model_id, model_file_name), provider=provider, sess_options=sess_options
-        )
-
-        return cls(model=model, **kwargs)
+        if not os.path.isfile(model_path):
+            raise Exception(f"Model not found: {model_path}")
+
+        # TODO: session options
+        return cls(model_path, provider=provider)
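Taken together, the new loader resolves a plain local path (a directory plus optional subfolder and file name, or a direct file path) and defers session creation until first use. A hedged usage sketch of the class above; the path and input names are placeholders, not part of the commit:

```python
# Hypothetical usage; "./text_encoder" must contain a model.onnx file.
text_encoder = IAIOnnxRuntimeModel.from_pretrained(
    "./text_encoder", provider="CPUExecutionProvider"
)

# Weights are readable and patchable before any session exists, e.g.:
# emb = text_encoder.tensors["text_model.embeddings.token_embedding.weight"]

text_encoder.create_session()   # lazily builds the InferenceSession
# outputs = text_encoder(input_ids=token_ids)  # __call__ runs the session
text_encoder.release_session()  # drops the session but keeps the proto
```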