From b7296000e4a2146e5e83d53e9a39391d68f4ab2a Mon Sep 17 00:00:00 2001
From: Ryan <36354352+gogurtenjoyer@users.noreply.github.com>
Date: Mon, 11 Sep 2023 00:44:43 -0400
Subject: [PATCH] made MPS calls conditional on MPS actually being the chosen
 device with backend available

---
 invokeai/app/invocations/latent.py            | 22 +++++++++++++------
 .../backend/model_management/model_cache.py   | 10 +++++++--
 2 files changed, 23 insertions(+), 9 deletions(-)

diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index 78f8f624f2..385ddc5df8 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -6,7 +6,6 @@ from typing import List, Literal, Optional, Union
 import einops
 import numpy as np
 import torch
-from torch import mps
 import torchvision.transforms as T
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.models.attention_processor import (
@@ -64,6 +63,9 @@ from .compel import ConditioningField
 from .controlnet_image_processors import ControlField
 from .model import ModelInfo, UNetField, VaeField
 
+if choose_torch_device() == torch.device("mps"):
+    from torch import mps
+
 DEFAULT_PRECISION = choose_precision(choose_torch_device())
 
 
@@ -542,7 +544,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
             # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
             result_latents = result_latents.to("cpu")
             torch.cuda.empty_cache()
-            mps.empty_cache()
+            if choose_torch_device() == torch.device("mps"):
+                mps.empty_cache()
 
             name = f"{context.graph_execution_state_id}__{self.id}"
             context.services.latents.save(name, result_latents)
@@ -614,7 +617,8 @@ class LatentsToImageInvocation(BaseInvocation):
 
             # clear memory as vae decode can request a lot
             torch.cuda.empty_cache()
-            mps.empty_cache()
+            if choose_torch_device() == torch.device("mps"):
+                mps.empty_cache()
 
             with torch.inference_mode():
                 # copied from diffusers pipeline
@@ -627,7 +631,8 @@ class LatentsToImageInvocation(BaseInvocation):
                 image = VaeImageProcessor.numpy_to_pil(np_image)[0]
 
             torch.cuda.empty_cache()
-            mps.empty_cache()
+            if choose_torch_device() == torch.device("mps"):
+                mps.empty_cache()
 
         image_dto = context.services.images.create(
             image=image,
@@ -687,7 +692,8 @@ class ResizeLatentsInvocation(BaseInvocation):
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         resized_latents = resized_latents.to("cpu")
         torch.cuda.empty_cache()
-        mps.empty_cache()
+        if device == torch.device("mps"):
+            mps.empty_cache()
 
         name = f"{context.graph_execution_state_id}__{self.id}"
         # context.services.latents.set(name, resized_latents)
@@ -724,7 +730,8 @@ class ScaleLatentsInvocation(BaseInvocation):
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         resized_latents = resized_latents.to("cpu")
         torch.cuda.empty_cache()
-        mps.empty_cache()
+        if device == torch.device("mps"):
+            mps.empty_cache()
 
         name = f"{context.graph_execution_state_id}__{self.id}"
         # context.services.latents.set(name, resized_latents)
@@ -881,7 +888,8 @@ class BlendLatentsInvocation(BaseInvocation):
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         blended_latents = blended_latents.to("cpu")
         torch.cuda.empty_cache()
-        mps.empty_cache()
+        if device == torch.device("mps"):
+            mps.empty_cache()
 
         name = f"{context.graph_execution_state_id}__{self.id}"
         # context.services.latents.set(name, resized_latents)
diff --git a/invokeai/backend/model_management/model_cache.py b/invokeai/backend/model_management/model_cache.py
index 2db46e9f64..d2850e21ec 100644
--- a/invokeai/backend/model_management/model_cache.py
+++ b/invokeai/backend/model_management/model_cache.py
@@ -30,8 +30,12 @@ from torch import mps
 
 import invokeai.backend.util.logging as logger
 
+from ..util.devices import choose_torch_device
 from .models import BaseModelType, ModelBase, ModelType, SubModelType
 
+if choose_torch_device() == torch.device("mps"):
+    from torch import mps
+
 # Maximum size of the cache, in gigs
 # Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
 DEFAULT_MAX_CACHE_SIZE = 6.0
@@ -407,7 +411,8 @@ class ModelCache(object):
 
         gc.collect()
         torch.cuda.empty_cache()
-        mps.empty_cache()
+        if choose_torch_device() == torch.device("mps"):
+            mps.empty_cache()
 
         self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")
 
@@ -428,7 +433,8 @@ class ModelCache(object):
 
         gc.collect()
         torch.cuda.empty_cache()
-        mps.empty_cache()
+        if choose_torch_device() == torch.device("mps"):
+            mps.empty_cache()
 
     def _local_model_hash(self, model_path: Union[str, Path]) -> str:
         sha = hashlib.sha256()
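
For reference, a minimal standalone sketch of the pattern this patch applies: import torch.mps and call mps.empty_cache() only when MPS is actually the chosen device, while leaving the CUDA call in place (it is a safe no-op when CUDA was never initialized). Here choose_torch_device() is a simplified stand-in for the helper in invokeai/backend/util/devices.py, and empty_device_cache() is a hypothetical convenience function, not part of the patch.

# Standalone sketch (not part of the patch) of the conditional-MPS pattern.
import torch


def choose_torch_device() -> torch.device:
    # Simplified stand-in for invokeai.backend.util.devices.choose_torch_device.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


# Import torch.mps only when MPS is actually the chosen device, so CUDA-only
# and CPU-only installs never touch the MPS backend at import time.
if choose_torch_device() == torch.device("mps"):
    from torch import mps


def empty_device_cache() -> None:
    # Release cached allocator memory on whichever backend is in use.
    torch.cuda.empty_cache()  # no-op when CUDA was never initialized
    if choose_torch_device() == torch.device("mps"):
        mps.empty_cache()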