Add some docs to OriginalWeightsStorage and fix type hints.

This commit is contained in:
Ryan Dick 2024-07-31 15:08:24 -04:00
parent 86f705bf48
commit 0bb7ed44f6

View File

@ -8,9 +8,13 @@ from invokeai.backend.util.devices import TorchDevice
class OriginalWeightsStorage:
"""A class for tracking the original weights of a model for patch/unpatch operations."""
def __init__(self, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
self._weights = {}
self._changed_weights = set()
# The original weights of the model.
self._weights: dict[str, torch.Tensor] = {}
# The keys of the weights that have been changed (via `save()`) during the lifetime of this instance.
self._changed_weights: set[str] = set()
if cached_weights:
self._weights.update(cached_weights)