mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add empty_cache() for MPS hardware.
This commit is contained in:
parent
d989c7fa34
commit
fab055995e
@ -6,6 +6,7 @@ from typing import List, Literal, Optional, Union
|
|||||||
import einops
|
import einops
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from torch import mps
|
||||||
import torchvision.transforms as T
|
import torchvision.transforms as T
|
||||||
from diffusers.image_processor import VaeImageProcessor
|
from diffusers.image_processor import VaeImageProcessor
|
||||||
from diffusers.models.attention_processor import (
|
from diffusers.models.attention_processor import (
|
||||||
@ -541,6 +542,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
result_latents = result_latents.to("cpu")
|
result_latents = result_latents.to("cpu")
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
mps.empty_cache()
|
||||||
|
|
||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||||
context.services.latents.save(name, result_latents)
|
context.services.latents.save(name, result_latents)
|
||||||
@ -612,6 +614,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
|
|
||||||
# clear memory as vae decode can request a lot
|
# clear memory as vae decode can request a lot
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
mps.empty_cache()
|
||||||
|
|
||||||
with torch.inference_mode():
|
with torch.inference_mode():
|
||||||
# copied from diffusers pipeline
|
# copied from diffusers pipeline
|
||||||
@ -624,6 +627,7 @@ class LatentsToImageInvocation(BaseInvocation):
|
|||||||
image = VaeImageProcessor.numpy_to_pil(np_image)[0]
|
image = VaeImageProcessor.numpy_to_pil(np_image)[0]
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
mps.empty_cache()
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=image,
|
image=image,
|
||||||
@ -683,6 +687,7 @@ class ResizeLatentsInvocation(BaseInvocation):
|
|||||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
resized_latents = resized_latents.to("cpu")
|
resized_latents = resized_latents.to("cpu")
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
mps.empty_cache()
|
||||||
|
|
||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||||
# context.services.latents.set(name, resized_latents)
|
# context.services.latents.set(name, resized_latents)
|
||||||
@ -719,6 +724,7 @@ class ScaleLatentsInvocation(BaseInvocation):
|
|||||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
resized_latents = resized_latents.to("cpu")
|
resized_latents = resized_latents.to("cpu")
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
mps.empty_cache()
|
||||||
|
|
||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||||
# context.services.latents.set(name, resized_latents)
|
# context.services.latents.set(name, resized_latents)
|
||||||
@ -875,6 +881,7 @@ class BlendLatentsInvocation(BaseInvocation):
|
|||||||
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
||||||
blended_latents = blended_latents.to("cpu")
|
blended_latents = blended_latents.to("cpu")
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
mps.empty_cache()
|
||||||
|
|
||||||
name = f"{context.graph_execution_state_id}__{self.id}"
|
name = f"{context.graph_execution_state_id}__{self.id}"
|
||||||
# context.services.latents.set(name, resized_latents)
|
# context.services.latents.set(name, resized_latents)
|
||||||
|
@ -26,6 +26,7 @@ from pathlib import Path
|
|||||||
from typing import Any, Dict, Optional, Type, Union, types
|
from typing import Any, Dict, Optional, Type, Union, types
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from torch import mps
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
|
|
||||||
@ -406,6 +407,7 @@ class ModelCache(object):
|
|||||||
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
mps.empty_cache()
|
||||||
|
|
||||||
self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")
|
self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")
|
||||||
|
|
||||||
@ -426,6 +428,7 @@ class ModelCache(object):
|
|||||||
|
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
mps.empty_cache()
|
||||||
|
|
||||||
def _local_model_hash(self, model_path: Union[str, Path]) -> str:
|
def _local_model_hash(self, model_path: Union[str, Path]) -> str:
|
||||||
sha = hashlib.sha256()
|
sha = hashlib.sha256()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user