mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add specific exception for model probe failures
This commit is contained in:
parent
af1c1ab51f
commit
1353bf98b3
@ -14,6 +14,7 @@ from invokeai.backend.model_management.models import (
|
||||
OPENAPI_MODEL_CONFIGS,
|
||||
SchedulerPredictionType,
|
||||
ModelNotFoundException,
|
||||
InvalidModelException,
|
||||
)
|
||||
from invokeai.backend.model_management import MergeInterpolationMethod
|
||||
|
||||
@ -168,6 +169,9 @@ async def import_model(
|
||||
except ModelNotFoundException as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except InvalidModelException as e:
|
||||
log.error(str(e))
|
||||
raise HTTPException(status_code=415)
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
@ -12,6 +12,7 @@ from picklescan.scanner import scan_file_path
|
||||
from .models import (
|
||||
BaseModelType, ModelType, ModelVariantType,
|
||||
SchedulerPredictionType, SilenceWarnings,
|
||||
InvalidModelException
|
||||
)
|
||||
from .models.base import read_checkpoint_meta
|
||||
|
||||
@ -61,7 +62,7 @@ class ModelProbe(object):
|
||||
elif isinstance(model,(dict,ModelMixin,ConfigMixin)):
|
||||
return cls.probe(model_path=None, model=model, prediction_type_helper=prediction_type_helper)
|
||||
else:
|
||||
raise ValueError("model parameter {model} is neither a Path, nor a model")
|
||||
raise InvalidModelException("model parameter {model} is neither a Path, nor a model")
|
||||
|
||||
@classmethod
|
||||
def probe(cls,
|
||||
@ -141,7 +142,7 @@ class ModelProbe(object):
|
||||
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
|
||||
return ModelType.TextualInversion
|
||||
|
||||
raise ValueError(f"Unable to determine model type for {model_path}")
|
||||
raise InvalidModelException(f"Unable to determine model type for {model_path}")
|
||||
|
||||
@classmethod
|
||||
def get_model_type_from_folder(cls, folder_path: Path, model: ModelMixin)->ModelType:
|
||||
@ -171,7 +172,7 @@ class ModelProbe(object):
|
||||
return type
|
||||
|
||||
# give up
|
||||
raise ValueError(f"Unable to determine model type for {folder_path}")
|
||||
raise InvalidModelException(f"Unable to determine model type for {folder_path}")
|
||||
|
||||
@classmethod
|
||||
def _scan_and_load_checkpoint(cls,model_path: Path)->dict:
|
||||
@ -240,7 +241,7 @@ class CheckpointProbeBase(ProbeBase):
|
||||
elif in_channels == 4:
|
||||
return ModelVariantType.Normal
|
||||
else:
|
||||
raise ValueError(f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}")
|
||||
raise InvalidModelException(f"Cannot determine variant type (in_channels={in_channels}) at {self.checkpoint_path}")
|
||||
|
||||
class PipelineCheckpointProbe(CheckpointProbeBase):
|
||||
def get_base_type(self)->BaseModelType:
|
||||
@ -254,7 +255,7 @@ class PipelineCheckpointProbe(CheckpointProbeBase):
|
||||
# TODO: Verify that this is correct! Need an XL checkpoint file for this.
|
||||
if key_name in state_dict and state_dict[key_name].shape[-1] == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
raise ValueError("Cannot determine base type")
|
||||
raise InvalidModelException("Cannot determine base type")
|
||||
|
||||
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
|
||||
type = self.get_base_type()
|
||||
@ -335,7 +336,7 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif self.checkpoint_path and self.helper:
|
||||
return self.helper(self.checkpoint_path)
|
||||
raise ValueError("Unable to determine base type for {self.checkpoint_path}")
|
||||
raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}")
|
||||
|
||||
########################################################
|
||||
# classes for probing folders
|
||||
@ -371,7 +372,7 @@ class PipelineFolderProbe(FolderProbeBase):
|
||||
elif unet_conf['cross_attention_dim'] == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
raise ValueError(f'Unknown base model for {self.folder_path}')
|
||||
raise InvalidModelException(f'Unknown base model for {self.folder_path}')
|
||||
|
||||
def get_scheduler_prediction_type(self)->SchedulerPredictionType:
|
||||
if self.model:
|
||||
@ -428,7 +429,7 @@ class ControlNetFolderProbe(FolderProbeBase):
|
||||
def get_base_type(self)->BaseModelType:
|
||||
config_file = self.folder_path / 'config.json'
|
||||
if not config_file.exists():
|
||||
raise ValueError(f"Cannot determine base type for {self.folder_path}")
|
||||
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
|
||||
with open(config_file,'r') as file:
|
||||
config = json.load(file)
|
||||
# no obvious way to distinguish between sd2-base and sd2-768
|
||||
@ -445,7 +446,7 @@ class LoRAFolderProbe(FolderProbeBase):
|
||||
model_file = base_file
|
||||
break
|
||||
if not model_file:
|
||||
raise ValueError('Unknown LoRA format encountered')
|
||||
raise InvalidModelException('Unknown LoRA format encountered')
|
||||
return LoRACheckpointProbe(model_file,None).get_base_type()
|
||||
|
||||
############## register probe classes ######
|
||||
|
Loading…
Reference in New Issue
Block a user