mirror of
https://github.com/invoke-ai/InvokeAI
synced 2025-07-26 05:17:55 +00:00
revert to overload due to failing tests, use Torch futures instead
This commit is contained in:
committed by
psychedelicious
parent
5271fc1cac
commit
86719f2065
@ -1,3 +1,4 @@
|
||||
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
@ -76,7 +77,15 @@ class CachedModelOnlyFullLoad:
|
||||
for k, v in self._cpu_state_dict.items():
|
||||
new_state_dict[k] = v.to(self._compute_device, copy=True)
|
||||
self._model.load_state_dict(new_state_dict, assign=True)
|
||||
self._model.to(self._compute_device)
|
||||
|
||||
check_for_gguf = self._model.state_dict().get("img_in.weight")
|
||||
if isinstance(check_for_gguf, GGMLTensor):
|
||||
old_value = torch.__future__.get_overwrite_module_params_on_conversion()
|
||||
torch.__future__.set_overwrite_module_params_on_conversion(True)
|
||||
self._model.to(self._compute_device)
|
||||
torch.__future__.set_overwrite_module_params_on_conversion(old_value)
|
||||
else:
|
||||
self._model.to(self._compute_device)
|
||||
|
||||
self._is_in_vram = True
|
||||
return self._total_bytes
|
||||
@ -92,7 +101,15 @@ class CachedModelOnlyFullLoad:
|
||||
|
||||
if self._cpu_state_dict is not None:
|
||||
self._model.load_state_dict(self._cpu_state_dict, assign=True)
|
||||
self._model.to(self._offload_device)
|
||||
|
||||
check_for_gguf = self._model.state_dict().get("img_in.weight")
|
||||
if isinstance(check_for_gguf, GGMLTensor):
|
||||
old_value = torch.__future__.get_overwrite_module_params_on_conversion()
|
||||
torch.__future__.set_overwrite_module_params_on_conversion(True)
|
||||
self._model.to(self._compute_device)
|
||||
torch.__future__.set_overwrite_module_params_on_conversion(old_value)
|
||||
else:
|
||||
self._model.to(self._compute_device)
|
||||
|
||||
self._is_in_vram = False
|
||||
return self._total_bytes
|
||||
|
@ -119,19 +119,6 @@ class GGMLTensor(torch.Tensor):
|
||||
return self.tensor_shape[dim]
|
||||
return self.tensor_shape
|
||||
|
||||
@overload
|
||||
def to(self, *args, **kwargs) -> torch.Tensor: ...
|
||||
|
||||
def to(self, *args, **kwargs):
|
||||
for func_arg in args:
|
||||
if isinstance(func_arg, torch.dtype) and func_arg != self.quantized_data.dtype:
|
||||
raise ValueError("Operation changed the dtype of GGMLTensor unexpectedly.")
|
||||
if "dtype" in kwargs.keys():
|
||||
if kwargs["dtype"] != self.quantized_data.dtype:
|
||||
raise ValueError("Operation changed the dtype of GGMLTensor unexpectedly.")
|
||||
self.quantized_data = self.quantized_data.to(*args, **kwargs)
|
||||
return self
|
||||
|
||||
@property
|
||||
def shape(self) -> torch.Size: # pyright: ignore[reportIncompatibleVariableOverride] pyright doesn't understand this for some reason.
|
||||
"""The shape of the tensor after dequantization. I.e. the shape that will be used in any math ops."""
|
||||
|
Reference in New Issue
Block a user