fix one more type mismatch in probe module

This commit is contained in:
Lincoln Stein
2023-09-29 00:44:50 -04:00
parent 2f16a2c35d
commit 3b832f1db2

View File

@ -10,7 +10,7 @@ import json
import re
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Callable, Dict, Optional
from typing import Callable, Dict, Optional, Type
import safetensors.torch
import torch
@ -99,7 +99,7 @@ class ModelProbe(ModelProbeBase):
}
@classmethod
def register_probe(cls, format: ModelFormat, model_type: ModelType, probe_class: ProbeBase):
def register_probe(cls, format: ModelFormat, model_type: ModelType, probe_class: Type[ProbeBase]):
"""
Register a probe subclass to use when interrogating a model.
@ -648,20 +648,20 @@ class CLIPVisionFolderProbe(FolderProbeBase):
############## register probe classes ######
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
ModelProbe.register_probe(ModelFormat("diffusers"), ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe(ModelFormat("diffusers"), ModelType.Vae, VaeFolderProbe)
ModelProbe.register_probe(ModelFormat("diffusers"), ModelType.Lora, LoRAFolderProbe)
ModelProbe.register_probe(ModelFormat("diffusers"), ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe(ModelFormat("diffusers"), ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe(ModelFormat("diffusers"), ModelType.IPAdapter, IPAdapterFolderProbe)
ModelProbe.register_probe(ModelFormat("diffusers"), ModelType.CLIPVision, CLIPVisionFolderProbe)
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
ModelProbe.register_probe(ModelFormat("checkpoint"), ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe(ModelFormat("checkpoint"), ModelType.Vae, VaeCheckpointProbe)
ModelProbe.register_probe(ModelFormat("checkpoint"), ModelType.Lora, LoRACheckpointProbe)
ModelProbe.register_probe(ModelFormat("checkpoint"), ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe(ModelFormat("checkpoint"), ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe(ModelFormat("checkpoint"), ModelType.IPAdapter, IPAdapterCheckpointProbe)
ModelProbe.register_probe(ModelFormat("checkpoint"), ModelType.CLIPVision, CLIPVisionCheckpointProbe)
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
ModelProbe.register_probe(ModelFormat("onnx"), ModelType.ONNX, ONNXFolderProbe)