mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(mm): handle integer state dict keys in probe
It's possible for a model's state dict to have integer keys, though we do not actually support such models. As part of probing, we call `key.startswith(...)` on the state dict keys. This raises an `AttributeError` for integer keys. This logic is in `invokeai/backend/model_manager/probe.py:get_model_type_from_checkpoint` To fix this, we can cast the keys to strings first. The models w/ integer keys will still fail to be probed, but we'll get a `InvalidModelConfigException` instead of `AttributeError`. Closes #6044
This commit is contained in:
@ -28,7 +28,7 @@ from .config import (
|
|||||||
)
|
)
|
||||||
from .util.model_util import lora_token_vector_length, read_checkpoint_meta
|
from .util.model_util import lora_token_vector_length, read_checkpoint_meta
|
||||||
|
|
||||||
CkptType = Dict[str, Any]
|
CkptType = Dict[str | int, Any]
|
||||||
|
|
||||||
LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[SchedulerPredictionType, str]]]] = {
|
LEGACY_CONFIGS: Dict[BaseModelType, Dict[ModelVariantType, Union[str, Dict[SchedulerPredictionType, str]]]] = {
|
||||||
BaseModelType.StableDiffusion1: {
|
BaseModelType.StableDiffusion1: {
|
||||||
@ -219,7 +219,7 @@ class ModelProbe(object):
|
|||||||
ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
|
ckpt = checkpoint if checkpoint else read_checkpoint_meta(model_path, scan=True)
|
||||||
ckpt = ckpt.get("state_dict", ckpt)
|
ckpt = ckpt.get("state_dict", ckpt)
|
||||||
|
|
||||||
for key in ckpt.keys():
|
for key in [str(k) for k in ckpt.keys()]:
|
||||||
if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
|
if any(key.startswith(v) for v in {"cond_stage_model.", "first_stage_model.", "model.diffusion_model."}):
|
||||||
return ModelType.Main
|
return ModelType.Main
|
||||||
elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
|
elif any(key.startswith(v) for v in {"encoder.conv_in", "decoder.conv_in"}):
|
||||||
|
Reference in New Issue
Block a user