add support for Apple hardware using MPS acceleration

This commit is contained in:
Lincoln Stein 2022-08-31 00:33:23 -04:00
parent 1714816fe2
commit bdb0651eb2
16 changed files with 361 additions and 52 deletions

4
.gitignore vendored
View File

@ -180,4 +180,6 @@ outputs
# created from generated embeddings.
logs
testtube
checkpoints
checkpoints
# If it's a Mac
.DS_Store

228
README-Mac-MPS.md Normal file
View File

@ -0,0 +1,228 @@
# Apple Silicon Mac Users
Several people have gotten Stable Diffusion to work on Apple Silicon
Macs using Anaconda. I've gathered up most of their instructions and
put them in this fork (and readme). I haven't tested anything besides
Anaconda, and I've read about issues with things like miniforge, so if
you have an issue that isn't dealt with in this fork then head on over
to the [Apple
Silicon](https://github.com/CompVis/stable-diffusion/issues/25) issue
on GitHub (that page is so long that GitHub hides most of it by
default, so you need to find the hidden part and expand it to view the
whole thing). This fork would not have been possible without the work
done by the people on that issue.
You have to have macOS 12.3 Monterey or later. Anything earlier than that won't work.
BTW, I haven't tested any of this on Intel Macs.
How to:
```
git clone https://github.com/lstein/stable-diffusion.git
cd stable-diffusion
mkdir -p models/ldm/stable-diffusion-v1/
ln -s /path/to/ckpt/sd-v1-1.ckpt models/ldm/stable-diffusion-v1/model.ckpt
conda env create -f environment-mac.yaml
conda activate ldm
```
These instructions are identical to the main repo except I added
environment-mac.yaml because Mac doesn't have cudatoolkit.
After you follow all the instructions and run txt2img.py you might get several errors. Here's the errors I've seen and found solutions for.
### Doesn't work anymore?
We are using PyTorch nightly, which includes support for MPS. I don't
know exactly how Anaconda does updates, but I woke up one morning and
Stable Diffusion crashed and I couldn't think of anything I did that
would've changed anything the night before, when it worked. A day and
a half later I finally got it working again. I don't know what changed
overnight. PyTorch-nightly changes overnight but I'm pretty sure I
didn't manually update it. Either way, things are probably going to be
bumpy on Apple Silicon until PyTorch releases a firm version that we
can lock to.
To manually update to the latest version of PyTorch nightly (which could fix issues), run this command.
conda install pytorch torchvision torchaudio -c pytorch-nightly
## Debugging?
Tired of waiting for your renders to finish before you can see if it
works? Reduce the steps! The picture wont look like anything but if it
finishes, hey, it works! This could also help you figure out if you've
got a memory problem, because I'm betting 1 step doesn't use much
memory.
python ./scripts/txt2img.py --prompt "ocean" --ddim_steps 1
### "No module named cv2" (or some other module)
Did you remember to `conda activate ldm`? If your terminal prompt
begins with "(ldm)" then you activated it. If it begins with "(base)"
or something else you haven't.
If it says you're missing taming you need to rebuild your virtual
environment.
conda env remove -n ldm
conda env create -f environment-mac.yaml
If you have activated the ldm virtual environment and tried rebuilding
it, maybe the problem could be that I have something installed that
you don't and you'll just need to manually install it. Make sure you
activate the virtual environment so it installs there instead of
globally.
conda activate ldm
pip install *name*
You might also need to install Rust (I mention this again below).
### "The operator [name] is not current implemented for the MPS device." (sic)
Example error.
```
...
NotImplementedError: The operator 'aten::index.Tensor' is not current implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on [https://github.com/pytorch/pytorch/issues/77764](https://github.com/pytorch/pytorch/issues/77764). As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
```
Just do what it says:
export PYTORCH_ENABLE_MPS_FALLBACK=1
### "Could not build wheels for tokenizers"
I have not seen this error because I had Rust installed on my computer before I started playing with Stable Diffusion. The fix is to install Rust.
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
### How come `--seed` doesn't work?
> Completely reproducible results are not guaranteed across PyTorch
releases, individual commits, or different platforms. Furthermore,
results may not be reproducible between CPU and GPU executions, even
when using identical seeds.
[PyTorch docs](https://pytorch.org/docs/stable/notes/randomness.html)
There is an [open issue](https://github.com/pytorch/pytorch/issues/78035) (as of August 2022) in pytorch regarding gradient inconsistency. I am guessing that's what is causing this.
### libiomp5.dylib error?
OMP: Error #15: Initializing libiomp5.dylib, but found libomp.dylib already initialized.
There are several things you can do. First, you could use something
besides Anaconda like miniforge. I read a lot of things online telling
people to use something else, but I am stuck with Anaconda for other
reasons.
Or you can try this.
export KMP_DUPLICATE_LIB_OK=True
Or this (which takes forever on my computer and didn't work anyway).
conda install nomkl
This error happens with Anaconda on Macs, and
[nomkl](https://stackoverflow.com/questions/66224879/what-is-the-nomkl-python-package-used-for)
is supposed to fix the issue (it isn't a module but a fix of some
sort). [There's more
suggestions](https://stackoverflow.com/questions/53014306/error-15-initializing-libiomp5-dylib-but-found-libiomp5-dylib-already-initial),
like uninstalling tensorflow and reinstalling. I haven't tried them.
### Not enough memory.
This seems to be a common problem and is probably the underlying
problem for a lot of symptoms (listed below). The fix is to lower your
image size or to add `model.half()` right after the model is loaded. I
should probably test it out. I've read that the reason this fixes
problems is because it converts the model from 32-bit to 16-bit and
that leaves more RAM for other things. I have no idea how that would
affect the quality of the images though.
See [this issue](https://github.com/CompVis/stable-diffusion/issues/71).
### "Error: product of dimension sizes > 2**31'"
This error happens with img2img, which I haven't played with too much
yet. But I know it's because your image is too big or the resolution
isn't a multiple of 32x32. Because the stable-diffusion model was
trained on images that were 512 x 512, it's always best to use that
output size (which is the default). However, if you're using that size
and you get the above error, try 256 x 256 or 512 x 256 or something
as the source image.
BTW, 2**31-1 = [2,147,483,647](https://en.wikipedia.org/wiki/2,147,483,647#In_computing), which is also 32-bit signed [LONG_MAX](https://en.wikipedia.org/wiki/C_data_types) in C.
### I just got Rickrolled! Do I have a virus?
You don't have a virus. It's part of the project. Here's
[Rick](https://github.com/lstein/stable-diffusion/blob/main/assets/rick.jpeg)
and here's [the
code](https://github.com/lstein/stable-diffusion/blob/69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc/scripts/txt2img.py#L79)
that swaps him in. It's a NSFW filter, which IMO, doesn't work very
good (and we call this "computer vision", sheesh).
Actually, this could be happening because there's not enough RAM. You could try the `model.half()` suggestion or specify smaller output images.
### My images come out black
I haven't solved this issue. I just throw away my black
images. There's a [similar
issue](https://github.com/CompVis/stable-diffusion/issues/69) on CUDA
GPU's where the images come out green. Maybe it's the same issue?
Someone in that issue says to use "--precision full", but this fork
actually disables that flag. I don't know why, someone else provided
that code and I don't know what it does. Maybe the `model.half()`
suggestion above would fix this issue too. I should probably test it.
### "view size is not compatible with input tensor's size and stride"
```
File "/opt/anaconda3/envs/ldm/lib/python3.10/site-packages/torch/nn/functional.py", line 2511, in layer_norm
return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
```
Update to the latest version of lstein/stable-diffusion. We were
patching pytorch but we found a file in stable-diffusion that we could
change instead. This is a 32-bit vs 16-bit problem.
### The processor must support the Intel bla bla bla
What? Intel? On an Apple Silicon?
Intel MKL FATAL ERROR: This system does not meet the minimum requirements for use of the Intel(R) Math Kernel Library.
The processor must support the Intel(R) Supplemental Streaming SIMD Extensions 3 (Intel(R) SSSE3) instructions.██████████████| 50/50 [02:25<00:00, 2.53s/it]
The processor must support the Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2) instructions.
The processor must support the Intel(R) Advanced Vector Extensions (Intel(R) AVX) instructions.
This fixed it for me:
conda clean --yes --all
### Still slow?
I changed the defaults of n_samples and n_iter to 1 so that it uses
less RAM and makes less images so it will be faster the first time you
use it. I don't actually know what n_samples does internally, but I
know it consumes a lot more RAM. The n_iter flag just loops around the
image creation code, so it shouldn't consume more RAM (it should be
faster if you're going to do multiple images because the libraries and
model will already be loaded--use a prompt file to get this speed
boost).
These flags are the default sample and iter settings in this fork/branch:
~~~~
python scripts/txt2img.py --prompt "ocean" --n_samples=1 --n_iter=1
~~~

View File

@ -387,7 +387,7 @@ Credit goes to @rinongal and the repository located at
https://github.com/rinongal/textual_inversion Please see the
repository and associated paper for details and limitations.
# Latest
# Latest Changes
- v1.13 (in process)
@ -403,9 +403,9 @@ For older changelogs, please visit **[CHANGELOGS](CHANGELOG.md)**.
# Installation
There are separate installation walkthroughs for [Linux/Mac](#linuxmac) and [Windows](#windows).
There are separate installation walkthroughs for [Linux](#linux), [Windows](#windows) and [Macintosh](#Macintosh)
## Linux/Mac
## Linux
1. You will need to install the following prerequisites if they are not already available. Use your
operating system's preferred installer
@ -580,7 +580,15 @@ python scripts\dream.py -l
python scripts\dream.py
```
10. Subsequently, to relaunch the script, first activate the Anaconda command window (step 3), enter the stable-diffusion directory (step 5, "cd \path\to\stable-diffusion"), run "conda activate ldm" (step 6b), and then launch the dream script (step 9).
10. Subsequently, to relaunch the script, first activate the Anaconda
command window (step 3), enter the stable-diffusion directory (step 5,
"cd \path\to\stable-diffusion"), run "conda activate ldm" (step 6b),
and then launch the dream script (step 9).
**Note:** Tildebyte has written an alternative ["Easy peasy Windows
install"](https://github.com/lstein/stable-diffusion/wiki/Easy-peasy-Windows-install)
which uses the Windows Powershell and pew. If you are having trouble
with Anaconda on Windows, give this a try (or try it first!)
### Updating to newer versions of the script
@ -595,11 +603,16 @@ git pull
This will bring your local copy into sync with the remote one.
## Simplified API for text to image generation
## Macintosh
See (README-Mac-MPS)[README-Mac-MPS.md] for instructions.
# Simplified API for text to image generation
For programmers who wish to incorporate stable-diffusion into other
products, this repository includes a simplified API for text to image generation, which
lets you create images from a prompt in just three lines of code:
products, this repository includes a simplified API for text to image
generation, which lets you create images from a prompt in just three
lines of code:
```
from ldm.simplet2i import T2I
@ -608,9 +621,10 @@ outputs = model.txt2img("a unicorn in manhattan")
```
Outputs is a list of lists in the format [[filename1,seed1],[filename2,seed2]...]
Please see ldm/simplet2i.py for more information.
Please see ldm/simplet2i.py for more information. A set of example scripts is
coming RSN.
## Workaround for machines with limited internet connectivity
# Workaround for machines with limited internet connectivity
My development machine is a GPU node in a high-performance compute
cluster which has no connection to the internet. During model

32
environment-mac.yaml Normal file
View File

@ -0,0 +1,32 @@
name: ldm
channels:
- apple
- conda-forge
- pytorch-nightly
- defaults
dependencies:
- python=3.10.4
- pip=22.1.2
- pytorch
- torchvision
- numpy=1.23.1
- pip:
- albumentations==0.4.6
- opencv-python==4.6.0.66
- pudb==2019.2
- imageio==2.9.0
- imageio-ffmpeg==0.4.2
- pytorch-lightning==1.4.2
- omegaconf==2.1.1
- test-tube>=0.7.5
- streamlit==1.12.0
- pillow==9.2.0
- einops==0.3.0
- torch-fidelity==0.3.0
- transformers==4.19.2
- torchmetrics==0.6.0
- kornia==0.6.0
- -e git+https://github.com/openai/CLIP.git@main#egg=clip
- -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
- -e git+https://github.com/lstein/k-diffusion.git@master#egg=k-diffusion
- -e .

11
ldm/dream/devices.py Normal file
View File

@ -0,0 +1,11 @@
import torch
def choose_torch_device() -> str:
'''Convenience routine for guessing which GPU device to run model on'''
if torch.cuda.is_available():
return 'cuda'
if torch.backends.mps.is_available():
return 'mps'
return 'cpu'

View File

@ -4,6 +4,7 @@ import torch
import numpy as np
from tqdm import tqdm
from functools import partial
from ldm.dream.devices import choose_torch_device
from ldm.modules.diffusionmodules.util import (
make_ddim_sampling_parameters,
@ -14,17 +15,17 @@ from ldm.modules.diffusionmodules.util import (
class DDIMSampler(object):
def __init__(self, model, schedule='linear', device='cuda', **kwargs):
def __init__(self, model, schedule='linear', device=None, **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
self.device = device
self.device = device or choose_torch_device()
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device(self.device):
attr = attr.to(torch.device(self.device))
attr = attr.to(dtype=torch.float32, device=self.device)
setattr(self, name, attr)
def make_schedule(

View File

@ -2,7 +2,7 @@
import k_diffusion as K
import torch
import torch.nn as nn
from ldm.dream.devices import choose_torch_device
class CFGDenoiser(nn.Module):
def __init__(self, model):
@ -18,11 +18,11 @@ class CFGDenoiser(nn.Module):
class KSampler(object):
def __init__(self, model, schedule='lms', device='cuda', **kwargs):
def __init__(self, model, schedule='lms', device=None, **kwargs):
super().__init__()
self.model = K.external.CompVisDenoiser(model)
self.schedule = schedule
self.device = device
self.device = device or choose_torch_device()
def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)

View File

@ -4,6 +4,7 @@ import torch
import numpy as np
from tqdm import tqdm
from functools import partial
from ldm.dream.devices import choose_torch_device
from ldm.modules.diffusionmodules.util import (
make_ddim_sampling_parameters,
@ -13,18 +14,17 @@ from ldm.modules.diffusionmodules.util import (
class PLMSSampler(object):
def __init__(self, model, schedule='linear', device='cuda', **kwargs):
def __init__(self, model, schedule='linear', device=None, **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
self.device = device
self.device = device if device else choose_torch_device()
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device(self.device):
attr = attr.to(torch.device(self.device))
attr = attr.to(torch.float32).to(torch.device(self.device))
setattr(self, name, attr)
def make_schedule(

View File

@ -234,6 +234,7 @@ class BasicTransformerBlock(nn.Module):
)
def _forward(self, x, context=None):
x = x.contiguous() if x.device.type == 'mps' else x
x = self.attn1(self.norm1(x)) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x

View File

@ -5,6 +5,7 @@ import clip
from einops import rearrange, repeat
from transformers import CLIPTokenizer, CLIPTextModel
import kornia
from ldm.dream.devices import choose_torch_device
from ldm.modules.x_transformer import (
Encoder,
@ -67,7 +68,12 @@ class TransformerEmbedder(AbstractEncoder):
"""Some transformer encoder layers"""
def __init__(
self, n_embed, n_layer, vocab_size, max_seq_len=77, device='cuda'
self,
n_embed,
n_layer,
vocab_size,
max_seq_len=77,
device=choose_torch_device(),
):
super().__init__()
self.device = device
@ -89,7 +95,9 @@ class TransformerEmbedder(AbstractEncoder):
class BERTTokenizer(AbstractEncoder):
"""Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
def __init__(self, device='cuda', vq_interface=True, max_length=77):
def __init__(
self, device=choose_torch_device(), vq_interface=True, max_length=77
):
super().__init__()
from transformers import (
BertTokenizerFast,
@ -145,7 +153,7 @@ class BERTEmbedder(AbstractEncoder):
n_layer,
vocab_size=30522,
max_seq_len=77,
device='cuda',
device=choose_torch_device(),
use_tokenizer=True,
embedding_dropout=0.0,
):
@ -230,7 +238,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
def __init__(
self,
version='openai/clip-vit-large-patch14',
device='cuda',
device=choose_torch_device(),
max_length=77,
):
super().__init__()
@ -455,13 +463,13 @@ class FrozenCLIPTextEmbedder(nn.Module):
def __init__(
self,
version='ViT-L/14',
device='cuda',
device=choose_torch_device(),
max_length=77,
n_repeat=1,
normalize=True,
):
super().__init__()
self.model, _ = clip.load(version, jit=False, device='cpu')
self.model, _ = clip.load(version, jit=False, device=device)
self.device = device
self.max_length = max_length
self.n_repeat = n_repeat
@ -496,7 +504,7 @@ class FrozenClipImageEmbedder(nn.Module):
self,
model,
jit=False,
device='cuda' if torch.cuda.is_available() else 'cpu',
device=choose_torch_device(),
antialias=False,
):
super().__init__()

View File

@ -28,6 +28,7 @@ from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ksampler import KSampler
from ldm.dream.pngwriter import PngWriter
from ldm.dream.image_util import InitImageResizer
from ldm.dream.devices import choose_torch_device
"""Simplified text to image API for stable diffusion/latent diffusion
@ -523,19 +524,15 @@ class T2I:
return self.seed
def _get_device(self):
if torch.cuda.is_available():
return torch.device('cuda')
elif torch.backends.mps.is_available():
return torch.device('mps')
else:
return torch.device('cpu')
device_type = choose_torch_device()
return torch.device(device_type)
def load_model(self):
"""Load and initialize the model from configuration variables passed at object creation time"""
if self.model is None:
seed_everything(self.seed)
try:
config = OmegaConf.load(self.config)
config = OmegaConf.load(self.config)
self.device = self._get_device()
model = self._load_model_from_config(config, self.weights)
if self.embedding_path is not None:

View File

@ -14,7 +14,7 @@ from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import ismap
import time
from omegaconf import OmegaConf
from ldm.dream.devices import choose_torch_device
def download_models(mode):
@ -117,7 +117,8 @@ def get_cond(mode, selected_path):
c = rearrange(c, '1 c h w -> 1 h w c')
c = 2. * c - 1.
c = c.to(torch.device("cuda"))
device = choose_torch_device()
c = c.to(device)
example["LR_image"] = c
example["image"] = c_up
@ -267,4 +268,4 @@ def make_convolutional_sample(batch, model, mode="vanilla", custom_steps=None, e
log["sample"] = x_sample
log["time"] = t1 - t0
return log
return log

View File

@ -8,11 +8,11 @@ import re
import sys
import copy
import warnings
import time
import ldm.dream.readline
from ldm.dream.pngwriter import PngWriter, PromptFormatter
from ldm.dream.server import DreamServer, ThreadingDreamServer
def main():
"""Initialize command-line parsers and the diffusion model"""
arg_parser = create_argv_parser()
@ -81,7 +81,11 @@ def main():
sys.exit(-1)
# preload the model
tic = time.time()
t2i.load_model()
print(
f'model loaded in', '%4.2fs' % (time.time() - tic)
)
if not infile:
print(

View File

@ -6,7 +6,7 @@ import numpy as np
import torch
from main import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.dream.devices import choose_torch_device
def make_batch(image, mask, device):
image = np.array(Image.open(image).convert("RGB"))
@ -61,8 +61,8 @@ if __name__ == "__main__":
model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"],
strict=False)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
device = choose_torch_device()
model = model.to(device)
sampler = DDIMSampler(model)
os.makedirs(opt.outdir, exist_ok=True)

View File

@ -18,6 +18,7 @@ from pytorch_lightning import seed_everything
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.dream.devices import choose_torch_device
def chunk(it, size):
@ -40,7 +41,7 @@ def load_model_from_config(config, ckpt, verbose=False):
print("unexpected keys:")
print(u)
model.cuda()
model.to(choose_torch_device())
model.eval()
return model
@ -199,7 +200,7 @@ def main():
config = OmegaConf.load(f"{opt.config}")
model = load_model_from_config(config, f"{opt.ckpt}")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device = choose_torch_device()
model = model.to(device)
if opt.plms:
@ -241,8 +242,10 @@ def main():
print(f"target t_enc is {t_enc} steps")
precision_scope = autocast if opt.precision == "autocast" else nullcontext
if device.type in ['mps', 'cpu']:
precision_scope = nullcontext # have to use f32 on mps
with torch.no_grad():
with precision_scope("cuda"):
with precision_scope(device.type):
with model.ema_scope():
tic = time.time()
all_samples = list()

View File

@ -15,10 +15,10 @@ from contextlib import contextmanager, nullcontext
import k_diffusion as K
import torch.nn as nn
from ldm.util import instantiate_from_config
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.dream.devices import choose_torch_device
def chunk(it, size):
it = iter(it)
@ -40,7 +40,7 @@ def load_model_from_config(config, ckpt, verbose=False):
print("unexpected keys:")
print(u)
model.cuda()
model.to(choose_torch_device())
model.eval()
return model
@ -190,13 +190,14 @@ def main():
opt.ckpt = "models/ldm/text2img-large/model.ckpt"
opt.outdir = "outputs/txt2img-samples-laion400m"
seed_everything(opt.seed)
config = OmegaConf.load(f"{opt.config}")
model = load_model_from_config(config, f"{opt.ckpt}")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
seed_everything(opt.seed)
device = torch.device(choose_torch_device())
model = model.to(device)
#for klms
model_wrap = K.external.CompVisDenoiser(model)
@ -240,11 +241,17 @@ def main():
start_code = None
if opt.fixed_code:
start_code = torch.randn([opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f], device=device)
shape = [opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f]
if device.type == 'mps':
start_code = torch.randn(shape, device='cpu').to(device)
else:
torch.randn(shape, device=device)
precision_scope = autocast if opt.precision=="autocast" else nullcontext
if device.type in ['mps', 'cpu']:
precision_scope = nullcontext # have to use f32 on mps
with torch.no_grad():
with precision_scope("cuda"):
with precision_scope(device.type):
with model.ema_scope():
tic = time.time()
all_samples = list()