mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
[mm] clear the cache entry for a model that got an OOM during loading
This commit is contained in:
parent
7bc77ddb40
commit
579082ac10
@ -24,9 +24,9 @@ class RectangleMaskInvocation(BaseInvocation, WithMetadata):
|
|||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> MaskOutput:
|
def invoke(self, context: InvocationContext) -> MaskOutput:
|
||||||
mask = torch.zeros((1, self.height, self.width), dtype=torch.bool)
|
mask = torch.zeros((1, self.height, self.width), dtype=torch.bool)
|
||||||
mask[:, self.y_top : self.y_top + self.rectangle_height, self.x_left : self.x_left + self.rectangle_width] = (
|
mask[
|
||||||
True
|
:, self.y_top : self.y_top + self.rectangle_height, self.x_left : self.x_left + self.rectangle_width
|
||||||
)
|
] = True
|
||||||
|
|
||||||
mask_tensor_name = context.tensors.save(mask)
|
mask_tensor_name = context.tensors.save(mask)
|
||||||
return MaskOutput(
|
return MaskOutput(
|
||||||
|
@ -17,7 +17,8 @@ class MigrateCallback(Protocol):
|
|||||||
See :class:`Migration` for an example.
|
See :class:`Migration` for an example.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __call__(self, cursor: sqlite3.Cursor) -> None: ...
|
def __call__(self, cursor: sqlite3.Cursor) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
class MigrationError(RuntimeError):
|
class MigrationError(RuntimeError):
|
||||||
|
@ -271,7 +271,12 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
|||||||
|
|
||||||
start_model_to_time = time.time()
|
start_model_to_time = time.time()
|
||||||
snapshot_before = self._capture_memory_snapshot()
|
snapshot_before = self._capture_memory_snapshot()
|
||||||
cache_entry.model.to(target_device)
|
try:
|
||||||
|
cache_entry.model.to(target_device)
|
||||||
|
except torch.cuda.OutOfMemoryError as e: # blow away cache entry
|
||||||
|
self._delete_cache_entry(cache_entry)
|
||||||
|
raise e
|
||||||
|
|
||||||
snapshot_after = self._capture_memory_snapshot()
|
snapshot_after = self._capture_memory_snapshot()
|
||||||
end_model_to_time = time.time()
|
end_model_to_time = time.time()
|
||||||
self.logger.debug(
|
self.logger.debug(
|
||||||
@ -389,8 +394,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
|||||||
)
|
)
|
||||||
current_size -= cache_entry.size
|
current_size -= cache_entry.size
|
||||||
models_cleared += 1
|
models_cleared += 1
|
||||||
del self._cache_stack[pos]
|
self._delete_cache_entry(cache_entry)
|
||||||
del self._cached_models[model_key]
|
|
||||||
del cache_entry
|
del cache_entry
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@ -417,3 +421,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
|||||||
mps.empty_cache()
|
mps.empty_cache()
|
||||||
|
|
||||||
self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
|
self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
|
||||||
|
|
||||||
|
def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
|
||||||
|
self._cache_stack.remove(cache_entry.key)
|
||||||
|
del self._cached_models[cache_entry.key]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user