diff --git a/invokeai/app/api/dependencies.py b/invokeai/app/api/dependencies.py
index 271a2e3be3..58e6c81492 100644
--- a/invokeai/app/api/dependencies.py
+++ b/invokeai/app/api/dependencies.py
@@ -4,7 +4,7 @@ import os
 from argparse import Namespace
 
 from ...backend import Globals
-from ..services.generate_initializer import get_generate
+from ..services.generate_initializer import get_generator_factory
 from ..services.graph import GraphExecutionState
 from ..services.image_storage import DiskImageStorage
 from ..services.invocation_queue import MemoryInvocationQueue
@@ -47,7 +47,7 @@ class ApiDependencies:
         # TODO: Use a logger
         print(f">> Internet connectivity is {Globals.internet_available}")
 
-        generate = get_generate(args, config)
+        generator_factory = get_generator_factory(args, config)
 
         events = FastAPIEventService(event_handler_id)
 
@@ -61,7 +61,7 @@ class ApiDependencies:
         db_location = os.path.join(output_folder, "invokeai.db")
 
         services = InvocationServices(
-            generate=generate,
+            generator_factory=generator_factory,
             events=events,
             images=images,
             queue=MemoryInvocationQueue(),
diff --git a/invokeai/app/cli_app.py b/invokeai/app/cli_app.py
index 721760b222..d0190903ff 100644
--- a/invokeai/app/cli_app.py
+++ b/invokeai/app/cli_app.py
@@ -17,7 +17,7 @@ from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_gra
 from .invocations import *
 from .invocations.baseinvocation import BaseInvocation
 from .services.events import EventServiceBase
-from .services.generate_initializer import get_generate
+from .services.generate_initializer import get_generator_factory
 from .services.graph import EdgeConnection, GraphExecutionState
 from .services.image_storage import DiskImageStorage
 from .services.invocation_queue import MemoryInvocationQueue
@@ -106,11 +106,7 @@ def invoke_cli():
     args = Args()
     config = args.parse_args()
 
-    generate = get_generate(args, config)
-
-    # NOTE: load model on first use, uncomment to load at startup
-    # TODO: Make this a config option?
-    # generate.load_model()
+    generator_factory = get_generator_factory(args, config)
 
     events = EventServiceBase()
 
@@ -122,7 +118,7 @@ def invoke_cli():
     db_location = os.path.join(output_folder, "invokeai.db")
 
     services = InvocationServices(
-        generate=generate,
+        generator_factory=generator_factory,
        events=events,
        images=DiskImageStorage(output_folder),
        queue=MemoryInvocationQueue(),
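Note: both entry points now receive a lazy InvokeAIGeneratorFactory rather than a fully-initialized Generate object, which is why the explicit load_model() preload above could be dropped. A minimal sketch of the call pattern this enables (the args/config plumbing is elided; the prompt string and output path are illustrative):

```python
from invokeai.app.services.generate_initializer import get_generator_factory
from invokeai.backend.generator import Txt2Img

factory = get_generator_factory(args, config)  # builds a ModelManager; loads no model yet
txt2img = factory.make_generator(Txt2Img)      # cheap: the model is resolved inside generate()
output = next(txt2img.generate("an old oak tree at dusk"))  # illustrative prompt
output.image.save("oak.png")
```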
diff --git a/invokeai/app/invocations/generate.py b/invokeai/app/invocations/generate.py
index 15c5f17438..cf2ef8aa45 100644
--- a/invokeai/app/invocations/generate.py
+++ b/invokeai/app/invocations/generate.py
@@ -12,9 +12,10 @@ from ..services.image_storage import ImageType
 from ..services.invocation_services import InvocationServices
 from .baseinvocation import BaseInvocation, InvocationContext
 from .image import ImageField, ImageOutput
+from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
 
 SAMPLER_NAME_VALUES = Literal[
-    "ddim", "plms", "k_lms", "k_dpm_2", "k_dpm_2_a", "k_euler", "k_euler_a", "k_heun"
+    tuple(InvokeAIGenerator.schedulers())
 ]
 
 
@@ -57,19 +58,24 @@ class TextToImageInvocation(BaseInvocation):
         # Handle invalid model parameter
         # TODO: figure out if this can be done via a validator that uses the model_cache
         # TODO: How to get the default model name now?
-        if self.model is None or self.model == "":
-            self.model = context.services.generate.model_name
+        factory = context.services.generator_factory
+        if self.model:
+            factory.model_name = self.model
+        else:
+            self.model = factory.model_name
 
-        # Set the model (if already cached, this does nothing)
-        context.services.generate.set_model(self.model)
+        txt2img = factory.make_generator(Txt2Img)
 
-        results = context.services.generate.prompt2image(
+        outputs = txt2img.generate(
             prompt=self.prompt,
             step_callback=step_callback,
             **self.dict(
                 exclude={"prompt"}
             ),  # Shorthand for passing all of the parameters above manually
         )
+        # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
+        # each time it is called. We only need the first one.
+        generate_output = next(outputs)
 
         # Results are image and seed, unwrap for now and ignore the seed
         # TODO: pre-seed?
@@ -78,7 +84,7 @@ class TextToImageInvocation(BaseInvocation):
         image_name = context.services.images.create_name(
             context.graph_execution_state_id, self.id
         )
-        context.services.images.save(image_type, image_name, results[0][0])
+        context.services.images.save(image_type, image_name, generate_output.image)
         return ImageOutput(
             image=ImageField(image_type=image_type, image_name=image_name)
         )
@@ -115,23 +121,24 @@ class ImageToImageInvocation(TextToImageInvocation):
 
         # Handle invalid model parameter
         # TODO: figure out if this can be done via a validator that uses the model_cache
         # TODO: How to get the default model name now?
-        if self.model is None or self.model == "":
-            self.model = context.services.generate.model_name
+        factory = context.services.generator_factory
+        self.model = self.model or factory.model_name
+        factory.model_name = self.model
+        img2img = factory.make_generator(Img2Img)
 
-        # Set the model (if already cached, this does nothing)
-        context.services.generate.set_model(self.model)
-
-        results = context.services.generate.prompt2image(
-            prompt=self.prompt,
-            init_img=image,
-            init_mask=mask,
-            step_callback=step_callback,
-            **self.dict(
-                exclude={"prompt", "image", "mask"}
-            ),  # Shorthand for passing all of the parameters above manually
+        generator_output = next(
+            img2img.generate(
+                prompt=self.prompt,
+                init_img=image,
+                init_mask=mask,
+                step_callback=step_callback,
+                **self.dict(
+                    exclude={"prompt", "image", "mask"}
+                ),  # Shorthand for passing all of the parameters above manually
+            )
         )
-        result_image = results[0][0]
+        result_image = generator_output.image
 
         # Results are image and seed, unwrap for now and ignore the seed
         # TODO: pre-seed?
@@ -145,7 +152,6 @@
             image=ImageField(image_type=image_type, image_name=image_name)
         )
 
-
 class InpaintInvocation(ImageToImageInvocation):
     """Generates an image using inpaint."""
 
@@ -180,23 +186,24 @@ class InpaintInvocation(ImageToImageInvocation):
 
         # Handle invalid model parameter
         # TODO: figure out if this can be done via a validator that uses the model_cache
         # TODO: How to get the default model name now?
-        if self.model is None or self.model == "":
-            self.model = context.services.generate.model_name
+        factory = context.services.generator_factory
+        self.model = self.model or factory.model_name
+        factory.model_name = self.model
+        inpaint = factory.make_generator(Inpaint)
 
-        # Set the model (if already cached, this does nothing)
-        context.services.generate.set_model(self.model)
-
-        results = context.services.generate.prompt2image(
-            prompt=self.prompt,
-            init_img=image,
-            init_mask=mask,
-            step_callback=step_callback,
-            **self.dict(
-                exclude={"prompt", "image", "mask"}
-            ),  # Shorthand for passing all of the parameters above manually
+        generator_output = next(
+            inpaint.generate(
+                prompt=self.prompt,
+                init_img=image,
+                init_mask=mask,
+                step_callback=step_callback,
+                **self.dict(
+                    exclude={"prompt", "image", "mask"}
+                ),  # Shorthand for passing all of the parameters above manually
+            )
         )
-        result_image = results[0][0]
+        result_image = generator_output.image
 
         # Results are image and seed, unwrap for now and ignore the seed
         # TODO: pre-seed?
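Since generate() yields indefinitely, callers decide how many images to draw: the invocations above take exactly one with next(), but a batch is just more pulls from the same iterator. A hedged sketch of that (factory and init_image are illustrative names from the surrounding code; islice is the only addition):

```python
from itertools import islice

from invokeai.backend.generator import Img2Img

img2img = factory.make_generator(Img2Img)
outputs = img2img.generate(
    prompt="turn the sketch into a watercolor",  # illustrative prompt
    init_img=init_image,
    init_mask=None,
)
# Each pull produces a fresh InvokeAIGeneratorOutput with its own seed.
for result in islice(outputs, 4):
    print(result.seed)
```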
print(">> changed to seamless tiling mode") - # preload the model - try: - gen.load_model() - except KeyError: - pass - except Exception as e: - report_model_error(args, e) - # try to autoconvert new models # autoimport new .ckpt files if path := args.autoconvert: - gen.model_manager.autoconvert_weights( + model_manager.autoconvert_weights( conf_path=args.conf, weights_directory=path, ) - return gen - + return factory def load_face_restoration(opt): try: @@ -171,85 +166,3 @@ def report_model_error(opt: Namespace, e: Exception): # sys.argv = previous_args # main() # would rather do a os.exec(), but doesn't exist? # sys.exit(0) - - -# Temporary initializer for Generate until we migrate off of it -def old_get_generate(args, config) -> Generate: - # TODO: Remove the need for globals - from invokeai.backend.globals import Globals - - # alert - setting globals here - Globals.root = os.path.expanduser( - args.root_dir or os.environ.get("INVOKEAI_ROOT") or os.path.abspath(".") - ) - Globals.try_patchmatch = args.patchmatch - - print(f'>> InvokeAI runtime directory is "{Globals.root}"') - - # these two lines prevent a horrible warning message from appearing - # when the frozen CLIP tokenizer is imported - import transformers - - transformers.logging.set_verbosity_error() - - # Loading Face Restoration and ESRGAN Modules - gfpgan, codeformer, esrgan = None, None, None - try: - if config.restore or config.esrgan: - from ldm.invoke.restoration import Restoration - - restoration = Restoration() - if config.restore: - gfpgan, codeformer = restoration.load_face_restore_models( - config.gfpgan_model_path - ) - else: - print(">> Face restoration disabled") - if config.esrgan: - esrgan = restoration.load_esrgan(config.esrgan_bg_tile) - else: - print(">> Upscaling disabled") - else: - print(">> Face restoration and upscaling disabled") - except (ModuleNotFoundError, ImportError): - print(traceback.format_exc(), file=sys.stderr) - print(">> You may need to install the ESRGAN and/or GFPGAN modules") - - # normalize the config directory relative to root - if not os.path.isabs(config.conf): - config.conf = os.path.normpath(os.path.join(Globals.root, config.conf)) - - if config.embeddings: - if not os.path.isabs(config.embedding_path): - embedding_path = os.path.normpath( - os.path.join(Globals.root, config.embedding_path) - ) - else: - embedding_path = None - - # TODO: lazy-initialize this by wrapping it - try: - generate = Generate( - conf=config.conf, - model=config.model, - sampler_name=config.sampler_name, - embedding_path=embedding_path, - full_precision=config.full_precision, - precision=config.precision, - gfpgan=gfpgan, - codeformer=codeformer, - esrgan=esrgan, - free_gpu_mem=config.free_gpu_mem, - safety_checker=config.safety_checker, - max_loaded_models=config.max_loaded_models, - ) - except (FileNotFoundError, TypeError, AssertionError): - # emergency_model_reconfigure() # TODO? - sys.exit(-1) - except (IOError, KeyError) as e: - print(f"{e}. 
Aborting.") - sys.exit(-1) - - generate.free_gpu_mem = config.free_gpu_mem - - return generate diff --git a/invokeai/app/services/invocation_services.py b/invokeai/app/services/invocation_services.py index 42cbd6c271..0177d79107 100644 --- a/invokeai/app/services/invocation_services.py +++ b/invokeai/app/services/invocation_services.py @@ -1,5 +1,5 @@ # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) -from invokeai.backend import Generate +from invokeai.backend import InvokeAIGeneratorFactory from .events import EventServiceBase from .image_storage import ImageStorageBase @@ -10,7 +10,7 @@ from .item_storage import ItemStorageABC class InvocationServices: """Services that can be used by invocations""" - generate: Generate # TODO: wrap Generate, or split it up from model? + generator_factory: InvokeAIGeneratorFactory events: EventServiceBase images: ImageStorageBase queue: InvocationQueueABC @@ -20,15 +20,15 @@ class InvocationServices: processor: "InvocationProcessorABC" def __init__( - self, - generate: Generate, - events: EventServiceBase, - images: ImageStorageBase, - queue: InvocationQueueABC, - graph_execution_manager: ItemStorageABC["GraphExecutionState"], - processor: "InvocationProcessorABC", + self, + generator_factory: InvokeAIGeneratorFactory, + events: EventServiceBase, + images: ImageStorageBase, + queue: InvocationQueueABC, + graph_execution_manager: ItemStorageABC["GraphExecutionState"], + processor: "InvocationProcessorABC", ): - self.generate = generate + self.generator_factory = generator_factory self.events = events self.images = images self.queue = queue diff --git a/invokeai/backend/__init__.py b/invokeai/backend/__init__.py index 06089369c2..75fd0b5cb4 100644 --- a/invokeai/backend/__init__.py +++ b/invokeai/backend/__init__.py @@ -2,6 +2,12 @@ Initialization file for invokeai.backend """ from .generate import Generate +from .generator import ( + InvokeAIGeneratorBasicParams, + InvokeAIGeneratorFactory, + InvokeAIGenerator, + InvokeAIGeneratorOutput +) from .model_management import ModelManager from .args import Args from .globals import Globals diff --git a/invokeai/backend/generator/base.py b/invokeai/backend/generator/base.py index 497a56b360..d6c70b4d80 100644 --- a/invokeai/backend/generator/base.py +++ b/invokeai/backend/generator/base.py @@ -4,7 +4,6 @@ including img2img, txt2img, and inpaint """ from __future__ import annotations -import copy import importlib import dataclasses import diffusers @@ -13,7 +12,6 @@ import random import traceback from abc import ABCMeta, abstractmethod from contextlib import nullcontext -from pathlib import Path import cv2 import numpy as np @@ -22,19 +20,59 @@ from PIL import Image, ImageChops, ImageFilter from accelerate.utils import set_seed from diffusers import DiffusionPipeline from tqdm import trange -from typing import List, Type, Callable +from typing import List, Type from dataclasses import dataclass, field from diffusers.schedulers import SchedulerMixin as Scheduler -import invokeai.assets.web as web_assets from ..util.util import rand_perlin_2d - +from ..safety_checker import SafetyChecker from ..prompting.conditioning import get_uc_and_c_and_ec from ..model_management.model_manager import ModelManager from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline downsampling = 8 -CAUTION_IMG = "caution.png" + +@dataclass +class InvokeAIGeneratorBasicParams: + seed: int=None + width: int=512 + height: int=512 + cfg_scale: int=7.5 + steps: int=20 + ddim_eta: float=0.0 + 
diff --git a/invokeai/backend/generator/base.py b/invokeai/backend/generator/base.py
index 497a56b360..d6c70b4d80 100644
--- a/invokeai/backend/generator/base.py
+++ b/invokeai/backend/generator/base.py
@@ -4,7 +4,6 @@ including img2img, txt2img, and inpaint
 """
 from __future__ import annotations
 
-import copy
 import importlib
 import dataclasses
 import diffusers
@@ -13,7 +12,6 @@ import random
 import traceback
 from abc import ABCMeta, abstractmethod
 from contextlib import nullcontext
-from pathlib import Path
 
 import cv2
 import numpy as np
@@ -22,19 +20,59 @@ from PIL import Image, ImageChops, ImageFilter
 from accelerate.utils import set_seed
 from diffusers import DiffusionPipeline
 from tqdm import trange
-from typing import List, Type, Callable
+from typing import List, Type
 from dataclasses import dataclass, field
 from diffusers.schedulers import SchedulerMixin as Scheduler
 
-import invokeai.assets.web as web_assets
 from ..util.util import rand_perlin_2d
-
+from ..safety_checker import SafetyChecker
 from ..prompting.conditioning import get_uc_and_c_and_ec
 from ..model_management.model_manager import ModelManager
 from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
 
 downsampling = 8
-CAUTION_IMG = "caution.png"
+
+@dataclass
+class InvokeAIGeneratorBasicParams:
+    seed: int=None
+    width: int=512
+    height: int=512
+    cfg_scale: float=7.5
+    steps: int=20
+    ddim_eta: float=0.0
+    model_name: str='stable-diffusion-1.5'
+    scheduler: str='ddim'
+    precision: str='float16'
+    perlin: float=0.0
+    threshold: float=0.0
+    h_symmetry_time_pct: float=None
+    v_symmetry_time_pct: float=None
+    variation_amount: float=0.0
+    with_variations: list=field(default_factory=list)
+    safety_checker: SafetyChecker=None
+
+@dataclass
+class InvokeAIGeneratorOutput:
+    '''
+    InvokeAIGeneratorOutput is a dataclass that contains the outputs of a generation
+    operation, including the image, its seed, the model name used to generate the image
+    and the model hash, as well as all the generate() parameters that went into
+    generating the image (in .params, also available as attributes)
+    '''
+    image: Image
+    seed: int
+    model_name: str
+    model_hash: str
+    params: dict
+
+    def __getattribute__(self, name):
+        try:
+            return object.__getattribute__(self, name)
+        except AttributeError:
+            params = object.__getattribute__(self, 'params')
+            if name in params:
+                return params[name]
+            raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{name}'")
 
 class InvokeAIGeneratorFactory(object):
     def __init__(self,
@@ -49,31 +87,15 @@ class InvokeAIGeneratorFactory(object):
             self.params,
             **keyword_args
         )
-@dataclass
-class InvokeAIGeneratorBasicParams:
-    seed: int=None
-    width: int=512
-    height: int=512
-    cfg_scale: int=7.5
-    steps: int=20
-    ddim_eta: float=0.0
-    model: str='stable-diffusion-1.5'
-    scheduler: int='ddim'
-    precision: str='float16'
-    perlin: float=0.0
-    threshold: int=0.0
-    h_symmetry_time_pct: float=None
-    v_symmetry_time_pct: float=None
-    variation_amount: float = 0.0
-    with_variations: list = field(default_factory=list)
-
-@dataclass
-class InvokeAIGeneratorOutput:
-    image: Image
-    seed: int
-    model_name: str
-    model_hash: str
-    params: InvokeAIGeneratorBasicParams
+
+    # getter and setter shortcuts for commonly used parameters
+    @property
+    def model_name(self) -> str:
+        return self.params.model_name
+
+    @model_name.setter
+    def model_name(self, model_name: str):
+        self.params.model_name = model_name
 
 # we are interposing a wrapper around the original Generator classes so that
 # old code that calls Generate will continue to work.
@@ -93,7 +115,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
 
     def __init__(self,
                  model_manager: ModelManager,
-                 params: InvokeAIGeneratorBasicParams
+                 params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
                  ):
         self.model_manager=model_manager
         self.params=params
@@ -105,7 +127,7 @@
                  **keyword_args,
                  )->List[InvokeAIGeneratorOutput]:
 
-        model_name = self.params.model or self.model_manager.current_model
+        model_name = self.params.model_name or self.model_manager.current_model
         model_info: dict = self.model_manager.get_model(model_name)
         model:StableDiffusionGeneratorPipeline = model_info['model']
         model_hash = model_info['hash']
@@ -124,24 +146,33 @@
             generator.set_variation(self.params.seed,
                                     self.params.variation_amount,
                                     self.params.with_variations)
 
+        generator_args = dataclasses.asdict(self.params)
+        generator_args.update(keyword_args)
+
         while True:
             results = generator.generate(prompt,
                                          conditioning=(uc, c, extra_conditioning_info),
                                          sampler=scheduler,
-                                         **dataclasses.asdict(self.params),
-                                         **keyword_args
+                                         **generator_args,
                                          )
             output = InvokeAIGeneratorOutput(
                 image=results[0][0],
                 seed=results[0][1],
                 model_name = model_name,
                 model_hash = model_hash,
-                params=copy.copy(self.params)
+                params=generator_args,
             )
             if callback:
                 callback(output)
             yield output
+
+    @classmethod
+    def schedulers(cls) -> List[str]:
+        '''
+        Return list of all the schedulers that we currently handle.
+        '''
+        return list(cls.scheduler_map.keys())
 
     def load_generator(self, model: StableDiffusionGeneratorPipeline, class_name: str):
         module_name = f'invokeai.backend.generator.{class_name.lower()}'
@@ -219,8 +250,7 @@ class Inpaint(Img2Img):
     def _generator_name(self)->str:
         return 'Inpaint'
 
-
-
+
 class Generator:
     downsampling_factor: int
     latent_channels: int
@@ -240,7 +270,6 @@ class Generator:
         self.with_variations = []
         self.use_mps_noise = False
        self.free_gpu_mem = None
-        self.caution_img = None
 
     # this is going to be overridden in img2img.py, txt2img.py and inpaint.py
     def get_make_image(self, prompt, **kwargs):
@@ -272,7 +301,7 @@ class Generator:
         perlin=0.0,
         h_symmetry_time_pct=None,
         v_symmetry_time_pct=None,
-        safety_checker: dict = None,
+        safety_checker: SafetyChecker=None,
         free_gpu_mem: bool = False,
         **kwargs,
     ):
@@ -325,7 +354,7 @@ class Generator:
                 image = make_image(x_T)
 
                 if self.safety_checker is not None:
-                    image = self.safety_check(image)
+                    image = self.safety_checker.check(image)
 
                 results.append([image, seed])
 
@@ -548,53 +577,6 @@ class Generator:
 
         return v2
 
-    def safety_check(self, image: Image.Image):
-        """
-        If the CompViz safety checker flags an NSFW image, we
-        blur it out.
-        """
-        import diffusers
-
-        checker = self.safety_checker["checker"]
-        extractor = self.safety_checker["extractor"]
-        features = extractor([image], return_tensors="pt")
-        features.to(self.model.device)
-
-        # unfortunately checker requires the numpy version, so we have to convert back
-        x_image = np.array(image).astype(np.float32) / 255.0
-        x_image = x_image[None].transpose(0, 3, 1, 2)
-
-        diffusers.logging.set_verbosity_error()
-        checked_image, has_nsfw_concept = checker(
-            images=x_image, clip_input=features.pixel_values
-        )
-        if has_nsfw_concept[0]:
-            print(
-                "** An image with potential non-safe content has been detected. A blurred image will be returned. **"
-            )
-            return self.blur(image)
-        else:
-            return image
-
-    def blur(self, input):
-        blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
-        try:
-            caution = self.get_caution_img()
-            if caution:
-                blurry.paste(caution, (0, 0), caution)
-        except FileNotFoundError:
-            pass
-        return blurry
-
-    def get_caution_img(self):
-        path = None
-        if self.caution_img:
-            return self.caution_img
-        path = Path(web_assets.__path__[0]) / CAUTION_IMG
-        caution = Image.open(path)
-        self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
-        return self.caution_img
-
     # this is a handy routine for debugging use. Given a generated sample,
     # convert it into a PNG image and store it at the indicated path
     def save_sample(self, sample, filepath):
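Two details above are easy to miss: schedulers() now drives the SAMPLER_NAME_VALUES literal in the invocations, and InvokeAIGeneratorOutput resolves unknown attributes against its params dict via __getattribute__. A small sketch with illustrative field values:

```python
from invokeai.backend.generator import InvokeAIGenerator, InvokeAIGeneratorOutput

print(InvokeAIGenerator.schedulers())  # e.g. ['ddim', 'k_lms', ...] from scheduler_map

output = InvokeAIGeneratorOutput(
    image=None,  # a PIL.Image in real use
    seed=42,
    model_name="stable-diffusion-1.5",
    model_hash="deadbeef",  # illustrative hash
    params={"steps": 20, "cfg_scale": 7.5},
)
print(output.seed)       # declared field: 42
print(output.cfg_scale)  # resolved through __getattribute__ from params: 7.5
```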
**" - ) - return self.blur(image) - else: - return image - - def blur(self, input): - blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32)) - try: - caution = self.get_caution_img() - if caution: - blurry.paste(caution, (0, 0), caution) - except FileNotFoundError: - pass - return blurry - - def get_caution_img(self): - path = None - if self.caution_img: - return self.caution_img - path = Path(web_assets.__path__[0]) / CAUTION_IMG - caution = Image.open(path) - self.caution_img = caution.resize((caution.width // 2, caution.height // 2)) - return self.caution_img - # this is a handy routine for debugging use. Given a generated sample, # convert it into a PNG image and store it at the indicated path def save_sample(self, sample, filepath): diff --git a/invokeai/backend/safety_checker.py b/invokeai/backend/safety_checker.py new file mode 100644 index 0000000000..86cf31cc13 --- /dev/null +++ b/invokeai/backend/safety_checker.py @@ -0,0 +1,89 @@ +''' +SafetyChecker class - checks images against the StabilityAI NSFW filter +and blurs images that contain potential NSFW content. +''' +import diffusers +import numpy as np +import torch +import traceback +from diffusers.pipelines.stable_diffusion.safety_checker import ( + StableDiffusionSafetyChecker, +) +from pathlib import Path +from PIL import Image, ImageFilter +from transformers import AutoFeatureExtractor + +import invokeai.assets.web as web_assets +from .globals import global_cache_dir + +class SafetyChecker(object): + CAUTION_IMG = "caution.png" + + def __init__(self, device: torch.device): + self.device = device + try: + print(">> Initializing NSFW checker") + safety_model_id = "CompVis/stable-diffusion-safety-checker" + safety_model_path = global_cache_dir("hub") + self.safety_checker = StableDiffusionSafetyChecker.from_pretrained( + safety_model_id, + local_files_only=True, + cache_dir=safety_model_path, + ) + self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained( + safety_model_id, + local_files_only=True, + cache_dir=safety_model_path, + ) + self.safety_checker.to(device) + self.safety_feature_extractor.to(device) + except Exception: + print( + "** An error was encountered while installing the safety checker:" + ) + print(traceback.format_exc()) + else: + print(">> NSFW checker is disabled") + + def check(self, image: Image.Image): + """ + Check provided image against the StabilityAI safety checker and return + + """ + + features = self.safety_feature_extractor([image], return_tensors="pt") + # unfortunately checker requires the numpy version, so we have to convert back + x_image = np.array(image).astype(np.float32) / 255.0 + x_image = x_image[None].transpose(0, 3, 1, 2) + + diffusers.logging.set_verbosity_error() + checked_image, has_nsfw_concept = self.safety_checker( + images=x_image, clip_input=features.pixel_values + ) + if has_nsfw_concept[0]: + print( + "** An image with potential non-safe content has been detected. A blurred image will be returned. 
**" + ) + return self.blur(image) + else: + return image + + def blur(self, input): + blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32)) + try: + caution = self.get_caution_img() + if caution: + blurry.paste(caution, (0, 0), caution) + except FileNotFoundError: + pass + return blurry + + def get_caution_img(self): + path = None + if self.caution_img: + return self.caution_img + path = Path(web_assets.__path__[0]) / self.CAUTION_IMG + caution = Image.open(path) + self.caution_img = caution.resize((caution.width // 2, caution.height // 2)) + return self.caution_img +