Remove device and dtype members from LoRAModelRaw, they can too easily get out-of-sync with the underlying layer states.

This commit is contained in:
Ryan Dick 2023-10-31 15:15:03 -04:00 committed by Kent Keirsey
parent 2ba5b44ec4
commit 545c811bf1

View File

@ -440,33 +440,19 @@ class IA3Layer(LoRALayerBase):
class LoRAModelRaw: # (torch.nn.Module):
_name: str
layers: Dict[str, LoRALayer]
_device: torch.device
_dtype: torch.dtype
def __init__(
self,
name: str,
layers: Dict[str, LoRALayer],
device: torch.device,
dtype: torch.dtype,
):
self._name = name
self._device = device or torch.cpu
self._dtype = dtype or torch.float32
self.layers = layers
@property
def name(self):
return self._name
@property
def device(self):
return self._device
@property
def dtype(self):
return self._dtype
def to(
self,
device: Optional[torch.device] = None,
@ -475,8 +461,6 @@ class LoRAModelRaw: # (torch.nn.Module):
# TODO: try revert if exception?
for key, layer in self.layers.items():
layer.to(device=device, dtype=dtype)
self._device = device
self._dtype = dtype
def calc_size(self) -> int:
model_size = 0
@ -557,8 +541,6 @@ class LoRAModelRaw: # (torch.nn.Module):
file_path = Path(file_path)
model = cls(
device=device,
dtype=dtype,
name=file_path.stem, # TODO:
layers=dict(),
)