mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'lstein-import-safetensors' of github.com:invoke-ai/InvokeAI into lstein-import-safetensors
This commit is contained in:
@ -335,4 +335,5 @@ class CkptGenerator():
|
||||
os.makedirs(dirname, exist_ok=True)
|
||||
image.save(filepath,'PNG')
|
||||
|
||||
|
||||
def torch_dtype(self)->torch.dtype:
|
||||
return torch.float16 if self.precision == 'float16' else torch.float32
|
||||
|
@ -72,16 +72,18 @@ class CkptTxt2Img(CkptGenerator):
|
||||
device = self.model.device
|
||||
if self.use_mps_noise or device.type == 'mps':
|
||||
x = torch.randn([1,
|
||||
self.latent_channels,
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor],
|
||||
device='cpu').to(device)
|
||||
self.latent_channels,
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor],
|
||||
dtype=self.torch_dtype(),
|
||||
device='cpu').to(device)
|
||||
else:
|
||||
x = torch.randn([1,
|
||||
self.latent_channels,
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor],
|
||||
device=device)
|
||||
self.latent_channels,
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor],
|
||||
dtype=self.torch_dtype(),
|
||||
device=device)
|
||||
if self.perlin > 0.0:
|
||||
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
|
||||
return x
|
||||
|
@ -21,10 +21,19 @@ def choose_precision(device) -> str:
|
||||
return 'float16'
|
||||
return 'float32'
|
||||
|
||||
def torch_dtype(device) -> torch.dtype:
|
||||
if Globals.full_precision:
|
||||
return torch.float32
|
||||
if choose_precision(device) == 'float16':
|
||||
return torch.float16
|
||||
else:
|
||||
return torch.float32
|
||||
|
||||
def choose_autocast(precision):
|
||||
'''Returns an autocast context or nullcontext for the given precision string'''
|
||||
# float16 currently requires autocast to avoid errors like:
|
||||
# 'expected scalar type Half but found Float'
|
||||
print(f'DEBUG: choose_autocast() called')
|
||||
if precision == 'autocast' or precision == 'float16':
|
||||
return autocast
|
||||
return nullcontext
|
||||
|
@ -8,6 +8,7 @@ import os
|
||||
import os.path as osp
|
||||
import random
|
||||
import traceback
|
||||
from contextlib import nullcontext
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
@ -18,8 +19,6 @@ from einops import rearrange
|
||||
from pytorch_lightning import seed_everything
|
||||
from tqdm import trange
|
||||
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
from ldm.models.diffusion.cross_attention_map_saving import AttentionMapSaver
|
||||
from ldm.models.diffusion.ddpm import DiffusionWrapper
|
||||
from ldm.util import rand_perlin_2d
|
||||
|
||||
@ -64,7 +63,7 @@ class Generator:
|
||||
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
|
||||
safety_checker:dict=None,
|
||||
**kwargs):
|
||||
scope = choose_autocast(self.precision)
|
||||
scope = nullcontext
|
||||
self.safety_checker = safety_checker
|
||||
attention_maps_images = []
|
||||
attention_maps_callback = lambda saver: attention_maps_images.append(saver.get_stacked_maps_image())
|
||||
@ -236,7 +235,8 @@ class Generator:
|
||||
|
||||
def get_perlin_noise(self,width,height):
|
||||
fixdevice = 'cpu' if (self.model.device.type == 'mps') else self.model.device
|
||||
return torch.stack([rand_perlin_2d((height, width), (8, 8), device = self.model.device).to(fixdevice) for _ in range(self.latent_channels)], dim=0).to(self.model.device)
|
||||
noise = torch.stack([rand_perlin_2d((height, width), (8, 8), device = self.model.device).to(fixdevice) for _ in range(self.latent_channels)], dim=0).to(self.model.device)
|
||||
return noise
|
||||
|
||||
def new_seed(self):
|
||||
self.seed = random.randrange(0, np.iinfo(np.uint32).max)
|
||||
@ -341,3 +341,6 @@ class Generator:
|
||||
image.save(filepath,'PNG')
|
||||
|
||||
|
||||
def torch_dtype(self)->torch.dtype:
|
||||
return torch.float16 if self.precision == 'float16' else torch.float32
|
||||
|
||||
|
@ -36,10 +36,9 @@ class Txt2Img(Generator):
|
||||
threshold = ThresholdSettings(threshold, warmup=0.2) if threshold else None)
|
||||
.add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta))
|
||||
|
||||
|
||||
def make_image(x_T) -> PIL.Image.Image:
|
||||
pipeline_output = pipeline.image_from_embeddings(
|
||||
latents=torch.zeros_like(x_T),
|
||||
latents=torch.zeros_like(x_T,dtype=self.torch_dtype()),
|
||||
noise=x_T,
|
||||
num_inference_steps=steps,
|
||||
conditioning_data=conditioning_data,
|
||||
@ -59,16 +58,18 @@ class Txt2Img(Generator):
|
||||
input_channels = min(self.latent_channels, 4)
|
||||
if self.use_mps_noise or device.type == 'mps':
|
||||
x = torch.randn([1,
|
||||
input_channels,
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor],
|
||||
device='cpu').to(device)
|
||||
input_channels,
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor],
|
||||
dtype=self.torch_dtype(),
|
||||
device='cpu').to(device)
|
||||
else:
|
||||
x = torch.randn([1,
|
||||
input_channels,
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor],
|
||||
device=device)
|
||||
input_channels,
|
||||
height // self.downsampling_factor,
|
||||
width // self.downsampling_factor],
|
||||
dtype=self.torch_dtype(),
|
||||
device=device)
|
||||
if self.perlin > 0.0:
|
||||
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(width // self.downsampling_factor, height // self.downsampling_factor)
|
||||
return x
|
||||
|
@ -90,9 +90,9 @@ class Txt2Img2Img(Generator):
|
||||
def get_noise_like(self, like: torch.Tensor):
|
||||
device = like.device
|
||||
if device.type == 'mps':
|
||||
x = torch.randn_like(like, device='cpu').to(device)
|
||||
x = torch.randn_like(like, device='cpu', dtype=self.torch_dtype()).to(device)
|
||||
else:
|
||||
x = torch.randn_like(like, device=device)
|
||||
x = torch.randn_like(like, device=device, dtype=self.torch_dtype())
|
||||
if self.perlin > 0.0:
|
||||
shape = like.shape
|
||||
x = (1-self.perlin)*x + self.perlin*self.get_perlin_noise(shape[3], shape[2])
|
||||
@ -117,10 +117,12 @@ class Txt2Img2Img(Generator):
|
||||
self.latent_channels,
|
||||
scaled_height // self.downsampling_factor,
|
||||
scaled_width // self.downsampling_factor],
|
||||
device='cpu').to(device)
|
||||
dtype=self.torch_dtype(),
|
||||
device='cpu').to(device)
|
||||
else:
|
||||
return torch.randn([1,
|
||||
self.latent_channels,
|
||||
scaled_height // self.downsampling_factor,
|
||||
scaled_width // self.downsampling_factor],
|
||||
device=device)
|
||||
dtype=self.torch_dtype(),
|
||||
device=device)
|
||||
|
@ -43,6 +43,9 @@ Globals.always_use_cpu = False
|
||||
# The CLI will test connectivity at startup time.
|
||||
Globals.internet_available = True
|
||||
|
||||
# whether we are forcing full precision
|
||||
Globals.full_precision = False
|
||||
|
||||
def global_config_dir()->Path:
|
||||
return Path(Globals.root, Globals.config_dir)
|
||||
|
||||
|
@ -349,7 +349,7 @@ class ModelManager(object):
|
||||
|
||||
if self.precision == 'float16':
|
||||
print(' | Using faster float16 precision')
|
||||
model.to(torch.float16)
|
||||
model = model.to(torch.float16)
|
||||
else:
|
||||
print(' | Using more accurate float32 precision')
|
||||
|
||||
@ -763,7 +763,7 @@ class ModelManager(object):
|
||||
for model in legacy_locations:
|
||||
source = models_dir /model
|
||||
if source.exists():
|
||||
print(f'DEBUG: Moving {models_dir / model} into hub')
|
||||
print(f'** Moving {models_dir / model} into hub')
|
||||
move(models_dir / model, hub)
|
||||
|
||||
# anything else gets moved into the diffusers directory
|
||||
|
Reference in New Issue
Block a user