mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00

commit 3ee82d8a3b: Merge branch 'toffaletti-dream-m1' into main

This provides support for Apple M1 hardware.
@@ -52,7 +52,7 @@ dependencies:
   - -e git+https://github.com/huggingface/diffusers.git@v0.2.4#egg=diffusers
   - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
   - -e git+https://github.com/openai/CLIP.git@main#egg=clip
-  - -e git+https://github.com/lstein/k-diffusion.git@master#egg=k-diffusion
+  - -e git+https://github.com/Birch-san/k-diffusion.git@mps#egg=k_diffusion
   - -e .
 variables:
   PYTORCH_ENABLE_MPS_FALLBACK: 1
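Two Mac-specific changes land in the environment file: k-diffusion now comes from Birch-san's `mps` branch, and `PYTORCH_ENABLE_MPS_FALLBACK=1` tells PyTorch to run any op that has no MPS kernel on the CPU instead of raising `NotImplementedError`. A minimal sketch of exercising the same setting in-process (not part of this commit; `torch.backends.mps` needs PyTorch 1.12 or newer):

    import os

    # Assumption: this must be set before torch is imported, so ops
    # without an MPS kernel fall back to the CPU instead of raising
    # NotImplementedError.
    os.environ.setdefault('PYTORCH_ENABLE_MPS_FALLBACK', '1')

    import torch

    # Reports whether the Metal backend is usable on this machine.
    device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
    print('>> using device', device)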
@@ -8,4 +8,10 @@ def choose_torch_device() -> str:
         return 'mps'
     return 'cpu'
+
+def choose_autocast_device(device) -> str:
+    '''Returns an autocast compatible device from a torch device'''
+    device_type = device.type # this returns 'mps' on M1
+    # autocast only supports cuda or cpu
+    if device_type not in ('cuda','cpu'):
+        return 'cpu'
+    return device_type
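The new `choose_autocast_device` helper exists because `torch.autocast` only accepts `'cuda'` or `'cpu'` as a `device_type`, so a model living on `mps` still has to enter a CPU autocast context. A hedged usage sketch (the `autocast` call here is illustrative; the commit's own call site appears in the T2I hunks below):

    import torch
    from ldm.dream.devices import choose_autocast_device, choose_torch_device

    device = torch.device(choose_torch_device())      # 'cuda', 'mps', or 'cpu'
    autocast_device = choose_autocast_device(device)  # 'mps' is mapped to 'cpu'

    # Passing 'mps' to torch.autocast would fail; the helper guards that.
    with torch.autocast(device_type=autocast_device):
        x = torch.randn(1, 4, 64, 64, device=device)
        y = (x * 2.0).sum()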
@@ -8,6 +8,7 @@ import torch
 import numpy as np
 import random
 import os
+import traceback
 from omegaconf import OmegaConf
 from PIL import Image
 from tqdm import tqdm, trange
@@ -28,7 +29,7 @@ from ldm.models.diffusion.plms import PLMSSampler
 from ldm.models.diffusion.ksampler import KSampler
 from ldm.dream.pngwriter import PngWriter
 from ldm.dream.image_util import InitImageResizer
-from ldm.dream.devices import choose_torch_device
+from ldm.dream.devices import choose_autocast_device, choose_torch_device

 """Simplified text to image API for stable diffusion/latent diffusion

@@ -114,26 +115,28 @@ class T2I:
     """

     def __init__(
         self,
         iterations=1,
         steps=50,
         seed=None,
         cfg_scale=7.5,
         weights='models/ldm/stable-diffusion-v1/model.ckpt',
         config='configs/stable-diffusion/v1-inference.yaml',
         grid=False,
         width=512,
         height=512,
         sampler_name='k_lms',
         latent_channels=4,
         downsampling_factor=8,
         ddim_eta=0.0,  # deterministic
         precision='autocast',
         full_precision=False,
         strength=0.75,  # default in scripts/img2img.py
         embedding_path=None,
+        device_type = 'cuda',
         # just to keep track of this parameter when regenerating prompt
+        # needs to be replaced when new configuration system implemented.
         latent_diffusion_weights=False,
     ):
         self.iterations = iterations
         self.width = width
@@ -151,11 +154,17 @@ class T2I:
         self.full_precision = full_precision
         self.strength = strength
         self.embedding_path = embedding_path
+        self.device_type = device_type
         self.model = None  # empty for now
         self.sampler = None
         self.device = None
         self.latent_diffusion_weights = latent_diffusion_weights

+        if device_type == 'cuda' and not torch.cuda.is_available():
+            device_type = choose_torch_device()
+            print(">> cuda not available, using device", device_type)
+        self.device = torch.device(device_type)
+
         # for VRAM usage statistics
         device_type = choose_torch_device()
         self.session_peakmem = torch.cuda.max_memory_allocated() if device_type == 'cuda' else None
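The constructor now accepts `device_type` and degrades gracefully: a `'cuda'` request on a machine without CUDA is re-resolved through `choose_torch_device()`, which returns `'mps'` on Apple Silicon. A hedged construction sketch (the module path is assumed; the diff never names the file that holds `T2I`):

    # Hypothetical module path; the diff does not name the file containing T2I.
    from ldm.simplet2i import T2I

    # Requesting 'cuda' on an M1 Mac is safe: the constructor prints
    # '>> cuda not available, using device mps' and proceeds on MPS.
    t2i = T2I(device_type='cuda', full_precision=True)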
@@ -312,8 +321,9 @@ class T2I:
                 callback=step_callback,
             )

-        with scope(self.device.type), self.model.ema_scope():
-            for n in trange(iterations, desc='>> Generating'):
+        device_type = choose_autocast_device(self.device)
+        with scope(device_type), self.model.ema_scope():
+            for n in trange(iterations, desc='Generating'):
                 seed_everything(seed)
                 image = next(images_iterator)
                 results.append([image, seed])
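The generation loop now maps `self.device` through the new helper before entering `scope`. The definition of `scope` sits outside this diff; presumably it selects `torch.autocast` when `precision='autocast'` and a no-op context otherwise. A sketch under that assumption:

    from contextlib import nullcontext
    import torch

    precision = 'autocast'   # the constructor default shown above

    # Assumed wiring (not visible in this diff): autocast when requested,
    # otherwise a do-nothing context manager.
    scope = torch.autocast if precision == 'autocast' else nullcontext

    device = torch.device('cpu')
    device_type = 'cpu' if device.type not in ('cuda', 'cpu') else device.type
    with scope(device_type):
        pass  # the sampling loop runs here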
@@ -346,7 +356,7 @@ class T2I:
                     )
                 except Exception as e:
                     print(
-                        f'Error running RealESRGAN - Your image was not upscaled.\n{e}'
+                        f'>> Error running RealESRGAN - Your image was not upscaled.\n{e}'
                     )
                 if image_callback is not None:
                     if save_original:
@@ -359,11 +369,11 @@ class T2I:
         except KeyboardInterrupt:
             print('*interrupted*')
             print(
-                'Partial results will be returned; if --grid was requested, nothing will be returned.'
+                '>> Partial results will be returned; if --grid was requested, nothing will be returned.'
             )
         except RuntimeError as e:
-            print(str(e))
-            print('Are you sure your system has an adequate NVIDIA GPU?')
+            print(traceback.format_exc(), file=sys.stderr)
+            print('>> Are you sure your system has an adequate NVIDIA GPU?')

         toc = time.time()
         print('>> Usage stats:')
@@ -464,7 +474,6 @@ class T2I:
         )

         t_enc = int(strength * steps)
-        # print(f"target t_enc is {t_enc} steps")

         while True:
             uc, c = self._get_uc_and_c(prompt, skip_normalize)
@@ -515,7 +524,7 @@ class T2I:
         x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
         if len(x_samples) != 1:
             raise Exception(
-                f'expected to get a single image, but got {len(x_samples)}')
+                f'>> expected to get a single image, but got {len(x_samples)}')
         x_sample = 255.0 * rearrange(
             x_samples[0].cpu().numpy(), 'c h w -> h w c'
         )
@@ -525,17 +534,12 @@ class T2I:
             self.seed = random.randrange(0, np.iinfo(np.uint32).max)
         return self.seed

-    def _get_device(self):
-        device_type = choose_torch_device()
-        return torch.device(device_type)
-
     def load_model(self):
         """Load and initialize the model from configuration variables passed at object creation time"""
         if self.model is None:
             seed_everything(self.seed)
             try:
                 config = OmegaConf.load(self.config)
-                self.device = self._get_device()
                 model = self._load_model_from_config(config, self.weights)
                 if self.embedding_path is not None:
                     model.embedding_manager.load(
@@ -544,12 +548,10 @@ class T2I:
                 self.model = model.to(self.device)
                 # model.to doesn't change the cond_stage_model.device used to move the tokenizer output, so set it here
                 self.model.cond_stage_model.device = self.device
-            except AttributeError:
-                import traceback
-                print(
-                    'Error loading model. Only the CUDA backend is supported', file=sys.stderr)
+            except AttributeError as e:
+                print(f'>> Error loading model. {str(e)}', file=sys.stderr)
                 print(traceback.format_exc(), file=sys.stderr)
-                raise SystemExit
+                raise SystemExit from e

             self._set_sampler()

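Two fixes here: the error message now reports the actual `AttributeError` instead of claiming only CUDA is supported, and `raise SystemExit from e` chains the original exception as `__cause__`. A standalone illustration of the chaining (not from the commit):

    import sys
    import traceback

    try:
        raise AttributeError('cond_stage_model is missing')
    except AttributeError as e:
        print(f'>> Error loading model. {str(e)}', file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        # SystemExit.__cause__ is now the AttributeError, so debuggers
        # and logs can still see what actually went wrong.
        raise SystemExit from e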
@@ -9,6 +9,7 @@ import sys
 import copy
 import warnings
 import time
+from ldm.dream.devices import choose_torch_device
 import ldm.dream.readline
 from ldm.dream.pngwriter import PngWriter, PromptFormatter
 from ldm.dream.server import DreamServer, ThreadingDreamServer
@@ -60,6 +61,7 @@ def main():
         # this is solely for recreating the prompt
         latent_diffusion_weights=opt.laion400m,
         embedding_path=opt.embedding_path,
+        device_type=opt.device
     )

     # make sure the output directory exists
@@ -346,6 +348,8 @@ def create_argv_parser():
         dest='full_precision',
         action='store_true',
         help='Use slower full precision math for calculations',
+        # MPS only functions with full precision, see https://github.com/lstein/stable-diffusion/issues/237
+        default=choose_torch_device() == 'mps',
     )
     parser.add_argument(
         '-g',
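Since the flag is `store_true`, only its `default` changes behavior when the user passes nothing: full precision is now on by default on MPS (which misbehaves at half precision, per the linked issue) and stays off elsewhere. A hedged sketch (the `-F`/`--full_precision` flag spellings are assumed; the diff only shows `dest`):

    import argparse
    from ldm.dream.devices import choose_torch_device

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-F', '--full_precision',          # assumed flag names
        dest='full_precision',
        action='store_true',
        help='Use slower full precision math for calculations',
        # On 'mps' this evaluates to True, making full precision the default.
        default=choose_torch_device() == 'mps',
    )
    opt = parser.parse_args([])
    print(opt.full_precision)   # True on Apple Silicon, False elsewhere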
@@ -418,6 +422,13 @@ def create_argv_parser():
         default='model',
         help='Indicates the Stable Diffusion model to use.',
     )
+    parser.add_argument(
+        '--device',
+        '-d',
+        type=str,
+        default='cuda',
+        help="device to run stable diffusion on. defaults to cuda `torch.cuda.current_device()` if available"
+    )
     return parser

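End to end, the flag flows from the parser through `main()` into `T2I(device_type=opt.device)`. A small standalone check of the new argument (a sketch, not the repo's actual parser):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--device', '-d',
        type=str,
        default='cuda',
        help='device to run stable diffusion on',
    )

    # e.g. the equivalent of passing `--device mps` on an M1 Mac
    opt = parser.parse_args(['--device', 'mps'])
    print(opt.device)   # mps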
|
Loading…
x
Reference in New Issue
Block a user