mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Use a ModelIdentifierField to identify the spandrel model in the UpscaleSpandrelInvocation.
This commit is contained in:
@ -16,6 +16,10 @@ class SpandrelImageToImageModel(RawModel):
|
||||
def __init__(self, spandrel_model: ImageModelDescriptor[Any]):
|
||||
self._spandrel_model = spandrel_model
|
||||
|
||||
def run(self, image_tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""Run the image-to-image model."""
|
||||
return self._spandrel_model(image_tensor)
|
||||
|
||||
@classmethod
|
||||
def load_from_file(cls, file_path: str | Path):
|
||||
model = ModelLoader().load_from_file(file_path)
|
||||
@ -67,3 +71,13 @@ class SpandrelImageToImageModel(RawModel):
|
||||
# TODO(ryand): spandrel.ImageModelDescriptor.to(...) does not support non_blocking. We will access the model
|
||||
# directly if we want to apply this optimization.
|
||||
self._spandrel_model.to(device=device, dtype=dtype)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
"""The device of the underlying model."""
|
||||
return self._spandrel_model.device
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
"""The dtype of the underlying model."""
|
||||
return self._spandrel_model.dtype
|
||||
|
Reference in New Issue
Block a user