mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Set the dtype correctly for SpandrelImageToImageModels when they are loaded.
This commit is contained in:
parent
59ce9cf41c
commit
2a1514272f
@ -1,6 +1,8 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
from invokeai.backend.model_manager.config import (
|
from invokeai.backend.model_manager.config import (
|
||||||
AnyModel,
|
AnyModel,
|
||||||
AnyModelConfig,
|
AnyModelConfig,
|
||||||
@ -31,4 +33,13 @@ class SpandrelImageToImageModelLoader(ModelLoader):
|
|||||||
model_path = Path(config.path)
|
model_path = Path(config.path)
|
||||||
model = SpandrelImageToImageModel.load_from_file(model_path)
|
model = SpandrelImageToImageModel.load_from_file(model_path)
|
||||||
|
|
||||||
|
torch_dtype = self._torch_dtype
|
||||||
|
if not model.supports_dtype(torch_dtype):
|
||||||
|
self._logger.warning(
|
||||||
|
f"The configured dtype ('{self._torch_dtype}') is not supported by the {model.get_model_type_name()} "
|
||||||
|
"model. Falling back to 'float32'."
|
||||||
|
)
|
||||||
|
torch_dtype = torch.float32
|
||||||
|
model.to(dtype=torch_dtype)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
@ -50,6 +50,12 @@ class SpandrelImageToImageModel(RawModel):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unexpected dtype '{dtype}'.")
|
raise ValueError(f"Unexpected dtype '{dtype}'.")
|
||||||
|
|
||||||
|
def get_model_type_name(self) -> str:
|
||||||
|
"""The model type name. Intended for logging / debugging purposes. Do not rely on this field remaining
|
||||||
|
consistent over time.
|
||||||
|
"""
|
||||||
|
return str(type(self._spandrel_model.model))
|
||||||
|
|
||||||
def to(
|
def to(
|
||||||
self,
|
self,
|
||||||
device: Optional[torch.device] = None,
|
device: Optional[torch.device] = None,
|
||||||
|
Loading…
Reference in New Issue
Block a user