diff --git a/invokeai/backend/util/original_weights_storage.py b/invokeai/backend/util/original_weights_storage.py index 3632c52b09..af945b086f 100644 --- a/invokeai/backend/util/original_weights_storage.py +++ b/invokeai/backend/util/original_weights_storage.py @@ -8,9 +8,13 @@ from invokeai.backend.util.devices import TorchDevice class OriginalWeightsStorage: + """A class for tracking the original weights of a model for patch/unpatch operations.""" + def __init__(self, cached_weights: Optional[Dict[str, torch.Tensor]] = None): - self._weights = {} - self._changed_weights = set() + # The original weights of the model. + self._weights: dict[str, torch.Tensor] = {} + # The keys of the weights that have been changed (via `save()`) during the lifetime of this instance. + self._changed_weights: set[str] = set() if cached_weights: self._weights.update(cached_weights)