mirror of https://github.com/invoke-ai/InvokeAI
commit 3fb095de88
@@ -29,7 +29,7 @@ from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
 from ldm.invoke.conditioning import get_uc_and_c_and_ec
 from ldm.invoke.devices import choose_torch_device, choose_precision
 from ldm.invoke.generator.inpaint import infill_methods
-from ldm.invoke.globals import global_cache_dir
+from ldm.invoke.globals import global_cache_dir, Globals
 from ldm.invoke.image_util import InitImageResizer
 from ldm.invoke.model_manager import ModelManager
 from ldm.invoke.pngwriter import PngWriter
@@ -201,6 +201,7 @@ class Generate:
         self.precision = 'float32'
         if self.precision == 'auto':
             self.precision = choose_precision(self.device)
+        Globals.full_precision = self.precision=='float32'
 
         # model caching system for fast switching
         self.model_manager = ModelManager(mconfig,self.device,self.precision,max_loaded_models=max_loaded_models)
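The hunk above is the root of the commit: once Generate has resolved an 'auto' precision request via choose_precision(), it mirrors the result into the process-wide Globals, so the module-level helpers added later in this diff can consult it without the precision string being threaded through every call. A minimal, self-contained sketch of that flow; the Globals and choose_precision stand-ins below are simplified, not the real implementations:

import torch

class Globals:                     # stand-in for ldm.invoke.globals.Globals
    full_precision = False

def choose_precision(device) -> str:
    # simplified: the real helper also special-cases GPUs with poor fp16 support
    return 'float16' if device.type == 'cuda' else 'float32'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
precision = 'auto'
if precision == 'auto':
    precision = choose_precision(device)
Globals.full_precision = (precision == 'float32')
print(precision, Globals.full_precision)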
@@ -335,4 +335,5 @@ class CkptGenerator():
         os.makedirs(dirname, exist_ok=True)
         image.save(filepath,'PNG')
 
+    def torch_dtype(self)->torch.dtype:
+        return torch.float16 if self.precision == 'float16' else torch.float32
@@ -72,16 +72,18 @@ class CkptTxt2Img(CkptGenerator):
         device = self.model.device
         if self.use_mps_noise or device.type == 'mps':
             x = torch.randn([1,
                             self.latent_channels,
                             height // self.downsampling_factor,
                             width // self.downsampling_factor],
-                            device='cpu').to(device)
+                            dtype=self.torch_dtype(),
+                            device='cpu').to(device)
         else:
             x = torch.randn([1,
                             self.latent_channels,
                             height // self.downsampling_factor,
                             width // self.downsampling_factor],
-                            device=device)
+                            dtype=self.torch_dtype(),
+                            device=device)
         if self.perlin > 0.0:
             x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
         return x
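Both branches of get_noise() now pass dtype=self.torch_dtype() to torch.randn, so the initial latents are created in the model's precision instead of defaulting to float32. The MPS branch keeps drawing noise on the CPU first, presumably so seeded generation stays reproducible on Apple Silicon, and then moves it to the device. A self-contained sketch of the same pattern; the helper name and default shapes here are illustrative, not from the diff:

import torch

def make_latent_noise(device: torch.device, dtype: torch.dtype,
                      channels: int = 4, h: int = 64, w: int = 64) -> torch.Tensor:
    # Draw noise on the CPU for MPS, then move it; elsewhere draw it in place.
    # Passing dtype up front keeps a float32 tensor out of a float16 model.
    if device.type == 'mps':
        return torch.randn([1, channels, h, w], dtype=dtype, device='cpu').to(device)
    return torch.randn([1, channels, h, w], dtype=dtype, device=device)

x = make_latent_noise(torch.device('cpu'), torch.float16)
print(x.dtype)  # torch.float16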
@@ -21,10 +21,19 @@ def choose_precision(device) -> str:
             return 'float16'
     return 'float32'
 
+def torch_dtype(device) -> torch.dtype:
+    if Globals.full_precision:
+        return torch.float32
+    if choose_precision(device) == 'float16':
+        return torch.float16
+    else:
+        return torch.float32
+
 def choose_autocast(precision):
     '''Returns an autocast context or nullcontext for the given precision string'''
     # float16 currently requires autocast to avoid errors like:
     # 'expected scalar type Half but found Float'
+    print(f'DEBUG: choose_autocast() called')
     if precision == 'autocast' or precision == 'float16':
         return autocast
     return nullcontext
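torch_dtype() becomes the single decision point for tensor dtypes: Globals.full_precision overrides everything, otherwise the choice falls through to choose_precision() for the given device. Later hunks import it wherever ad-hoc dtype=torch.float32 defaults used to leak in. Usage is a one-liner; the device and tensor below are only an example:

import torch
from ldm.invoke.devices import torch_dtype  # added by this commit

device = torch.device('cpu')
# any tensor created deep in the pipeline can now match the model's precision:
mask = torch.ones(4, 4).to(torch_dtype(device))  # float32 on CPU, float16 on capable CUDA GPUs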
@@ -8,6 +8,7 @@ import os
 import os.path as osp
 import random
 import traceback
+from contextlib import nullcontext
 
 import cv2
 import numpy as np
@@ -18,8 +19,6 @@ from einops import rearrange
 from pytorch_lightning import seed_everything
 from tqdm import trange
 
-from ldm.invoke.devices import choose_autocast
-from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
 from ldm.models.diffusion.ddpm import DiffusionWrapper
 from ldm.util import rand_perlin_2d
 
@@ -64,7 +63,7 @@ class Generator:
                  image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
                  safety_checker:dict=None,
                  **kwargs):
-        scope = choose_autocast(self.precision)
+        scope = nullcontext
         self.safety_checker = safety_checker
         attention_maps_images = []
         attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image())
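With dtypes now handled explicitly, the generator no longer wraps sampling in autocast. nullcontext is a drop-in replacement because, like the autocast callable that choose_autocast() used to return, it can be called with a positional argument (presumably the device type at the existing call sites) and then simply does nothing. A tiny demonstration of why the with-statements don't need to change:

from contextlib import nullcontext

scope = nullcontext    # what this hunk substitutes for choose_autocast(self.precision)
with scope('cuda'):    # nullcontext accepts and ignores the argument
    x = 1 + 1          # generation code runs with no implicit dtype casting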
@@ -236,7 +235,8 @@ class Generator:
 
     def get_perlin_noise(self,width,height):
         fixdevice = 'cpu' if (self.model.device.type == 'mps') else self.model.device
-        return torch.stack([rand_perlin_2d((height, width), (8, 8), device = self.model.device).to(fixdevice) for _ in range(self.latent_channels)], dim=0).to(self.model.device)
+        noise = torch.stack([rand_perlin_2d((height, width), (8, 8), device = self.model.device).to(fixdevice) for _ in range(self.latent_channels)], dim=0).to(self.model.device)
+        return noise
 
     def new_seed(self):
         self.seed = random.randrange(0, np.iinfo(np.uint32).max)
@@ -341,3 +341,6 @@ class Generator:
         image.save(filepath,'PNG')
 
 
+    def torch_dtype(self)->torch.dtype:
+        return torch.float16 if self.precision == 'float16' else torch.float32
+
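Note that this is the second copy of the same two-line helper: both CkptGenerator (the -335,4 hunk above) and Generator gain an instance-level torch_dtype() keyed off their own self.precision, while ldm.invoke.devices.torch_dtype(device) is the module-level variant keyed off Globals.full_precision. A condensed sketch of the instance-level flavor; the class name here is hypothetical:

import torch

class GeneratorBase:
    # mirrors the helper added to both generator base classes: the dtype
    # follows the instance's precision string, not the process-wide Globals
    def __init__(self, precision: str):
        self.precision = precision

    def torch_dtype(self) -> torch.dtype:
        return torch.float16 if self.precision == 'float16' else torch.float32

print(GeneratorBase('float16').torch_dtype())  # torch.float16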
@@ -36,10 +36,9 @@ class Txt2Img(Generator):
                              threshold = ThresholdSettings(threshold, warmup=0.2) if threshold else None)
             .add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
 
-
         def make_image(x_T) -> PIL.Image.Image:
             pipeline_output = pipeline.image_from_embeddings(
-                latents=torch.zeros_like(x_T),
+                latents=torch.zeros_like(x_T,dtype=self.torch_dtype()),
                 noise=x_T,
                 num_inference_steps=steps,
                 conditioning_data=conditioning_data,
@@ -59,16 +58,18 @@ class Txt2Img(Generator):
         input_channels = min(self.latent_channels, 4)
         if self.use_mps_noise or device.type == 'mps':
             x = torch.randn([1,
                             input_channels,
                             height // self.downsampling_factor,
                             width // self.downsampling_factor],
-                            device='cpu').to(device)
+                            dtype=self.torch_dtype(),
+                            device='cpu').to(device)
         else:
             x = torch.randn([1,
                             input_channels,
                             height // self.downsampling_factor,
                             width // self.downsampling_factor],
-                            device=device)
+                            dtype=self.torch_dtype(),
+                            device=device)
         if self.perlin > 0.0:
             x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
         return x
@@ -90,9 +90,9 @@ class Txt2Img2Img(Generator):
     def get_noise_like(self, like: torch.Tensor):
         device = like.device
         if device.type == 'mps':
-            x = torch.randn_like(like, device='cpu').to(device)
+            x = torch.randn_like(like, device='cpu', dtype=self.torch_dtype()).to(device)
         else:
-            x = torch.randn_like(like, device=device)
+            x = torch.randn_like(like, device=device, dtype=self.torch_dtype())
         if self.perlin > 0.0:
             shape = like.shape
             x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
@@ -117,10 +117,12 @@ class Txt2Img2Img(Generator):
                                 self.latent_channels,
                                 scaled_height // self.downsampling_factor,
                                 scaled_width // self.downsampling_factor],
-                               device='cpu').to(device)
+                               dtype=self.torch_dtype(),
+                               device='cpu').to(device)
         else:
             return torch.randn([1,
                                 self.latent_channels,
                                 scaled_height // self.downsampling_factor,
                                 scaled_width // self.downsampling_factor],
-                               device=device)
+                               dtype=self.torch_dtype(),
+                               device=device)
@@ -43,6 +43,9 @@ Globals.always_use_cpu = False
 # The CLI will test connectivity at startup time.
 Globals.internet_available = True
 
+# whether we are forcing full precision
+Globals.full_precision = False
+
 def global_config_dir()->Path:
     return Path(Globals.root, Globals.config_dir)
 
@@ -349,7 +349,7 @@ class ModelManager(object):
 
         if self.precision == 'float16':
             print(' | Using faster float16 precision')
-            model.to(torch.float16)
+            model = model.to(torch.float16)
         else:
             print(' | Using more accurate float32 precision')
 
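A subtle fix: for an nn.Module, .to() converts the parameters in place and returns self, so the old code happened to work, but for a plain Tensor (or anything with tensor semantics) .to() returns a new object and the result must be captured. Reassigning is correct in both cases. A quick check of the two behaviors:

import torch
from torch import nn

net = nn.Linear(4, 4)
assert net.to(torch.float16) is net     # Module.to() converts in place, returns self

t = torch.ones(4)
assert t.to(torch.float16) is not t     # Tensor.to() returns a new tensor
assert t.dtype == torch.float32         # the original is untouched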
@@ -761,7 +761,7 @@ class ModelManager(object):
         for model in legacy_locations:
             source = models_dir /model
             if source.exists():
-                print(f'DEBUG: Moving {models_dir / model} into hub')
+                print(f'** Moving {models_dir / model} into hub')
                 move(models_dir / model, hub)
 
         # anything else gets moved into the diffusers directory
@@ -7,6 +7,7 @@ import torch
 import diffusers
 from torch import nn
 from diffusers.models.unet_2d_condition import UNet2DConditionModel
+from ldm.invoke.devices import torch_dtype
 
 # adapted from bloc97's CrossAttentionControl colab
 # https://github.com/bloc97/CrossAttentionControl
@@ -383,7 +384,7 @@ def inject_attention_function(unet, context: Context):
             remapped_saved_attention_slice = torch.index_select(saved_attention_slice, -1, index_map)
             this_attention_slice = suggested_attention_slice
 
-            mask = context.cross_attention_mask
+            mask = context.cross_attention_mask.to(torch_dtype(suggested_attention_slice.device))
             saved_mask = mask
             this_mask = 1 - mask
             attention_slice = remapped_saved_attention_slice * saved_mask + \
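Casting the cross-attention mask to torch_dtype(...) keeps the blend arithmetic in a single dtype: multiplying a float16 attention slice by a float32 mask silently promotes the result back to float32, which is exactly the 'expected scalar type Half but found Float' class of error this commit is chasing. The promotion is easy to demonstrate:

import torch

attn = torch.rand(2, 8, 8, dtype=torch.float16)
mask = torch.ones(2, 8, 8)                                   # float32 by default
assert (attn * mask).dtype == torch.float32                  # silent promotion
assert (attn * mask.to(attn.dtype)).dtype == torch.float16   # what the cast preserves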
@@ -4,7 +4,7 @@ import torch
 from transformers import CLIPTokenizer, CLIPTextModel
 
 from ldm.modules.textual_inversion_manager import TextualInversionManager
+from ldm.invoke.devices import torch_dtype
 
 class WeightedPromptFragmentsToEmbeddingsConverter():
 
@@ -207,7 +207,7 @@ class WeightedPromptFragmentsToEmbeddingsConverter():
             per_token_weights += [1.0] * pad_length
 
         all_token_ids_tensor = torch.tensor(all_token_ids, dtype=torch.long, device=device)
-        per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch.float32, device=device)
+        per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch_dtype(self.text_encoder.device), device=device)
         #print(f"assembled all_token_ids_tensor with shape {all_token_ids_tensor.shape}")
         return all_token_ids_tensor, per_token_weights_tensor
 
@@ -111,7 +111,6 @@ class TextualInversionManager():
         if ti.trigger_token_id is not None:
             raise ValueError(f"Tokens already injected for textual inversion with trigger '{ti.trigger_string}'")
 
-        print(f'DEBUG: Injecting token {ti.trigger_string}')
         trigger_token_id = self._get_or_create_token_id_and_assign_embedding(ti.trigger_string, ti.embedding[0])
 
         if ti.embedding_vector_length > 1:
@@ -8,6 +8,7 @@ from threading import Thread
 from urllib import request
 from tqdm import tqdm
 from pathlib import Path
+from ldm.invoke.devices import torch_dtype
 
 import numpy as np
 import torch
@@ -235,7 +236,8 @@ def rand_perlin_2d(shape, res, device, fade = lambda t: 6*t**5 - 15*t**4 + 10*t**3):
     n01 = dot(tile_grads([0, -1],[1, None]), [0, -1]).to(device)
     n11 = dot(tile_grads([1, None], [1, None]), [-1,-1]).to(device)
     t = fade(grid[:shape[0], :shape[1]])
-    return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]).to(device)
+    noise = math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]).to(device)
+    return noise.to(dtype=torch_dtype(device))
 
 def ask_user(question: str, answers: list):
     from itertools import chain, repeat
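The same reasoning closes the loop here: rand_perlin_2d() builds its noise in float32 via torch.lerp and then casts the result to torch_dtype(device), so the x = (1-perlin)*x + perlin*noise blend in the generators cannot drag half-precision latents back to float32. A sketch of the drift the cast prevents; the values and shapes are illustrative:

import torch

x = torch.randn(4, 4, dtype=torch.float16)     # latents in the model's dtype
perlin = torch.randn(4, 4)                     # float32 noise, as before this commit
blended = 0.9 * x + 0.1 * perlin
assert blended.dtype == torch.float32          # latents silently promoted

blended = 0.9 * x + 0.1 * perlin.to(x.dtype)   # with the cast, dtype is preserved
assert blended.dtype == torch.float16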