Revert the GGMLTensor.to() override due to failing tests; use the torch.__future__ overwrite-module-params flag instead

This commit is contained in:
David Burnett
2025-04-29 11:03:31 +01:00
committed by psychedelicious
parent 5271fc1cac
commit 86719f2065
2 changed files with 19 additions and 15 deletions

View File

@ -1,3 +1,4 @@
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from typing import Any
import torch
@ -76,7 +77,15 @@ class CachedModelOnlyFullLoad:
for k, v in self._cpu_state_dict.items():
new_state_dict[k] = v.to(self._compute_device, copy=True)
self._model.load_state_dict(new_state_dict, assign=True)
self._model.to(self._compute_device)
# Probe one known key to detect GGUF-quantized weights (GGMLTensor).
# NOTE(review): presumably only some architectures expose "img_in.weight";
# models without that key fall through to the plain move below.
check_for_gguf = self._model.state_dict().get("img_in.weight")
if isinstance(check_for_gguf, GGMLTensor):
    # Temporarily make nn.Module.to() assign new parameter tensors instead of
    # mutating the existing ones in place. The flag is process-global, so it is
    # restored in a `finally` even when the move fails (e.g. out of memory).
    old_value = torch.__future__.get_overwrite_module_params_on_conversion()
    torch.__future__.set_overwrite_module_params_on_conversion(True)
    try:
        self._model.to(self._compute_device)
    finally:
        torch.__future__.set_overwrite_module_params_on_conversion(old_value)
else:
    self._model.to(self._compute_device)
self._is_in_vram = True
return self._total_bytes
@ -92,7 +101,15 @@ class CachedModelOnlyFullLoad:
if self._cpu_state_dict is not None:
self._model.load_state_dict(self._cpu_state_dict, assign=True)
self._model.to(self._offload_device)
# Probe one known key to detect GGUF-quantized weights (GGMLTensor).
# NOTE(review): presumably only some architectures expose "img_in.weight";
# models without that key fall through to the plain move below.
check_for_gguf = self._model.state_dict().get("img_in.weight")
if isinstance(check_for_gguf, GGMLTensor):
    # Temporarily make nn.Module.to() assign new parameter tensors instead of
    # mutating the existing ones in place. The flag is process-global, so it is
    # restored in a `finally` even when the move fails.
    old_value = torch.__future__.get_overwrite_module_params_on_conversion()
    torch.__future__.set_overwrite_module_params_on_conversion(True)
    try:
        # Unloading must target the offload device, not the compute device.
        self._model.to(self._offload_device)
    finally:
        torch.__future__.set_overwrite_module_params_on_conversion(old_value)
else:
    # Unloading must target the offload device, not the compute device.
    self._model.to(self._offload_device)
self._is_in_vram = False
return self._total_bytes

View File

@ -119,19 +119,6 @@ class GGMLTensor(torch.Tensor):
return self.tensor_shape[dim]
return self.tensor_shape
@overload
def to(self, *args, **kwargs) -> torch.Tensor: ...
def to(self, *args, **kwargs):
    """Move the wrapped quantized data via ``Tensor.to`` and return ``self``.

    The arguments are forwarded unchanged to ``self.quantized_data.to``. A
    request that would change the dtype of the quantized data is rejected.

    Raises:
        ValueError: If a positional ``torch.dtype`` argument, or the ``dtype``
            keyword argument, differs from ``self.quantized_data.dtype``.
    """
    current_dtype = self.quantized_data.dtype

    # Positional arguments: only actual torch.dtype values are checked;
    # devices, strings, None and so on are passed through untouched.
    if any(isinstance(arg, torch.dtype) and arg != current_dtype for arg in args):
        raise ValueError("Operation changed the dtype of GGMLTensor unexpectedly.")

    # Keyword argument: any value that is not the current dtype is rejected.
    if "dtype" in kwargs and kwargs["dtype"] != current_dtype:
        raise ValueError("Operation changed the dtype of GGMLTensor unexpectedly.")

    # The wrapper is updated in place; no new wrapper object is created.
    self.quantized_data = self.quantized_data.to(*args, **kwargs)
    return self
@property
def shape(self) -> torch.Size: # pyright: ignore[reportIncompatibleVariableOverride] pyright doesn't understand this for some reason.
"""The shape of the tensor after dequantization. I.e. the shape that will be used in any math ops."""