From 390ce9f2498b71e1cc4062d2470ae489f41627d7 Mon Sep 17 00:00:00 2001
From: Brandon Rising
Date: Fri, 28 Jul 2023 16:54:03 -0400
Subject: [PATCH] Fix onnx installer

---
 .../backend/install/model_install_backend.py  | 13 +++++++++----
 .../backend/model_management/model_probe.py   | 17 ++++++++++++++++-
 2 files changed, 25 insertions(+), 5 deletions(-)

diff --git a/invokeai/backend/install/model_install_backend.py b/invokeai/backend/install/model_install_backend.py
index b3ab88b5dd..eb70db2fee 100644
--- a/invokeai/backend/install/model_install_backend.py
+++ b/invokeai/backend/install/model_install_backend.py
@@ -12,6 +12,7 @@ from typing import List, Dict, Callable, Union, Set
 import requests
 from diffusers import DiffusionPipeline
 from diffusers import logging as dlogging
+import onnx
 from huggingface_hub import hf_hub_url, HfFolder, HfApi
 from omegaconf import OmegaConf
 from tqdm import tqdm
@@ -288,8 +289,10 @@ class ModelInstall(object):
 
         with TemporaryDirectory(dir=self.config.models_path) as staging:
             staging = Path(staging)
-            if "model_index.json" in files:
+            if "model_index.json" in files and "unet/model.onnx" not in files:
                 location = self._download_hf_pipeline(repo_id, staging)  # pipeline
+            elif "unet/model.onnx" in files:
+                location = self._download_hf_model(repo_id, files, staging)
             else:
                 for suffix in ["safetensors", "bin"]:
                     if f"pytorch_lora_weights.{suffix}" in files:
@@ -354,7 +357,7 @@ class ModelInstall(object):
             model_format=info.format,
         )
         legacy_conf = None
-        if info.model_type == ModelType.Main:
+        if info.model_type == ModelType.Main or info.model_type == ModelType.ONNX:
             attributes.update(
                 dict(
                     variant=info.variant_type,
@@ -419,8 +422,9 @@ class ModelInstall(object):
         location = staging / name
         paths = list()
         for filename in files:
+            filePath = Path(filename)
             p = hf_download_with_resume(
-                repo_id, model_dir=location, model_name=filename, access_token=self.access_token
+                repo_id, model_dir=location / filePath.parent, model_name=filePath.name, access_token=self.access_token, subfolder=filePath.parent
             )
             if p:
                 paths.append(p)
@@ -468,11 +472,12 @@ def hf_download_with_resume(
     model_name: str,
     model_dest: Path = None,
     access_token: str = None,
+    subfolder: str = None,
 ) -> Path:
     model_dest = model_dest or Path(os.path.join(model_dir, model_name))
     os.makedirs(model_dir, exist_ok=True)
 
-    url = hf_hub_url(repo_id, model_name)
+    url = hf_hub_url(repo_id, model_name, subfolder=subfolder)
 
     header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
     open_mode = "wb"
diff --git a/invokeai/backend/model_management/model_probe.py b/invokeai/backend/model_management/model_probe.py
index ee14d8ba93..7da722df77 100644
--- a/invokeai/backend/model_management/model_probe.py
+++ b/invokeai/backend/model_management/model_probe.py
@@ -41,6 +41,7 @@ class ModelProbe(object):
     PROBES = {
         "diffusers": {},
         "checkpoint": {},
+        "onnx": {},
     }
 
     CLASS2TYPE = {
@@ -53,7 +54,7 @@ class ModelProbe(object):
     }
 
     @classmethod
-    def register_probe(cls, format: Literal["diffusers", "checkpoint"], model_type: ModelType, probe_class: ProbeBase):
+    def register_probe(cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: ProbeBase):
         cls.PROBES[format][model_type] = probe_class
 
     @classmethod
@@ -95,6 +96,7 @@ class ModelProbe(object):
             if format_type == "diffusers"
             else cls.get_model_type_from_checkpoint(model_path, model)
         )
+        format_type = 'onnx' if model_type == ModelType.ONNX else format_type
         probe_class = cls.PROBES[format_type].get(model_type)
         if not probe_class:
             return None
@@ -168,6 +170,8 @@ class ModelProbe(object):
         if model:
             class_name = model.__class__.__name__
         else:
+            if (folder_path / 'unet/model.onnx').exists():
+                return ModelType.ONNX
             if (folder_path / "learned_embeds.bin").exists():
                 return ModelType.TextualInversion
 
@@ -460,6 +464,16 @@ class TextualInversionFolderProbe(FolderProbeBase):
         return TextualInversionCheckpointProbe(None, checkpoint=checkpoint).get_base_type()
 
 
+class ONNXFolderProbe(FolderProbeBase):
+    def get_format(self) -> str:
+        return "onnx"
+
+    def get_base_type(self) -> BaseModelType:
+        return BaseModelType.StableDiffusion1
+
+    def get_variant_type(self) -> ModelVariantType:
+        return ModelVariantType.Normal
+
 class ControlNetFolderProbe(FolderProbeBase):
     def get_base_type(self) -> BaseModelType:
         config_file = self.folder_path / "config.json"
@@ -497,3 +511,4 @@ ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
+ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)