mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
node-based txt2img working without generate
This commit is contained in:
@ -2,6 +2,12 @@
|
||||
Initialization file for invokeai.backend
|
||||
"""
|
||||
from .generate import Generate
|
||||
from .generator import (
|
||||
InvokeAIGeneratorBasicParams,
|
||||
InvokeAIGeneratorFactory,
|
||||
InvokeAIGenerator,
|
||||
InvokeAIGeneratorOutput
|
||||
)
|
||||
from .model_management import ModelManager
|
||||
from .args import Args
|
||||
from .globals import Globals
|
||||
|
@ -4,7 +4,6 @@ including img2img, txt2img, and inpaint
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import importlib
|
||||
import dataclasses
|
||||
import diffusers
|
||||
@ -13,7 +12,6 @@ import random
|
||||
import traceback
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
@ -22,19 +20,59 @@ from PIL import Image, ImageChops, ImageFilter
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DiffusionPipeline
|
||||
from tqdm import trange
|
||||
from typing import List, Type, Callable
|
||||
from typing import List, Type
|
||||
from dataclasses import dataclass, field
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
|
||||
import invokeai.assets.web as web_assets
|
||||
from ..util.util import rand_perlin_2d
|
||||
|
||||
from ..safety_checker import SafetyChecker
|
||||
from ..prompting.conditioning import get_uc_and_c_and_ec
|
||||
from ..model_management.model_manager import ModelManager
|
||||
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||
|
||||
downsampling = 8
|
||||
CAUTION_IMG = "caution.png"
|
||||
|
||||
@dataclass
|
||||
class InvokeAIGeneratorBasicParams:
|
||||
seed: int=None
|
||||
width: int=512
|
||||
height: int=512
|
||||
cfg_scale: int=7.5
|
||||
steps: int=20
|
||||
ddim_eta: float=0.0
|
||||
model_name: str='stable-diffusion-1.5'
|
||||
scheduler: int='ddim'
|
||||
precision: str='float16'
|
||||
perlin: float=0.0
|
||||
threshold: int=0.0
|
||||
h_symmetry_time_pct: float=None
|
||||
v_symmetry_time_pct: float=None
|
||||
variation_amount: float = 0.0
|
||||
with_variations: list=field(default_factory=list)
|
||||
safety_checker: SafetyChecker=None
|
||||
|
||||
@dataclass
|
||||
class InvokeAIGeneratorOutput:
|
||||
'''
|
||||
InvokeAIGeneratorOutput is a dataclass that contains the outputs of a generation
|
||||
operation, including the image, its seed, the model name used to generate the image
|
||||
and the model hash, as well as all the generate() parameters that went into
|
||||
generating the image (in .params, also available as attributes)
|
||||
'''
|
||||
image: Image
|
||||
seed: int
|
||||
model_name: str
|
||||
model_hash: str
|
||||
params: dict
|
||||
|
||||
def __getattribute__(self,name):
|
||||
try:
|
||||
return object.__getattribute__(self, name)
|
||||
except AttributeError:
|
||||
params = object.__getattribute__(self, 'params')
|
||||
if name in params:
|
||||
return params[name]
|
||||
raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{name}'")
|
||||
|
||||
class InvokeAIGeneratorFactory(object):
|
||||
def __init__(self,
|
||||
@ -49,31 +87,15 @@ class InvokeAIGeneratorFactory(object):
|
||||
self.params,
|
||||
**keyword_args
|
||||
)
|
||||
@dataclass
|
||||
class InvokeAIGeneratorBasicParams:
|
||||
seed: int=None
|
||||
width: int=512
|
||||
height: int=512
|
||||
cfg_scale: int=7.5
|
||||
steps: int=20
|
||||
ddim_eta: float=0.0
|
||||
model: str='stable-diffusion-1.5'
|
||||
scheduler: int='ddim'
|
||||
precision: str='float16'
|
||||
perlin: float=0.0
|
||||
threshold: int=0.0
|
||||
h_symmetry_time_pct: float=None
|
||||
v_symmetry_time_pct: float=None
|
||||
variation_amount: float = 0.0
|
||||
with_variations: list = field(default_factory=list)
|
||||
|
||||
@dataclass
|
||||
class InvokeAIGeneratorOutput:
|
||||
image: Image
|
||||
seed: int
|
||||
model_name: str
|
||||
model_hash: str
|
||||
params: InvokeAIGeneratorBasicParams
|
||||
# getter and setter shortcuts for commonly used parameters
|
||||
@property
|
||||
def model_name(self)->str:
|
||||
return self.params.model_name
|
||||
|
||||
@model_name.setter
|
||||
def model_name(self, model_name: str):
|
||||
self.params.model_name=model_name
|
||||
|
||||
# we are interposing a wrapper around the original Generator classes so that
|
||||
# old code that calls Generate will continue to work.
|
||||
@ -93,7 +115,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
|
||||
def __init__(self,
|
||||
model_manager: ModelManager,
|
||||
params: InvokeAIGeneratorBasicParams
|
||||
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
|
||||
):
|
||||
self.model_manager=model_manager
|
||||
self.params=params
|
||||
@ -105,7 +127,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
**keyword_args,
|
||||
)->List[InvokeAIGeneratorOutput]:
|
||||
|
||||
model_name = self.params.model or self.model_manager.current_model
|
||||
model_name = self.params.model_name or self.model_manager.current_model
|
||||
model_info: dict = self.model_manager.get_model(model_name)
|
||||
model:StableDiffusionGeneratorPipeline = model_info['model']
|
||||
model_hash = model_info['hash']
|
||||
@ -124,24 +146,33 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
generator.set_variation(self.params.seed,
|
||||
self.params.variation_amount,
|
||||
self.params.with_variations)
|
||||
|
||||
generator_args = dataclasses.asdict(self.params)
|
||||
generator_args.update(keyword_args)
|
||||
|
||||
while True:
|
||||
results = generator.generate(prompt,
|
||||
conditioning=(uc, c, extra_conditioning_info),
|
||||
sampler=scheduler,
|
||||
**dataclasses.asdict(self.params),
|
||||
**keyword_args
|
||||
**generator_args,
|
||||
)
|
||||
output = InvokeAIGeneratorOutput(
|
||||
image=results[0][0],
|
||||
seed=results[0][1],
|
||||
model_name = model_name,
|
||||
model_hash = model_hash,
|
||||
params=copy.copy(self.params)
|
||||
params=generator_args,
|
||||
)
|
||||
if callback:
|
||||
callback(output)
|
||||
yield output
|
||||
|
||||
@classmethod
|
||||
def schedulers(self)->List[str]:
|
||||
'''
|
||||
Return list of all the schedulers that we currently handle.
|
||||
'''
|
||||
return list(self.scheduler_map.keys())
|
||||
|
||||
def load_generator(self, model: StableDiffusionGeneratorPipeline, class_name: str):
|
||||
module_name = f'invokeai.backend.generator.{class_name.lower()}'
|
||||
@ -219,8 +250,7 @@ class Inpaint(Img2Img):
|
||||
|
||||
def _generator_name(self)->str:
|
||||
return 'Inpaint'
|
||||
|
||||
|
||||
|
||||
class Generator:
|
||||
downsampling_factor: int
|
||||
latent_channels: int
|
||||
@ -240,7 +270,6 @@ class Generator:
|
||||
self.with_variations = []
|
||||
self.use_mps_noise = False
|
||||
self.free_gpu_mem = None
|
||||
self.caution_img = None
|
||||
|
||||
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
|
||||
def get_make_image(self, prompt, **kwargs):
|
||||
@ -272,7 +301,7 @@ class Generator:
|
||||
perlin=0.0,
|
||||
h_symmetry_time_pct=None,
|
||||
v_symmetry_time_pct=None,
|
||||
safety_checker: dict = None,
|
||||
safety_checker: SafetyChecker=None,
|
||||
free_gpu_mem: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
@ -325,7 +354,7 @@ class Generator:
|
||||
image = make_image(x_T)
|
||||
|
||||
if self.safety_checker is not None:
|
||||
image = self.safety_check(image)
|
||||
image = self.safety_checker.check(image)
|
||||
|
||||
results.append([image, seed])
|
||||
|
||||
@ -548,53 +577,6 @@ class Generator:
|
||||
|
||||
return v2
|
||||
|
||||
def safety_check(self, image: Image.Image):
|
||||
"""
|
||||
If the CompViz safety checker flags an NSFW image, we
|
||||
blur it out.
|
||||
"""
|
||||
import diffusers
|
||||
|
||||
checker = self.safety_checker["checker"]
|
||||
extractor = self.safety_checker["extractor"]
|
||||
features = extractor([image], return_tensors="pt")
|
||||
features.to(self.model.device)
|
||||
|
||||
# unfortunately checker requires the numpy version, so we have to convert back
|
||||
x_image = np.array(image).astype(np.float32) / 255.0
|
||||
x_image = x_image[None].transpose(0, 3, 1, 2)
|
||||
|
||||
diffusers.logging.set_verbosity_error()
|
||||
checked_image, has_nsfw_concept = checker(
|
||||
images=x_image, clip_input=features.pixel_values
|
||||
)
|
||||
if has_nsfw_concept[0]:
|
||||
print(
|
||||
"** An image with potential non-safe content has been detected. A blurred image will be returned. **"
|
||||
)
|
||||
return self.blur(image)
|
||||
else:
|
||||
return image
|
||||
|
||||
def blur(self, input):
|
||||
blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
|
||||
try:
|
||||
caution = self.get_caution_img()
|
||||
if caution:
|
||||
blurry.paste(caution, (0, 0), caution)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
return blurry
|
||||
|
||||
def get_caution_img(self):
|
||||
path = None
|
||||
if self.caution_img:
|
||||
return self.caution_img
|
||||
path = Path(web_assets.__path__[0]) / CAUTION_IMG
|
||||
caution = Image.open(path)
|
||||
self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
|
||||
return self.caution_img
|
||||
|
||||
# this is a handy routine for debugging use. Given a generated sample,
|
||||
# convert it into a PNG image and store it at the indicated path
|
||||
def save_sample(self, sample, filepath):
|
||||
|
89
invokeai/backend/safety_checker.py
Normal file
89
invokeai/backend/safety_checker.py
Normal file
@ -0,0 +1,89 @@
|
||||
'''
|
||||
SafetyChecker class - checks images against the StabilityAI NSFW filter
|
||||
and blurs images that contain potential NSFW content.
|
||||
'''
|
||||
import diffusers
|
||||
import numpy as np
|
||||
import torch
|
||||
import traceback
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import (
|
||||
StableDiffusionSafetyChecker,
|
||||
)
|
||||
from pathlib import Path
|
||||
from PIL import Image, ImageFilter
|
||||
from transformers import AutoFeatureExtractor
|
||||
|
||||
import invokeai.assets.web as web_assets
|
||||
from .globals import global_cache_dir
|
||||
|
||||
class SafetyChecker(object):
|
||||
CAUTION_IMG = "caution.png"
|
||||
|
||||
def __init__(self, device: torch.device):
|
||||
self.device = device
|
||||
try:
|
||||
print(">> Initializing NSFW checker")
|
||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||
safety_model_path = global_cache_dir("hub")
|
||||
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
||||
safety_model_id,
|
||||
local_files_only=True,
|
||||
cache_dir=safety_model_path,
|
||||
)
|
||||
self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||
safety_model_id,
|
||||
local_files_only=True,
|
||||
cache_dir=safety_model_path,
|
||||
)
|
||||
self.safety_checker.to(device)
|
||||
self.safety_feature_extractor.to(device)
|
||||
except Exception:
|
||||
print(
|
||||
"** An error was encountered while installing the safety checker:"
|
||||
)
|
||||
print(traceback.format_exc())
|
||||
else:
|
||||
print(">> NSFW checker is disabled")
|
||||
|
||||
def check(self, image: Image.Image):
|
||||
"""
|
||||
Check provided image against the StabilityAI safety checker and return
|
||||
|
||||
"""
|
||||
|
||||
features = self.safety_feature_extractor([image], return_tensors="pt")
|
||||
# unfortunately checker requires the numpy version, so we have to convert back
|
||||
x_image = np.array(image).astype(np.float32) / 255.0
|
||||
x_image = x_image[None].transpose(0, 3, 1, 2)
|
||||
|
||||
diffusers.logging.set_verbosity_error()
|
||||
checked_image, has_nsfw_concept = self.safety_checker(
|
||||
images=x_image, clip_input=features.pixel_values
|
||||
)
|
||||
if has_nsfw_concept[0]:
|
||||
print(
|
||||
"** An image with potential non-safe content has been detected. A blurred image will be returned. **"
|
||||
)
|
||||
return self.blur(image)
|
||||
else:
|
||||
return image
|
||||
|
||||
def blur(self, input):
|
||||
blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
|
||||
try:
|
||||
caution = self.get_caution_img()
|
||||
if caution:
|
||||
blurry.paste(caution, (0, 0), caution)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
return blurry
|
||||
|
||||
def get_caution_img(self):
|
||||
path = None
|
||||
if self.caution_img:
|
||||
return self.caution_img
|
||||
path = Path(web_assets.__path__[0]) / self.CAUTION_IMG
|
||||
caution = Image.open(path)
|
||||
self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
|
||||
return self.caution_img
|
||||
|
Reference in New Issue
Block a user