mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
[mm] clear the cache entry for a model that got an OOM during loading
This commit is contained in:
parent
7bc77ddb40
commit
579082ac10
@ -24,9 +24,9 @@ class RectangleMaskInvocation(BaseInvocation, WithMetadata):
|
||||
|
||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||
mask = torch.zeros((1, self.height, self.width), dtype=torch.bool)
|
||||
mask[:, self.y_top : self.y_top + self.rectangle_height, self.x_left : self.x_left + self.rectangle_width] = (
|
||||
True
|
||||
)
|
||||
mask[
|
||||
:, self.y_top : self.y_top + self.rectangle_height, self.x_left : self.x_left + self.rectangle_width
|
||||
] = True
|
||||
|
||||
mask_tensor_name = context.tensors.save(mask)
|
||||
return MaskOutput(
|
||||
|
@ -17,7 +17,8 @@ class MigrateCallback(Protocol):
|
||||
See :class:`Migration` for an example.
|
||||
"""
|
||||
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None: ...
|
||||
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||
...
|
||||
|
||||
|
||||
class MigrationError(RuntimeError):
|
||||
|
@ -271,7 +271,12 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
|
||||
start_model_to_time = time.time()
|
||||
snapshot_before = self._capture_memory_snapshot()
|
||||
cache_entry.model.to(target_device)
|
||||
try:
|
||||
cache_entry.model.to(target_device)
|
||||
except torch.cuda.OutOfMemoryError as e: # blow away cache entry
|
||||
self._delete_cache_entry(cache_entry)
|
||||
raise e
|
||||
|
||||
snapshot_after = self._capture_memory_snapshot()
|
||||
end_model_to_time = time.time()
|
||||
self.logger.debug(
|
||||
@ -389,8 +394,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
)
|
||||
current_size -= cache_entry.size
|
||||
models_cleared += 1
|
||||
del self._cache_stack[pos]
|
||||
del self._cached_models[model_key]
|
||||
self._delete_cache_entry(cache_entry)
|
||||
del cache_entry
|
||||
|
||||
else:
|
||||
@ -417,3 +421,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
||||
mps.empty_cache()
|
||||
|
||||
self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
|
||||
|
||||
def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
|
||||
self._cache_stack.remove(cache_entry.key)
|
||||
del self._cached_models[cache_entry.key]
|
||||
|
Loading…
Reference in New Issue
Block a user