From 1f6699ac431b6bdbc8f2154bc0287c4cf51786c0 Mon Sep 17 00:00:00 2001
From: Ryan Dick <ryanjdick3@gmail.com>
Date: Fri, 29 Sep 2023 10:28:12 -0400
Subject: [PATCH] Consolidate all model.to(...) calls in the model cache to use
 a utility function with better logging.

---
 .../backend/model_management/model_cache.py   | 95 ++++++++-----------
 1 file changed, 42 insertions(+), 53 deletions(-)

diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management/model_cache.py
index bcd61a77f8..b4db4cf5ea 100644
--- a/invokeai/backend/model_management/model_cache.py
+++ b/invokeai/backend/model_management/model_cache.py
@@ -265,6 +265,44 @@ class ModelCache(object):
 
         return self.ModelLocker(self, key, cache_entry.model, gpu_load, cache_entry.size)
 
+    def _move_model_to_device(self, key, target_device):
+        cache_entry = self._cached_models[key]
+
+        source_device = cache_entry.model.device
+        if source_device == target_device:
+            return
+
+        start_model_to_time = time.time()
+        snapshot_before = MemorySnapshot.capture()
+        cache_entry.model.to(target_device)
+        snapshot_after = MemorySnapshot.capture()
+        end_model_to_time = time.time()
+        self.logger.debug(
+            f"Moved model '{key}' from {source_device} to"
+            f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
+            f" Estimated model size: {(cache_entry.size/GIG):.2f} GB."
+            f" {get_pretty_snapshot_diff(snapshot_before, snapshot_after)}."
+        )
+
+        # If the estimated model size does not match the change in VRAM, log a warning.
+        if (
+            snapshot_before.vram is not None
+            and snapshot_after.vram is not None
+            and not math.isclose(
+                abs(snapshot_before.vram - snapshot_after.vram),
+                cache_entry.size,
+                rel_tol=0.1,
+                abs_tol=10 * MB,
+            )
+        ):
+            self.logger.warning(
+                f"Moving model '{key}' from {source_device} to"
+                f" {target_device} caused an unexpected change in VRAM usage. The model's"
+                " estimated size may be incorrect. Estimated model size:"
+                f" {(cache_entry.size/GIG):.2f} GB."
+                f" {get_pretty_snapshot_diff(snapshot_before, snapshot_after)}."
+            )
+
     class ModelLocker(object):
         def __init__(self, cache, key, model, gpu_load, size_needed):
             """
@@ -294,32 +332,7 @@ class ModelCache(object):
                     if self.cache.lazy_offloading:
                         self.cache._offload_unlocked_models(self.size_needed)
 
-                    if self.model.device != self.cache.execution_device:
-                        start_model_to_time = time.time()
-                        snapshot_before = MemorySnapshot.capture()
-                        self.model.to(self.cache.execution_device)  # move into GPU
-                        snapshot_after = MemorySnapshot.capture()
-                        end_model_to_time = time.time()
-                        self.cache.logger.debug(
-                            f"Moved model '{self.key}' from {self.cache.storage_device} to"
-                            f" {self.cache.execution_device} in {(end_model_to_time-start_model_to_time):.2f}s."
-                            f" Estimated model size: {(self.cache_entry.size/GIG):.2f} GB."
-                            f" {get_pretty_snapshot_diff(snapshot_before, snapshot_after)}."
-                        )
-
-                        if not math.isclose(
-                            abs((snapshot_before.vram or 0) - (snapshot_after.vram or 0)),
-                            self.cache_entry.size,
-                            rel_tol=0.1,
-                            abs_tol=10 * MB,
-                        ):
-                            self.cache.logger.warning(
-                                f"Moving '{self.key}' from {self.cache.storage_device} to"
-                                f" {self.cache.execution_device} caused an unexpected change in VRAM usage. The model's"
-                                " estimated size may be incorrect. Estimated model size:"
-                                f" {(self.cache_entry.size/GIG):.2f} GB."
-                                f" {get_pretty_snapshot_diff(snapshot_before, snapshot_after)}."
-                            )
+                    self.cache._move_model_to_device(self.key, self.cache.execution_device)
 
                     self.cache.logger.debug(f"Locking {self.key} in {self.cache.execution_device}")
                     self.cache._print_cuda_stats()
@@ -332,7 +345,7 @@ class ModelCache(object):
             # in the event that the caller wants the model in RAM, we
             # move it into CPU if it is in GPU and not locked
             elif self.cache_entry.loaded and not self.cache_entry.locked:
-                self.model.to(self.cache.storage_device)
+                self.cache._move_model_to_device(self.key, self.cache.storage_device)
 
             return self.model
 
@@ -472,33 +485,9 @@ class ModelCache(object):
             if vram_in_use <= reserved:
                 break
             if not cache_entry.locked and cache_entry.loaded:
-                start_model_to_time = time.time()
-                snapshot_before = MemorySnapshot.capture()
-                cache_entry.model.to(self.storage_device)
-                snapshot_after = MemorySnapshot.capture()
-                end_model_to_time = time.time()
-                self.logger.debug(
-                    f"Moved model '{model_key}' from {self.execution_device} to {self.storage_device} in"
-                    f" {(end_model_to_time-start_model_to_time):.2f}s. Estimated model size:"
-                    f" {(cache_entry.size/GIG):.2f} GB."
-                    f" {get_pretty_snapshot_diff(snapshot_before, snapshot_after)}."
-                )
+                self._move_model_to_device(model_key, self.storage_device)
 
-                if not math.isclose(
-                    abs((snapshot_before.vram or 0) - (snapshot_after.vram or 0)),
-                    cache_entry.size,
-                    rel_tol=0.1,
-                    abs_tol=10 * MB,
-                ):
-                    self.logger.warning(
-                        f"Moving '{model_key}' from {self.execution_device} to"
-                        f" {self.storage_device} caused an unexpected change in VRAM usage. The model's"
-                        " estimated size may be incorrect. Estimated model size:"
-                        f" {(cache_entry.size/GIG):.2f} GB."
-                        f" {get_pretty_snapshot_diff(snapshot_before, snapshot_after)}."
-                    )
-
-                vram_in_use = snapshot_after.vram or 0
+                vram_in_use = torch.cuda.memory_allocated()
                 self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
 
         gc.collect()