do not use autocast for diffusers

- All tensors in the diffusers code path are now set explicitly to
  float32 or float16, depending on the --precision flag.
- autocast is still used in the ckpt path, since that path is being
  deprecated.
Lincoln Stein 2023-01-16 19:32:06 -05:00
parent 563196bd03
commit 7e8f364d8d
11 changed files with 46 additions and 27 deletions
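
For readers skimming the diff: the change swaps implicit casting for explicit dtype selection. A minimal sketch of the difference (illustrative names only, not InvokeAI's API; `device` and `precision` stand in for the generator's attributes and the --precision flag):

```python
import torch
from contextlib import nullcontext

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
precision = 'float16' if device.type == 'cuda' else 'float32'

# Before: rely on autocast to down-convert ops inside the block.
scope = torch.autocast if precision == 'float16' else nullcontext
with scope(device.type):
    noise = torch.randn(1, 4, 64, 64, device=device)  # still created as float32
    # ops in here may silently run in float16 under autocast

# After: pick the dtype once and create every tensor with it explicitly.
dtype = torch.float16 if precision == 'float16' else torch.float32
noise = torch.randn(1, 4, 64, 64, dtype=dtype, device=device)
assert noise.dtype == dtype
```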

@@ -29,7 +29,7 @@ from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
 from ldm.invoke.conditioning import get_uc_and_c_and_ec
 from ldm.invoke.devices import choose_torch_device, choose_precision
 from ldm.invoke.generator.inpaint import infill_methods
-from ldm.invoke.globals import global_cache_dir
+from ldm.invoke.globals import global_cache_dir, Globals
 from ldm.invoke.image_util import InitImageResizer
 from ldm.invoke.model_manager import ModelManager
 from ldm.invoke.pngwriter import PngWriter
@@ -201,6 +201,7 @@ class Generate:
             self.precision = 'float32'
         if self.precision == 'auto':
             self.precision = choose_precision(self.device)
+        Globals.full_precision = self.precision=='float32'
 
         # model caching system for fast switching
         self.model_manager = ModelManager(mconfig,self.device,self.precision,max_loaded_models=max_loaded_models)

@@ -335,4 +335,5 @@ class CkptGenerator():
             os.makedirs(dirname, exist_ok=True)
         image.save(filepath,'PNG')
 
+    def torch_dtype(self)->torch.dtype:
+        return torch.float16 if self.precision == 'float16' else torch.float32

@@ -72,16 +72,18 @@ class CkptTxt2Img(CkptGenerator):
         device = self.model.device
         if self.use_mps_noise or device.type == 'mps':
             x = torch.randn([1,
-                             self.latent_channels,
-                             height // self.downsampling_factor,
-                             width // self.downsampling_factor],
-                             device='cpu').to(device)
+                             self.latent_channels,
+                             height // self.downsampling_factor,
+                             width // self.downsampling_factor],
+                             dtype=self.torch_dtype(),
+                             device='cpu').to(device)
         else:
             x = torch.randn([1,
-                             self.latent_channels,
-                             height // self.downsampling_factor,
-                             width // self.downsampling_factor],
-                             device=device)
+                             self.latent_channels,
+                             height // self.downsampling_factor,
+                             width // self.downsampling_factor],
+                             dtype=self.torch_dtype(),
+                             device=device)
         if self.perlin > 0.0:
             x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
         return x
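
The same two-branch pattern reappears in the diffusers Txt2Img further down; condensed here as a stand-alone sketch. The function name and defaults are mine, not the project's, and the rationale comment on the mps branch is an inference (the diff shows the CPU round-trip but not why):

```python
import torch

def make_noise(latent_channels, height, width, dtype, device,
               downsampling_factor=8, use_mps_noise=False):
    # Sketch of the pattern above: the noise tensor is created directly
    # in the target dtype instead of being down-cast later by autocast.
    shape = [1, latent_channels,
             height // downsampling_factor,
             width // downsampling_factor]
    if use_mps_noise or device.type == 'mps':
        # mps branch: sample on the CPU, then move to the device
        return torch.randn(shape, dtype=dtype, device='cpu').to(device)
    return torch.randn(shape, dtype=dtype, device=device)

x = make_noise(4, 512, 512, torch.float32, torch.device('cpu'))
assert x.shape == (1, 4, 64, 64) and x.dtype == torch.float32
```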

@@ -17,7 +17,7 @@ from ..models.diffusion import cross_attention_control
 from ..models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
 from ..modules.encoders.modules import WeightedFrozenCLIPEmbedder
 from ..modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter
+from ldm.invoke.devices import torch_dtype
 
 def get_uc_and_c_and_ec(prompt_string, model, log_tokens=False, skip_normalize_legacy_blend=False):
@@ -238,7 +238,6 @@ def _get_embeddings_and_tokens_for_prompt(model, flattened_prompt: FlattenedPrompt,
     if log_tokens:
         text = " ".join(fragments)
         log_tokenization(text, model, display_label=log_display_label)
 
     return embeddings, tokens

@@ -21,10 +21,19 @@ def choose_precision(device) -> str:
             return 'float16'
     return 'float32'
 
+def torch_dtype(device) -> torch.dtype:
+    if Globals.full_precision:
+        return torch.float32
+    if choose_precision(device) == 'float16':
+        return torch.float16
+    else:
+        return torch.float32
+
 def choose_autocast(precision):
     '''Returns an autocast context or nullcontext for the given precision string'''
     # float16 currently requires autocast to avoid errors like:
     # 'expected scalar type Half but found Float'
     if precision == 'autocast' or precision == 'float16':
         return autocast
     return nullcontext
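
A usage sketch of the new `torch_dtype()` helper: the global full-precision override wins before the per-device heuristic is consulted. The stand-ins below condense the code above so the snippet runs on its own (the real `choose_precision` inspects actual hardware):

```python
import torch

class Globals:                      # stand-in for ldm.invoke.globals.Globals
    full_precision = False

def choose_precision(device) -> str:
    # condensed stand-in: half precision only on CUDA
    return 'float16' if device.type == 'cuda' else 'float32'

def torch_dtype(device) -> torch.dtype:
    if Globals.full_precision:      # --precision float32 short-circuits
        return torch.float32
    return torch.float16 if choose_precision(device) == 'float16' else torch.float32

cuda = torch.device('cuda')         # constructing the device object needs no GPU
assert torch_dtype(cuda) == torch.float16
Globals.full_precision = True
assert torch_dtype(cuda) == torch.float32  # the override wins
```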

@@ -22,6 +22,7 @@ from ldm.invoke.devices import choose_autocast
 from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
 from ldm.models.diffusion.ddpm import DiffusionWrapper
 from ldm.util import rand_perlin_2d
+from contextlib import nullcontext
 
 downsampling = 8
 CAUTION_IMG = 'assets/caution.png'
@@ -64,7 +65,7 @@ class Generator:
                  image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
                  safety_checker:dict=None,
                  **kwargs):
-        scope = choose_autocast(self.precision)
+        scope = nullcontext
         self.safety_checker = safety_checker
         attention_maps_images = []
         attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image())
@@ -341,3 +342,6 @@ class Generator:
         image.save(filepath,'PNG')
+
+    def torch_dtype(self)->torch.dtype:
+        return torch.float16 if self.precision == 'float16' else torch.float32
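
Why `scope = nullcontext` is a drop-in replacement: `choose_autocast(...)` returned a callable that produces a context manager, and `nullcontext` has the same shape, down to accepting (and ignoring) one positional argument. A small sketch; the `with scope(...)` call site is assumed, not shown in this hunk:

```python
import torch
from contextlib import nullcontext

def choose_autocast(precision):
    # condensed from ldm/invoke/devices.py above
    if precision in ('autocast', 'float16'):
        return torch.autocast
    return nullcontext

scope = nullcontext            # the new, casting-free behavior
# scope = choose_autocast('float16')   # the old behavior, for comparison

# Either variant works under the same call site:
with scope('cpu'):
    y = torch.ones(2, 2) * 2   # runs in the tensors' own dtype; nothing is cast
assert y.dtype == torch.float32
```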

@@ -36,10 +36,9 @@ class Txt2Img(Generator):
                  threshold = ThresholdSettings(threshold, warmup=0.2) if threshold else None)
             .add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
 
         def make_image(x_T) -> PIL.Image.Image:
             pipeline_output = pipeline.image_from_embeddings(
-                latents=torch.zeros_like(x_T),
+                latents=torch.zeros_like(x_T,dtype=self.torch_dtype()),
                 noise=x_T,
                 num_inference_steps=steps,
                 conditioning_data=conditioning_data,
@@ -59,16 +58,18 @@ class Txt2Img(Generator):
         input_channels = min(self.latent_channels, 4)
         if self.use_mps_noise or device.type == 'mps':
             x = torch.randn([1,
-                             input_channels,
-                             height // self.downsampling_factor,
-                             width // self.downsampling_factor],
-                             device='cpu').to(device)
+                             input_channels,
+                             height // self.downsampling_factor,
+                             width // self.downsampling_factor],
+                             dtype=self.torch_dtype(),
+                             device='cpu').to(device)
         else:
             x = torch.randn([1,
-                             input_channels,
-                             height // self.downsampling_factor,
-                             width // self.downsampling_factor],
-                             device=device)
+                             input_channels,
+                             height // self.downsampling_factor,
+                             width // self.downsampling_factor],
+                             dtype=self.torch_dtype(),
+                             device=device)
         if self.perlin > 0.0:
             x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
         return x

@@ -43,6 +43,9 @@ Globals.always_use_cpu = False
 # The CLI will test connectivity at startup time.
 Globals.internet_available = True
 
+# whether we are forcing full precision
+Globals.full_precision = False
+
 def global_config_dir()->Path:
     return Path(Globals.root, Globals.config_dir)

@@ -761,7 +761,7 @@ class ModelManager(object):
         for model in legacy_locations:
             source = models_dir / model
             if source.exists():
-                print(f'DEBUG: Moving {models_dir / model} into hub')
+                print(f'** Moving {models_dir / model} into hub')
                 move(models_dir / model, hub)
 
     # anything else gets moved into the diffusers directory

@@ -4,7 +4,7 @@ import torch
 from transformers import CLIPTokenizer, CLIPTextModel
 
 from ldm.modules.textual_inversion_manager import TextualInversionManager
+from ldm.invoke.devices import torch_dtype
 
 class WeightedPromptFragmentsToEmbeddingsConverter():
@@ -207,7 +207,7 @@ class WeightedPromptFragmentsToEmbeddingsConverter():
             per_token_weights += [1.0] * pad_length
 
         all_token_ids_tensor = torch.tensor(all_token_ids, dtype=torch.long, device=device)
-        per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch.float32, device=device)
+        per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch_dtype(self.text_encoder.device), device=device)
         #print(f"assembled all_token_ids_tensor with shape {all_token_ids_tensor.shape}")
         return all_token_ids_tensor, per_token_weights_tensor
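
A sketch of the hazard this hunk guards against: without autocast, a float32 weights tensor silently promotes the weighted embeddings back to float32, which then trips the 'expected scalar type Half but found Float' class of error once they meet a half-precision layer. Illustrative tensors only:

```python
import torch

emb = torch.ones(2, 3, dtype=torch.float16)        # e.g. float16 encoder output
w32 = torch.tensor([0.5, 2.0], dtype=torch.float32)

mixed = emb * w32.unsqueeze(-1)      # elementwise ops silently promote
assert mixed.dtype == torch.float32  # no longer half precision

w16 = w32.to(emb.dtype)              # what dtype=torch_dtype(...) achieves up front
kept = emb * w16.unsqueeze(-1)
assert kept.dtype == torch.float16
```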

@@ -111,7 +111,6 @@ class TextualInversionManager():
         if ti.trigger_token_id is not None:
             raise ValueError(f"Tokens already injected for textual inversion with trigger '{ti.trigger_string}'")
 
-        print(f'DEBUG: Injecting token {ti.trigger_string}')
         trigger_token_id = self._get_or_create_token_id_and_assign_embedding(ti.trigger_string, ti.embedding[0])
 
         if ti.embedding_vector_length > 1: