Fix onnx installer

Brandon Rising 2023-07-28 16:54:03 -04:00
parent 8935ae0ea3
commit 390ce9f249
2 changed files with 25 additions and 5 deletions

Changed file 1 of 2:

@@ -12,6 +12,7 @@ from typing import List, Dict, Callable, Union, Set
 import requests
 from diffusers import DiffusionPipeline
 from diffusers import logging as dlogging
+import onnx
 from huggingface_hub import hf_hub_url, HfFolder, HfApi
 from omegaconf import OmegaConf
 from tqdm import tqdm
@@ -288,8 +289,10 @@ class ModelInstall(object):
         with TemporaryDirectory(dir=self.config.models_path) as staging:
             staging = Path(staging)
-            if "model_index.json" in files:
+            if "model_index.json" in files and "unet/model.onnx" not in files:
                 location = self._download_hf_pipeline(repo_id, staging)  # pipeline
+            elif "unet/model.onnx" in files:
+                location = self._download_hf_model(repo_id, files, staging)
             else:
                 for suffix in ["safetensors", "bin"]:
                     if f"pytorch_lora_weights.{suffix}" in files:
@@ -354,7 +357,7 @@ class ModelInstall(object):
             model_format=info.format,
         )
         legacy_conf = None
-        if info.model_type == ModelType.Main:
+        if info.model_type == ModelType.Main or info.model_type == ModelType.ONNX:
             attributes.update(
                 dict(
                     variant=info.variant_type,
@@ -419,8 +422,9 @@ class ModelInstall(object):
         location = staging / name
         paths = list()
         for filename in files:
+            filePath = Path(filename)
             p = hf_download_with_resume(
-                repo_id, model_dir=location, model_name=filename, access_token=self.access_token
+                repo_id, model_dir=location / filePath.parent, model_name=filePath.name, access_token=self.access_token, subfolder=filePath.parent
             )
             if p:
                 paths.append(p)
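
The new filePath split exists so that nested repo files keep their subfolder layout on disk and the subfolder can be passed through to the download URL. A quick illustration of what pathlib produces for a nested ONNX file (the staging path is illustrative):

from pathlib import Path

location = Path("/tmp/staging/some-model")   # illustrative staging directory
file_path = Path("unet/model.onnx")          # mirrors filePath = Path(filename)

print(file_path.parent)             # unet       -> used as model_dir suffix and subfolder
print(file_path.name)               # model.onnx -> used as model_name
print(location / file_path.parent)  # /tmp/staging/some-model/unet

# Top-level files keep working: their parent is simply "."
print(Path("model_index.json").parent)  # .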
@@ -468,11 +472,12 @@ def hf_download_with_resume(
     model_name: str,
     model_dest: Path = None,
     access_token: str = None,
+    subfolder: str = None,
 ) -> Path:
     model_dest = model_dest or Path(os.path.join(model_dir, model_name))
     os.makedirs(model_dir, exist_ok=True)
-    url = hf_hub_url(repo_id, model_name)
+    url = hf_hub_url(repo_id, model_name, subfolder=subfolder)
     header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
     open_mode = "wb"

Changed file 2 of 2:

@@ -41,6 +41,7 @@ class ModelProbe(object):
     PROBES = {
         "diffusers": {},
         "checkpoint": {},
+        "onnx": {},
     }

     CLASS2TYPE = {
@@ -53,7 +54,7 @@ class ModelProbe(object):
     }

     @classmethod
-    def register_probe(cls, format: Literal["diffusers", "checkpoint"], model_type: ModelType, probe_class: ProbeBase):
+    def register_probe(cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: ProbeBase):
         cls.PROBES[format][model_type] = probe_class

     @classmethod
@@ -95,6 +96,7 @@ class ModelProbe(object):
             if format_type == "diffusers"
             else cls.get_model_type_from_checkpoint(model_path, model)
         )
+        format_type = 'onnx' if model_type == ModelType.ONNX else format_type
         probe_class = cls.PROBES[format_type].get(model_type)
         if not probe_class:
             return None
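
The added format_type override reroutes ONNX models out of the diffusers probe table before the lookup. A toy version of that dispatch, using simplified stand-ins for ModelType and PROBES (the table contents are illustrative):

from enum import Enum

class ModelType(str, Enum):   # stand-in for the real enum
    Main = "main"
    ONNX = "onnx"

PROBES = {                    # simplified stand-in for ModelProbe.PROBES
    "diffusers": {ModelType.Main: "StableDiffusionFolderProbe"},
    "checkpoint": {},
    "onnx": {ModelType.ONNX: "ONNXFolderProbe"},
}

def resolve_probe(format_type: str, model_type: ModelType):
    # ONNX folders are first classified like diffusers folders, so the
    # format key is rewritten before consulting the table:
    format_type = "onnx" if model_type == ModelType.ONNX else format_type
    return PROBES[format_type].get(model_type)

print(resolve_probe("diffusers", ModelType.Main))  # StableDiffusionFolderProbe
print(resolve_probe("diffusers", ModelType.ONNX))  # ONNXFolderProbe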
@@ -168,6 +170,8 @@ class ModelProbe(object):
         if model:
             class_name = model.__class__.__name__
         else:
+            if (folder_path / 'unet/model.onnx').exists():
+                return ModelType.ONNX
             if (folder_path / "learned_embeds.bin").exists():
                 return ModelType.TextualInversion
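
Folder-type detection now keys off the presence of unet/model.onnx inside the model directory. A standalone check of that layout test on a scratch directory (directory names are illustrative):

import tempfile
from pathlib import Path

def looks_like_onnx_export(folder_path: Path) -> bool:
    # Same test as the probe: an ONNX diffusers export ships its UNet as unet/model.onnx.
    return (folder_path / "unet/model.onnx").exists()

with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "unet").mkdir()
    (root / "unet" / "model.onnx").touch()         # empty marker file is enough for the check
    print(looks_like_onnx_export(root))            # True
    print(looks_like_onnx_export(root / "other"))  # False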
@@ -460,6 +464,16 @@ class TextualInversionFolderProbe(FolderProbeBase):
         return TextualInversionCheckpointProbe(None, checkpoint=checkpoint).get_base_type()


+class ONNXFolderProbe(FolderProbeBase):
+    def get_format(self) -> str:
+        return "onnx"
+
+    def get_base_type(self) -> BaseModelType:
+        return BaseModelType.StableDiffusion1
+
+    def get_variant_type(self) -> ModelVariantType:
+        return ModelVariantType.Normal
+
+
 class ControlNetFolderProbe(FolderProbeBase):
     def get_base_type(self) -> BaseModelType:
         config_file = self.folder_path / "config.json"
@@ -497,3 +511,4 @@ ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
+ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)