diff --git a/invokeai/app/invocations/mask.py b/invokeai/app/invocations/mask.py
index a7f3207764..acacaedaed 100644
--- a/invokeai/app/invocations/mask.py
+++ b/invokeai/app/invocations/mask.py
@@ -24,9 +24,9 @@ class RectangleMaskInvocation(BaseInvocation, WithMetadata):
 
     def invoke(self, context: InvocationContext) -> MaskOutput:
         mask = torch.zeros((1, self.height, self.width), dtype=torch.bool)
-        mask[:, self.y_top : self.y_top + self.rectangle_height, self.x_left : self.x_left + self.rectangle_width] = (
-            True
-        )
+        mask[
+            :, self.y_top : self.y_top + self.rectangle_height, self.x_left : self.x_left + self.rectangle_width
+        ] = True
 
         mask_tensor_name = context.tensors.save(mask)
         return MaskOutput(
diff --git a/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_common.py b/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_common.py
index 9b2444dae4..47ed5da505 100644
--- a/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_common.py
+++ b/invokeai/app/services/shared/sqlite_migrator/sqlite_migrator_common.py
@@ -17,7 +17,8 @@ class MigrateCallback(Protocol):
     See :class:`Migration` for an example.
     """
 
-    def __call__(self, cursor: sqlite3.Cursor) -> None: ...
+    def __call__(self, cursor: sqlite3.Cursor) -> None:
+        ...
 
 
 class MigrationError(RuntimeError):
diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
index f2e0c01a94..e48be7c008 100644
--- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
+++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py
@@ -271,7 +271,12 @@ class ModelCache(ModelCacheBase[AnyModel]):
 
         start_model_to_time = time.time()
         snapshot_before = self._capture_memory_snapshot()
-        cache_entry.model.to(target_device)
+        try:
+            cache_entry.model.to(target_device)
+        except torch.cuda.OutOfMemoryError as e:  # blow away cache entry
+            self._delete_cache_entry(cache_entry)
+            raise e
+
         snapshot_after = self._capture_memory_snapshot()
         end_model_to_time = time.time()
         self.logger.debug(
@@ -389,8 +394,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
                 )
                 current_size -= cache_entry.size
                 models_cleared += 1
-                del self._cache_stack[pos]
-                del self._cached_models[model_key]
+                self._delete_cache_entry(cache_entry)
                 del cache_entry
 
             else:
@@ -417,3 +421,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
             mps.empty_cache()
 
         self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")
+
+    def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None:
+        self._cache_stack.remove(cache_entry.key)
+        del self._cached_models[cache_entry.key]