Merge branch 'toffaletti-dream-m1' into main

This provides support for Apple M1 hardware
Lincoln Stein 2022-09-01 17:55:36 -04:00
commit 3ee82d8a3b
4 changed files with 60 additions and 41 deletions

View File

@@ -52,7 +52,7 @@ dependencies:
     - -e git+https://github.com/huggingface/diffusers.git@v0.2.4#egg=diffusers
     - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
     - -e git+https://github.com/openai/CLIP.git@main#egg=clip
-    - -e git+https://github.com/lstein/k-diffusion.git@master#egg=k-diffusion
+    - -e git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k_diffusion
     - -e .
 variables:
   PYTORCH_ENABLE_MPS_FALLBACK: 1
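
PYTORCH_ENABLE_MPS_FALLBACK=1 tells PyTorch to run any operator that is not yet implemented for the Metal (MPS) backend on the CPU instead of raising an error, which is what makes the M1 path workable with current PyTorch builds. A minimal sketch of checking that the backend is actually usable before relying on it (the helper name below is illustrative, not part of this change):

import os

# Must be set before the first MPS operator runs; unimplemented ops then fall back to CPU.
os.environ.setdefault('PYTORCH_ENABLE_MPS_FALLBACK', '1')

import torch

def mps_usable() -> bool:
    # is_built(): this PyTorch build includes MPS support at all
    # is_available(): the Metal device can actually be used on this machine
    return torch.backends.mps.is_built() and torch.backends.mps.is_available()

if __name__ == '__main__':
    print('MPS usable:', mps_usable())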

View File

@@ -8,4 +8,10 @@ def choose_torch_device() -> str:
         return 'mps'
     return 'cpu'
+
+def choose_autocast_device(device) -> str:
+    '''Returns an autocast compatible device from a torch device'''
+    device_type = device.type # this returns 'mps' on M1
+    # autocast only supports cuda or cpu
+    if device_type not in ('cuda','cpu'):
+        return 'cpu'
+    return device_type
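
The new choose_autocast_device helper exists because torch.autocast() only accepts 'cuda' or 'cpu' as a device type; passing 'mps' fails, so on Apple Silicon the autocast context is taken on the CPU side (or skipped entirely when running in full precision). A hedged usage sketch, assuming only the two helpers added in this file:

import torch
from contextlib import nullcontext
from ldm.dream.devices import choose_torch_device, choose_autocast_device

device = torch.device(choose_torch_device())       # 'cuda', 'mps', or 'cpu'
autocast_device = choose_autocast_device(device)   # maps 'mps' (and anything else) to 'cpu'

# Skip autocast when forcing full precision (the practical setting on MPS today);
# otherwise autocast on a device type torch.autocast actually supports.
full_precision = device.type == 'mps'
scope = nullcontext if full_precision else torch.autocast

with scope(autocast_device):
    x = torch.ones(8, device=device) * 2.0
print(x.device, x.dtype)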

View File

@@ -8,6 +8,7 @@ import torch
 import numpy as np
 import random
 import os
+import traceback
 from omegaconf import OmegaConf
 from PIL import Image
 from tqdm import tqdm, trange
@@ -28,7 +29,7 @@ from ldm.models.diffusion.plms import PLMSSampler
 from ldm.models.diffusion.ksampler import KSampler
 from ldm.dream.pngwriter import PngWriter
 from ldm.dream.image_util import InitImageResizer
-from ldm.dream.devices import choose_torch_device
+from ldm.dream.devices import choose_autocast_device, choose_torch_device

 """Simplified text to image API for stable diffusion/latent diffusion
@@ -114,26 +115,28 @@ class T2I:
     """

     def __init__(
         self,
         iterations=1,
         steps=50,
         seed=None,
         cfg_scale=7.5,
         weights='models/ldm/stable-diffusion-v1/model.ckpt',
         config='configs/stable-diffusion/v1-inference.yaml',
         grid=False,
         width=512,
         height=512,
         sampler_name='k_lms',
         latent_channels=4,
         downsampling_factor=8,
         ddim_eta=0.0,  # deterministic
         precision='autocast',
         full_precision=False,
         strength=0.75,  # default in scripts/img2img.py
         embedding_path=None,
+        device_type = 'cuda',
         # just to keep track of this parameter when regenerating prompt
+        # needs to be replaced when new configuration system implemented.
         latent_diffusion_weights=False,
     ):
         self.iterations = iterations
         self.width = width
@@ -151,11 +154,17 @@ class T2I:
         self.full_precision = full_precision
         self.strength = strength
         self.embedding_path = embedding_path
+        self.device_type = device_type
         self.model = None  # empty for now
         self.sampler = None
         self.device = None
         self.latent_diffusion_weights = latent_diffusion_weights

+        if device_type == 'cuda' and not torch.cuda.is_available():
+            device_type = choose_torch_device()
+            print(">> cuda not available, using device", device_type)
+        self.device = torch.device(device_type)
+
         # for VRAM usage statistics
         device_type = choose_torch_device()
         self.session_peakmem = torch.cuda.max_memory_allocated() if device_type == 'cuda' else None
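
With this change the requested device degrades gracefully: a T2I built with the default device_type='cuda' on a machine without an NVIDIA GPU ends up on 'mps' (or 'cpu') instead of crashing. A hypothetical construction, assuming the T2I class lives at ldm.simplet2i as in this repository:

from ldm.simplet2i import T2I

# On an M1 Mac the default device_type='cuda' is rewritten by the constructor to
# whatever choose_torch_device() reports, so callers need no Apple-specific branch.
t2i = T2I(full_precision=True)   # device_type defaults to 'cuda'
print(t2i.device)                # e.g. device(type='mps') on Apple Silicon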
@@ -312,8 +321,9 @@ class T2I:
                 callback=step_callback,
             )

-            with scope(self.device.type), self.model.ema_scope():
-                for n in trange(iterations, desc='>> Generating'):
+            device_type = choose_autocast_device(self.device)
+            with scope(device_type), self.model.ema_scope():
+                for n in trange(iterations, desc='Generating'):
                     seed_everything(seed)
                     image = next(images_iterator)
                     results.append([image, seed])
@@ -346,7 +356,7 @@
                         )
                 except Exception as e:
                     print(
-                        f'Error running RealESRGAN - Your image was not upscaled.\n{e}'
+                        f'>> Error running RealESRGAN - Your image was not upscaled.\n{e}'
                     )
                 if image_callback is not None:
                     if save_original:
@@ -359,11 +369,11 @@
         except KeyboardInterrupt:
             print('*interrupted*')
             print(
-                'Partial results will be returned; if --grid was requested, nothing will be returned.'
+                '>> Partial results will be returned; if --grid was requested, nothing will be returned.'
             )
         except RuntimeError as e:
-            print(str(e))
-            print('Are you sure your system has an adequate NVIDIA GPU?')
+            print(traceback.format_exc(), file=sys.stderr)
+            print('>> Are you sure your system has an adequate NVIDIA GPU?')

         toc = time.time()
         print('>> Usage stats:')
@@ -464,7 +474,6 @@
             )
             t_enc = int(strength * steps)
-            # print(f"target t_enc is {t_enc} steps")

             while True:
                 uc, c = self._get_uc_and_c(prompt, skip_normalize)
@@ -515,7 +524,7 @@
         x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
         if len(x_samples) != 1:
             raise Exception(
-                f'expected to get a single image, but got {len(x_samples)}')
+                f'>> expected to get a single image, but got {len(x_samples)}')
         x_sample = 255.0 * rearrange(
             x_samples[0].cpu().numpy(), 'c h w -> h w c'
         )
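
For reference, the clamp/rearrange pair in this hunk converts decoder output from CHW tensors in [-1, 1] into HWC arrays in [0, 255] ready for PIL. A tiny standalone illustration with random data standing in for a real decode:

import numpy as np
import torch
from einops import rearrange
from PIL import Image

x_samples = torch.rand(1, 3, 64, 64) * 2 - 1   # pretend decoder output in [-1, 1]
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
x_sample = 255.0 * rearrange(x_samples[0].cpu().numpy(), 'c h w -> h w c')
Image.fromarray(x_sample.astype(np.uint8)).save('sample.png')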
@@ -525,17 +534,12 @@
         self.seed = random.randrange(0, np.iinfo(np.uint32).max)
         return self.seed

-    def _get_device(self):
-        device_type = choose_torch_device()
-        return torch.device(device_type)
-
     def load_model(self):
         """Load and initialize the model from configuration variables passed at object creation time"""
         if self.model is None:
             seed_everything(self.seed)
             try:
                 config = OmegaConf.load(self.config)
-                self.device = self._get_device()
                 model = self._load_model_from_config(config, self.weights)
                 if self.embedding_path is not None:
                     model.embedding_manager.load(
@@ -544,12 +548,10 @@
                 self.model = model.to(self.device)
                 # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
                 self.model.cond_stage_model.device = self.device
-            except AttributeError:
-                import traceback
-                print(
-                    'Error loading model. Only the CUDA backend is supported', file=sys.stderr)
+            except AttributeError as e:
+                print(f'>> Error loading model. {str(e)}', file=sys.stderr)
                 print(traceback.format_exc(), file=sys.stderr)
-                raise SystemExit
+                raise SystemExit from e

             self._set_sampler()
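
The cond_stage_model.device assignment above matters because nn.Module.to() moves parameters and buffers but leaves plain Python attributes alone, and the CLIP conditioning stage keeps such a device attribute to place tokenizer output. A minimal illustration of the pattern (class names here are placeholders, not the repository's):

import torch

class CondStage(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(4, 4)
        self.device = torch.device('cpu')   # plain attribute; .to() never rewrites it

class Wrapper(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.cond_stage_model = CondStage()

device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
model = Wrapper().to(device)
# parameters moved, bookkeeping attribute did not:
print(next(model.parameters()).device, model.cond_stage_model.device)
model.cond_stage_model.device = device   # what load_model() now does by hand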

View File

@@ -9,6 +9,7 @@ import sys
 import copy
 import warnings
 import time
+from ldm.dream.devices import choose_torch_device
 import ldm.dream.readline
 from ldm.dream.pngwriter import PngWriter, PromptFormatter
 from ldm.dream.server import DreamServer, ThreadingDreamServer
@@ -60,6 +61,7 @@ def main():
         # this is solely for recreating the prompt
         latent_diffusion_weights=opt.laion400m,
         embedding_path=opt.embedding_path,
+        device_type=opt.device
     )

     # make sure the output directory exists
@@ -346,6 +348,8 @@ def create_argv_parser():
         dest='full_precision',
         action='store_true',
         help='Use slower full precision math for calculations',
+        # MPS only functions with full precision, see https://github.com/lstein/stable-diffusion/issues/237
+        default=choose_torch_device() == 'mps',
     )
     parser.add_argument(
         '-g',
@@ -418,6 +422,13 @@ def create_argv_parser():
         default='model',
         help='Indicates the Stable Diffusion model to use.',
     )
+    parser.add_argument(
+        '--device',
+        '-d',
+        type=str,
+        default='cuda',
+        help="device to run stable diffusion on. defaults to cuda `torch.cuda.current_device()` if available"
+    )
     return parser
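
The new --device flag defaults to 'cuda', and together with the constructor fallback above that keeps NVIDIA setups unchanged while letting Apple Silicon users pass -d mps explicitly; the full-precision option likewise now defaults to on whenever choose_torch_device() reports 'mps'. A self-contained sketch of the parsing behaviour these two arguments add (a stand-in parser, not the repository's create_argv_parser; flag spellings are illustrative apart from --device/-d):

import argparse

def on_mps() -> bool:
    # stand-in for choose_torch_device() == 'mps'
    try:
        import torch
        return torch.backends.mps.is_available()
    except Exception:
        return False

parser = argparse.ArgumentParser()
parser.add_argument('--full_precision', dest='full_precision', action='store_true',
                    # MPS only works in full precision, so default it on for Apple Silicon
                    default=on_mps(),
                    help='Use slower full precision math for calculations')
parser.add_argument('--device', '-d', type=str, default='cuda',
                    help='device to run stable diffusion on')

opt = parser.parse_args(['--device', 'mps'])
print(opt.device, opt.full_precision)   # 'mps True' on an M1 Mac, 'mps False' elsewhere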