From 1ab20f43c858c6cc788c2d0fb18c63c79e5caef9 Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Tue, 2 Jul 2024 09:51:51 -0400
Subject: [PATCH] Tidy spandrel model probe logic, and document the reasons behind the current implementation.

---
 invokeai/backend/model_manager/probe.py | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py
index 53da5fc152..c7267e9f1e 100644
--- a/invokeai/backend/model_manager/probe.py
+++ b/invokeai/backend/model_manager/probe.py
@@ -4,6 +4,7 @@ from pathlib import Path
 from typing import Any, Dict, Literal, Optional, Union
 
 import safetensors.torch
+import spandrel
 import torch
 from picklescan.scanner import scan_file_path
 
@@ -242,15 +243,19 @@ class ModelProbe(object):
                 return ModelType.TextualInversion
 
         # Check if the model can be loaded as a SpandrelImageToImageModel.
+        # This check is intentionally performed last, as it can be expensive (it requires loading the model from disk).
         try:
-            # TODO(ryand): Figure out why load_from_state_dict() doesn't work as expected.
-            # _ = SpandrelImageToImageModel.load_from_state_dict(ckpt)
+            # It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were
+            # explored to avoid this:
+            # 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta
+            #    device. Unfortunately, some Spandrel models perform operations during initialization that are not
+            #    supported on meta tensors.
+            # 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model.
+            #    This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to
+            #    maintain it, and the risk of false positive detections is higher.
             _ = SpandrelImageToImageModel.load_from_file(model_path)
             return ModelType.SpandrelImageToImage
-        except Exception as e:
-            # TODO(ryand): Catch a more specific exception type here if we can.
-            # TODO(ryand): Delete this print statement.
-            print(e)
+        except spandrel.UnsupportedModelError:
             pass
 
         raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")