Merge branch 'lstein-import-safetensors' of github.com:invoke-ai/InvokeAI into lstein-import-safetensors

Lincoln Stein 2023-01-17 22:52:50 -05:00
commit 4a9e93463d
16 changed files with 73 additions and 37 deletions

View File

@@ -119,6 +119,7 @@ jobs:
         run: >
           configure_invokeai.py
           --yes
+          --default_only
           --full-precision # can't use fp16 weights without a GPU
       - name: Run the tests

View File

@@ -217,7 +217,7 @@ manager, please follow these steps:
 !!! tip

-Do not move the source code repository after installation. The virtual environment directory has absolute paths in it that get confused if the directory is moved.
+    Do not move the source code repository after installation. The virtual environment directory has absolute paths in it that get confused if the directory is moved.

 ---

View File

@@ -29,7 +29,7 @@ from ldm.invoke.concepts_lib import HuggingFaceConceptsLibrary
 from ldm.invoke.conditioning import get_uc_and_c_and_ec
 from ldm.invoke.devices import choose_torch_device, choose_precision
 from ldm.invoke.generator.inpaint import infill_methods
-from ldm.invoke.globals import global_cache_dir
+from ldm.invoke.globals import global_cache_dir, Globals
 from ldm.invoke.image_util import InitImageResizer
 from ldm.invoke.model_manager import ModelManager
 from ldm.invoke.pngwriter import PngWriter
@@ -201,6 +201,7 @@ class Generate:
             self.precision = 'float32'
         if self.precision == 'auto':
             self.precision = choose_precision(self.device)
+        Globals.full_precision = self.precision=='float32'

         # model caching system for fast switching
         self.model_manager = ModelManager(mconfig,self.device,self.precision,max_loaded_models=max_loaded_models)

View File

@@ -335,4 +335,5 @@ class CkptGenerator():
             os.makedirs(dirname, exist_ok=True)
         image.save(filepath,'PNG')

+    def torch_dtype(self)->torch.dtype:
+        return torch.float16 if self.precision == 'float16' else torch.float32

View File

@@ -72,16 +72,18 @@ class CkptTxt2Img(CkptGenerator):
         device = self.model.device
         if self.use_mps_noise or device.type == 'mps':
             x = torch.randn([1,
-                             self.latent_channels,
-                             height // self.downsampling_factor,
-                             width // self.downsampling_factor],
-                            device='cpu').to(device)
+                             self.latent_channels,
+                             height // self.downsampling_factor,
+                             width // self.downsampling_factor],
+                             dtype=self.torch_dtype(),
+                             device='cpu').to(device)
         else:
             x = torch.randn([1,
-                             self.latent_channels,
-                             height // self.downsampling_factor,
-                             width // self.downsampling_factor],
-                            device=device)
+                             self.latent_channels,
+                             height // self.downsampling_factor,
+                             width // self.downsampling_factor],
+                             dtype=self.torch_dtype(),
+                             device=device)
         if self.perlin > 0.0:
             x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
         return x
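
Note: in both branches above, on MPS the noise is drawn on the CPU and only then moved to the device (reportedly so that seeded generation stays reproducible on Apple Silicon), and the new dtype argument makes the latents match the model's precision either way. A minimal standalone sketch of the pattern, with illustrative shape values:

    import torch

    def make_latent_noise(device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        shape = [1, 4, 64, 64]   # illustrative: [batch, latent_channels, h//8, w//8]
        if device.type == 'mps':
            # draw on CPU first, then move -- mirrors the mps branch above
            return torch.randn(shape, dtype=dtype, device='cpu').to(device)
        return torch.randn(shape, dtype=dtype, device=device)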

View File

@@ -21,10 +21,19 @@ def choose_precision(device) -> str:
             return 'float16'
     return 'float32'

+def torch_dtype(device) -> torch.dtype:
+    if Globals.full_precision:
+        return torch.float32
+    if choose_precision(device) == 'float16':
+        return torch.float16
+    else:
+        return torch.float32
+
 def choose_autocast(precision):
     '''Returns an autocast context or nullcontext for the given precision string'''
     # float16 currently requires autocast to avoid errors like:
     # 'expected scalar type Half but found Float'
-    print(f'DEBUG: choose_autocast() called')
     if precision == 'autocast' or precision == 'float16':
         return autocast
     return nullcontext
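
For context, a minimal sketch of how the new torch_dtype() helper resolves a dtype, assuming Globals and choose_precision behave as defined above (the device arguments are illustrative):

    import torch
    from ldm.invoke.globals import Globals
    from ldm.invoke.devices import torch_dtype

    # forcing full precision short-circuits to float32 on any device
    Globals.full_precision = True
    assert torch_dtype(torch.device('cuda')) == torch.float32

    # otherwise the result follows choose_precision(): CPU devices
    # resolve to 'float32', capable GPUs to 'float16'
    Globals.full_precision = False
    assert torch_dtype(torch.device('cpu')) == torch.float32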

View File

@@ -8,6 +8,7 @@ import os
 import os.path as osp
 import random
 import traceback
+from contextlib import nullcontext

 import cv2
 import numpy as np
@@ -18,8 +19,6 @@ from einops import rearrange
 from pytorch_lightning import seed_everything
 from tqdm import trange

-from ldm.invoke.devices import choose_autocast
 from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
 from ldm.models.diffusion.ddpm import DiffusionWrapper
 from ldm.util import rand_perlin_2d
@@ -64,7 +63,7 @@ class Generator:
                  image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
                  safety_checker:dict=None,
                  **kwargs):
-        scope = choose_autocast(self.precision)
+        scope = nullcontext
         self.safety_checker = safety_checker
         attention_maps_images = []
         attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image())
@@ -236,7 +235,8 @@ class Generator:
     def get_perlin_noise(self,width,height):
         fixdevice = 'cpu' if (self.model.device.type == 'mps') else self.model.device
-        return torch.stack([rand_perlin_2d((height, width), (8, 8), device = self.model.device).to(fixdevice) for _ in range(self.latent_channels)], dim=0).to(self.model.device)
+        noise = torch.stack([rand_perlin_2d((height, width), (8, 8), device = self.model.device).to(fixdevice) for _ in range(self.latent_channels)], dim=0).to(self.model.device)
+        return noise

     def new_seed(self):
         self.seed = random.randrange(0, np.iinfo(np.uint32).max)
@@ -341,3 +341,6 @@ class Generator:
         image.save(filepath,'PNG')

+    def torch_dtype(self)->torch.dtype:
+        return torch.float16 if self.precision == 'float16' else torch.float32

View File

@@ -36,10 +36,9 @@ class Txt2Img(Generator):
                                threshold = ThresholdSettings(threshold, warmup=0.2) if threshold else None)
                               .add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))

         def make_image(x_T) -> PIL.Image.Image:
             pipeline_output = pipeline.image_from_embeddings(
-                latents=torch.zeros_like(x_T),
+                latents=torch.zeros_like(x_T,dtype=self.torch_dtype()),
                 noise=x_T,
                 num_inference_steps=steps,
                 conditioning_data=conditioning_data,
@@ -59,16 +58,18 @@ class Txt2Img(Generator):
         input_channels = min(self.latent_channels, 4)
         if self.use_mps_noise or device.type == 'mps':
             x = torch.randn([1,
-                             input_channels,
-                             height // self.downsampling_factor,
-                             width // self.downsampling_factor],
-                            device='cpu').to(device)
+                             input_channels,
+                             height // self.downsampling_factor,
+                             width // self.downsampling_factor],
+                             dtype=self.torch_dtype(),
+                             device='cpu').to(device)
         else:
             x = torch.randn([1,
-                             input_channels,
-                             height // self.downsampling_factor,
-                             width // self.downsampling_factor],
-                            device=device)
+                             input_channels,
+                             height // self.downsampling_factor,
+                             width // self.downsampling_factor],
+                             dtype=self.torch_dtype(),
+                             device=device)
         if self.perlin > 0.0:
             x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
         return x

View File

@@ -90,9 +90,9 @@ class Txt2Img2Img(Generator):
     def get_noise_like(self, like: torch.Tensor):
         device = like.device
         if device.type == 'mps':
-            x = torch.randn_like(like, device='cpu').to(device)
+            x = torch.randn_like(like, device='cpu', dtype=self.torch_dtype()).to(device)
         else:
-            x = torch.randn_like(like, device=device)
+            x = torch.randn_like(like, device=device, dtype=self.torch_dtype())
         if self.perlin > 0.0:
             shape = like.shape
             x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
@@ -117,10 +117,12 @@ class Txt2Img2Img(Generator):
                                self.latent_channels,
                                scaled_height // self.downsampling_factor,
                                scaled_width // self.downsampling_factor],
-                              device='cpu').to(device)
+                              dtype=self.torch_dtype(),
+                              device='cpu').to(device)
         else:
             return torch.randn([1,
                                self.latent_channels,
                                scaled_height // self.downsampling_factor,
                                scaled_width // self.downsampling_factor],
-                              device=device)
+                              dtype=self.torch_dtype(),
+                              device=device)

View File

@@ -43,6 +43,9 @@ Globals.always_use_cpu = False
 # The CLI will test connectivity at startup time.
 Globals.internet_available = True

+# whether we are forcing full precision
+Globals.full_precision = False
+
 def global_config_dir()->Path:
     return Path(Globals.root, Globals.config_dir)

View File

@@ -349,7 +349,7 @@ class ModelManager(object):
         if self.precision == 'float16':
             print(' | Using faster float16 precision')
-            model.to(torch.float16)
+            model = model.to(torch.float16)
         else:
             print(' | Using more accurate float32 precision')
@@ -763,7 +763,7 @@ class ModelManager(object):
         for model in legacy_locations:
             source = models_dir /model
             if source.exists():
-                print(f'DEBUG: Moving {models_dir / model} into hub')
+                print(f'** Moving {models_dir / model} into hub')
                 move(models_dir / model, hub)

         # anything else gets moved into the diffusers directory
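
The reassignment in the first hunk matters for Tensor semantics: Tensor.to() returns a converted copy and leaves the original untouched (nn.Module.to() converts in place, so reassigning is harmless for modules and correct if the object is ever a plain tensor). A small illustration:

    import torch

    t = torch.ones(2)            # float32 by default
    t.to(torch.float16)          # returns a new tensor; t itself is unchanged
    assert t.dtype == torch.float32
    t = t.to(torch.float16)      # reassign to keep the half-precision copy
    assert t.dtype == torch.float16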

View File

@@ -7,6 +7,7 @@ import torch
 import diffusers
 from torch import nn
 from diffusers.models.unet_2d_condition import UNet2DConditionModel
+from ldm.invoke.devices import torch_dtype

 # adapted from bloc97's CrossAttentionControl colab
 # https://github.com/bloc97/CrossAttentionControl
@@ -383,7 +384,7 @@ def inject_attention_function(unet, context: Context):
             remapped_saved_attention_slice = torch.index_select(saved_attention_slice, -1, index_map)
             this_attention_slice = suggested_attention_slice

-            mask = context.cross_attention_mask
+            mask = context.cross_attention_mask.to(torch_dtype(suggested_attention_slice.device))
             saved_mask = mask
             this_mask = 1 - mask
             attention_slice = remapped_saved_attention_slice * saved_mask + \

View File

@@ -4,7 +4,7 @@ import torch
 from transformers import CLIPTokenizer, CLIPTextModel

 from ldm.modules.textual_inversion_manager import TextualInversionManager
+from ldm.invoke.devices import torch_dtype

 class WeightedPromptFragmentsToEmbeddingsConverter():
@@ -207,7 +207,7 @@ class WeightedPromptFragmentsToEmbeddingsConverter():
             per_token_weights += [1.0] * pad_length

         all_token_ids_tensor = torch.tensor(all_token_ids, dtype=torch.long, device=device)
-        per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch.float32, device=device)
+        per_token_weights_tensor = torch.tensor(per_token_weights, dtype=torch_dtype(self.text_encoder.device), device=device)
         #print(f"assembled all_token_ids_tensor with shape {all_token_ids_tensor.shape}")
         return all_token_ids_tensor, per_token_weights_tensor

View File

@@ -111,7 +111,6 @@ class TextualInversionManager():
         if ti.trigger_token_id is not None:
             raise ValueError(f"Tokens already injected for textual inversion with trigger '{ti.trigger_string}'")

-        print(f'DEBUG: Injecting token {ti.trigger_string}')
         trigger_token_id = self._get_or_create_token_id_and_assign_embedding(ti.trigger_string, ti.embedding[0])

         if ti.embedding_vector_length > 1:

View File

@@ -8,6 +8,7 @@ from threading import Thread
 from urllib import request
 from tqdm import tqdm
 from pathlib import Path
+from ldm.invoke.devices import torch_dtype

 import numpy as np
 import torch
@@ -235,7 +236,8 @@ def rand_perlin_2d(shape, res, device, fade = lambda t: 6*t**5 - 15*t**4 + 10*t**3):
     n01 = dot(tile_grads([0, -1],[1, None]), [0, -1]).to(device)
     n11 = dot(tile_grads([1, None], [1, None]), [-1,-1]).to(device)
     t = fade(grid[:shape[0], :shape[1]])
-    return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]).to(device)
+    noise = math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]).to(device)
+    return noise.to(dtype=torch_dtype(device))

 def ask_user(question: str, answers: list):
     from itertools import chain, repeat
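
Casting the Perlin noise to torch_dtype(device) keeps it blendable with latents created at the same dtype: PyTorch type promotion would otherwise silently upcast a float16/float32 mix to float32. A minimal sketch (shapes and the blend weight are illustrative):

    import torch

    x = torch.randn(4, 64, 64, dtype=torch.float16)
    perlin = torch.randn(4, 64, 64)               # float32 by default
    assert (0.9*x + 0.1*perlin).dtype == torch.float32              # silent upcast
    assert (0.9*x + 0.1*perlin.to(x.dtype)).dtype == torch.float16  # consistent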

View File

@@ -197,6 +197,14 @@ def recommended_datasets()->dict:
             datasets[ds]=True
     return datasets

+#---------------------------------------------
+def default_dataset()->dict:
+    datasets = dict()
+    for ds in Datasets.keys():
+        if Datasets[ds].get('default',False):
+            datasets[ds]=True
+    return datasets
+
 #---------------------------------------------
 def all_datasets()->dict:
     datasets = dict()
@@ -646,7 +654,7 @@ def download_weights(opt:dict) -> Union[str, None]:
     precision = 'float32' if opt.full_precision else choose_precision(torch.device(choose_torch_device()))

     if opt.yes_to_all:
-        models = recommended_datasets()
+        models = default_dataset() if opt.default_only else recommended_datasets()
         access_token = authenticate(opt.yes_to_all)
         if len(models)>0:
             successfully_downloaded = download_weight_datasets(models, access_token, precision=precision)
@@ -808,6 +816,9 @@ def main():
                         dest='yes_to_all',
                         action='store_true',
                         help='answer "yes" to all prompts')
+    parser.add_argument('--default_only',
+                        action='store_true',
+                        help='when --yes specified, only install the default model')
     parser.add_argument('--config_file',
                         '-c',
                         dest='config_file',
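
For reference, the new default_dataset() helper above just filters the Datasets table on a per-entry 'default' flag, and --default_only routes non-interactive (--yes) installs through it instead of recommended_datasets(). A sketch against a hypothetical two-entry table (the real Datasets mapping is loaded from the project's model config and carries more fields per entry):

    # hypothetical entries; only the optional 'default' key matters here
    Datasets = {
        'stable-diffusion-1.5': {'default': True},
        'waifu-diffusion-1.4':  {},
    }

    def default_dataset() -> dict:
        datasets = dict()
        for ds in Datasets.keys():
            if Datasets[ds].get('default', False):
                datasets[ds] = True
        return datasets

    assert default_dataset() == {'stable-diffusion-1.5': True}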