add specific exception for model probe failures

This commit is contained in:
Lincoln Stein 2023-07-17 23:08:39 -04:00
parent af1c1ab51f
commit 1353bf98b3
2 changed files with 14 additions and 9 deletions

View File

@ -14,6 +14,7 @@ from invokeai.backend.model_management.models import (
OPENAPI_MODEL_CONFIGS,
SchedulerPredictionType,
ModelNotFoundException,
InvalidModelException,
)
from invokeai.backend.model_management import MergeInterpolationMethod
@ -168,6 +169,9 @@ async def import_model(
except ModelNotFoundException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
except InvalidModelException as e:
log.error(str(e))
raise HTTPException(status_code=415)
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))

View File

@ -12,6 +12,7 @@ from picklescan.scanner import scan_file_path
from .models import (
BaseModelType, ModelType, ModelVariantType,
SchedulerPredictionType, SilenceWarnings,
InvalidModelException
)
from .models.base import read_checkpoint_meta
@ -61,7 +62,7 @@ class ModelProbe(object):
elif isinstance(model,(dict,ModelMixin,ConfigMixin)):
return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper)
else:
raise ValueError("model parameter {model} is neither a Path, nor a model")
raise InvalidModelException("model parameter {model} is neither a Path, nor a model")
@classmethod
def probe(cls,
@ -141,7 +142,7 @@ class ModelProbe(object):
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
return ModelType.TextualInversion
raise ValueError(f"Unable to determine model type for {model_path}")
raise InvalidModelException(f"Unable to determine model type for {model_path}")
@classmethod
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType:
@ -171,7 +172,7 @@ class ModelProbe(object):
return type
# give up
raise ValueError(f"Unable to determine model type for {folder_path}")
raise InvalidModelException(f"Unable to determine model type for {folder_path}")
@classmethod
def _scan_and_load_checkpoint(cls,model_path: Path)->dict:
@ -240,7 +241,7 @@ class CheckpointProbeBase(ProbeBase):
elif in_channels == 4:
return ModelVariantType.Normal
else:
raise ValueError(f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}")
raise InvalidModelException(f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}")
class PipelineCheckpointProbe(CheckpointProbeBase):
def get_base_type(self)->BaseModelType:
@ -254,7 +255,7 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
# TODO: Verify that this is correct! Need an XL checkpoint file for this.
if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
return BaseModelType.StableDiffusionXL
raise ValueError("Cannot determine base type")
raise InvalidModelException("Cannot determine base type")
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
type = self.get_base_type()
@ -335,7 +336,7 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
return BaseModelType.StableDiffusion2
elif self.checkpoint_path and self.helper:
return self.helper(self.checkpoint_path)
raise ValueError("Unable to determine base type for {self.checkpoint_path}")
raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}")
########################################################
# classes for probing folders
@ -371,7 +372,7 @@ class PipelineFolderProbe(FolderProbeBase):
elif unet_conf['cross_attention_dim'] == 2048:
return BaseModelType.StableDiffusionXL
else:
raise ValueError(f'Unknown base model for {self.folder_path}')
raise InvalidModelException(f'Unknown base model for {self.folder_path}')
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
if self.model:
@ -428,7 +429,7 @@ class ControlNetFolderProbe(FolderProbeBase):
def get_base_type(self)->BaseModelType:
config_file = self.folder_path / 'config.json'
if not config_file.exists():
raise ValueError(f"Cannot determine base type for {self.folder_path}")
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
with open(config_file,'r') as file:
config = json.load(file)
# no obvious way to distinguish between sd2-base and sd2-768
@ -445,7 +446,7 @@ class LoRAFolderProbe(FolderProbeBase):
model_file = base_file
break
if not model_file:
raise ValueError('Unknown LoRA format encountered')
raise InvalidModelException('Unknown LoRA format encountered')
return LoRACheckpointProbe(model_file,None).get_base_type()
############## register probe classes ######