mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
node-based txt2img working without generate
This commit is contained in:
parent
87789c1de8
commit
5d37fa6e36
@ -4,7 +4,7 @@ import os
|
|||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
|
|
||||||
from ...backend import Globals
|
from ...backend import Globals
|
||||||
from ..services.generate_initializer import get_generate
|
from ..services.generate_initializer import get_generator_factory
|
||||||
from ..services.graph import GraphExecutionState
|
from ..services.graph import GraphExecutionState
|
||||||
from ..services.image_storage import DiskImageStorage
|
from ..services.image_storage import DiskImageStorage
|
||||||
from ..services.invocation_queue import MemoryInvocationQueue
|
from ..services.invocation_queue import MemoryInvocationQueue
|
||||||
@ -47,7 +47,7 @@ class ApiDependencies:
|
|||||||
# TODO: Use a logger
|
# TODO: Use a logger
|
||||||
print(f">> Internet connectivity is {Globals.internet_available}")
|
print(f">> Internet connectivity is {Globals.internet_available}")
|
||||||
|
|
||||||
generate = get_generate(args, config)
|
generator_factory = get_generator_factory(args, config)
|
||||||
|
|
||||||
events = FastAPIEventService(event_handler_id)
|
events = FastAPIEventService(event_handler_id)
|
||||||
|
|
||||||
@ -61,7 +61,7 @@ class ApiDependencies:
|
|||||||
db_location = os.path.join(output_folder, "invokeai.db")
|
db_location = os.path.join(output_folder, "invokeai.db")
|
||||||
|
|
||||||
services = InvocationServices(
|
services = InvocationServices(
|
||||||
generate=generate,
|
generator_factory=generator_factory,
|
||||||
events=events,
|
events=events,
|
||||||
images=images,
|
images=images,
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
|
@ -17,7 +17,7 @@ from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_gra
|
|||||||
from .invocations import *
|
from .invocations import *
|
||||||
from .invocations.baseinvocation import BaseInvocation
|
from .invocations.baseinvocation import BaseInvocation
|
||||||
from .services.events import EventServiceBase
|
from .services.events import EventServiceBase
|
||||||
from .services.generate_initializer import get_generate
|
from .services.generate_initializer import get_generator_factory
|
||||||
from .services.graph import EdgeConnection, GraphExecutionState
|
from .services.graph import EdgeConnection, GraphExecutionState
|
||||||
from .services.image_storage import DiskImageStorage
|
from .services.image_storage import DiskImageStorage
|
||||||
from .services.invocation_queue import MemoryInvocationQueue
|
from .services.invocation_queue import MemoryInvocationQueue
|
||||||
@ -106,11 +106,7 @@ def invoke_cli():
|
|||||||
args = Args()
|
args = Args()
|
||||||
config = args.parse_args()
|
config = args.parse_args()
|
||||||
|
|
||||||
generate = get_generate(args, config)
|
generator_factory = get_generator_factory(args, config)
|
||||||
|
|
||||||
# NOTE: load model on first use, uncomment to load at startup
|
|
||||||
# TODO: Make this a config option?
|
|
||||||
# generate.load_model()
|
|
||||||
|
|
||||||
events = EventServiceBase()
|
events = EventServiceBase()
|
||||||
|
|
||||||
@ -122,7 +118,7 @@ def invoke_cli():
|
|||||||
db_location = os.path.join(output_folder, "invokeai.db")
|
db_location = os.path.join(output_folder, "invokeai.db")
|
||||||
|
|
||||||
services = InvocationServices(
|
services = InvocationServices(
|
||||||
generate=generate,
|
generator_factory=generator_factory,
|
||||||
events=events,
|
events=events,
|
||||||
images=DiskImageStorage(output_folder),
|
images=DiskImageStorage(output_folder),
|
||||||
queue=MemoryInvocationQueue(),
|
queue=MemoryInvocationQueue(),
|
||||||
|
@ -12,9 +12,10 @@ from ..services.image_storage import ImageType
|
|||||||
from ..services.invocation_services import InvocationServices
|
from ..services.invocation_services import InvocationServices
|
||||||
from .baseinvocation import BaseInvocation, InvocationContext
|
from .baseinvocation import BaseInvocation, InvocationContext
|
||||||
from .image import ImageField, ImageOutput
|
from .image import ImageField, ImageOutput
|
||||||
|
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
|
||||||
|
|
||||||
SAMPLER_NAME_VALUES = Literal[
|
SAMPLER_NAME_VALUES = Literal[
|
||||||
"ddim", "plms", "k_lms", "k_dpm_2", "k_dpm_2_a", "k_euler", "k_euler_a", "k_heun"
|
tuple(InvokeAIGenerator.schedulers())
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -57,19 +58,24 @@ class TextToImageInvocation(BaseInvocation):
|
|||||||
# Handle invalid model parameter
|
# Handle invalid model parameter
|
||||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||||
# TODO: How to get the default model name now?
|
# TODO: How to get the default model name now?
|
||||||
if self.model is None or self.model == "":
|
factory = context.services.generator_factory
|
||||||
self.model = context.services.generate.model_name
|
if self.model:
|
||||||
|
factory.model_name = self.model
|
||||||
|
else:
|
||||||
|
self.model = factory.model_name
|
||||||
|
|
||||||
# Set the model (if already cached, this does nothing)
|
txt2img = factory.make_generator(Txt2Img)
|
||||||
context.services.generate.set_model(self.model)
|
|
||||||
|
|
||||||
results = context.services.generate.prompt2image(
|
outputs = txt2img.generate(
|
||||||
prompt=self.prompt,
|
prompt=self.prompt,
|
||||||
step_callback=step_callback,
|
step_callback=step_callback,
|
||||||
**self.dict(
|
**self.dict(
|
||||||
exclude={"prompt"}
|
exclude={"prompt"}
|
||||||
), # Shorthand for passing all of the parameters above manually
|
), # Shorthand for passing all of the parameters above manually
|
||||||
)
|
)
|
||||||
|
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
|
||||||
|
# each time it is called. We only need the first one.
|
||||||
|
generate_output = next(outputs)
|
||||||
|
|
||||||
# Results are image and seed, unwrap for now and ignore the seed
|
# Results are image and seed, unwrap for now and ignore the seed
|
||||||
# TODO: pre-seed?
|
# TODO: pre-seed?
|
||||||
@ -78,7 +84,7 @@ class TextToImageInvocation(BaseInvocation):
|
|||||||
image_name = context.services.images.create_name(
|
image_name = context.services.images.create_name(
|
||||||
context.graph_execution_state_id, self.id
|
context.graph_execution_state_id, self.id
|
||||||
)
|
)
|
||||||
context.services.images.save(image_type, image_name, results[0][0])
|
context.services.images.save(image_type, image_name, generate_output.image)
|
||||||
return ImageOutput(
|
return ImageOutput(
|
||||||
image=ImageField(image_type=image_type, image_name=image_name)
|
image=ImageField(image_type=image_type, image_name=image_name)
|
||||||
)
|
)
|
||||||
@ -115,23 +121,24 @@ class ImageToImageInvocation(TextToImageInvocation):
|
|||||||
# Handle invalid model parameter
|
# Handle invalid model parameter
|
||||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||||
# TODO: How to get the default model name now?
|
# TODO: How to get the default model name now?
|
||||||
if self.model is None or self.model == "":
|
factory = context.services.generator_factory
|
||||||
self.model = context.services.generate.model_name
|
self.model = self.model or factory.model_name
|
||||||
|
factory.model_name = self.model
|
||||||
|
img2img = factory.make_generator(Img2Img)
|
||||||
|
|
||||||
# Set the model (if already cached, this does nothing)
|
generator_output = next(
|
||||||
context.services.generate.set_model(self.model)
|
img2img.generate(
|
||||||
|
prompt=self.prompt,
|
||||||
results = context.services.generate.prompt2image(
|
init_img=image,
|
||||||
prompt=self.prompt,
|
init_mask=mask,
|
||||||
init_img=image,
|
step_callback=step_callback,
|
||||||
init_mask=mask,
|
**self.dict(
|
||||||
step_callback=step_callback,
|
exclude={"prompt", "image", "mask"}
|
||||||
**self.dict(
|
), # Shorthand for passing all of the parameters above manually
|
||||||
exclude={"prompt", "image", "mask"}
|
)
|
||||||
), # Shorthand for passing all of the parameters above manually
|
|
||||||
)
|
)
|
||||||
|
|
||||||
result_image = results[0][0]
|
result_image = generator_output.image
|
||||||
|
|
||||||
# Results are image and seed, unwrap for now and ignore the seed
|
# Results are image and seed, unwrap for now and ignore the seed
|
||||||
# TODO: pre-seed?
|
# TODO: pre-seed?
|
||||||
@ -145,7 +152,6 @@ class ImageToImageInvocation(TextToImageInvocation):
|
|||||||
image=ImageField(image_type=image_type, image_name=image_name)
|
image=ImageField(image_type=image_type, image_name=image_name)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class InpaintInvocation(ImageToImageInvocation):
|
class InpaintInvocation(ImageToImageInvocation):
|
||||||
"""Generates an image using inpaint."""
|
"""Generates an image using inpaint."""
|
||||||
|
|
||||||
@ -180,23 +186,24 @@ class InpaintInvocation(ImageToImageInvocation):
|
|||||||
# Handle invalid model parameter
|
# Handle invalid model parameter
|
||||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||||
# TODO: How to get the default model name now?
|
# TODO: How to get the default model name now?
|
||||||
if self.model is None or self.model == "":
|
factory = context.services.generator_factory
|
||||||
self.model = context.services.generate.model_name
|
self.model = self.model or factory.model_name
|
||||||
|
factory.model_name = self.model
|
||||||
|
inpaint = factory.make_generator(Inpaint)
|
||||||
|
|
||||||
# Set the model (if already cached, this does nothing)
|
generator_output = next(
|
||||||
context.services.generate.set_model(self.model)
|
inpaint.generate(
|
||||||
|
prompt=self.prompt,
|
||||||
results = context.services.generate.prompt2image(
|
init_img=image,
|
||||||
prompt=self.prompt,
|
init_mask=mask,
|
||||||
init_img=image,
|
step_callback=step_callback,
|
||||||
init_mask=mask,
|
**self.dict(
|
||||||
step_callback=step_callback,
|
exclude={"prompt", "image", "mask"}
|
||||||
**self.dict(
|
), # Shorthand for passing all of the parameters above manually
|
||||||
exclude={"prompt", "image", "mask"}
|
)
|
||||||
), # Shorthand for passing all of the parameters above manually
|
|
||||||
)
|
)
|
||||||
|
|
||||||
result_image = results[0][0]
|
result_image = generator_output.image
|
||||||
|
|
||||||
# Results are image and seed, unwrap for now and ignore the seed
|
# Results are image and seed, unwrap for now and ignore the seed
|
||||||
# TODO: pre-seed?
|
# TODO: pre-seed?
|
||||||
|
@ -1,16 +1,17 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import torch
|
||||||
import traceback
|
import traceback
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
|
||||||
import invokeai.version
|
import invokeai.version
|
||||||
from invokeai.backend import Generate, ModelManager
|
from ...backend import ModelManager, InvokeAIGeneratorBasicParams, InvokeAIGeneratorFactory
|
||||||
|
from ...backend.util import choose_precision, choose_torch_device
|
||||||
from ...backend import Globals
|
from ...backend import Globals
|
||||||
|
|
||||||
|
|
||||||
# TODO: most of this code should be split into individual services as the Generate.py code is deprecated
|
# TODO: most of this code should be split into individual services as the Generate.py code is deprecated
|
||||||
def get_generate(args, config) -> Generate:
|
def get_generator_factory(args, config) -> InvokeAIGeneratorFactory:
|
||||||
if not args.conf:
|
if not args.conf:
|
||||||
config_file = os.path.join(Globals.root, "configs", "models.yaml")
|
config_file = os.path.join(Globals.root, "configs", "models.yaml")
|
||||||
if not os.path.exists(config_file):
|
if not os.path.exists(config_file):
|
||||||
@ -63,49 +64,43 @@ def get_generate(args, config) -> Generate:
|
|||||||
print(f"{e}. Aborting.")
|
print(f"{e}. Aborting.")
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
|
|
||||||
# creating a Generate object:
|
# creating an InvokeAIGeneratorFactory object:
|
||||||
try:
|
try:
|
||||||
gen = Generate(
|
device = torch.device(choose_torch_device())
|
||||||
conf=args.conf,
|
precision = 'float16' if args.precision=='float16' \
|
||||||
model=args.model,
|
else 'float32' if args.precision=='float32' \
|
||||||
sampler_name=args.sampler_name,
|
else choose_precision(device)
|
||||||
embedding_path=embedding_path,
|
|
||||||
full_precision=args.full_precision,
|
model_manager = ModelManager(
|
||||||
precision=args.precision,
|
OmegaConf.load(args.conf),
|
||||||
gfpgan=gfpgan,
|
precision=precision,
|
||||||
codeformer=codeformer,
|
device_type=device,
|
||||||
esrgan=esrgan,
|
|
||||||
free_gpu_mem=args.free_gpu_mem,
|
|
||||||
safety_checker=args.safety_checker,
|
|
||||||
max_loaded_models=args.max_loaded_models,
|
max_loaded_models=args.max_loaded_models,
|
||||||
)
|
)
|
||||||
|
# TO DO: initialize and pass safety checker!!!
|
||||||
|
params = InvokeAIGeneratorBasicParams(
|
||||||
|
precision=precision,
|
||||||
|
)
|
||||||
|
factory = InvokeAIGeneratorFactory(model_manager, params)
|
||||||
except (FileNotFoundError, TypeError, AssertionError) as e:
|
except (FileNotFoundError, TypeError, AssertionError) as e:
|
||||||
report_model_error(opt, e)
|
report_model_error(args, e)
|
||||||
except (IOError, KeyError) as e:
|
except (IOError, KeyError) as e:
|
||||||
print(f"{e}. Aborting.")
|
print(f"{e}. Aborting.")
|
||||||
sys.exit(-1)
|
sys.exit(-1)
|
||||||
|
|
||||||
if args.seamless:
|
if args.seamless:
|
||||||
|
#TODO: do something here ?
|
||||||
print(">> changed to seamless tiling mode")
|
print(">> changed to seamless tiling mode")
|
||||||
|
|
||||||
# preload the model
|
|
||||||
try:
|
|
||||||
gen.load_model()
|
|
||||||
except KeyError:
|
|
||||||
pass
|
|
||||||
except Exception as e:
|
|
||||||
report_model_error(args, e)
|
|
||||||
|
|
||||||
# try to autoconvert new models
|
# try to autoconvert new models
|
||||||
# autoimport new .ckpt files
|
# autoimport new .ckpt files
|
||||||
if path := args.autoconvert:
|
if path := args.autoconvert:
|
||||||
gen.model_manager.autoconvert_weights(
|
model_manager.autoconvert_weights(
|
||||||
conf_path=args.conf,
|
conf_path=args.conf,
|
||||||
weights_directory=path,
|
weights_directory=path,
|
||||||
)
|
)
|
||||||
|
|
||||||
return gen
|
return factory
|
||||||
|
|
||||||
|
|
||||||
def load_face_restoration(opt):
|
def load_face_restoration(opt):
|
||||||
try:
|
try:
|
||||||
@ -171,85 +166,3 @@ def report_model_error(opt: Namespace, e: Exception):
|
|||||||
# sys.argv = previous_args
|
# sys.argv = previous_args
|
||||||
# main() # would rather do a os.exec(), but doesn't exist?
|
# main() # would rather do a os.exec(), but doesn't exist?
|
||||||
# sys.exit(0)
|
# sys.exit(0)
|
||||||
|
|
||||||
|
|
||||||
# Temporary initializer for Generate until we migrate off of it
|
|
||||||
def old_get_generate(args, config) -> Generate:
|
|
||||||
# TODO: Remove the need for globals
|
|
||||||
from invokeai.backend.globals import Globals
|
|
||||||
|
|
||||||
# alert - setting globals here
|
|
||||||
Globals.root = os.path.expanduser(
|
|
||||||
args.root_dir or os.environ.get("INVOKEAI_ROOT") or os.path.abspath(".")
|
|
||||||
)
|
|
||||||
Globals.try_patchmatch = args.patchmatch
|
|
||||||
|
|
||||||
print(f'>> InvokeAI runtime directory is "{Globals.root}"')
|
|
||||||
|
|
||||||
# these two lines prevent a horrible warning message from appearing
|
|
||||||
# when the frozen CLIP tokenizer is imported
|
|
||||||
import transformers
|
|
||||||
|
|
||||||
transformers.logging.set_verbosity_error()
|
|
||||||
|
|
||||||
# Loading Face Restoration and ESRGAN Modules
|
|
||||||
gfpgan, codeformer, esrgan = None, None, None
|
|
||||||
try:
|
|
||||||
if config.restore or config.esrgan:
|
|
||||||
from ldm.invoke.restoration import Restoration
|
|
||||||
|
|
||||||
restoration = Restoration()
|
|
||||||
if config.restore:
|
|
||||||
gfpgan, codeformer = restoration.load_face_restore_models(
|
|
||||||
config.gfpgan_model_path
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
print(">> Face restoration disabled")
|
|
||||||
if config.esrgan:
|
|
||||||
esrgan = restoration.load_esrgan(config.esrgan_bg_tile)
|
|
||||||
else:
|
|
||||||
print(">> Upscaling disabled")
|
|
||||||
else:
|
|
||||||
print(">> Face restoration and upscaling disabled")
|
|
||||||
except (ModuleNotFoundError, ImportError):
|
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
print(">> You may need to install the ESRGAN and/or GFPGAN modules")
|
|
||||||
|
|
||||||
# normalize the config directory relative to root
|
|
||||||
if not os.path.isabs(config.conf):
|
|
||||||
config.conf = os.path.normpath(os.path.join(Globals.root, config.conf))
|
|
||||||
|
|
||||||
if config.embeddings:
|
|
||||||
if not os.path.isabs(config.embedding_path):
|
|
||||||
embedding_path = os.path.normpath(
|
|
||||||
os.path.join(Globals.root, config.embedding_path)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
embedding_path = None
|
|
||||||
|
|
||||||
# TODO: lazy-initialize this by wrapping it
|
|
||||||
try:
|
|
||||||
generate = Generate(
|
|
||||||
conf=config.conf,
|
|
||||||
model=config.model,
|
|
||||||
sampler_name=config.sampler_name,
|
|
||||||
embedding_path=embedding_path,
|
|
||||||
full_precision=config.full_precision,
|
|
||||||
precision=config.precision,
|
|
||||||
gfpgan=gfpgan,
|
|
||||||
codeformer=codeformer,
|
|
||||||
esrgan=esrgan,
|
|
||||||
free_gpu_mem=config.free_gpu_mem,
|
|
||||||
safety_checker=config.safety_checker,
|
|
||||||
max_loaded_models=config.max_loaded_models,
|
|
||||||
)
|
|
||||||
except (FileNotFoundError, TypeError, AssertionError):
|
|
||||||
# emergency_model_reconfigure() # TODO?
|
|
||||||
sys.exit(-1)
|
|
||||||
except (IOError, KeyError) as e:
|
|
||||||
print(f"{e}. Aborting.")
|
|
||||||
sys.exit(-1)
|
|
||||||
|
|
||||||
generate.free_gpu_mem = config.free_gpu_mem
|
|
||||||
|
|
||||||
return generate
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
from invokeai.backend import Generate
|
from invokeai.backend import InvokeAIGeneratorFactory
|
||||||
|
|
||||||
from .events import EventServiceBase
|
from .events import EventServiceBase
|
||||||
from .image_storage import ImageStorageBase
|
from .image_storage import ImageStorageBase
|
||||||
@ -10,7 +10,7 @@ from .item_storage import ItemStorageABC
|
|||||||
class InvocationServices:
|
class InvocationServices:
|
||||||
"""Services that can be used by invocations"""
|
"""Services that can be used by invocations"""
|
||||||
|
|
||||||
generate: Generate # TODO: wrap Generate, or split it up from model?
|
generator_factory: InvokeAIGeneratorFactory
|
||||||
events: EventServiceBase
|
events: EventServiceBase
|
||||||
images: ImageStorageBase
|
images: ImageStorageBase
|
||||||
queue: InvocationQueueABC
|
queue: InvocationQueueABC
|
||||||
@ -20,15 +20,15 @@ class InvocationServices:
|
|||||||
processor: "InvocationProcessorABC"
|
processor: "InvocationProcessorABC"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
generate: Generate,
|
generator_factory: InvokeAIGeneratorFactory,
|
||||||
events: EventServiceBase,
|
events: EventServiceBase,
|
||||||
images: ImageStorageBase,
|
images: ImageStorageBase,
|
||||||
queue: InvocationQueueABC,
|
queue: InvocationQueueABC,
|
||||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||||
processor: "InvocationProcessorABC",
|
processor: "InvocationProcessorABC",
|
||||||
):
|
):
|
||||||
self.generate = generate
|
self.generator_factory = generator_factory
|
||||||
self.events = events
|
self.events = events
|
||||||
self.images = images
|
self.images = images
|
||||||
self.queue = queue
|
self.queue = queue
|
||||||
|
@ -2,6 +2,12 @@
|
|||||||
Initialization file for invokeai.backend
|
Initialization file for invokeai.backend
|
||||||
"""
|
"""
|
||||||
from .generate import Generate
|
from .generate import Generate
|
||||||
|
from .generator import (
|
||||||
|
InvokeAIGeneratorBasicParams,
|
||||||
|
InvokeAIGeneratorFactory,
|
||||||
|
InvokeAIGenerator,
|
||||||
|
InvokeAIGeneratorOutput
|
||||||
|
)
|
||||||
from .model_management import ModelManager
|
from .model_management import ModelManager
|
||||||
from .args import Args
|
from .args import Args
|
||||||
from .globals import Globals
|
from .globals import Globals
|
||||||
|
@ -4,7 +4,6 @@ including img2img, txt2img, and inpaint
|
|||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import copy
|
|
||||||
import importlib
|
import importlib
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import diffusers
|
import diffusers
|
||||||
@ -13,7 +12,6 @@ import random
|
|||||||
import traceback
|
import traceback
|
||||||
from abc import ABCMeta, abstractmethod
|
from abc import ABCMeta, abstractmethod
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -22,19 +20,59 @@ from PIL import Image, ImageChops, ImageFilter
|
|||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
from typing import List, Type, Callable
|
from typing import List, Type
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||||
|
|
||||||
import invokeai.assets.web as web_assets
|
|
||||||
from ..util.util import rand_perlin_2d
|
from ..util.util import rand_perlin_2d
|
||||||
|
from ..safety_checker import SafetyChecker
|
||||||
from ..prompting.conditioning import get_uc_and_c_and_ec
|
from ..prompting.conditioning import get_uc_and_c_and_ec
|
||||||
from ..model_management.model_manager import ModelManager
|
from ..model_management.model_manager import ModelManager
|
||||||
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||||
|
|
||||||
downsampling = 8
|
downsampling = 8
|
||||||
CAUTION_IMG = "caution.png"
|
|
||||||
|
@dataclass
|
||||||
|
class InvokeAIGeneratorBasicParams:
|
||||||
|
seed: int=None
|
||||||
|
width: int=512
|
||||||
|
height: int=512
|
||||||
|
cfg_scale: int=7.5
|
||||||
|
steps: int=20
|
||||||
|
ddim_eta: float=0.0
|
||||||
|
model_name: str='stable-diffusion-1.5'
|
||||||
|
scheduler: int='ddim'
|
||||||
|
precision: str='float16'
|
||||||
|
perlin: float=0.0
|
||||||
|
threshold: int=0.0
|
||||||
|
h_symmetry_time_pct: float=None
|
||||||
|
v_symmetry_time_pct: float=None
|
||||||
|
variation_amount: float = 0.0
|
||||||
|
with_variations: list=field(default_factory=list)
|
||||||
|
safety_checker: SafetyChecker=None
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InvokeAIGeneratorOutput:
|
||||||
|
'''
|
||||||
|
InvokeAIGeneratorOutput is a dataclass that contains the outputs of a generation
|
||||||
|
operation, including the image, its seed, the model name used to generate the image
|
||||||
|
and the model hash, as well as all the generate() parameters that went into
|
||||||
|
generating the image (in .params, also available as attributes)
|
||||||
|
'''
|
||||||
|
image: Image
|
||||||
|
seed: int
|
||||||
|
model_name: str
|
||||||
|
model_hash: str
|
||||||
|
params: dict
|
||||||
|
|
||||||
|
def __getattribute__(self,name):
|
||||||
|
try:
|
||||||
|
return object.__getattribute__(self, name)
|
||||||
|
except AttributeError:
|
||||||
|
params = object.__getattribute__(self, 'params')
|
||||||
|
if name in params:
|
||||||
|
return params[name]
|
||||||
|
raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{name}'")
|
||||||
|
|
||||||
class InvokeAIGeneratorFactory(object):
|
class InvokeAIGeneratorFactory(object):
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@ -49,31 +87,15 @@ class InvokeAIGeneratorFactory(object):
|
|||||||
self.params,
|
self.params,
|
||||||
**keyword_args
|
**keyword_args
|
||||||
)
|
)
|
||||||
@dataclass
|
|
||||||
class InvokeAIGeneratorBasicParams:
|
|
||||||
seed: int=None
|
|
||||||
width: int=512
|
|
||||||
height: int=512
|
|
||||||
cfg_scale: int=7.5
|
|
||||||
steps: int=20
|
|
||||||
ddim_eta: float=0.0
|
|
||||||
model: str='stable-diffusion-1.5'
|
|
||||||
scheduler: int='ddim'
|
|
||||||
precision: str='float16'
|
|
||||||
perlin: float=0.0
|
|
||||||
threshold: int=0.0
|
|
||||||
h_symmetry_time_pct: float=None
|
|
||||||
v_symmetry_time_pct: float=None
|
|
||||||
variation_amount: float = 0.0
|
|
||||||
with_variations: list = field(default_factory=list)
|
|
||||||
|
|
||||||
@dataclass
|
# getter and setter shortcuts for commonly used parameters
|
||||||
class InvokeAIGeneratorOutput:
|
@property
|
||||||
image: Image
|
def model_name(self)->str:
|
||||||
seed: int
|
return self.params.model_name
|
||||||
model_name: str
|
|
||||||
model_hash: str
|
@model_name.setter
|
||||||
params: InvokeAIGeneratorBasicParams
|
def model_name(self, model_name: str):
|
||||||
|
self.params.model_name=model_name
|
||||||
|
|
||||||
# we are interposing a wrapper around the original Generator classes so that
|
# we are interposing a wrapper around the original Generator classes so that
|
||||||
# old code that calls Generate will continue to work.
|
# old code that calls Generate will continue to work.
|
||||||
@ -93,7 +115,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
|||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
model_manager: ModelManager,
|
model_manager: ModelManager,
|
||||||
params: InvokeAIGeneratorBasicParams
|
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
|
||||||
):
|
):
|
||||||
self.model_manager=model_manager
|
self.model_manager=model_manager
|
||||||
self.params=params
|
self.params=params
|
||||||
@ -105,7 +127,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
|||||||
**keyword_args,
|
**keyword_args,
|
||||||
)->List[InvokeAIGeneratorOutput]:
|
)->List[InvokeAIGeneratorOutput]:
|
||||||
|
|
||||||
model_name = self.params.model or self.model_manager.current_model
|
model_name = self.params.model_name or self.model_manager.current_model
|
||||||
model_info: dict = self.model_manager.get_model(model_name)
|
model_info: dict = self.model_manager.get_model(model_name)
|
||||||
model:StableDiffusionGeneratorPipeline = model_info['model']
|
model:StableDiffusionGeneratorPipeline = model_info['model']
|
||||||
model_hash = model_info['hash']
|
model_hash = model_info['hash']
|
||||||
@ -125,24 +147,33 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
|||||||
self.params.variation_amount,
|
self.params.variation_amount,
|
||||||
self.params.with_variations)
|
self.params.with_variations)
|
||||||
|
|
||||||
|
generator_args = dataclasses.asdict(self.params)
|
||||||
|
generator_args.update(keyword_args)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
results = generator.generate(prompt,
|
results = generator.generate(prompt,
|
||||||
conditioning=(uc, c, extra_conditioning_info),
|
conditioning=(uc, c, extra_conditioning_info),
|
||||||
sampler=scheduler,
|
sampler=scheduler,
|
||||||
**dataclasses.asdict(self.params),
|
**generator_args,
|
||||||
**keyword_args
|
|
||||||
)
|
)
|
||||||
output = InvokeAIGeneratorOutput(
|
output = InvokeAIGeneratorOutput(
|
||||||
image=results[0][0],
|
image=results[0][0],
|
||||||
seed=results[0][1],
|
seed=results[0][1],
|
||||||
model_name = model_name,
|
model_name = model_name,
|
||||||
model_hash = model_hash,
|
model_hash = model_hash,
|
||||||
params=copy.copy(self.params)
|
params=generator_args,
|
||||||
)
|
)
|
||||||
if callback:
|
if callback:
|
||||||
callback(output)
|
callback(output)
|
||||||
yield output
|
yield output
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def schedulers(self)->List[str]:
|
||||||
|
'''
|
||||||
|
Return list of all the schedulers that we currently handle.
|
||||||
|
'''
|
||||||
|
return list(self.scheduler_map.keys())
|
||||||
|
|
||||||
def load_generator(self, model: StableDiffusionGeneratorPipeline, class_name: str):
|
def load_generator(self, model: StableDiffusionGeneratorPipeline, class_name: str):
|
||||||
module_name = f'invokeai.backend.generator.{class_name.lower()}'
|
module_name = f'invokeai.backend.generator.{class_name.lower()}'
|
||||||
module = importlib.import_module(module_name)
|
module = importlib.import_module(module_name)
|
||||||
@ -220,7 +251,6 @@ class Inpaint(Img2Img):
|
|||||||
def _generator_name(self)->str:
|
def _generator_name(self)->str:
|
||||||
return 'Inpaint'
|
return 'Inpaint'
|
||||||
|
|
||||||
|
|
||||||
class Generator:
|
class Generator:
|
||||||
downsampling_factor: int
|
downsampling_factor: int
|
||||||
latent_channels: int
|
latent_channels: int
|
||||||
@ -240,7 +270,6 @@ class Generator:
|
|||||||
self.with_variations = []
|
self.with_variations = []
|
||||||
self.use_mps_noise = False
|
self.use_mps_noise = False
|
||||||
self.free_gpu_mem = None
|
self.free_gpu_mem = None
|
||||||
self.caution_img = None
|
|
||||||
|
|
||||||
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
|
# this is going to be overridden in img2img.py, txt2img.py and inpaint.py
|
||||||
def get_make_image(self, prompt, **kwargs):
|
def get_make_image(self, prompt, **kwargs):
|
||||||
@ -272,7 +301,7 @@ class Generator:
|
|||||||
perlin=0.0,
|
perlin=0.0,
|
||||||
h_symmetry_time_pct=None,
|
h_symmetry_time_pct=None,
|
||||||
v_symmetry_time_pct=None,
|
v_symmetry_time_pct=None,
|
||||||
safety_checker: dict = None,
|
safety_checker: SafetyChecker=None,
|
||||||
free_gpu_mem: bool = False,
|
free_gpu_mem: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
@ -325,7 +354,7 @@ class Generator:
|
|||||||
image = make_image(x_T)
|
image = make_image(x_T)
|
||||||
|
|
||||||
if self.safety_checker is not None:
|
if self.safety_checker is not None:
|
||||||
image = self.safety_check(image)
|
image = self.safety_checker.check(image)
|
||||||
|
|
||||||
results.append([image, seed])
|
results.append([image, seed])
|
||||||
|
|
||||||
@ -548,53 +577,6 @@ class Generator:
|
|||||||
|
|
||||||
return v2
|
return v2
|
||||||
|
|
||||||
def safety_check(self, image: Image.Image):
|
|
||||||
"""
|
|
||||||
If the CompViz safety checker flags an NSFW image, we
|
|
||||||
blur it out.
|
|
||||||
"""
|
|
||||||
import diffusers
|
|
||||||
|
|
||||||
checker = self.safety_checker["checker"]
|
|
||||||
extractor = self.safety_checker["extractor"]
|
|
||||||
features = extractor([image], return_tensors="pt")
|
|
||||||
features.to(self.model.device)
|
|
||||||
|
|
||||||
# unfortunately checker requires the numpy version, so we have to convert back
|
|
||||||
x_image = np.array(image).astype(np.float32) / 255.0
|
|
||||||
x_image = x_image[None].transpose(0, 3, 1, 2)
|
|
||||||
|
|
||||||
diffusers.logging.set_verbosity_error()
|
|
||||||
checked_image, has_nsfw_concept = checker(
|
|
||||||
images=x_image, clip_input=features.pixel_values
|
|
||||||
)
|
|
||||||
if has_nsfw_concept[0]:
|
|
||||||
print(
|
|
||||||
"** An image with potential non-safe content has been detected. A blurred image will be returned. **"
|
|
||||||
)
|
|
||||||
return self.blur(image)
|
|
||||||
else:
|
|
||||||
return image
|
|
||||||
|
|
||||||
def blur(self, input):
|
|
||||||
blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
|
|
||||||
try:
|
|
||||||
caution = self.get_caution_img()
|
|
||||||
if caution:
|
|
||||||
blurry.paste(caution, (0, 0), caution)
|
|
||||||
except FileNotFoundError:
|
|
||||||
pass
|
|
||||||
return blurry
|
|
||||||
|
|
||||||
def get_caution_img(self):
|
|
||||||
path = None
|
|
||||||
if self.caution_img:
|
|
||||||
return self.caution_img
|
|
||||||
path = Path(web_assets.__path__[0]) / CAUTION_IMG
|
|
||||||
caution = Image.open(path)
|
|
||||||
self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
|
|
||||||
return self.caution_img
|
|
||||||
|
|
||||||
# this is a handy routine for debugging use. Given a generated sample,
|
# this is a handy routine for debugging use. Given a generated sample,
|
||||||
# convert it into a PNG image and store it at the indicated path
|
# convert it into a PNG image and store it at the indicated path
|
||||||
def save_sample(self, sample, filepath):
|
def save_sample(self, sample, filepath):
|
||||||
|
89
invokeai/backend/safety_checker.py
Normal file
89
invokeai/backend/safety_checker.py
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
'''
|
||||||
|
SafetyChecker class - checks images against the StabilityAI NSFW filter
|
||||||
|
and blurs images that contain potential NSFW content.
|
||||||
|
'''
|
||||||
|
import diffusers
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import traceback
|
||||||
|
from diffusers.pipelines.stable_diffusion.safety_checker import (
|
||||||
|
StableDiffusionSafetyChecker,
|
||||||
|
)
|
||||||
|
from pathlib import Path
|
||||||
|
from PIL import Image, ImageFilter
|
||||||
|
from transformers import AutoFeatureExtractor
|
||||||
|
|
||||||
|
import invokeai.assets.web as web_assets
|
||||||
|
from .globals import global_cache_dir
|
||||||
|
|
||||||
|
class SafetyChecker(object):
|
||||||
|
CAUTION_IMG = "caution.png"
|
||||||
|
|
||||||
|
def __init__(self, device: torch.device):
|
||||||
|
self.device = device
|
||||||
|
try:
|
||||||
|
print(">> Initializing NSFW checker")
|
||||||
|
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||||
|
safety_model_path = global_cache_dir("hub")
|
||||||
|
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(
|
||||||
|
safety_model_id,
|
||||||
|
local_files_only=True,
|
||||||
|
cache_dir=safety_model_path,
|
||||||
|
)
|
||||||
|
self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||||
|
safety_model_id,
|
||||||
|
local_files_only=True,
|
||||||
|
cache_dir=safety_model_path,
|
||||||
|
)
|
||||||
|
self.safety_checker.to(device)
|
||||||
|
self.safety_feature_extractor.to(device)
|
||||||
|
except Exception:
|
||||||
|
print(
|
||||||
|
"** An error was encountered while installing the safety checker:"
|
||||||
|
)
|
||||||
|
print(traceback.format_exc())
|
||||||
|
else:
|
||||||
|
print(">> NSFW checker is disabled")
|
||||||
|
|
||||||
|
def check(self, image: Image.Image):
|
||||||
|
"""
|
||||||
|
Check provided image against the StabilityAI safety checker and return
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
features = self.safety_feature_extractor([image], return_tensors="pt")
|
||||||
|
# unfortunately checker requires the numpy version, so we have to convert back
|
||||||
|
x_image = np.array(image).astype(np.float32) / 255.0
|
||||||
|
x_image = x_image[None].transpose(0, 3, 1, 2)
|
||||||
|
|
||||||
|
diffusers.logging.set_verbosity_error()
|
||||||
|
checked_image, has_nsfw_concept = self.safety_checker(
|
||||||
|
images=x_image, clip_input=features.pixel_values
|
||||||
|
)
|
||||||
|
if has_nsfw_concept[0]:
|
||||||
|
print(
|
||||||
|
"** An image with potential non-safe content has been detected. A blurred image will be returned. **"
|
||||||
|
)
|
||||||
|
return self.blur(image)
|
||||||
|
else:
|
||||||
|
return image
|
||||||
|
|
||||||
|
def blur(self, input):
|
||||||
|
blurry = input.filter(filter=ImageFilter.GaussianBlur(radius=32))
|
||||||
|
try:
|
||||||
|
caution = self.get_caution_img()
|
||||||
|
if caution:
|
||||||
|
blurry.paste(caution, (0, 0), caution)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
return blurry
|
||||||
|
|
||||||
|
def get_caution_img(self):
|
||||||
|
path = None
|
||||||
|
if self.caution_img:
|
||||||
|
return self.caution_img
|
||||||
|
path = Path(web_assets.__path__[0]) / self.CAUTION_IMG
|
||||||
|
caution = Image.open(path)
|
||||||
|
self.caution_img = caution.resize((caution.width // 2, caution.height // 2))
|
||||||
|
return self.caution_img
|
||||||
|
|
Loading…
Reference in New Issue
Block a user