do not use autocast for diffusers (#2349)

fixes #2345
commit 3fb095de88
Kevin Turner, 2023-01-17 14:26:35 -08:00, committed by GitHub
13 changed files with 59 additions and 35 deletions

View File

@@ -29,7 +29,7 @@ from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
 from ldm.invoke.conditioning import get_uc_and_c_and_ec
 from ldm.invoke.devices import choose_torch_device, choose_precision
 from ldm.invoke.generator.inpaint import infill_methods
-from ldm.invoke.globals import global_cache_dir
+from ldm.invoke.globals import global_cache_dir, Globals
 from ldm.invoke.image_util import InitImageResizer
 from ldm.invoke.model_manager import ModelManager
 from ldm.invoke.pngwriter import PngWriter
@@ -201,6 +201,7 @@ class Generate:
             self.precision = 'float32'
         if self.precision == 'auto':
            self.precision = choose_precision(self.device)
+        Globals.full_precision = self.precision=='float32'
         # model caching system for fast switching
        self.model_manager = ModelManager(mconfig,self.device,self.precision,max_loaded_models=max_loaded_models)
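The new `Globals.full_precision` flag mirrors the precision decision into process-wide state, so modules that never see the `Generate` instance (such as the `torch_dtype()` helper added to `ldm.invoke.devices` below) can consult it. A minimal sketch of the resulting flow; the `Globals` stand-in and the simplified `choose_precision` (half precision on CUDA, full precision elsewhere) are assumptions for illustration:

```python
import torch

class Globals:  # stand-in for ldm.invoke.globals.Globals, illustration only
    full_precision = False

def choose_precision(device: torch.device) -> str:
    # simplified assumption: half precision on CUDA, full precision elsewhere
    return 'float16' if device.type == 'cuda' else 'float32'

def resolve_precision(requested: str, device: torch.device) -> str:
    precision = requested
    if precision == 'auto':
        precision = choose_precision(device)
    # mirror the decision globally, as the hunk above does
    Globals.full_precision = precision == 'float32'
    return precision

print(resolve_precision('auto', torch.device('cpu')))  # float32; flag becomes True
```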

View File

@@ -335,4 +335,5 @@ class CkptGenerator():
         os.makedirs(dirname, exist_ok=True)
         image.save(filepath,'PNG')
 
+    def torch_dtype(self)->torch.dtype:
+        return torch.float16 if self.precision == 'float16' else torch.float32

View File

@@ -72,16 +72,18 @@ class CkptTxt2Img(CkptGenerator):
         device = self.model.device
         if self.use_mps_noise or device.type == 'mps':
             x = torch.randn([1,
                              self.latent_channels,
                              height // self.downsampling_factor,
                              width // self.downsampling_factor],
+                            dtype=self.torch_dtype(),
                             device='cpu').to(device)
         else:
             x = torch.randn([1,
                              self.latent_channels,
                              height // self.downsampling_factor,
                              width // self.downsampling_factor],
+                            dtype=self.torch_dtype(),
                             device=device)
         if self.perlin > 0.0:
             x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
         return x
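An explicit dtype now threads through both branches of the noise allocation, so the initial latents match a half-precision model without relying on an autocast context. A condensed sketch of the pattern; the CPU-then-move branch for MPS is kept as in the diff, and the stated reason for it is an assumption:

```python
import torch

def make_noise(latent_channels: int, height: int, width: int,
               downsampling_factor: int, device: torch.device,
               dtype: torch.dtype) -> torch.Tensor:
    shape = [1, latent_channels,
             height // downsampling_factor,
             width // downsampling_factor]
    if device.type == 'mps':
        # draw on the CPU and move over, as the MPS branch above does
        # (presumably for reproducible seeding across backends)
        return torch.randn(shape, dtype=dtype, device='cpu').to(device)
    return torch.randn(shape, dtype=dtype, device=device)

x = make_noise(4, 512, 512, 8, torch.device('cpu'), torch.float16)
assert x.shape == (1, 4, 64, 64) and x.dtype is torch.float16
```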

View File

@@ -21,10 +21,19 @@ def choose_precision(device) -> str:
             return 'float16'
     return 'float32'
 
+def torch_dtype(device) -> torch.dtype:
+    if Globals.full_precision:
+        return torch.float32
+    if choose_precision(device) == 'float16':
+        return torch.float16
+    else:
+        return torch.float32
+
 def choose_autocast(precision):
     '''Returns an autocast context or nullcontext for the given precision string'''
     # float16 currently requires autocast to avoid errors like:
     # 'expected scalar type Half but found Float'
-    print(f'DEBUG: choose_autocast() called')
     if precision == 'autocast' or precision == 'float16':
         return autocast
     return nullcontext
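The new `torch_dtype()` helper is the pivot of this commit: instead of entering `autocast` and letting PyTorch coerce dtypes per-op, callers ask for a concrete dtype once and allocate tensors that already match the model weights. A self-contained sketch of the decision table, with the global flag and the `choose_precision` result passed in explicitly (an illustrative restructuring, not the repo's signature):

```python
import torch

def torch_dtype_for(full_precision: bool, chosen_precision: str) -> torch.dtype:
    # same decision table as torch_dtype() above, made self-contained
    if full_precision:
        return torch.float32
    return torch.float16 if chosen_precision == 'float16' else torch.float32

assert torch_dtype_for(True,  'float16') is torch.float32  # override wins
assert torch_dtype_for(False, 'float16') is torch.float16
assert torch_dtype_for(False, 'float32') is torch.float32
```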

View File

@@ -8,6 +8,7 @@ import os
 import os.path as osp
 import random
 import traceback
+from contextlib import nullcontext
 
 import cv2
 import numpy as np
@@ -18,8 +19,6 @@ from einops import rearrange
 from pytorch_lightning import seed_everything
 from tqdm import trange
 
-from ldm.invoke.devices import choose_autocast
-from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
 from ldm.models.diffusion.ddpm import DiffusionWrapper
 from ldm.util import rand_perlin_2d
@@ -64,7 +63,7 @@ class Generator:
              image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
              safety_checker:dict=None,
              **kwargs):
-        scope = choose_autocast(self.precision)
+        scope = nullcontext
         self.safety_checker = safety_checker
         attention_maps_images = []
         attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image())
@@ -236,7 +235,8 @@ class Generator:
     def get_perlin_noise(self,width,height):
         fixdevice = 'cpu' if (self.model.device.type == 'mps') else self.model.device
-        return torch.stack([rand_perlin_2d((height, width), (8, 8), device = self.model.device).to(fixdevice) for _ in range(self.latent_channels)], dim=0).to(self.model.device)
+        noise = torch.stack([rand_perlin_2d((height, width), (8, 8), device = self.model.device).to(fixdevice) for _ in range(self.latent_channels)], dim=0).to(self.model.device)
+        return noise
 
     def new_seed(self):
         self.seed = random.randrange(0, np.iinfo(np.uint32).max)
@@ -341,3 +341,6 @@ class Generator:
         image.save(filepath,'PNG')
 
+    def torch_dtype(self)->torch.dtype:
+        return torch.float16 if self.precision == 'float16' else torch.float32
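Replacing `choose_autocast(self.precision)` with `nullcontext` is the behavioral core of the fix: diffusers models are loaded at a fixed dtype, so autocast's per-op recasting is unnecessary there. Because `scope` is presumably entered later as a context manager, `nullcontext` is a drop-in no-op replacement. A sketch of the before/after shape; the `with scope():` usage is an assumption based on how the variable reads:

```python
from contextlib import nullcontext

import torch

# Before: half precision relied on autocast coercing every op inside the block.
# scope = torch.autocast('cuda')        # legacy ckpt-era path

# After: tensors and weights already share a dtype, so a no-op context
# preserves the code shape without rewriting the call sites.
scope = nullcontext

with scope():
    a = torch.ones(2, dtype=torch.float16)
    b = torch.ones(2, dtype=torch.float16)
    c = a * b  # runs at the tensors' own dtype; nothing is recast
```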

View File

@@ -36,10 +36,9 @@ class Txt2Img(Generator):
                                   threshold = ThresholdSettings(threshold, warmup=0.2) if threshold else None)
             .add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
 
         def make_image(x_T) -> PIL.Image.Image:
             pipeline_output = pipeline.image_from_embeddings(
-                latents=torch.zeros_like(x_T),
+                latents=torch.zeros_like(x_T,dtype=self.torch_dtype()),
                 noise=x_T,
                 num_inference_steps=steps,
                 conditioning_data=conditioning_data,
@@ -59,16 +58,18 @@ class Txt2Img(Generator):
         input_channels = min(self.latent_channels, 4)
         if self.use_mps_noise or device.type == 'mps':
             x = torch.randn([1,
                              input_channels,
                              height // self.downsampling_factor,
                              width // self.downsampling_factor],
+                            dtype=self.torch_dtype(),
                             device='cpu').to(device)
         else:
             x = torch.randn([1,
                              input_channels,
                              height // self.downsampling_factor,
                              width // self.downsampling_factor],
+                            dtype=self.torch_dtype(),
                             device=device)
         if self.perlin > 0.0:
             x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
         return x

View File

@@ -90,9 +90,9 @@ class Txt2Img2Img(Generator):
     def get_noise_like(self, like: torch.Tensor):
         device = like.device
         if device.type == 'mps':
-            x = torch.randn_like(like, device='cpu').to(device)
+            x = torch.randn_like(like, device='cpu', dtype=self.torch_dtype()).to(device)
         else:
-            x = torch.randn_like(like, device=device)
+            x = torch.randn_like(like, device=device, dtype=self.torch_dtype())
         if self.perlin > 0.0:
             shape = like.shape
             x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
@@ -117,10 +117,12 @@ class Txt2Img2Img(Generator):
                                self.latent_channels,
                                scaled_height // self.downsampling_factor,
                                scaled_width // self.downsampling_factor],
+                              dtype=self.torch_dtype(),
                               device='cpu').to(device)
         else:
             return torch.randn([1,
                                self.latent_channels,
                                scaled_height // self.downsampling_factor,
                                scaled_width // self.downsampling_factor],
+                              dtype=self.torch_dtype(),
                               device=device)

View File

@@ -43,6 +43,9 @@ Globals.always_use_cpu = False
 # The CLI will test connectivity at startup time.
 Globals.internet_available = True
 
+# whether we are forcing full precision
+Globals.full_precision = False
+
 def global_config_dir()->Path:
     return Path(Globals.root, Globals.config_dir)

View File

@@ -349,7 +349,7 @@ class ModelManager(object):
         if self.precision == 'float16':
             print('   | Using faster float16 precision')
-            model.to(torch.float16)
+            model = model.to(torch.float16)
         else:
             print('   | Using more accurate float32 precision')
@@ -761,7 +761,7 @@ class ModelManager(object):
         for model in legacy_locations:
             source = models_dir / model
             if source.exists():
-                print(f'DEBUG: Moving {models_dir / model} into hub')
+                print(f'** Moving {models_dir / model} into hub')
                 move(models_dir / model, hub)
         # anything else gets moved into the diffusers directory
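The `model = model.to(torch.float16)` change is worth noting: `Tensor.to()` is out-of-place and returns a new tensor, while `nn.Module.to()` converts in place and returns `self`, so binding the result is the one idiom that is correct for both. A small demonstration:

```python
import torch

t = torch.zeros(2)           # float32
t.to(torch.float16)          # Tensor.to() returns a NEW tensor; t is unchanged
assert t.dtype is torch.float32

t = t.to(torch.float16)      # binding the result is always correct
assert t.dtype is torch.float16

m = torch.nn.Linear(2, 2)
m2 = m.to(torch.float16)     # Module.to() converts in place and returns self
assert m2 is m and m.weight.dtype is torch.float16
```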

View File

@@ -7,6 +7,7 @@ import torch
 import diffusers
 from torch import nn
 from diffusers.models.unet_2d_condition import UNet2DConditionModel
+from ldm.invoke.devices import torch_dtype
 
 # adapted from bloc97's CrossAttentionControl colab
 # https://github.com/bloc97/CrossAttentionControl
@@ -383,7 +384,7 @@ def inject_attention_function(unet, context: Context):
             remapped_saved_attention_slice = torch.index_select(saved_attention_slice, -1, index_map)
             this_attention_slice = suggested_attention_slice
 
-            mask = context.cross_attention_mask
+            mask = context.cross_attention_mask.to(torch_dtype(suggested_attention_slice.device))
             saved_mask = mask
             this_mask = 1 - mask
             attention_slice = remapped_saved_attention_slice * saved_mask + \

View File

@@ -4,7 +4,7 @@ import torch
 from transformers import CLIPTokenizer, CLIPTextModel
 
 from ldm.modules.textual_inversion_manager import TextualInversionManager
+from ldm.invoke.devices import torch_dtype
 
 class WeightedPromptFragmentsToEmbeddingsConverter():
@@ -207,7 +207,7 @@ class WeightedPromptFragmentsToEmbeddingsConverter():
             per_token_weights += [1.0] * pad_length
 
         all_token_ids_tensor = torch.tensor(all_token_ids, dtype=torch.long, device=device)
-        per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch.float32, device=device)
+        per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch_dtype(self.text_encoder.device), device=device)
         #print(f"assembled all_token_ids_tensor with shape {all_token_ids_tensor.shape}")
         return all_token_ids_tensor, per_token_weights_tensor
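Hard-coding `torch.float32` for the per-token weights was a latent dtype mismatch once the text encoder runs in half precision; deriving the dtype from the encoder's device keeps both tensors in step. A sketch of why the mismatch matters (shapes illustrative):

```python
import torch

emb = torch.randn(77, 768, dtype=torch.float16)  # half-precision encoder output
w32 = torch.ones(77, 1, dtype=torch.float32)

# Elementwise ops silently promote Half * Float to float32, changing the
# dtype flowing downstream; matmul-style ops raise
# "expected scalar type Half but found Float" instead.
assert (emb * w32).dtype is torch.float32

w16 = torch.ones(77, 1, dtype=emb.dtype)         # match the encoder's dtype
assert (emb * w16).dtype is torch.float16
```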

View File

@@ -111,7 +111,6 @@ class TextualInversionManager():
         if ti.trigger_token_id is not None:
             raise ValueError(f"Tokens already injected for textual inversion with trigger '{ti.trigger_string}'")
 
-        print(f'DEBUG: Injecting token {ti.trigger_string}')
         trigger_token_id = self._get_or_create_token_id_and_assign_embedding(ti.trigger_string, ti.embedding[0])
 
         if ti.embedding_vector_length > 1:

View File

@@ -8,6 +8,7 @@ from threading import Thread
 from urllib import request
 from tqdm import tqdm
 from pathlib import Path
+from ldm.invoke.devices import torch_dtype
 
 import numpy as np
 import torch
@@ -235,7 +236,8 @@ def rand_perlin_2d(shape, res, device, fade = lambda t: 6*t**5 - 15*t**4 + 10*t**3):
     n01 = dot(tile_grads([0, -1],[1, None]), [0, -1]).to(device)
     n11 = dot(tile_grads([1, None], [1, None]), [-1,-1]).to(device)
     t = fade(grid[:shape[0], :shape[1]])
-    return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]).to(device)
+    noise = math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]).to(device)
+    return noise.to(dtype=torch_dtype(device))
 
 def ask_user(question: str, answers: list):
     from itertools import chain, repeat
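`rand_perlin_2d` appears to compute in float32 internally and is now cast once on the way out, since its output gets blended into `randn()` noise that may already be half precision. A usage sketch of that blend, mirroring the `self.perlin > 0.0` branches earlier in the diff, with a random stand-in for the perlin tensor:

```python
import torch

x = torch.randn(1, 4, 64, 64, dtype=torch.float16)  # initial latent noise
perlin = torch.rand(4, 64, 64)                       # stand-in for rand_perlin_2d()
blend = 0.3

# cast once so the lerp stays in the latents' dtype, as the return above does
x = (1 - blend) * x + blend * perlin.to(x.dtype)
assert x.dtype is torch.float16
```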