Run python black, point out that onnx is an alpha feature in the installer

This commit is contained in:
Brandon Rising 2023-07-31 16:47:48 -04:00
parent af4fd328a6
commit aeac557c41
3 changed files with 12 additions and 5 deletions

View File

@ -168,7 +168,7 @@ def graphical_accelerator():
"cuda",
)
nvidia_with_dml = (
"an [gold1 b]NVIDIA[/] GPU (using CUDA™, and DirectML™ for ONNX)",
"an [gold1 b]NVIDIA[/] GPU (using CUDA™, and DirectML™ for ONNX) -- ALPHA",
"cuda_and_dml",
)
amd = (

View File

@ -438,7 +438,11 @@ class ModelInstall(object):
for filename in files:
filePath = Path(filename)
p = hf_download_with_resume(
repo_id, model_dir=location / filePath.parent, model_name=filePath.name, access_token=self.access_token, subfolder=filePath.parent
repo_id,
model_dir=location / filePath.parent,
model_name=filePath.name,
access_token=self.access_token,
subfolder=filePath.parent,
)
if p:
paths.append(p)

View File

@ -54,7 +54,9 @@ class ModelProbe(object):
}
@classmethod
def register_probe(cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: ProbeBase):
def register_probe(
cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: ProbeBase
):
cls.PROBES[format][model_type] = probe_class
@classmethod
@ -96,7 +98,7 @@ class ModelProbe(object):
if format_type == "diffusers"
else cls.get_model_type_from_checkpoint(model_path, model)
)
format_type = 'onnx' if model_type == ModelType.ONNX else format_type
format_type = "onnx" if model_type == ModelType.ONNX else format_type
probe_class = cls.PROBES[format_type].get(model_type)
if not probe_class:
return None
@ -170,7 +172,7 @@ class ModelProbe(object):
if model:
class_name = model.__class__.__name__
else:
if (folder_path / 'unet/model.onnx').exists():
if (folder_path / "unet/model.onnx").exists():
return ModelType.ONNX
if (folder_path / "learned_embeds.bin").exists():
return ModelType.TextualInversion
@ -474,6 +476,7 @@ class ONNXFolderProbe(FolderProbeBase):
def get_variant_type(self) -> ModelVariantType:
return ModelVariantType.Normal
class ControlNetFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
config_file = self.folder_path / "config.json"