diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 5c3f1c6e8f..14713eb964 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -79,6 +79,7 @@ class SubModelType(str, Enum): Tokenizer = "tokenizer" Tokenizer2 = "tokenizer_2" Tokenizer3 = "tokenizer_3" + Transformer = "transformer" VAE = "vae" VAEDecoder = "vae_decoder" VAEEncoder = "vae_encoder" diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py index 012fd42d55..bdddba86ac 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py @@ -73,6 +73,7 @@ class CacheRecord(Generic[T]): device: torch.device state_dict: Optional[Dict[str, torch.Tensor]] size: int + is_quantized: bool = False loaded: bool = False _locks: int = 0 diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py index 335a15a5c8..a071570c22 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_default.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_default.py @@ -60,9 +60,7 @@ class ModelCache(ModelCacheBase[AnyModel]): execution_device: torch.device = torch.device("cuda"), storage_device: torch.device = torch.device("cpu"), precision: torch.dtype = torch.float16, - sequential_offload: bool = False, lazy_offloading: bool = True, - sha_chunksize: int = 16777216, log_memory_usage: bool = False, logger: Optional[Logger] = None, ): @@ -74,7 +72,6 @@ class ModelCache(ModelCacheBase[AnyModel]): :param storage_device: Torch device to save inactive model in [torch.device('cpu')] :param precision: Precision for loaded models [torch.float16] :param lazy_offloading: Keep model in VRAM until another model needs to be loaded - :param sequential_offload: Conserve VRAM by loading and unloading each stage of the pipeline sequentially :param log_memory_usage: If True, a memory snapshot will be captured before and after every model cache operation, and the result will be logged (at debug level). There is a time cost to capturing the memory snapshots, so it is recommended to disable this feature unless you are actively inspecting the model cache's @@ -163,8 +160,16 @@ class ModelCache(ModelCacheBase[AnyModel]): size = calc_model_size_by_data(model) self.make_room(size) + is_quantized = hasattr(model, "is_quantized") and model.is_quantized state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None - cache_record = CacheRecord(key=key, model=model, device=self.storage_device, state_dict=state_dict, size=size) + cache_record = CacheRecord( + key=key, + model=model, + device=self._storage_device, + is_quantized=is_quantized, + state_dict=state_dict, + size=size, + ) self._cached_models[key] = cache_record self._cache_stack.append(key) @@ -230,19 +235,26 @@ class ModelCache(ModelCacheBase[AnyModel]): reserved = self._max_vram_cache_size * GIG vram_in_use = torch.cuda.memory_allocated() + size_required self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM needed for models; max allowed={(reserved/GIG):.2f}GB") + delete_it = False for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size): if vram_in_use <= reserved: break if not cache_entry.loaded: continue if not cache_entry.locked: - self.move_model_to_device(cache_entry, self.storage_device) - cache_entry.loaded = False + if cache_entry.is_quantized: + self._delete_cache_entry(cache_entry) + delete_it = True + else: + self.move_model_to_device(cache_entry, self.storage_device) + cache_entry.loaded = False vram_in_use = torch.cuda.memory_allocated() + size_required self.logger.debug( f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB" ) + if delete_it: + del cache_entry TorchDevice.empty_cache() def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None: @@ -265,6 +277,9 @@ class ModelCache(ModelCacheBase[AnyModel]): if not hasattr(cache_entry.model, "to"): return + if cache_entry.is_quantized: # can't move quantized models around + return + # This roundabout method for moving the model around is done to avoid # the cost of moving the model from RAM to VRAM and then back from VRAM to RAM. # When moving to VRAM, we copy (not move) each element of the state dict from @@ -407,3 +422,5 @@ class ModelCache(ModelCacheBase[AnyModel]): def _delete_cache_entry(self, cache_entry: CacheRecord[AnyModel]) -> None: self._cache_stack.remove(cache_entry.key) del self._cached_models[cache_entry.key] + gc.collect() + TorchDevice.empty_cache() diff --git a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py index 3ca7a5b2e4..5e0cb508cf 100644 --- a/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py +++ b/invokeai/backend/model_manager/load/model_loaders/stable_diffusion.py @@ -36,9 +36,11 @@ VARIANT_TO_IN_CHANNEL_MAP = { class StableDiffusionDiffusersModel(GenericDiffusersLoader): """Class to load main models.""" + # note - will be removed for load_single_file() model_base_to_model_type = { BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder", BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder", + BaseModelType.StableDiffusion3: "SD3", # non-functional, for completeness only BaseModelType.StableDiffusionXL: "SDXL", BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner", } @@ -65,7 +67,10 @@ class StableDiffusionDiffusersModel(GenericDiffusersLoader): if variant and "no file named" in str( e ): # try without the variant, just in case user's preferences changed - result = load_class.from_pretrained(model_path, torch_dtype=self._torch_dtype) + result = load_class.from_pretrained( + model_path, + torch_dtype=self._torch_dtype, + ) else: raise e diff --git a/pyproject.toml b/pyproject.toml index bf983a0c8b..d60ffbee86 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,6 +34,7 @@ classifiers = [ dependencies = [ # Core generation dependencies, pinned for reproducible builds. "accelerate==0.30.1", + "bitsandbytes", "clip_anytorch==2.6.0", # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip", "compel==2.0.2", "controlnet-aux==0.0.7",