node-based txt2img working without generate

This commit is contained in:
Lincoln Stein
2023-03-09 00:18:29 -05:00
parent 87789c1de8
commit 5d37fa6e36
8 changed files with 247 additions and 254 deletions

View File

@ -2,6 +2,12 @@
Initialization file for invokeai.backend
"""
from .generate import Generate
from .generator import (
InvokeAIGeneratorBasicParams,
InvokeAIGeneratorFactory,
InvokeAIGenerator,
InvokeAIGeneratorOutput
)
from .model_management import ModelManager
from .args import Args
from .globals import Globals

View File

@ -4,7 +4,6 @@ including img2img, txt2img, and inpaint
"""
from __future__ import annotations
import copy
import importlib
import dataclasses
import diffusers
@ -13,7 +12,6 @@ import random
import traceback
from abc import ABCMeta, abstractmethod
from contextlib import nullcontext
from pathlib import Path
import cv2
import numpy as np
@ -22,19 +20,59 @@ from PIL import Image, ImageChops, ImageFilter
from accelerate.utils import set_seed
from diffusers import DiffusionPipeline
from tqdm import trange
from typing import List, Type, Callable
from typing import List, Type
from dataclasses import dataclass, field
from diffusers.schedulers import SchedulerMixin as Scheduler
import invokeai.assets.web as web_assets
from ..util.util import rand_perlin_2d
from ..safety_checker import SafetyChecker
from ..prompting.conditioning import get_uc_and_c_and_ec
from ..model_management.model_manager import ModelManager
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
downsampling = 8
CAUTION_IMG = "caution.png"
@dataclass
class InvokeAIGeneratorBasicParams:
seed: int=None
width: int=512
height: int=512
cfg_scale: int=7.5
steps: int=20
ddim_eta: float=0.0
model_name: str='stable-diffusion-1.5'
scheduler: int='ddim'
precision: str='float16'
perlin: float=0.0
threshold: int=0.0
h_symmetry_time_pct: float=None
v_symmetry_time_pct: float=None
variation_amount: float = 0.0
with_variations: list=field(default_factory=list)
safety_checker: SafetyChecker=None
@dataclass
class InvokeAIGeneratorOutput:
'''
InvokeAIGeneratorOutput is a dataclass that contains the outputs of a generation
operation, including the image, its seed, the model name used to generate the image
and the model hash, as well as all the generate() parameters that went into
generating the image (in .params, also available as attributes)
'''
image: Image
seed: int
model_name: str
model_hash: str
params: dict
def __getattribute__(self,name):
try:
return object.__getattribute__(self, name)
except AttributeError:
params = object.__getattribute__(self, 'params')
if name in params:
return params[name]
raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{name}'")
class InvokeAIGeneratorFactory(object):
def __init__(self,
@ -49,31 +87,15 @@ class InvokeAIGeneratorFactory(object):
self.params,
**keyword_args
)
@dataclass
class InvokeAIGeneratorBasicParams:
seed: int=None
width: int=512
height: int=512
cfg_scale: int=7.5
steps: int=20
ddim_eta: float=0.0
model: str='stable-diffusion-1.5'
scheduler: int='ddim'
precision: str='float16'
perlin: float=0.0
threshold: int=0.0
h_symmetry_time_pct: float=None
v_symmetry_time_pct: float=None
variation_amount: float = 0.0
with_variations: list = field(default_factory=list)
@dataclass
class InvokeAIGeneratorOutput:
image: Image
seed: int
model_name: str
model_hash: str
params: InvokeAIGeneratorBasicParams
# getter and setter shortcuts for commonly used parameters
@property
def model_name(self)->str:
return self.params.model_name
@model_name.setter
def model_name(self, model_name: str):
self.params.model_name=model_name
# we are interposing a wrapper around the original Generator classes so that
# old code that calls Generate will continue to work.
@ -93,7 +115,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
def __init__(self,
model_manager: ModelManager,
params: InvokeAIGeneratorBasicParams
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
):
self.model_manager=model_manager
self.params=params
@ -105,7 +127,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
**keyword_args,
)->List[InvokeAIGeneratorOutput]:
model_name = self.params.model or self.model_manager.current_model
model_name = self.params.model_name or self.model_manager.current_model
model_info: dict = self.model_manager.get_model(model_name)
model:StableDiffusionGeneratorPipeline = model_info['model']
model_hash = model_info['hash']
@ -124,24 +146,33 @@ class InvokeAIGenerator(metaclass=ABCMeta):
generator.set_variation(self.params.seed,
self.params.variation_amount,
self.params.with_variations)
generator_args = dataclasses.asdict(self.params)
generator_args.update(keyword_args)
while True:
results = generator.generate(prompt,
conditioning=(uc, c, extra_conditioning_info),
sampler=scheduler,
**dataclasses.asdict(self.params),
**keyword_args
**generator_args,
)
output = InvokeAIGeneratorOutput(
image=results[0][0],
seed=results[0][1],
model_name = model_name,
model_hash = model_hash,
params=copy.copy(self.params)
params=generator_args,
)
if callback:
callback(output)
yield output
@classmethod
def schedulers(self)->List[str]:
'''
Return list of all the schedulers that we currently handle.
'''
return list(self.scheduler_map.keys())
def load_generator(self, model: StableDiffusionGeneratorPipeline, class_name: str):
module_name = f'invokeai.backend.generator.{class_name.lower()}'
@ -219,8 +250,7 @@ class Inpaint(Img2Img):
def _generator_name(self)->str:
return 'Inpaint'
class Generator:
downsampling_factor: int
latent_channels: int
@ -240,7 +270,6 @@ class Generator:
self.with_variations = []
self.use_mps_noise = False
self.free_gpu_mem = None
self.caution_img = None
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
def get_make_image(self, prompt, **kwargs):
@ -272,7 +301,7 @@ class Generator:
perlin=0.0,
h_symmetry_time_pct=None,
v_symmetry_time_pct=None,
safety_checker: dict = None,
safety_checker: SafetyChecker=None,
free_gpu_mem: bool = False,
**kwargs,
):
@ -325,7 +354,7 @@ class Generator:
image = make_image(x_T)
if self.safety_checker is not None:
image = self.safety_check(image)
image = self.safety_checker.check(image)
results.append([image, seed])
@ -548,53 +577,6 @@ class Generator:
return v2
def safety_check(self, image: Image.Image):
"""
If the CompViz safety checker flags an NSFW image, we
blur it out.
"""
import diffusers
checker = self.safety_checker["checker"]
extractor = self.safety_checker["extractor"]
features = extractor([image], return_tensors="pt")
features.to(self.model.device)
# unfortunately checker requires the numpy version, so we have to convert back
x_image = np.array(image).astype(np.float32) / 255.0
x_image = x_image[None].transpose(0, 3, 1, 2)
diffusers.logging.set_verbosity_error()
checked_image, has_nsfw_concept = checker(
images=x_image, clip_input=features.pixel_values
)
if has_nsfw_concept[0]:
print(
"** An image with potential non-safe content has been detected. A blurred image will be returned. **"
)
return self.blur(image)
else:
return image
def blur(self, input):
blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
try:
caution = self.get_caution_img()
if caution:
blurry.paste(caution, (0, 0), caution)
except FileNotFoundError:
pass
return blurry
def get_caution_img(self):
path = None
if self.caution_img:
return self.caution_img
path = Path(web_assets.__path__[0]) / CAUTION_IMG
caution = Image.open(path)
self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
return self.caution_img
# this is a handy routine for debugging use. Given a generated sample,
# convert it into a PNG image and store it at the indicated path
def save_sample(self, sample, filepath):

