From 26a7b7b66d9d959aec74f9066fa733802b719cef Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Thu, 17 Aug 2023 20:01:21 -0700 Subject: [PATCH] feat(model_probe): provide more clues when we fail to load a model. --- .../backend/model_management/model_probe.py | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/invokeai/backend/model_management/model_probe.py b/invokeai/backend/model_management/model_probe.py index f157fb177a..6fe0eb1714 100644 --- a/invokeai/backend/model_management/model_probe.py +++ b/invokeai/backend/model_management/model_probe.py @@ -1,12 +1,11 @@ import json -import torch -import safetensors.torch - from dataclasses import dataclass - -from diffusers import ModelMixin, ConfigMixin from pathlib import Path from typing import Callable, Literal, Union, Dict, Optional + +import safetensors.torch +import torch +from diffusers import ModelMixin, ConfigMixin from picklescan.scanner import scan_file_path from .models import ( @@ -17,8 +16,8 @@ from .models import ( SilenceWarnings, InvalidModelException, ) -from .util import lora_token_vector_length from .models.base import read_checkpoint_meta +from .util import lora_token_vector_length @dataclass @@ -171,6 +170,7 @@ class ModelProbe(object): Get the model type of a hugging-face style folder. """ class_name = None + error_hint = None if model: class_name = model.__class__.__name__ else: @@ -190,12 +190,18 @@ class ModelProbe(object): with open(config_path, "r") as file: conf = json.load(file) class_name = conf["_class_name"] + else: + error_hint = f"No model_index.json or config.json found in {folder_path}." if class_name and (type := cls.CLASS2TYPE.get(class_name)): return type + else: + error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]" # give up - raise InvalidModelException(f"Unable to determine model type for {folder_path}") + raise InvalidModelException( + f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "") + ) @classmethod def _scan_and_load_checkpoint(cls, model_path: Path) -> dict: