Set the dtype correctly for SpandrelImageToImageModels when they are loaded.

This commit is contained in:
Ryan Dick 2024-06-28 15:22:39 -04:00
parent 59ce9cf41c
commit 2a1514272f
2 changed files with 17 additions and 0 deletions

View File

@ -1,6 +1,8 @@
from pathlib import Path
from typing import Optional
import torch
from invokeai.backend.model_manager.config import (
AnyModel,
AnyModelConfig,
@ -31,4 +33,13 @@ class SpandrelImageToImageModelLoader(ModelLoader):
model_path = Path(config.path)
model = SpandrelImageToImageModel.load_from_file(model_path)
torch_dtype = self._torch_dtype
if not model.supports_dtype(torch_dtype):
self._logger.warning(
f"The configured dtype ('{self._torch_dtype}') is not supported by the {model.get_model_type_name()} "
"model. Falling back to 'float32'."
)
torch_dtype = torch.float32
model.to(dtype=torch_dtype)
return model

View File

@ -50,6 +50,12 @@ class SpandrelImageToImageModel(RawModel):
else:
raise ValueError(f"Unexpected dtype '{dtype}'.")
def get_model_type_name(self) -> str:
"""The model type name. Intended for logging / debugging purposes. Do not rely on this field remaining
consistent over time.
"""
return str(type(self._spandrel_model.model))
def to(
self,
device: Optional[torch.device] = None,