View File

@ -0,0 +1,89 @@
'''
SafetyChecker class - checks images against the StabilityAI NSFW filter
and blurs images that contain potential NSFW content.
'''
import diffusers
import numpy as np
import torch
import traceback
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
from pathlib import Path
from PIL import Image, ImageFilter
from transformers import AutoFeatureExtractor
import invokeai.assets.web as web_assets
from .globals import global_cache_dir
class SafetyChecker(object):
CAUTION_IMG = "caution.png"
def __init__(self, device: torch.device):
self.device = device
try:
print(">> Initializing NSFW checker")
safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_model_path = global_cache_dir("hub")
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
safety_model_id,
local_files_only=True,
cache_dir=safety_model_path,
)
self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(
safety_model_id,
local_files_only=True,
cache_dir=safety_model_path,
)
self.safety_checker.to(device)
self.safety_feature_extractor.to(device)
except Exception:
print(
"** An error was encountered while installing the safety checker:"
)
print(traceback.format_exc())
else:
print(">> NSFW checker is disabled")
def check(self, image: Image.Image):
"""
Check provided image against the StabilityAI safety checker and return
"""
features = self.safety_feature_extractor([image], return_tensors="pt")
# unfortunately checker requires the numpy version, so we have to convert back
x_image = np.array(image).astype(np.float32) / 255.0
x_image = x_image[None].transpose(0, 3, 1, 2)
diffusers.logging.set_verbosity_error()
checked_image, has_nsfw_concept = self.safety_checker(
images=x_image, clip_input=features.pixel_values
)
if has_nsfw_concept[0]:
print(
"** An image with potential non-safe content has been detected. A blurred image will be returned. **"
)
return self.blur(image)
else:
return image
def blur(self, input):
blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
try:
caution = self.get_caution_img()
if caution:
blurry.paste(caution, (0, 0), caution)
except FileNotFoundError:
pass
return blurry
def get_caution_img(self):
path = None
if self.caution_img:
return self.caution_img
path = Path(web_assets.__path__[0]) / self.CAUTION_IMG
caution = Image.open(path)
self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
return self.caution_img