fix import ordering, remove code I reverted that the resync added back

This commit is contained in:
David Burnett
2025-04-29 12:10:57 +01:00
committed by psychedelicious
parent 99e154d773
commit 6c0bd7d150
2 changed files with 4 additions and 17 deletions

View File

@ -1,8 +1,9 @@
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
from typing import Any
import torch
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
class CachedModelOnlyFullLoad:
"""A wrapper around a PyTorch model to handle full loads and unloads between the CPU and the compute device.
@ -78,8 +79,7 @@ class CachedModelOnlyFullLoad:
new_state_dict[k] = v.to(self._compute_device, copy=True)
self._model.load_state_dict(new_state_dict, assign=True)
check_for_gguf = hasattr(self._model, 'state_dict') and self._model.state_dict().get("img_in.weight")
check_for_gguf = hasattr(self._model, "state_dict") and self._model.state_dict().get("img_in.weight")
if isinstance(check_for_gguf, GGMLTensor):
old_value = torch.__future__.get_overwrite_module_params_on_conversion()
torch.__future__.set_overwrite_module_params_on_conversion(True)
@ -103,7 +103,7 @@ class CachedModelOnlyFullLoad:
if self._cpu_state_dict is not None:
self._model.load_state_dict(self._cpu_state_dict, assign=True)
check_for_gguf = hasattr(self._model, 'state_dict') and self._model.state_dict().get("img_in.weight")
check_for_gguf = hasattr(self._model, "state_dict") and self._model.state_dict().get("img_in.weight")
if isinstance(check_for_gguf, GGMLTensor):
old_value = torch.__future__.get_overwrite_module_params_on_conversion()
torch.__future__.set_overwrite_module_params_on_conversion(True)

View File

@ -119,19 +119,6 @@ class GGMLTensor(torch.Tensor):
return self.tensor_shape[dim]
return self.tensor_shape
@overload
def to(self, *args, **kwargs) -> torch.Tensor: ...
def to(self, *args, **kwargs):
for func_arg in args:
if isinstance(func_arg, torch.dtype) and func_arg != self.quantized_data.dtype:
raise ValueError("Operation changed the dtype of GGMLTensor unexpectedly.")
if "dtype" in kwargs.keys():
if kwargs["dtype"] != self.quantized_data.dtype:
raise ValueError("Operation changed the dtype of GGMLTensor unexpectedly.")
self.quantized_data = self.quantized_data.to(*args, **kwargs)
return self
@property
def shape(self) -> torch.Size: # pyright: ignore[reportIncompatibleVariableOverride] pyright doesn't understand this for some reason.
"""The shape of the tensor after dequantization. I.e. the shape that will be used in any math ops."""