mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Tidy spandrel model probe logic, and document the reasons behind the current implementation.
This commit is contained in:
parent
9328c17ded
commit
1ab20f43c8
@ -4,6 +4,7 @@ from pathlib import Path
|
|||||||
from typing import Any, Dict, Literal, Optional, Union
|
from typing import Any, Dict, Literal, Optional, Union
|
||||||
|
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
|
import spandrel
|
||||||
import torch
|
import torch
|
||||||
from picklescan.scanner import scan_file_path
|
from picklescan.scanner import scan_file_path
|
||||||
|
|
||||||
@ -242,15 +243,19 @@ class ModelProbe(object):
|
|||||||
return ModelType.TextualInversion
|
return ModelType.TextualInversion
|
||||||
|
|
||||||
# Check if the model can be loaded as a SpandrelImageToImageModel.
|
# Check if the model can be loaded as a SpandrelImageToImageModel.
|
||||||
|
# This check is intentionally performed last, as it can be expensive (it requires loading the model from disk).
|
||||||
try:
|
try:
|
||||||
# TODO(ryand): Figure out why load_from_state_dict() doesn't work as expected.
|
# It would be nice to avoid having to load the Spandrel model from disk here. A couple of options were
|
||||||
# _ = SpandrelImageToImageModel.load_from_state_dict(ckpt)
|
# explored to avoid this:
|
||||||
|
# 1. Call `SpandrelImageToImageModel.load_from_state_dict(ckpt)`, where `ckpt` is a state_dict on the meta
|
||||||
|
# device. Unfortunately, some Spandrel models perform operations during initialization that are not
|
||||||
|
# supported on meta tensors.
|
||||||
|
# 2. Spandrel has internal logic to determine a model's type from its state_dict before loading the model.
|
||||||
|
# This logic is not exposed in spandrel's public API. We could copy the logic here, but then we have to
|
||||||
|
# maintain it, and the risk of false positive detections is higher.
|
||||||
_ = SpandrelImageToImageModel.load_from_file(model_path)
|
_ = SpandrelImageToImageModel.load_from_file(model_path)
|
||||||
return ModelType.SpandrelImageToImage
|
return ModelType.SpandrelImageToImage
|
||||||
except Exception as e:
|
except spandrel.UnsupportedModelError:
|
||||||
# TODO(ryand): Catch a more specific exception type here if we can.
|
|
||||||
# TODO(ryand): Delete this print statement.
|
|
||||||
print(e)
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")
|
raise InvalidModelConfigException(f"Unable to determine model type for {model_path}")
|
||||||
|
Loading…
Reference in New Issue
Block a user