Run python black, point out that onnx is an alpha feature in the installer

This commit is contained in:
Brandon Rising 2023-07-31 16:47:48 -04:00
parent af4fd328a6
commit aeac557c41
3 changed files with 12 additions and 5 deletions

View File

@ -168,7 +168,7 @@ def graphical_accelerator():
"cuda", "cuda",
) )
nvidia_with_dml = ( nvidia_with_dml = (
"an [gold1 b]NVIDIA[/] GPU (using CUDA™, and DirectML™ for ONNX)", "an [gold1 b]NVIDIA[/] GPU (using CUDA™, and DirectML™ for ONNX) -- ALPHA",
"cuda_and_dml", "cuda_and_dml",
) )
amd = ( amd = (

View File

@ -438,7 +438,11 @@ class ModelInstall(object):
for filename in files: for filename in files:
filePath = Path(filename) filePath = Path(filename)
p = hf_download_with_resume( p = hf_download_with_resume(
repo_id, model_dir=location / filePath.parent, model_name=filePath.name, access_token=self.access_token, subfolder=filePath.parent repo_id,
model_dir=location / filePath.parent,
model_name=filePath.name,
access_token=self.access_token,
subfolder=filePath.parent,
) )
if p: if p:
paths.append(p) paths.append(p)

View File

@ -54,7 +54,9 @@ class ModelProbe(object):
} }
@classmethod @classmethod
def register_probe(cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: ProbeBase): def register_probe(
cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: ProbeBase
):
cls.PROBES[format][model_type] = probe_class cls.PROBES[format][model_type] = probe_class
@classmethod @classmethod
@ -96,7 +98,7 @@ class ModelProbe(object):
if format_type == "diffusers" if format_type == "diffusers"
else cls.get_model_type_from_checkpoint(model_path, model) else cls.get_model_type_from_checkpoint(model_path, model)
) )
format_type = 'onnx' if model_type == ModelType.ONNX else format_type format_type = "onnx" if model_type == ModelType.ONNX else format_type
probe_class = cls.PROBES[format_type].get(model_type) probe_class = cls.PROBES[format_type].get(model_type)
if not probe_class: if not probe_class:
return None return None
@ -170,7 +172,7 @@ class ModelProbe(object):
if model: if model:
class_name = model.__class__.__name__ class_name = model.__class__.__name__
else: else:
if (folder_path / 'unet/model.onnx').exists(): if (folder_path / "unet/model.onnx").exists():
return ModelType.ONNX return ModelType.ONNX
if (folder_path / "learned_embeds.bin").exists(): if (folder_path / "learned_embeds.bin").exists():
return ModelType.TextualInversion return ModelType.TextualInversion
@ -474,6 +476,7 @@ class ONNXFolderProbe(FolderProbeBase):
def get_variant_type(self) -> ModelVariantType: def get_variant_type(self) -> ModelVariantType:
return ModelVariantType.Normal return ModelVariantType.Normal
class ControlNetFolderProbe(FolderProbeBase): class ControlNetFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType: def get_base_type(self) -> BaseModelType:
config_file = self.folder_path / "config.json" config_file = self.folder_path / "config.json"