add check for state_dict, required to load TI's

This commit is contained in:
David Burnett
2025-04-29 11:43:28 +01:00
committed by psychedelicious
parent 73ab4b8895
commit 8abcc99ced

View File

@ -78,7 +78,8 @@ class CachedModelOnlyFullLoad:
new_state_dict[k] = v.to(self._compute_device, copy=True)
self._model.load_state_dict(new_state_dict, assign=True)
check_for_gguf = self._model.state_dict().get("img_in.weight")
check_for_gguf = hasattr(self._model, 'state_dict') and self._model.state_dict().get("img_in.weight")
if isinstance(check_for_gguf, GGMLTensor):
old_value = torch.__future__.get_overwrite_module_params_on_conversion()
torch.__future__.set_overwrite_module_params_on_conversion(True)
@ -102,7 +103,7 @@ class CachedModelOnlyFullLoad:
if self._cpu_state_dict is not None:
self._model.load_state_dict(self._cpu_state_dict, assign=True)
check_for_gguf = self._model.state_dict().get("img_in.weight")
check_for_gguf = hasattr(self._model, 'state_dict') and self._model.state_dict().get("img_in.weight")
if isinstance(check_for_gguf, GGMLTensor):
old_value = torch.__future__.get_overwrite_module_params_on_conversion()
torch.__future__.set_overwrite_module_params_on_conversion(True)