mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Fix onnx installer
This commit is contained in:
parent
8935ae0ea3
commit
390ce9f249
@ -12,6 +12,7 @@ from typing import List, Dict, Callable, Union, Set
|
|||||||
import requests
|
import requests
|
||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline
|
||||||
from diffusers import logging as dlogging
|
from diffusers import logging as dlogging
|
||||||
|
import onnx
|
||||||
from huggingface_hub import hf_hub_url, HfFolder, HfApi
|
from huggingface_hub import hf_hub_url, HfFolder, HfApi
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@ -288,8 +289,10 @@ class ModelInstall(object):
|
|||||||
|
|
||||||
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
with TemporaryDirectory(dir=self.config.models_path) as staging:
|
||||||
staging = Path(staging)
|
staging = Path(staging)
|
||||||
if "model_index.json" in files:
|
if "model_index.json" in files and "unet/model.onnx" not in files:
|
||||||
location = self._download_hf_pipeline(repo_id, staging) # pipeline
|
location = self._download_hf_pipeline(repo_id, staging) # pipeline
|
||||||
|
elif "unet/model.onnx" in files:
|
||||||
|
location = self._download_hf_model(repo_id, files, staging)
|
||||||
else:
|
else:
|
||||||
for suffix in ["safetensors", "bin"]:
|
for suffix in ["safetensors", "bin"]:
|
||||||
if f"pytorch_lora_weights.{suffix}" in files:
|
if f"pytorch_lora_weights.{suffix}" in files:
|
||||||
@ -354,7 +357,7 @@ class ModelInstall(object):
|
|||||||
model_format=info.format,
|
model_format=info.format,
|
||||||
)
|
)
|
||||||
legacy_conf = None
|
legacy_conf = None
|
||||||
if info.model_type == ModelType.Main:
|
if info.model_type == ModelType.Main or info.model_type == ModelType.ONNX:
|
||||||
attributes.update(
|
attributes.update(
|
||||||
dict(
|
dict(
|
||||||
variant=info.variant_type,
|
variant=info.variant_type,
|
||||||
@ -419,8 +422,9 @@ class ModelInstall(object):
|
|||||||
location = staging / name
|
location = staging / name
|
||||||
paths = list()
|
paths = list()
|
||||||
for filename in files:
|
for filename in files:
|
||||||
|
filePath = Path(filename)
|
||||||
p = hf_download_with_resume(
|
p = hf_download_with_resume(
|
||||||
repo_id, model_dir=location, model_name=filename, access_token=self.access_token
|
repo_id, model_dir=location / filePath.parent, model_name=filePath.name, access_token=self.access_token, subfolder=filePath.parent
|
||||||
)
|
)
|
||||||
if p:
|
if p:
|
||||||
paths.append(p)
|
paths.append(p)
|
||||||
@ -468,11 +472,12 @@ def hf_download_with_resume(
|
|||||||
model_name: str,
|
model_name: str,
|
||||||
model_dest: Path = None,
|
model_dest: Path = None,
|
||||||
access_token: str = None,
|
access_token: str = None,
|
||||||
|
subfolder: str = None,
|
||||||
) -> Path:
|
) -> Path:
|
||||||
model_dest = model_dest or Path(os.path.join(model_dir, model_name))
|
model_dest = model_dest or Path(os.path.join(model_dir, model_name))
|
||||||
os.makedirs(model_dir, exist_ok=True)
|
os.makedirs(model_dir, exist_ok=True)
|
||||||
|
|
||||||
url = hf_hub_url(repo_id, model_name)
|
url = hf_hub_url(repo_id, model_name, subfolder=subfolder)
|
||||||
|
|
||||||
header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
|
header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
|
||||||
open_mode = "wb"
|
open_mode = "wb"
|
||||||
|
@ -41,6 +41,7 @@ class ModelProbe(object):
|
|||||||
PROBES = {
|
PROBES = {
|
||||||
"diffusers": {},
|
"diffusers": {},
|
||||||
"checkpoint": {},
|
"checkpoint": {},
|
||||||
|
"onnx": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
CLASS2TYPE = {
|
CLASS2TYPE = {
|
||||||
@ -53,7 +54,7 @@ class ModelProbe(object):
|
|||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register_probe(cls, format: Literal["diffusers", "checkpoint"], model_type: ModelType, probe_class: ProbeBase):
|
def register_probe(cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: ProbeBase):
|
||||||
cls.PROBES[format][model_type] = probe_class
|
cls.PROBES[format][model_type] = probe_class
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -95,6 +96,7 @@ class ModelProbe(object):
|
|||||||
if format_type == "diffusers"
|
if format_type == "diffusers"
|
||||||
else cls.get_model_type_from_checkpoint(model_path, model)
|
else cls.get_model_type_from_checkpoint(model_path, model)
|
||||||
)
|
)
|
||||||
|
format_type = 'onnx' if model_type == ModelType.ONNX else format_type
|
||||||
probe_class = cls.PROBES[format_type].get(model_type)
|
probe_class = cls.PROBES[format_type].get(model_type)
|
||||||
if not probe_class:
|
if not probe_class:
|
||||||
return None
|
return None
|
||||||
@ -168,6 +170,8 @@ class ModelProbe(object):
|
|||||||
if model:
|
if model:
|
||||||
class_name = model.__class__.__name__
|
class_name = model.__class__.__name__
|
||||||
else:
|
else:
|
||||||
|
if (folder_path / 'unet/model.onnx').exists():
|
||||||
|
return ModelType.ONNX
|
||||||
if (folder_path / "learned_embeds.bin").exists():
|
if (folder_path / "learned_embeds.bin").exists():
|
||||||
return ModelType.TextualInversion
|
return ModelType.TextualInversion
|
||||||
|
|
||||||
@ -460,6 +464,16 @@ class TextualInversionFolderProbe(FolderProbeBase):
|
|||||||
return TextualInversionCheckpointProbe(None, checkpoint=checkpoint).get_base_type()
|
return TextualInversionCheckpointProbe(None, checkpoint=checkpoint).get_base_type()
|
||||||
|
|
||||||
|
|
||||||
|
class ONNXFolderProbe(FolderProbeBase):
|
||||||
|
def get_format(self) -> str:
|
||||||
|
return "onnx"
|
||||||
|
|
||||||
|
def get_base_type(self) -> BaseModelType:
|
||||||
|
return BaseModelType.StableDiffusion1
|
||||||
|
|
||||||
|
def get_variant_type(self) -> ModelVariantType:
|
||||||
|
return ModelVariantType.Normal
|
||||||
|
|
||||||
class ControlNetFolderProbe(FolderProbeBase):
|
class ControlNetFolderProbe(FolderProbeBase):
|
||||||
def get_base_type(self) -> BaseModelType:
|
def get_base_type(self) -> BaseModelType:
|
||||||
config_file = self.folder_path / "config.json"
|
config_file = self.folder_path / "config.json"
|
||||||
@ -497,3 +511,4 @@ ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
|
|||||||
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
|
||||||
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
|
||||||
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
|
||||||
|
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
|
||||||
|
Loading…
Reference in New Issue
Block a user