mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'development' into inpaint-model
This commit is contained in:
@ -7,25 +7,28 @@ import numpy as np
|
||||
import random
|
||||
import os
|
||||
from tqdm import tqdm, trange
|
||||
from PIL import Image
|
||||
from PIL import Image, ImageFilter
|
||||
from einops import rearrange, repeat
|
||||
from pytorch_lightning import seed_everything
|
||||
from ldm.invoke.devices import choose_autocast
|
||||
from ldm.util import rand_perlin_2d
|
||||
|
||||
downsampling = 8
|
||||
CAUTION_IMG = 'assets/caution.png'
|
||||
|
||||
class Generator():
|
||||
def __init__(self, model, precision):
|
||||
self.model = model
|
||||
self.precision = precision
|
||||
self.seed = None
|
||||
self.latent_channels = model.channels
|
||||
self.model = model
|
||||
self.precision = precision
|
||||
self.seed = None
|
||||
self.latent_channels = model.channels
|
||||
self.downsampling_factor = downsampling # BUG: should come from model or config
|
||||
self.perlin = 0.0
|
||||
self.threshold = 0
|
||||
self.variation_amount = 0
|
||||
self.with_variations = []
|
||||
self.safety_checker = None
|
||||
self.perlin = 0.0
|
||||
self.threshold = 0
|
||||
self.variation_amount = 0
|
||||
self.with_variations = []
|
||||
self.use_mps_noise = False
|
||||
|
||||
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
|
||||
def get_make_image(self,prompt,**kwargs):
|
||||
@ -42,9 +45,11 @@ class Generator():
|
||||
|
||||
def generate(self,prompt,init_image,width,height,sampler, iterations=1,seed=None,
|
||||
image_callback=None, step_callback=None, threshold=0.0, perlin=0.0,
|
||||
safety_checker:dict=None,
|
||||
**kwargs):
|
||||
scope = choose_autocast(self.precision)
|
||||
make_image = self.get_make_image(
|
||||
self.safety_checker = safety_checker
|
||||
make_image = self.get_make_image(
|
||||
prompt,
|
||||
sampler = sampler,
|
||||
init_image = init_image,
|
||||
@ -79,10 +84,17 @@ class Generator():
|
||||
except:
|
||||
pass
|
||||
image = make_image(x_T)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
image = self.safety_check(image)
|
||||
|
||||
results.append([image, seed])
|
||||
|
||||
if image_callback is not None:
|
||||
image_callback(image, seed, first_seed=first_seed)
|
||||
|
||||
seed = self.new_seed()
|
||||
|
||||
return results
|
||||
|
||||
def sample_to_image(self,samples)->Image.Image:
|
||||
@ -171,6 +183,39 @@ class Generator():
|
||||
|
||||
return v2
|
||||
|
||||
def safety_check(self,image:Image.Image):
|
||||
'''
|
||||
If the CompViz safety checker flags an NSFW image, we
|
||||
blur it out.
|
||||
'''
|
||||
import diffusers
|
||||
|
||||
checker = self.safety_checker['checker']
|
||||
extractor = self.safety_checker['extractor']
|
||||
features = extractor([image], return_tensors="pt")
|
||||
|
||||
# unfortunately checker requires the numpy version, so we have to convert back
|
||||
x_image = np.array(image).astype(np.float32) / 255.0
|
||||
x_image = x_image[None].transpose(0, 3, 1, 2)
|
||||
|
||||
diffusers.logging.set_verbosity_error()
|
||||
checked_image, has_nsfw_concept = checker(images=x_image, clip_input=features.pixel_values)
|
||||
if has_nsfw_concept[0]:
|
||||
print('** An image with potential non-safe content has been detected. A blurred image will be returned. **')
|
||||
return self.blur(image)
|
||||
else:
|
||||
return image
|
||||
|
||||
def blur(self,input):
|
||||
blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
|
||||
try:
|
||||
caution = Image.open(CAUTION_IMG)
|
||||
caution = caution.resize((caution.width // 2, caution.height //2))
|
||||
blurry.paste(caution,(0,0),caution)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
return blurry
|
||||
|
||||
# this is a handy routine for debugging use. Given a generated sample,
|
||||
# convert it into a PNG image and store it at the indicated path
|
||||
def save_sample(self, sample, filepath):
|
||||
|
@ -59,7 +59,7 @@ class Txt2Img(Generator):
|
||||
# returns a tensor filled with random numbers from a normal distribution
|
||||
def get_noise(self,width,height):
|
||||
device = self.model.device
|
||||
if device.type == 'mps':
|
||||
if self.use_mps_noise or device.type == 'mps':
|
||||
x = torch.randn([1,
|
||||
self.latent_channels,
|
||||
height // self.downsampling_factor,
|
||||
|
@ -118,7 +118,7 @@ class Txt2Img2Img(Generator):
|
||||
scaled_height = height
|
||||
|
||||
device = self.model.device
|
||||
if device.type == 'mps':
|
||||
if self.use_mps_noise or device.type == 'mps':
|
||||
return torch.randn([1,
|
||||
self.latent_channels,
|
||||
scaled_height // self.downsampling_factor,
|
||||
|
Reference in New Issue
Block a user