Add empty_cache() for MPS hardware.

Author: Ryan
Date: 2023-09-09 22:54:20 -04:00
Committed by: Kent Keirsey
Parent: d989c7fa34
Commit: fab055995e
2 changed files with 10 additions and 0 deletions
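
The change pairs every existing torch.cuda.empty_cache() call with mps.empty_cache(), so cached allocator memory is also released on Apple-silicon (MPS) devices after latents are moved back to the CPU or models are evicted from the model cache. Below is a minimal sketch of the same pattern, assuming PyTorch 2.x where torch.mps.empty_cache() is available; the clear_device_cache name and the is_available() guards are illustrative additions and not part of this commit, which calls mps.empty_cache() unconditionally:

import gc

import torch


def clear_device_cache() -> None:
    # Collect unreachable Python objects first so their device memory becomes reclaimable.
    gc.collect()
    # Return cached blocks held by the CUDA caching allocator, if CUDA is present.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # Do the same for the Metal Performance Shaders (MPS) backend on Apple silicon.
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()

Emptying the cache does not delete live tensors; it only returns unused cached blocks to the device, which makes subsequent VAE decodes or model loads less likely to hit out-of-memory errors.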


@@ -6,6 +6,7 @@ from typing import List, Literal, Optional, Union
 import einops
 import numpy as np
 import torch
+from torch import mps
 import torchvision.transforms as T
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.models.attention_processor import (
@@ -541,6 +542,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
 # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
 result_latents = result_latents.to("cpu")
 torch.cuda.empty_cache()
+mps.empty_cache()
 name = f"{context.graph_execution_state_id}__{self.id}"
 context.services.latents.save(name, result_latents)
@@ -612,6 +614,7 @@ class LatentsToImageInvocation(BaseInvocation):
 # clear memory as vae decode can request a lot
 torch.cuda.empty_cache()
+mps.empty_cache()
 with torch.inference_mode():
 # copied from diffusers pipeline
@@ -624,6 +627,7 @@ class LatentsToImageInvocation(BaseInvocation):
 image = VaeImageProcessor.numpy_to_pil(np_image)[0]
 torch.cuda.empty_cache()
+mps.empty_cache()
 image_dto = context.services.images.create(
 image=image,
@@ -683,6 +687,7 @@ class ResizeLatentsInvocation(BaseInvocation):
 # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
 resized_latents = resized_latents.to("cpu")
 torch.cuda.empty_cache()
+mps.empty_cache()
 name = f"{context.graph_execution_state_id}__{self.id}"
 # context.services.latents.set(name, resized_latents)
@@ -719,6 +724,7 @@ class ScaleLatentsInvocation(BaseInvocation):
 # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
 resized_latents = resized_latents.to("cpu")
 torch.cuda.empty_cache()
+mps.empty_cache()
 name = f"{context.graph_execution_state_id}__{self.id}"
 # context.services.latents.set(name, resized_latents)
@@ -875,6 +881,7 @@ class BlendLatentsInvocation(BaseInvocation):
 # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
 blended_latents = blended_latents.to("cpu")
 torch.cuda.empty_cache()
+mps.empty_cache()
 name = f"{context.graph_execution_state_id}__{self.id}"
 # context.services.latents.set(name, resized_latents)


@@ -26,6 +26,7 @@ from pathlib import Path
 from typing import Any, Dict, Optional, Type, Union, types
 import torch
+from torch import mps
 import invokeai.backend.util.logging as logger
@@ -406,6 +407,7 @@ class ModelCache(object):
 gc.collect()
 torch.cuda.empty_cache()
+mps.empty_cache()
 self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")
@@ -426,6 +428,7 @@ class ModelCache(object):
 gc.collect()
 torch.cuda.empty_cache()
+mps.empty_cache()
 def _local_model_hash(self, model_path: Union[str, Path]) -> str:
 sha = hashlib.sha256()