mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Set the dtype correctly for SpandrelImageToImageModels when they are loaded.
This commit is contained in:
parent
59ce9cf41c
commit
2a1514272f
@ -1,6 +1,8 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from invokeai.backend.model_manager.config import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
@ -31,4 +33,13 @@ class SpandrelImageToImageModelLoader(ModelLoader):
|
||||
model_path = Path(config.path)
|
||||
model = SpandrelImageToImageModel.load_from_file(model_path)
|
||||
|
||||
torch_dtype = self._torch_dtype
|
||||
if not model.supports_dtype(torch_dtype):
|
||||
self._logger.warning(
|
||||
f"The configured dtype ('{self._torch_dtype}') is not supported by the {model.get_model_type_name()} "
|
||||
"model. Falling back to 'float32'."
|
||||
)
|
||||
torch_dtype = torch.float32
|
||||
model.to(dtype=torch_dtype)
|
||||
|
||||
return model
|
||||
|
@ -50,6 +50,12 @@ class SpandrelImageToImageModel(RawModel):
|
||||
else:
|
||||
raise ValueError(f"Unexpected dtype '{dtype}'.")
|
||||
|
||||
def get_model_type_name(self) -> str:
|
||||
"""The model type name. Intended for logging / debugging purposes. Do not rely on this field remaining
|
||||
consistent over time.
|
||||
"""
|
||||
return str(type(self._spandrel_model.model))
|
||||
|
||||
def to(
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
|
Loading…
Reference in New Issue
Block a user