node-based txt2img working without generate

This commit is contained in:
Lincoln Stein 2023-03-09 00:18:29 -05:00
parent 87789c1de8
commit 5d37fa6e36
8 changed files with 247 additions and 254 deletions

View File

@ -4,7 +4,7 @@ import os
from argparse import Namespace
from ...backend import Globals
from ..services.generate_initializer import get_generate
from ..services.generate_initializer import get_generator_factory
from ..services.graph import GraphExecutionState
from ..services.image_storage import DiskImageStorage
from ..services.invocation_queue import MemoryInvocationQueue
@ -47,7 +47,7 @@ class ApiDependencies:
# TODO: Use a logger
print(f">> Internet connectivity is {Globals.internet_available}")
generate = get_generate(args, config)
generator_factory = get_generator_factory(args, config)
events = FastAPIEventService(event_handler_id)
@ -61,7 +61,7 @@ class ApiDependencies:
db_location = os.path.join(output_folder, "invokeai.db")
services = InvocationServices(
generate=generate,
generator_factory=generator_factory,
events=events,
images=images,
queue=MemoryInvocationQueue(),

View File

@ -17,7 +17,7 @@ from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_gra
from .invocations import *
from .invocations.baseinvocation import BaseInvocation
from .services.events import EventServiceBase
from .services.generate_initializer import get_generate
from .services.generate_initializer import get_generator_factory
from .services.graph import EdgeConnection, GraphExecutionState
from .services.image_storage import DiskImageStorage
from .services.invocation_queue import MemoryInvocationQueue
@ -106,11 +106,7 @@ def invoke_cli():
args = Args()
config = args.parse_args()
generate = get_generate(args, config)
# NOTE: load model on first use, uncomment to load at startup
# TODO: Make this a config option?
# generate.load_model()
generator_factory = get_generator_factory(args, config)
events = EventServiceBase()
@ -122,7 +118,7 @@ def invoke_cli():
db_location = os.path.join(output_folder, "invokeai.db")
services = InvocationServices(
generate=generate,
generator_factory=generator_factory,
events=events,
images=DiskImageStorage(output_folder),
queue=MemoryInvocationQueue(),

View File

@ -12,9 +12,10 @@ from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
SAMPLER_NAME_VALUES = Literal[
"ddim", "plms", "k_lms", "k_dpm_2", "k_dpm_2_a", "k_euler", "k_euler_a", "k_heun"
tuple(InvokeAIGenerator.schedulers())
]
@ -57,19 +58,24 @@ class TextToImageInvocation(BaseInvocation):
# Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now?
if self.model is None or self.model == "":
self.model = context.services.generate.model_name
factory = context.services.generator_factory
if self.model:
factory.model_name = self.model
else:
self.model = factory.model_name
# Set the model (if already cached, this does nothing)
context.services.generate.set_model(self.model)
txt2img = factory.make_generator(Txt2Img)
results = context.services.generate.prompt2image(
outputs = txt2img.generate(
prompt=self.prompt,
step_callback=step_callback,
**self.dict(
exclude={"prompt"}
), # Shorthand for passing all of the parameters above manually
)
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
# each time it is called. We only need the first one.
generate_output = next(outputs)
# Results are image and seed, unwrap for now and ignore the seed
# TODO: pre-seed?
@ -78,7 +84,7 @@ class TextToImageInvocation(BaseInvocation):
image_name = context.services.images.create_name(
context.graph_execution_state_id, self.id
)
context.services.images.save(image_type, image_name, results[0][0])
context.services.images.save(image_type, image_name, generate_output.image)
return ImageOutput(
image=ImageField(image_type=image_type, image_name=image_name)
)
@ -115,23 +121,24 @@ class ImageToImageInvocation(TextToImageInvocation):
# Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now?
if self.model is None or self.model == "":
self.model = context.services.generate.model_name
factory = context.services.generator_factory
self.model = self.model or factory.model_name
factory.model_name = self.model
img2img = factory.make_generator(Img2Img)
# Set the model (if already cached, this does nothing)
context.services.generate.set_model(self.model)
results = context.services.generate.prompt2image(
prompt=self.prompt,
init_img=image,
init_mask=mask,
step_callback=step_callback,
**self.dict(
exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
generator_output = next(
img2img.generate(
prompt=self.prompt,
init_img=image,
init_mask=mask,
step_callback=step_callback,
**self.dict(
exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
)
)
result_image = results[0][0]
result_image = generator_output.image
# Results are image and seed, unwrap for now and ignore the seed
# TODO: pre-seed?
@ -145,7 +152,6 @@ class ImageToImageInvocation(TextToImageInvocation):
image=ImageField(image_type=image_type, image_name=image_name)
)
class InpaintInvocation(ImageToImageInvocation):
"""Generates an image using inpaint."""
@ -180,23 +186,24 @@ class InpaintInvocation(ImageToImageInvocation):
# Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now?
if self.model is None or self.model == "":
self.model = context.services.generate.model_name
factory = context.services.generator_factory
self.model = self.model or factory.model_name
factory.model_name = self.model
inpaint = factory.make_generator(Inpaint)
# Set the model (if already cached, this does nothing)
context.services.generate.set_model(self.model)
results = context.services.generate.prompt2image(
prompt=self.prompt,
init_img=image,
init_mask=mask,
step_callback=step_callback,
**self.dict(
exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
generator_output = next(
inpaint.generate(
prompt=self.prompt,
init_img=image,
init_mask=mask,
step_callback=step_callback,
**self.dict(
exclude={"prompt", "image", "mask"}
), # Shorthand for passing all of the parameters above manually
)
)
result_image = results[0][0]
result_image = generator_output.image
# Results are image and seed, unwrap for now and ignore the seed
# TODO: pre-seed?

View File

@ -1,16 +1,17 @@
import os
import sys
import torch
import traceback
from argparse import Namespace
from omegaconf import OmegaConf
import invokeai.version
from invokeai.backend import Generate, ModelManager
from ...backend import ModelManager, InvokeAIGeneratorBasicParams, InvokeAIGeneratorFactory
from ...backend.util import choose_precision, choose_torch_device
from ...backend import Globals
# TODO: most of this code should be split into individual services as the Generate.py code is deprecated
def get_generate(args, config) -> Generate:
def get_generator_factory(args, config) -> InvokeAIGeneratorFactory:
if not args.conf:
config_file = os.path.join(Globals.root, "configs", "models.yaml")
if not os.path.exists(config_file):
@ -63,49 +64,43 @@ def get_generate(args, config) -> Generate:
print(f"{e}. Aborting.")
sys.exit(-1)
# creating a Generate object:
# creating an InvokeAIGeneratorFactory object:
try:
gen = Generate(
conf=args.conf,
model=args.model,
sampler_name=args.sampler_name,
embedding_path=embedding_path,
full_precision=args.full_precision,
precision=args.precision,
gfpgan=gfpgan,
codeformer=codeformer,
esrgan=esrgan,
free_gpu_mem=args.free_gpu_mem,
safety_checker=args.safety_checker,
device = torch.device(choose_torch_device())
precision = 'float16' if args.precision=='float16' \
else 'float32' if args.precision=='float32' \
else choose_precision(device)
model_manager = ModelManager(
OmegaConf.load(args.conf),
precision=precision,
device_type=device,
max_loaded_models=args.max_loaded_models,
)
# TO DO: initialize and pass safety checker!!!
params = InvokeAIGeneratorBasicParams(
precision=precision,
)
factory = InvokeAIGeneratorFactory(model_manager, params)
except (FileNotFoundError, TypeError, AssertionError) as e:
report_model_error(opt, e)
report_model_error(args, e)
except (IOError, KeyError) as e:
print(f"{e}. Aborting.")
sys.exit(-1)
if args.seamless:
#TODO: do something here ?
print(">> changed to seamless tiling mode")
# preload the model
try:
gen.load_model()
except KeyError:
pass
except Exception as e:
report_model_error(args, e)
# try to autoconvert new models
# autoimport new .ckpt files
if path := args.autoconvert:
gen.model_manager.autoconvert_weights(
model_manager.autoconvert_weights(
conf_path=args.conf,
weights_directory=path,
)
return gen
return factory
def load_face_restoration(opt):
try:
@ -171,85 +166,3 @@ def report_model_error(opt: Namespace, e: Exception):
# sys.argv = previous_args
# main() # would rather do a os.exec(), but doesn't exist?
# sys.exit(0)
# Temporary initializer for Generate until we migrate off of it
def old_get_generate(args, config) -> Generate:
# TODO: Remove the need for globals
from invokeai.backend.globals import Globals
# alert - setting globals here
Globals.root = os.path.expanduser(
args.root_dir or os.environ.get("INVOKEAI_ROOT") or os.path.abspath(".")
)
Globals.try_patchmatch = args.patchmatch
print(f'>> InvokeAI runtime directory is "{Globals.root}"')
# these two lines prevent a horrible warning message from appearing
# when the frozen CLIP tokenizer is imported
import transformers
transformers.logging.set_verbosity_error()
# Loading Face Restoration and ESRGAN Modules
gfpgan, codeformer, esrgan = None, None, None
try:
if config.restore or config.esrgan:
from ldm.invoke.restoration import Restoration
restoration = Restoration()
if config.restore:
gfpgan, codeformer = restoration.load_face_restore_models(
config.gfpgan_model_path
)
else:
print(">> Face restoration disabled")
if config.esrgan:
esrgan = restoration.load_esrgan(config.esrgan_bg_tile)
else:
print(">> Upscaling disabled")
else:
print(">> Face restoration and upscaling disabled")
except (ModuleNotFoundError, ImportError):
print(traceback.format_exc(), file=sys.stderr)
print(">> You may need to install the ESRGAN and/or GFPGAN modules")
# normalize the config directory relative to root
if not os.path.isabs(config.conf):
config.conf = os.path.normpath(os.path.join(Globals.root, config.conf))
if config.embeddings:
if not os.path.isabs(config.embedding_path):
embedding_path = os.path.normpath(
os.path.join(Globals.root, config.embedding_path)
)
else:
embedding_path = None
# TODO: lazy-initialize this by wrapping it
try:
generate = Generate(
conf=config.conf,
model=config.model,
sampler_name=config.sampler_name,
embedding_path=embedding_path,
full_precision=config.full_precision,
precision=config.precision,
gfpgan=gfpgan,
codeformer=codeformer,
esrgan=esrgan,
free_gpu_mem=config.free_gpu_mem,
safety_checker=config.safety_checker,
max_loaded_models=config.max_loaded_models,
)
except (FileNotFoundError, TypeError, AssertionError):
# emergency_model_reconfigure() # TODO?
sys.exit(-1)
except (IOError, KeyError) as e:
print(f"{e}. Aborting.")
sys.exit(-1)
generate.free_gpu_mem = config.free_gpu_mem
return generate

View File

@ -1,5 +1,5 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from invokeai.backend import Generate
from invokeai.backend import InvokeAIGeneratorFactory
from .events import EventServiceBase
from .image_storage import ImageStorageBase
@ -10,7 +10,7 @@ from .item_storage import ItemStorageABC
class InvocationServices:
"""Services that can be used by invocations"""
generate: Generate # TODO: wrap Generate, or split it up from model?
generator_factory: InvokeAIGeneratorFactory
events: EventServiceBase
images: ImageStorageBase
queue: InvocationQueueABC
@ -20,15 +20,15 @@ class InvocationServices:
processor: "InvocationProcessorABC"
def __init__(
self,
generate: Generate,
events: EventServiceBase,
images: ImageStorageBase,
queue: InvocationQueueABC,
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
processor: "InvocationProcessorABC",
self,
generator_factory: InvokeAIGeneratorFactory,
events: EventServiceBase,
images: ImageStorageBase,
queue: InvocationQueueABC,
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
processor: "InvocationProcessorABC",
):
self.generate = generate
self.generator_factory = generator_factory
self.events = events
self.images = images
self.queue = queue

View File

@ -2,6 +2,12 @@
Initialization file for invokeai.backend
"""
from .generate import Generate
from .generator import (
InvokeAIGeneratorBasicParams,
InvokeAIGeneratorFactory,
InvokeAIGenerator,
InvokeAIGeneratorOutput
)
from .model_management import ModelManager
from .args import Args
from .globals import Globals

View File

@ -4,7 +4,6 @@ including img2img, txt2img, and inpaint
"""
from __future__ import annotations
import copy
import importlib
import dataclasses
import diffusers
@ -13,7 +12,6 @@ import random
import traceback
from abc import ABCMeta, abstractmethod
from contextlib import nullcontext
from pathlib import Path
import cv2
import numpy as np
@ -22,19 +20,59 @@ from PIL import Image, ImageChops, ImageFilter
from accelerate.utils import set_seed
from diffusers import DiffusionPipeline
from tqdm import trange
from typing import List, Type, Callable
from typing import List, Type
from dataclasses import dataclass, field
from diffusers.schedulers import SchedulerMixin as Scheduler
import invokeai.assets.web as web_assets
from ..util.util import rand_perlin_2d
from ..safety_checker import SafetyChecker
from ..prompting.conditioning import get_uc_and_c_and_ec
from ..model_management.model_manager import ModelManager
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
downsampling = 8
CAUTION_IMG = "caution.png"
@dataclass
class InvokeAIGeneratorBasicParams:
seed: int=None
width: int=512
height: int=512
cfg_scale: int=7.5
steps: int=20
ddim_eta: float=0.0
model_name: str='stable-diffusion-1.5'
scheduler: int='ddim'
precision: str='float16'
perlin: float=0.0
threshold: int=0.0
h_symmetry_time_pct: float=None
v_symmetry_time_pct: float=None
variation_amount: float = 0.0
with_variations: list=field(default_factory=list)
safety_checker: SafetyChecker=None
@dataclass
class InvokeAIGeneratorOutput:
'''
InvokeAIGeneratorOutput is a dataclass that contains the outputs of a generation
operation, including the image, its seed, the model name used to generate the image
and the model hash, as well as all the generate() parameters that went into
generating the image (in .params, also available as attributes)
'''
image: Image
seed: int
model_name: str
model_hash: str
params: dict
def __getattribute__(self,name):
try:
return object.__getattribute__(self, name)
except AttributeError:
params = object.__getattribute__(self, 'params')
if name in params:
return params[name]
raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{name}'")
class InvokeAIGeneratorFactory(object):
def __init__(self,
@ -49,31 +87,15 @@ class InvokeAIGeneratorFactory(object):
self.params,
**keyword_args
)
@dataclass
class InvokeAIGeneratorBasicParams:
seed: int=None
width: int=512
height: int=512
cfg_scale: int=7.5
steps: int=20
ddim_eta: float=0.0
model: str='stable-diffusion-1.5'
scheduler: int='ddim'
precision: str='float16'
perlin: float=0.0
threshold: int=0.0
h_symmetry_time_pct: float=None
v_symmetry_time_pct: float=None
variation_amount: float = 0.0
with_variations: list = field(default_factory=list)
@dataclass
class InvokeAIGeneratorOutput:
image: Image
seed: int
model_name: str
model_hash: str
params: InvokeAIGeneratorBasicParams
# getter and setter shortcuts for commonly used parameters
@property
def model_name(self)->str:
return self.params.model_name
@model_name.setter
def model_name(self, model_name: str):
self.params.model_name=model_name
# we are interposing a wrapper around the original Generator classes so that
# old code that calls Generate will continue to work.
@ -93,7 +115,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
def __init__(self,
model_manager: ModelManager,
params: InvokeAIGeneratorBasicParams
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
):
self.model_manager=model_manager
self.params=params
@ -105,7 +127,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
**keyword_args,
)->List[InvokeAIGeneratorOutput]:
model_name = self.params.model or self.model_manager.current_model
model_name = self.params.model_name or self.model_manager.current_model
model_info: dict = self.model_manager.get_model(model_name)
model:StableDiffusionGeneratorPipeline = model_info['model']
model_hash = model_info['hash']
@ -125,24 +147,33 @@ class InvokeAIGenerator(metaclass=ABCMeta):
self.params.variation_amount,
self.params.with_variations)
generator_args = dataclasses.asdict(self.params)
generator_args.update(keyword_args)
while True:
results = generator.generate(prompt,
conditioning=(uc, c, extra_conditioning_info),
sampler=scheduler,
**dataclasses.asdict(self.params),
**keyword_args
**generator_args,
)
output = InvokeAIGeneratorOutput(
image=results[0][0],
seed=results[0][1],
model_name = model_name,
model_hash = model_hash,
params=copy.copy(self.params)
params=generator_args,
)
if callback:
callback(output)
yield output
@classmethod
def schedulers(self)->List[str]:
'''
Return list of all the schedulers that we currently handle.
'''
return list(self.scheduler_map.keys())
def load_generator(self, model: StableDiffusionGeneratorPipeline, class_name: str):
module_name = f'invokeai.backend.generator.{class_name.lower()}'
module = importlib.import_module(module_name)
@ -220,7 +251,6 @@ class Inpaint(Img2Img):
def _generator_name(self)->str:
return 'Inpaint'
class Generator:
downsampling_factor: int
latent_channels: int
@ -240,7 +270,6 @@ class Generator:
self.with_variations = []
self.use_mps_noise = False
self.free_gpu_mem = None
self.caution_img = None
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
def get_make_image(self, prompt, **kwargs):
@ -272,7 +301,7 @@ class Generator:
perlin=0.0,
h_symmetry_time_pct=None,
v_symmetry_time_pct=None,
safety_checker: dict = None,
safety_checker: SafetyChecker=None,
free_gpu_mem: bool = False,
**kwargs,
):
@ -325,7 +354,7 @@ class Generator:
image = make_image(x_T)
if self.safety_checker is not None:
image = self.safety_check(image)
image = self.safety_checker.check(image)
results.append([image, seed])
@ -548,53 +577,6 @@ class Generator:
return v2
def safety_check(self, image: Image.Image):
"""
If the CompViz safety checker flags an NSFW image, we
blur it out.
"""
import diffusers
checker = self.safety_checker["checker"]
extractor = self.safety_checker["extractor"]
features = extractor([image], return_tensors="pt")
features.to(self.model.device)
# unfortunately checker requires the numpy version, so we have to convert back
x_image = np.array(image).astype(np.float32) / 255.0
x_image = x_image[None].transpose(0, 3, 1, 2)
diffusers.logging.set_verbosity_error()
checked_image, has_nsfw_concept = checker(
images=x_image, clip_input=features.pixel_values
)
if has_nsfw_concept[0]:
print(
"** An image with potential non-safe content has been detected. A blurred image will be returned. **"
)
return self.blur(image)
else:
return image
def blur(self, input):
blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
try:
caution = self.get_caution_img()
if caution:
blurry.paste(caution, (0, 0), caution)
except FileNotFoundError:
pass
return blurry
def get_caution_img(self):
path = None
if self.caution_img:
return self.caution_img
path = Path(web_assets.__path__[0]) / CAUTION_IMG
caution = Image.open(path)
self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
return self.caution_img
# this is a handy routine for debugging use. Given a generated sample,
# convert it into a PNG image and store it at the indicated path
def save_sample(self, sample, filepath):

View File

@ -0,0 +1,89 @@
'''
SafetyChecker class - checks images against the StabilityAI NSFW filter
and blurs images that contain potential NSFW content.
'''
import diffusers
import numpy as np
import torch
import traceback
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
from pathlib import Path
from PIL import Image, ImageFilter
from transformers import AutoFeatureExtractor
import invokeai.assets.web as web_assets
from .globals import global_cache_dir
class SafetyChecker(object):
CAUTION_IMG = "caution.png"
def __init__(self, device: torch.device):
self.device = device
try:
print(">> Initializing NSFW checker")
safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_model_path = global_cache_dir("hub")
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
safety_model_id,
local_files_only=True,
cache_dir=safety_model_path,
)
self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(
safety_model_id,
local_files_only=True,
cache_dir=safety_model_path,
)
self.safety_checker.to(device)
self.safety_feature_extractor.to(device)
except Exception:
print(
"** An error was encountered while installing the safety checker:"
)
print(traceback.format_exc())
else:
print(">> NSFW checker is disabled")
def check(self, image: Image.Image):
"""
Check provided image against the StabilityAI safety checker and return
"""
features = self.safety_feature_extractor([image], return_tensors="pt")
# unfortunately checker requires the numpy version, so we have to convert back
x_image = np.array(image).astype(np.float32) / 255.0
x_image = x_image[None].transpose(0, 3, 1, 2)
diffusers.logging.set_verbosity_error()
checked_image, has_nsfw_concept = self.safety_checker(
images=x_image, clip_input=features.pixel_values
)
if has_nsfw_concept[0]:
print(
"** An image with potential non-safe content has been detected. A blurred image will be returned. **"
)
return self.blur(image)
else:
return image
def blur(self, input):
blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
try:
caution = self.get_caution_img()
if caution:
blurry.paste(caution, (0, 0), caution)
except FileNotFoundError:
pass
return blurry
def get_caution_img(self):
path = None
if self.caution_img:
return self.caution_img
path = Path(web_assets.__path__[0]) / self.CAUTION_IMG
caution = Image.open(path)
self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
return self.caution_img