mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Run python black, point out that onnx is an alpha feature in the installer
This commit is contained in:
parent
af4fd328a6
commit
aeac557c41
@ -168,7 +168,7 @@ def graphical_accelerator():
|
|||||||
"cuda",
|
"cuda",
|
||||||
)
|
)
|
||||||
nvidia_with_dml = (
|
nvidia_with_dml = (
|
||||||
"an [gold1 b]NVIDIA[/] GPU (using CUDA™, and DirectML™ for ONNX)",
|
"an [gold1 b]NVIDIA[/] GPU (using CUDA™, and DirectML™ for ONNX) -- ALPHA",
|
||||||
"cuda_and_dml",
|
"cuda_and_dml",
|
||||||
)
|
)
|
||||||
amd = (
|
amd = (
|
||||||
|
@ -438,7 +438,11 @@ class ModelInstall(object):
|
|||||||
for filename in files:
|
for filename in files:
|
||||||
filePath = Path(filename)
|
filePath = Path(filename)
|
||||||
p = hf_download_with_resume(
|
p = hf_download_with_resume(
|
||||||
repo_id, model_dir=location / filePath.parent, model_name=filePath.name, access_token=self.access_token, subfolder=filePath.parent
|
repo_id,
|
||||||
|
model_dir=location / filePath.parent,
|
||||||
|
model_name=filePath.name,
|
||||||
|
access_token=self.access_token,
|
||||||
|
subfolder=filePath.parent,
|
||||||
)
|
)
|
||||||
if p:
|
if p:
|
||||||
paths.append(p)
|
paths.append(p)
|
||||||
|
@ -54,7 +54,9 @@ class ModelProbe(object):
|
|||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register_probe(cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: ProbeBase):
|
def register_probe(
|
||||||
|
cls, format: Literal["diffusers", "checkpoint", "onnx"], model_type: ModelType, probe_class: ProbeBase
|
||||||
|
):
|
||||||
cls.PROBES[format][model_type] = probe_class
|
cls.PROBES[format][model_type] = probe_class
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -96,7 +98,7 @@ class ModelProbe(object):
|
|||||||
if format_type == "diffusers"
|
if format_type == "diffusers"
|
||||||
else cls.get_model_type_from_checkpoint(model_path, model)
|
else cls.get_model_type_from_checkpoint(model_path, model)
|
||||||
)
|
)
|
||||||
format_type = 'onnx' if model_type == ModelType.ONNX else format_type
|
format_type = "onnx" if model_type == ModelType.ONNX else format_type
|
||||||
probe_class = cls.PROBES[format_type].get(model_type)
|
probe_class = cls.PROBES[format_type].get(model_type)
|
||||||
if not probe_class:
|
if not probe_class:
|
||||||
return None
|
return None
|
||||||
@ -170,7 +172,7 @@ class ModelProbe(object):
|
|||||||
if model:
|
if model:
|
||||||
class_name = model.__class__.__name__
|
class_name = model.__class__.__name__
|
||||||
else:
|
else:
|
||||||
if (folder_path / 'unet/model.onnx').exists():
|
if (folder_path / "unet/model.onnx").exists():
|
||||||
return ModelType.ONNX
|
return ModelType.ONNX
|
||||||
if (folder_path / "learned_embeds.bin").exists():
|
if (folder_path / "learned_embeds.bin").exists():
|
||||||
return ModelType.TextualInversion
|
return ModelType.TextualInversion
|
||||||
@ -474,6 +476,7 @@ class ONNXFolderProbe(FolderProbeBase):
|
|||||||
def get_variant_type(self) -> ModelVariantType:
|
def get_variant_type(self) -> ModelVariantType:
|
||||||
return ModelVariantType.Normal
|
return ModelVariantType.Normal
|
||||||
|
|
||||||
|
|
||||||
class ControlNetFolderProbe(FolderProbeBase):
|
class ControlNetFolderProbe(FolderProbeBase):
|
||||||
def get_base_type(self) -> BaseModelType:
|
def get_base_type(self) -> BaseModelType:
|
||||||
config_file = self.folder_path / "config.json"
|
config_file = self.folder_path / "config.json"
|
||||||
|
Loading…
Reference in New Issue
Block a user