Made MPS calls conditional on MPS actually being the chosen device and the MPS backend being available

Authored by Ryan on 2023-09-11 00:44:43 -04:00; committed by Kent Keirsey
parent fab055995e
commit b7296000e4
2 changed files with 23 additions and 9 deletions
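
Both files apply the same pattern: import torch.mps and call mps.empty_cache() only when choose_torch_device() reports MPS as the active device, so CUDA-only and CPU-only installs never touch the MPS backend. Below is a minimal, self-contained sketch of that pattern, not the project code itself; the choose_torch_device() shown here is a simplified stand-in for InvokeAI's helper, and it assumes a PyTorch build recent enough to provide torch.backends.mps and torch.mps.empty_cache().

import torch

def choose_torch_device() -> torch.device:
    # Simplified stand-in for InvokeAI's choose_torch_device() helper.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

# Guard the import itself: torch.mps is only needed when MPS is the device
# actually selected at startup.
if choose_torch_device() == torch.device("mps"):
    from torch import mps

def clear_backend_caches() -> None:
    # torch.cuda.empty_cache() is a no-op when CUDA was never initialized,
    # so it stays unconditional; the MPS call is guarded the same way the
    # import is.
    torch.cuda.empty_cache()
    if choose_torch_device() == torch.device("mps"):
        mps.empty_cache()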

View File

@@ -6,7 +6,6 @@ from typing import List, Literal, Optional, Union
 import einops
 import numpy as np
 import torch
-from torch import mps
 import torchvision.transforms as T
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.models.attention_processor import (
@@ -64,6 +63,9 @@ from .compel import ConditioningField
 from .controlnet_image_processors import ControlField
 from .model import ModelInfo, UNetField, VaeField
 
+if choose_torch_device() == torch.device("mps"):
+    from torch import mps
+
 DEFAULT_PRECISION = choose_precision(choose_torch_device())
@@ -542,6 +544,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
 # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
 result_latents = result_latents.to("cpu")
 torch.cuda.empty_cache()
-mps.empty_cache()
+if choose_torch_device() == torch.device("mps"):
+    mps.empty_cache()
 
 name = f"{context.graph_execution_state_id}__{self.id}"
@@ -614,6 +617,7 @@ class LatentsToImageInvocation(BaseInvocation):
 # clear memory as vae decode can request a lot
 torch.cuda.empty_cache()
-mps.empty_cache()
+if choose_torch_device() == torch.device("mps"):
+    mps.empty_cache()
 
 with torch.inference_mode():
@@ -627,6 +631,7 @@ class LatentsToImageInvocation(BaseInvocation):
 image = VaeImageProcessor.numpy_to_pil(np_image)[0]
 
 torch.cuda.empty_cache()
-mps.empty_cache()
+if choose_torch_device() == torch.device("mps"):
+    mps.empty_cache()
 
 image_dto = context.services.images.create(
@@ -687,6 +692,7 @@ class ResizeLatentsInvocation(BaseInvocation):
 # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
 resized_latents = resized_latents.to("cpu")
 torch.cuda.empty_cache()
-mps.empty_cache()
+if device == torch.device("mps"):
+    mps.empty_cache()
 
 name = f"{context.graph_execution_state_id}__{self.id}"
@@ -724,6 +730,7 @@ class ScaleLatentsInvocation(BaseInvocation):
 # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
 resized_latents = resized_latents.to("cpu")
 torch.cuda.empty_cache()
-mps.empty_cache()
+if device == torch.device("mps"):
+    mps.empty_cache()
 
 name = f"{context.graph_execution_state_id}__{self.id}"
@@ -881,6 +888,7 @@ class BlendLatentsInvocation(BaseInvocation):
 # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
 blended_latents = blended_latents.to("cpu")
 torch.cuda.empty_cache()
-mps.empty_cache()
+if device == torch.device("mps"):
+    mps.empty_cache()
 
 name = f"{context.graph_execution_state_id}__{self.id}"
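
The six empty_cache() hunks above all add the same guard after a latents operation: move the result to the CPU, clear the CUDA cache, and clear the MPS cache only when MPS is in use. In ResizeLatentsInvocation, ScaleLatentsInvocation, and BlendLatentsInvocation the guard compares a local device value, presumably assigned from choose_torch_device() earlier in the method, outside the lines shown. A rough, self-contained illustration of that flow, with placeholder tensor shapes and names rather than the invocation code itself:

import torch

# Pick the device the same way the guards above do; assumes a PyTorch
# build where torch.backends.mps and torch.mps.empty_cache() exist.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

latents = torch.randn(1, 4, 64, 64, device=device)  # placeholder work
latents = latents.to("cpu")                          # keep only the CPU copy

torch.cuda.empty_cache()                             # harmless without CUDA
if device == torch.device("mps"):
    from torch import mps
    mps.empty_cache()                                # only runs on MPS machines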

View File

@@ -30,8 +30,12 @@ from torch import mps
 import invokeai.backend.util.logging as logger
 
+from ..util.devices import choose_torch_device
 from .models import BaseModelType, ModelBase, ModelType, SubModelType
 
+if choose_torch_device() == torch.device("mps"):
+    from torch import mps
+
 # Maximum size of the cache, in gigs
 # Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
 DEFAULT_MAX_CACHE_SIZE = 6.0
@@ -407,6 +411,7 @@ class ModelCache(object):
 gc.collect()
 torch.cuda.empty_cache()
-mps.empty_cache()
+if choose_torch_device() == torch.device("mps"):
+    mps.empty_cache()
 
 self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")
@@ -428,6 +433,7 @@ class ModelCache(object):
 gc.collect()
 torch.cuda.empty_cache()
-mps.empty_cache()
+if choose_torch_device() == torch.device("mps"):
+    mps.empty_cache()
 
 def _local_model_hash(self, model_path: Union[str, Path]) -> str:
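
The two ModelCache hunks repeat the same three-step cleanup used in the latents invocations: collect garbage, clear the CUDA cache, then clear the MPS cache behind the device check. A possible follow-up, not part of this commit, would be to gather that repetition into a single helper; the sketch below is hypothetical, and the helper name and its placement are assumptions.

import gc
import torch

def empty_device_caches(device: torch.device) -> None:
    # Hypothetical consolidation of the idiom repeated throughout this diff:
    # collect Python garbage, then release cached allocator memory on
    # whichever backend is actually in use.
    gc.collect()
    torch.cuda.empty_cache()       # no-op when CUDA was never initialized
    if device == torch.device("mps"):
        from torch import mps      # lazy import, only reached on MPS machines
        mps.empty_cache()

Call sites would then pass in the chosen device, for example empty_device_caches(choose_torch_device()).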