mirror of https://github.com/invoke-ai/InvokeAI
made MPS calls conditional on MPS actually being the chosen device with backend available
This commit is contained in:
parent
fab055995e
commit
b7296000e4
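
The change applies one pattern throughout: the torch.mps import and every mps.empty_cache() call are guarded by a check that MPS is actually the chosen device. A minimal sketch of that pattern, assuming choose_torch_device() lives in invokeai.backend.util.devices (as the relative import in the model-cache hunk suggests); the clear_device_cache() wrapper is hypothetical and only illustrates the guarded calls:

    import torch
    from invokeai.backend.util.devices import choose_torch_device

    # The torch.mps module may be missing or unusable on builds without the
    # MPS backend, so the import itself sits behind the device check.
    if choose_torch_device() == torch.device("mps"):
        from torch import mps

    def clear_device_cache() -> None:
        # Hypothetical helper showing the guarded pattern used in the diff below.
        # torch.cuda.empty_cache() is effectively a no-op when CUDA is not in use,
        # so it stays unconditional; mps.empty_cache() runs only on MPS.
        torch.cuda.empty_cache()
        if choose_torch_device() == torch.device("mps"):
            mps.empty_cache()

Guarding the import, and not just the calls, matters because a module-level "from torch import mps" can fail at import time on CUDA- or CPU-only installs.
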
@@ -6,7 +6,6 @@ from typing import List, Literal, Optional, Union
 import einops
 import numpy as np
 import torch
-from torch import mps
 import torchvision.transforms as T
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.models.attention_processor import (
@@ -64,6 +63,9 @@ from .compel import ConditioningField
 from .controlnet_image_processors import ControlField
 from .model import ModelInfo, UNetField, VaeField
 
+if choose_torch_device() == torch.device("mps"):
+    from torch import mps
+
 DEFAULT_PRECISION = choose_precision(choose_torch_device())
 
 
@@ -542,7 +544,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         result_latents = result_latents.to("cpu")
         torch.cuda.empty_cache()
-        mps.empty_cache()
+        if choose_torch_device() == torch.device("mps"):
+            mps.empty_cache()
 
         name = f"{context.graph_execution_state_id}__{self.id}"
         context.services.latents.save(name, result_latents)
@@ -614,7 +617,8 @@ class LatentsToImageInvocation(BaseInvocation):
 
         # clear memory as vae decode can request a lot
         torch.cuda.empty_cache()
-        mps.empty_cache()
+        if choose_torch_device() == torch.device("mps"):
+            mps.empty_cache()
 
         with torch.inference_mode():
             # copied from diffusers pipeline
@@ -627,7 +631,8 @@ class LatentsToImageInvocation(BaseInvocation):
             image = VaeImageProcessor.numpy_to_pil(np_image)[0]
 
         torch.cuda.empty_cache()
-        mps.empty_cache()
+        if choose_torch_device() == torch.device("mps"):
+            mps.empty_cache()
 
         image_dto = context.services.images.create(
             image=image,
@@ -687,7 +692,8 @@ class ResizeLatentsInvocation(BaseInvocation):
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         resized_latents = resized_latents.to("cpu")
         torch.cuda.empty_cache()
-        mps.empty_cache()
+        if device == torch.device("mps"):
+            mps.empty_cache()
 
         name = f"{context.graph_execution_state_id}__{self.id}"
         # context.services.latents.set(name, resized_latents)
@@ -724,7 +730,8 @@ class ScaleLatentsInvocation(BaseInvocation):
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         resized_latents = resized_latents.to("cpu")
         torch.cuda.empty_cache()
-        mps.empty_cache()
+        if device == torch.device("mps"):
+            mps.empty_cache()
 
         name = f"{context.graph_execution_state_id}__{self.id}"
         # context.services.latents.set(name, resized_latents)
@@ -881,7 +888,8 @@ class BlendLatentsInvocation(BaseInvocation):
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
         blended_latents = blended_latents.to("cpu")
         torch.cuda.empty_cache()
-        mps.empty_cache()
+        if device == torch.device("mps"):
+            mps.empty_cache()
 
         name = f"{context.graph_execution_state_id}__{self.id}"
         # context.services.latents.set(name, resized_latents)

@@ -30,8 +30,12 @@ from torch import mps
 
 import invokeai.backend.util.logging as logger
 
+from ..util.devices import choose_torch_device
 from .models import BaseModelType, ModelBase, ModelType, SubModelType
 
+if choose_torch_device() == torch.device("mps"):
+    from torch import mps
+
 # Maximum size of the cache, in gigs
 # Default is roughly enough to hold three fp16 diffusers models in RAM simultaneously
 DEFAULT_MAX_CACHE_SIZE = 6.0
@@ -407,7 +411,8 @@ class ModelCache(object):
 
         gc.collect()
         torch.cuda.empty_cache()
-        mps.empty_cache()
+        if choose_torch_device() == torch.device("mps"):
+            mps.empty_cache()
 
         self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")
 
@@ -428,7 +433,8 @@ class ModelCache(object):
 
         gc.collect()
         torch.cuda.empty_cache()
-        mps.empty_cache()
+        if choose_torch_device() == torch.device("mps"):
+            mps.empty_cache()
 
     def _local_model_hash(self, model_path: Union[str, Path]) -> str:
         sha = hashlib.sha256()