diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index 56d0d8f954..43d51e6b96 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -10,7 +10,7 @@ import json import re from abc import ABC, abstractmethod from pathlib import Path -from typing import Callable, Dict, Optional +from typing import Callable, Dict, Optional, Type import safetensors.torch import torch @@ -99,7 +99,7 @@ class ModelProbe(ModelProbeBase): } @classmethod - def register_probe(cls, format: ModelFormat, model_type: ModelType, probe_class: ProbeBase): + def register_probe(cls, format: ModelFormat, model_type: ModelType, probe_class: Type[ProbeBase]): """ Register a probe subclass to use when interrogating a model. @@ -648,20 +648,20 @@ class CLIPVisionFolderProbe(FolderProbeBase): ############## register probe classes ###### -ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe) -ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe) +ModelProbe.register_probe(ModelFormat("diffusers"), ModelType.Main, PipelineFolderProbe) +ModelProbe.register_probe(ModelFormat("diffusers"), ModelType.Vae, VaeFolderProbe) +ModelProbe.register_probe(ModelFormat("diffusers"), ModelType.Lora, LoRAFolderProbe) +ModelProbe.register_probe(ModelFormat("diffusers"), ModelType.TextualInversion, TextualInversionFolderProbe) +ModelProbe.register_probe(ModelFormat("diffusers"), ModelType.ControlNet, ControlNetFolderProbe) +ModelProbe.register_probe(ModelFormat("diffusers"), ModelType.IPAdapter, IPAdapterFolderProbe) +ModelProbe.register_probe(ModelFormat("diffusers"), ModelType.CLIPVision, CLIPVisionFolderProbe) -ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe) -ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe) +ModelProbe.register_probe(ModelFormat("checkpoint"), ModelType.Main, PipelineCheckpointProbe) +ModelProbe.register_probe(ModelFormat("checkpoint"), ModelType.Vae, VaeCheckpointProbe) +ModelProbe.register_probe(ModelFormat("checkpoint"), ModelType.Lora, LoRACheckpointProbe) +ModelProbe.register_probe(ModelFormat("checkpoint"), ModelType.TextualInversion, TextualInversionCheckpointProbe) +ModelProbe.register_probe(ModelFormat("checkpoint"), ModelType.ControlNet, ControlNetCheckpointProbe) +ModelProbe.register_probe(ModelFormat("checkpoint"), ModelType.IPAdapter, IPAdapterCheckpointProbe) +ModelProbe.register_probe(ModelFormat("checkpoint"), ModelType.CLIPVision, CLIPVisionCheckpointProbe) -ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe) +ModelProbe.register_probe(ModelFormat("onnx"), ModelType.ONNX, ONNXFolderProbe)