feat(model_probe): provide more clues when we fail to load a model.

This commit is contained in:
Kevin Turner 2023-08-17 20:01:21 -07:00
parent 8611ffe32d
commit 26a7b7b66d

View File

@ -1,12 +1,11 @@
import json import json
import torch
import safetensors.torch
from dataclasses import dataclass from dataclasses import dataclass
from diffusers import ModelMixin, ConfigMixin
from pathlib import Path from pathlib import Path
from typing import Callable, Literal, Union, Dict, Optional from typing import Callable, Literal, Union, Dict, Optional
import safetensors.torch
import torch
from diffusers import ModelMixin, ConfigMixin
from picklescan.scanner import scan_file_path from picklescan.scanner import scan_file_path
from .models import ( from .models import (
@ -17,8 +16,8 @@ from .models import (
SilenceWarnings, SilenceWarnings,
InvalidModelException, InvalidModelException,
) )
from .util import lora_token_vector_length
from .models.base import read_checkpoint_meta from .models.base import read_checkpoint_meta
from .util import lora_token_vector_length
@dataclass @dataclass
@ -171,6 +170,7 @@ class ModelProbe(object):
Get the model type of a hugging-face style folder. Get the model type of a hugging-face style folder.
""" """
class_name = None class_name = None
error_hint = None
if model: if model:
class_name = model.__class__.__name__ class_name = model.__class__.__name__
else: else:
@ -190,12 +190,18 @@ class ModelProbe(object):
with open(config_path, "r") as file: with open(config_path, "r") as file:
conf = json.load(file) conf = json.load(file)
class_name = conf["_class_name"] class_name = conf["_class_name"]
else:
error_hint = f"No model_index.json or config.json found in {folder_path}."
if class_name and (type := cls.CLASS2TYPE.get(class_name)): if class_name and (type := cls.CLASS2TYPE.get(class_name)):
return type return type
else:
error_hint = f"class {class_name} is not one of the supported classes [{', '.join(cls.CLASS2TYPE.keys())}]"
# give up # give up
raise InvalidModelException(f"Unable to determine model type for {folder_path}") raise InvalidModelException(
f"Unable to determine model type for {folder_path}" + (f"; {error_hint}" if error_hint else "")
)
@classmethod @classmethod
def _scan_and_load_checkpoint(cls, model_path: Path) -> dict: def _scan_and_load_checkpoint(cls, model_path: Path) -> dict: