mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Remove device and dtype members from LoRAModelRaw, they can too easily get out-of-sync with the underlying layer states.
This commit is contained in:
parent
2ba5b44ec4
commit
545c811bf1
@ -440,33 +440,19 @@ class IA3Layer(LoRALayerBase):
|
|||||||
class LoRAModelRaw: # (torch.nn.Module):
|
class LoRAModelRaw: # (torch.nn.Module):
|
||||||
_name: str
|
_name: str
|
||||||
layers: Dict[str, LoRALayer]
|
layers: Dict[str, LoRALayer]
|
||||||
_device: torch.device
|
|
||||||
_dtype: torch.dtype
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
layers: Dict[str, LoRALayer],
|
layers: Dict[str, LoRALayer],
|
||||||
device: torch.device,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
):
|
):
|
||||||
self._name = name
|
self._name = name
|
||||||
self._device = device or torch.cpu
|
|
||||||
self._dtype = dtype or torch.float32
|
|
||||||
self.layers = layers
|
self.layers = layers
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self):
|
def name(self):
|
||||||
return self._name
|
return self._name
|
||||||
|
|
||||||
@property
|
|
||||||
def device(self):
|
|
||||||
return self._device
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dtype(self):
|
|
||||||
return self._dtype
|
|
||||||
|
|
||||||
def to(
|
def to(
|
||||||
self,
|
self,
|
||||||
device: Optional[torch.device] = None,
|
device: Optional[torch.device] = None,
|
||||||
@ -475,8 +461,6 @@ class LoRAModelRaw: # (torch.nn.Module):
|
|||||||
# TODO: try revert if exception?
|
# TODO: try revert if exception?
|
||||||
for key, layer in self.layers.items():
|
for key, layer in self.layers.items():
|
||||||
layer.to(device=device, dtype=dtype)
|
layer.to(device=device, dtype=dtype)
|
||||||
self._device = device
|
|
||||||
self._dtype = dtype
|
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
def calc_size(self) -> int:
|
||||||
model_size = 0
|
model_size = 0
|
||||||
@ -557,8 +541,6 @@ class LoRAModelRaw: # (torch.nn.Module):
|
|||||||
file_path = Path(file_path)
|
file_path = Path(file_path)
|
||||||
|
|
||||||
model = cls(
|
model = cls(
|
||||||
device=device,
|
|
||||||
dtype=dtype,
|
|
||||||
name=file_path.stem, # TODO:
|
name=file_path.stem, # TODO:
|
||||||
layers=dict(),
|
layers=dict(),
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user