Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
First draft

commit 2634f0e43a
parent 704151e8e3
@@ -210,6 +210,31 @@ class ModelCache(object):
         return self.ModelLocker(self, key, cache_entry.model, gpu_load, cache_entry.size)
 
+    def clear_one_model(self) -> bool:
+        reserved = self.max_vram_cache_size * GIG
+        vram_in_use = torch.cuda.memory_allocated()
+        self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
+
+        smallest_key = None
+        smallest_size = float("inf")
+        for model_key, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
+            if not cache_entry.locked and cache_entry.loaded:
+                if cache_entry.size > 0 and cache_entry.size < smallest_size:
+                    smallest_key = model_key
+                    smallest_size = cache_entry.size
+
+        if smallest_key is not None:
+            cache_entry = self._cached_models[smallest_key]
+            self.logger.debug(f"!!!!!!!!!!!Offloading {smallest_key} from {self.execution_device} into {self.storage_device}")
+            with VRAMUsage() as mem:
+                cache_entry.model.to(self.storage_device)
+            self.logger.debug(f"GPU VRAM freed: {(mem.vram_used/GIG):.2f} GB")
+            vram_in_use += mem.vram_used  # note vram_used is negative
+            self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM used for models; max allowed={(reserved/GIG):.2f}GB")
+
+        torch.cuda.empty_cache()
+        gc.collect()
+        return smallest_key is not None
 
     class ModelLocker(object):
         def __init__(self, cache, key, model, gpu_load, size_needed):
             """
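The new clear_one_model() above relies on a VRAMUsage context manager that is defined elsewhere in model_cache.py and not shown in this diff. A minimal sketch consistent with how it is used here (the real class may differ) would be:

import torch

class VRAMUsage(object):
    """Sketch of a VRAMUsage-style context manager consistent with its use
    in clear_one_model(); the actual class in model_cache.py may differ."""
    def __init__(self):
        self.vram = None
        self.vram_used = 0

    def __enter__(self):
        self.vram = torch.cuda.memory_allocated()
        return self

    def __exit__(self, *args):
        # Delta in allocated CUDA memory: negative when the block frees memory,
        # which is why clear_one_model() adds it to vram_in_use.
        self.vram_used = torch.cuda.memory_allocated() - self.vram

Because vram_used is a delta, offloading a model inside the with-block yields a negative number, so adding it to vram_in_use reduces the reported usage.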
@@ -236,17 +261,48 @@ class ModelCache(object):
                self.cache_entry.lock()
 
                try:
                    if self.cache.lazy_offloading:
                        self.cache._offload_unlocked_models(self.size_needed)
-
-                   if self.model.device != self.cache.execution_device:
-                       self.cache.logger.debug(f"Moving {self.key} into {self.cache.execution_device}")
-                       with VRAMUsage() as mem:
-                           self.model.to(self.cache.execution_device)  # move into GPU
-                       self.cache.logger.debug(f"GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB")
-
-                   self.cache.logger.debug(f"Locking {self.key} in {self.cache.execution_device}")
-                   self.cache._print_cuda_stats()
+                   self.cache.logger.debug(f"Moving {self.key} into {self.cache.execution_device}")
+                   while True:
+                       try:
+                           with VRAMUsage() as mem:
+                               self.model.to(self.cache.execution_device)  # move into GPU
+                           self.cache.logger.debug(f"GPU VRAM used for load: {(mem.vram_used/GIG):.2f} GB")
+                           self.cache.logger.debug(f"Locking {self.key} in {self.cache.execution_device}")
+                           self.cache._print_cuda_stats()
+
+                           def my_forward(module, cache, *args, **kwargs):
+                               while True:
+                                   try:
+                                       return module._orig_forward(*args, **kwargs)
+                                   except:
+                                       if not cache.clear_one_model():
+                                           raise
+
+                           import functools
+                           from diffusers.models.unet_2d_blocks import DownBlock2D, CrossAttnDownBlock2D, UpBlock2D, CrossAttnUpBlock2D
+                           from transformers.models.clip.modeling_clip import CLIPEncoderLayer
+                           from diffusers.models.unet_2d_blocks import DownEncoderBlock2D, UpDecoderBlock2D
+
+                           for module_name, module in self.model.named_modules():
+                               if type(module) not in [
+                                   DownBlock2D, CrossAttnDownBlock2D, UpBlock2D, CrossAttnUpBlock2D,  # unet blocks
+                                   CLIPEncoderLayer,  # CLIPTextTransformer clip
+                                   DownEncoderBlock2D, UpDecoderBlock2D,  # vae
+                               ]:
+                                   continue
+                               # better here filter to only specific model modules
+                               module._orig_forward = module.forward
+                               module.forward = functools.partial(my_forward, module, self.cache)
+
+                           self.model._orig_forward = self.model.forward
+                           self.model.forward = functools.partial(my_forward, self.model, self.cache)
+
+                           break
+
+                       except:
+                           if not self.cache.clear_one_model():
+                               raise
 
                except:
                    self.cache_entry.unlock()
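The hunk above monkey-patches forward() on selected UNet, CLIP and VAE blocks so that a failure during inference evicts one cached model and retries instead of aborting. A self-contained sketch of that wrapping pattern on a plain torch module; FakeCache and retrying_forward are illustrative stand-ins for ModelCache and my_forward, not part of the commit:

import functools
import torch

class FakeCache:
    """Illustrative stand-in for ModelCache; not part of the commit."""
    def clear_one_model(self) -> bool:
        # Nothing to evict in this toy example, so a failure is re-raised immediately.
        return False

def retrying_forward(module, cache, *args, **kwargs):
    # Same shape as my_forward() in the diff: retry the original forward,
    # evicting one cached model per failure until nothing is left to evict.
    while True:
        try:
            return module._orig_forward(*args, **kwargs)
        except Exception:  # the draft uses a bare except; narrowing to OOM errors would be safer
            if not cache.clear_one_model():
                raise

cache = FakeCache()
layer = torch.nn.Linear(8, 8)
layer._orig_forward = layer.forward
layer.forward = functools.partial(retrying_forward, layer, cache)
out = layer(torch.randn(2, 8))  # dispatches through the retrying wrapper
print(out.shape)

Assigning to layer.forward installs an instance attribute that shadows the class method, which is exactly how the diff swaps in the wrapper without subclassing the diffusers and transformers blocks.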
@@ -264,10 +320,19 @@ class ModelCache(object):
             if not hasattr(self.model, "to"):
                 return
 
+            if hasattr(self.model, "_orig_forward"):
+                self.model.forward = self.model._orig_forward
+                delattr(self.model, "_orig_forward")
+
+            for module_name, module in self.model.named_modules():
+                if hasattr(module, "_orig_forward"):
+                    module.forward = module._orig_forward
+                    delattr(module, "_orig_forward")
+
             self.cache_entry.unlock()
-            if not self.cache.lazy_offloading:
-                self.cache._offload_unlocked_models()
-                self.cache._print_cuda_stats()
+            #if not self.cache.lazy_offloading:
+            #    self.cache._offload_unlocked_models()
+            #    self.cache._print_cuda_stats()
 
     # TODO: should it be called untrack_model?
     def uncache_model(self, cache_id: str):
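The __exit__ changes walk the module tree and restore every saved _orig_forward, so the patching applied in __enter__ never outlives the lock. A standalone sketch of that patch/unpatch symmetry, using toy module and function names that are not part of the commit:

import functools
import torch

def passthrough_forward(module, *args, **kwargs):
    # Stand-in for the wrapper installed in __enter__; simply delegates.
    return module._orig_forward(*args, **kwargs)

def patch_forwards(model: torch.nn.Module) -> None:
    # Mirrors the __enter__ hunk: save the original forward and install a wrapper.
    for _name, module in model.named_modules():
        module._orig_forward = module.forward
        module.forward = functools.partial(passthrough_forward, module)

def unpatch_forwards(model: torch.nn.Module) -> None:
    # Mirrors the __exit__ hunk: restore forward and drop the saved attribute.
    for _name, module in model.named_modules():
        if hasattr(module, "_orig_forward"):
            module.forward = module._orig_forward
            delattr(module, "_orig_forward")

net = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
patch_forwards(net)
_ = net(torch.randn(1, 4))   # runs through the wrappers
unpatch_forwards(net)
assert not hasattr(net[0], "_orig_forward")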