diff --git a/invokeai/backend/model_manager/load/model_loaders/spandrel_image_to_image.py b/invokeai/backend/model_manager/load/model_loaders/spandrel_image_to_image.py index 4241c21d24..7a57c5cf59 100644 --- a/invokeai/backend/model_manager/load/model_loaders/spandrel_image_to_image.py +++ b/invokeai/backend/model_manager/load/model_loaders/spandrel_image_to_image.py @@ -1,6 +1,8 @@ from pathlib import Path from typing import Optional +import torch + from invokeai.backend.model_manager.config import ( AnyModel, AnyModelConfig, @@ -31,4 +33,13 @@ class SpandrelImageToImageModelLoader(ModelLoader): model_path = Path(config.path) model = SpandrelImageToImageModel.load_from_file(model_path) + torch_dtype = self._torch_dtype + if not model.supports_dtype(torch_dtype): + self._logger.warning( + f"The configured dtype ('{self._torch_dtype}') is not supported by the {model.get_model_type_name()} " + "model. Falling back to 'float32'." + ) + torch_dtype = torch.float32 + model.to(dtype=torch_dtype) + return model diff --git a/invokeai/backend/spandrel_image_to_image_model.py b/invokeai/backend/spandrel_image_to_image_model.py index 270f521604..6413ebba6b 100644 --- a/invokeai/backend/spandrel_image_to_image_model.py +++ b/invokeai/backend/spandrel_image_to_image_model.py @@ -50,6 +50,12 @@ class SpandrelImageToImageModel(RawModel): else: raise ValueError(f"Unexpected dtype '{dtype}'.") + def get_model_type_name(self) -> str: + """The model type name. Intended for logging / debugging purposes. Do not rely on this field remaining + consistent over time. + """ + return str(type(self._spandrel_model.model)) + def to( self, device: Optional[torch.device] = None,