node-based txt2img working without generate

Lincoln Stein
2023-03-09 00:18:29 -05:00
parent 87789c1de8
commit 5d37fa6e36
8 changed files with 247 additions and 254 deletions
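The changes below swap the monolithic Generate object out of the node (invocation) path: services now carry an InvokeAIGeneratorFactory, and each invocation asks it for a task-specific generator (Txt2Img, Img2Img, Inpaint). A minimal sketch of the resulting flow, assembled from the hunks below; the prompt string, the models.yaml path, and max_loaded_models value are illustrative, everything else is named in this diff:

import torch
from omegaconf import OmegaConf
from invokeai.backend import (
    ModelManager,
    InvokeAIGeneratorBasicParams,
    InvokeAIGeneratorFactory,
)
from invokeai.backend.generator import Txt2Img
from invokeai.backend.util import choose_precision, choose_torch_device

# Startup wiring (see generate_initializer.py below).
device = torch.device(choose_torch_device())
precision = choose_precision(device)
model_manager = ModelManager(
    OmegaConf.load("configs/models.yaml"),  # illustrative path
    precision=precision,
    device_type=device,
    max_loaded_models=2,  # illustrative value
)
factory = InvokeAIGeneratorFactory(
    model_manager, InvokeAIGeneratorBasicParams(precision=precision)
)

# Per-invocation use (see invocations/generate.py below).
txt2img = factory.make_generator(Txt2Img)
output = next(txt2img.generate(prompt="an astronaut riding a horse"))
image = output.image  # InvokeAIGeneratorOutput carries the generated image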

View File

@@ -4,7 +4,7 @@ import os
from argparse import Namespace
from ...backend import Globals
- from ..services.generate_initializer import get_generate
+ from ..services.generate_initializer import get_generator_factory
from ..services.graph import GraphExecutionState
from ..services.image_storage import DiskImageStorage
from ..services.invocation_queue import MemoryInvocationQueue
@@ -47,7 +47,7 @@ class ApiDependencies:
# TODO: Use a logger
print(f">> Internet connectivity is {Globals.internet_available}")
- generate = get_generate(args, config)
+ generator_factory = get_generator_factory(args, config)
events = FastAPIEventService(event_handler_id)
@@ -61,7 +61,7 @@ class ApiDependencies:
db_location = os.path.join(output_folder, "invokeai.db")
services = InvocationServices(
-     generate=generate,
+     generator_factory=generator_factory,
    events=events,
    images=images,
    queue=MemoryInvocationQueue(),

View File

@@ -17,7 +17,7 @@ from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_gra
from .invocations import *
from .invocations.baseinvocation import BaseInvocation
from .services.events import EventServiceBase
- from .services.generate_initializer import get_generate
+ from .services.generate_initializer import get_generator_factory
from .services.graph import EdgeConnection, GraphExecutionState
from .services.image_storage import DiskImageStorage
from .services.invocation_queue import MemoryInvocationQueue
@@ -106,11 +106,7 @@ def invoke_cli():
args = Args()
config = args.parse_args()
- generate = get_generate(args, config)
- # NOTE: load model on first use, uncomment to load at startup
- # TODO: Make this a config option?
- # generate.load_model()
+ generator_factory = get_generator_factory(args, config)
events = EventServiceBase()
@@ -122,7 +118,7 @@ def invoke_cli():
db_location = os.path.join(output_folder, "invokeai.db")
services = InvocationServices(
-     generate=generate,
+     generator_factory=generator_factory,
    events=events,
    images=DiskImageStorage(output_folder),
    queue=MemoryInvocationQueue(),

View File

@@ -12,9 +12,10 @@ from ..services.image_storage import ImageType
from ..services.invocation_services import InvocationServices
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageField, ImageOutput
+ from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
SAMPLER_NAME_VALUES = Literal[
-     "ddim", "plms", "k_lms", "k_dpm_2", "k_dpm_2_a", "k_euler", "k_euler_a", "k_heun"
+     tuple(InvokeAIGenerator.schedulers())
]
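The hard-coded sampler list becomes whatever the backend reports. typing.Literal accepts a tuple of values when subscripted at runtime, so the allowed names are now computed at import time from InvokeAIGenerator.schedulers() instead of being frozen in the node schema. A self-contained illustration of the same trick; the stand-in schedulers() list is made up:

from typing import Literal

def schedulers() -> list[str]:
    # Stand-in for InvokeAIGenerator.schedulers(); values are made up.
    return ["ddim", "k_lms", "k_euler_a"]

# Subscripting with a tuple expands it: Literal[("a", "b")] == Literal["a", "b"],
# so the set of allowed names tracks the backend automatically.
SAMPLER_NAME_VALUES = Literal[tuple(schedulers())]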
@@ -57,19 +58,24 @@ class TextToImageInvocation(BaseInvocation):
# Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now?
- if self.model is None or self.model == "":
-     self.model = context.services.generate.model_name
+ factory = context.services.generator_factory
+ if self.model:
+     factory.model_name = self.model
+ else:
+     self.model = factory.model_name
# Set the model (if already cached, this does nothing)
- context.services.generate.set_model(self.model)
+ txt2img = factory.make_generator(Txt2Img)
- results = context.services.generate.prompt2image(
+ outputs = txt2img.generate(
    prompt=self.prompt,
    step_callback=step_callback,
    **self.dict(
        exclude={"prompt"}
    ), # Shorthand for passing all of the parameters above manually
)
+ # Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
+ # each time it is called. We only need the first one.
+ generate_output = next(outputs)
# Results are image and seed, unwrap for now and ignore the seed
# TODO: pre-seed?
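The new contract, per the comment above: generate() never ends on its own, so callers must pull exactly as many outputs as they want. A sketch, assuming only what this hunk shows (generate() yields InvokeAIGeneratorOutput objects with an .image attribute; txt2img is the generator built above, and save_image() is a hypothetical stand-in):

from itertools import islice

# One image: take the first output and drop the iterator.
first = next(txt2img.generate(prompt="a sunset over the ocean"))
save_image(first.image)

# Several variations: bound the infinite stream rather than looping forever.
for output in islice(txt2img.generate(prompt="a sunset over the ocean"), 4):
    save_image(output.image)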
@@ -78,7 +84,7 @@ class TextToImageInvocation(BaseInvocation):
image_name = context.services.images.create_name(
    context.graph_execution_state_id, self.id
)
- context.services.images.save(image_type, image_name, results[0][0])
+ context.services.images.save(image_type, image_name, generate_output.image)
return ImageOutput(
    image=ImageField(image_type=image_type, image_name=image_name)
)
@@ -115,23 +121,24 @@ class ImageToImageInvocation(TextToImageInvocation):
# Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now?
- if self.model is None or self.model == "":
-     self.model = context.services.generate.model_name
+ factory = context.services.generator_factory
+ self.model = self.model or factory.model_name
+ factory.model_name = self.model
+ img2img = factory.make_generator(Img2Img)
- # Set the model (if already cached, this does nothing)
- context.services.generate.set_model(self.model)
- results = context.services.generate.prompt2image(
-     prompt=self.prompt,
-     init_img=image,
-     init_mask=mask,
-     step_callback=step_callback,
-     **self.dict(
-         exclude={"prompt", "image", "mask"}
-     ), # Shorthand for passing all of the parameters above manually
+ generator_output = next(
+     img2img.generate(
+         prompt=self.prompt,
+         init_img=image,
+         init_mask=mask,
+         step_callback=step_callback,
+         **self.dict(
+             exclude={"prompt", "image", "mask"}
+         ), # Shorthand for passing all of the parameters above manually
+     )
+ )
- result_image = results[0][0]
+ result_image = generator_output.image
# Results are image and seed, unwrap for now and ignore the seed
# TODO: pre-seed?
@@ -145,7 +152,6 @@ class ImageToImageInvocation(TextToImageInvocation):
image=ImageField(image_type=image_type, image_name=image_name)
)
class InpaintInvocation(ImageToImageInvocation):
"""Generates an image using inpaint."""
@@ -180,23 +186,24 @@ class InpaintInvocation(ImageToImageInvocation):
# Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now?
- if self.model is None or self.model == "":
-     self.model = context.services.generate.model_name
+ factory = context.services.generator_factory
+ self.model = self.model or factory.model_name
+ factory.model_name = self.model
+ inpaint = factory.make_generator(Inpaint)
- # Set the model (if already cached, this does nothing)
- context.services.generate.set_model(self.model)
- results = context.services.generate.prompt2image(
-     prompt=self.prompt,
-     init_img=image,
-     init_mask=mask,
-     step_callback=step_callback,
-     **self.dict(
-         exclude={"prompt", "image", "mask"}
-     ), # Shorthand for passing all of the parameters above manually
+ generator_output = next(
+     inpaint.generate(
+         prompt=self.prompt,
+         init_img=image,
+         init_mask=mask,
+         step_callback=step_callback,
+         **self.dict(
+             exclude={"prompt", "image", "mask"}
+         ), # Shorthand for passing all of the parameters above manually
+     )
+ )
- result_image = results[0][0]
+ result_image = generator_output.image
# Results are image and seed, unwrap for now and ignore the seed
# TODO: pre-seed?
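All three invocations now repeat the same four steps: resolve the model name, point the factory at it, build the task-specific generator, and take the first output. If the duplication becomes a maintenance burden it could be hoisted into a shared helper; the _generate_one method below is a hypothetical sketch using only names from this diff, not part of this commit:

def _generate_one(self, context, generator_class, **generate_args):
    # Hypothetical shared helper for TextToImage/ImageToImage/Inpaint:
    # resolve the model, build the generator, return the first output.
    factory = context.services.generator_factory
    self.model = self.model or factory.model_name
    factory.model_name = self.model
    generator = factory.make_generator(generator_class)
    return next(generator.generate(prompt=self.prompt, **generate_args))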

View File

@@ -1,16 +1,17 @@
import os
import sys
+ import torch
import traceback
from argparse import Namespace
+ from omegaconf import OmegaConf
import invokeai.version
- from invokeai.backend import Generate, ModelManager
+ from ...backend import ModelManager, InvokeAIGeneratorBasicParams, InvokeAIGeneratorFactory
+ from ...backend.util import choose_precision, choose_torch_device
from ...backend import Globals
# TODO: most of this code should be split into individual services as the Generate.py code is deprecated
- def get_generate(args, config) -> Generate:
+ def get_generator_factory(args, config) -> InvokeAIGeneratorFactory:
if not args.conf:
    config_file = os.path.join(Globals.root, "configs", "models.yaml")
    if not os.path.exists(config_file):
@@ -63,49 +64,43 @@ def get_generate(args, config) -> Generate:
print(f"{e}. Aborting.")
sys.exit(-1)
- # creating a Generate object:
+ # creating an InvokeAIGeneratorFactory object:
try:
-     gen = Generate(
-         conf=args.conf,
-         model=args.model,
-         sampler_name=args.sampler_name,
-         embedding_path=embedding_path,
-         full_precision=args.full_precision,
-         precision=args.precision,
-         gfpgan=gfpgan,
-         codeformer=codeformer,
-         esrgan=esrgan,
-         free_gpu_mem=args.free_gpu_mem,
-         safety_checker=args.safety_checker,
+     device = torch.device(choose_torch_device())
+     precision = 'float16' if args.precision=='float16' \
+         else 'float32' if args.precision=='float32' \
+         else choose_precision(device)
+     model_manager = ModelManager(
+         OmegaConf.load(args.conf),
+         precision=precision,
+         device_type=device,
+         max_loaded_models=args.max_loaded_models,
+     )
+     # TO DO: initialize and pass safety checker!!!
+     params = InvokeAIGeneratorBasicParams(
+         precision=precision,
+     )
+     factory = InvokeAIGeneratorFactory(model_manager, params)
except (FileNotFoundError, TypeError, AssertionError) as e:
-     report_model_error(opt, e)
+     report_model_error(args, e)
except (IOError, KeyError) as e:
    print(f"{e}. Aborting.")
    sys.exit(-1)
if args.seamless:
+     #TODO: do something here ?
    print(">> changed to seamless tiling mode")
- # preload the model
- try:
-     gen.load_model()
- except KeyError:
-     pass
- except Exception as e:
-     report_model_error(args, e)
# try to autoconvert new models
# autoimport new .ckpt files
if path := args.autoconvert:
-     gen.model_manager.autoconvert_weights(
+     model_manager.autoconvert_weights(
        conf_path=args.conf,
        weights_directory=path,
    )
- return gen
+ return factory
def load_face_restoration(opt):
try:
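The precision ternary in get_generator_factory() above is compact but hard to scan; it is equivalent to the unrolled form below. Treating any value other than the two explicit literals as "defer to the device" is the behavior shown in the diff; that the fall-through value is 'auto' is an assumption about the CLI default:

# Equivalent to the chained conditional in get_generator_factory() above.
if args.precision == 'float16':
    precision = 'float16'
elif args.precision == 'float32':
    precision = 'float32'
else:
    # e.g. 'auto' (assumed default): let choose_precision() pick
    # a precision appropriate for the detected device.
    precision = choose_precision(device)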
@@ -171,85 +166,3 @@ def report_model_error(opt: Namespace, e: Exception):
# sys.argv = previous_args
# main() # would rather do a os.exec(), but doesn't exist?
# sys.exit(0)
- # Temporary initializer for Generate until we migrate off of it
- def old_get_generate(args, config) -> Generate:
-     # TODO: Remove the need for globals
-     from invokeai.backend.globals import Globals
-     # alert - setting globals here
-     Globals.root = os.path.expanduser(
-         args.root_dir or os.environ.get("INVOKEAI_ROOT") or os.path.abspath(".")
-     )
-     Globals.try_patchmatch = args.patchmatch
-     print(f'>> InvokeAI runtime directory is "{Globals.root}"')
-     # these two lines prevent a horrible warning message from appearing
-     # when the frozen CLIP tokenizer is imported
-     import transformers
-     transformers.logging.set_verbosity_error()
-     # Loading Face Restoration and ESRGAN Modules
-     gfpgan, codeformer, esrgan = None, None, None
-     try:
-         if config.restore or config.esrgan:
-             from ldm.invoke.restoration import Restoration
-             restoration = Restoration()
-             if config.restore:
-                 gfpgan, codeformer = restoration.load_face_restore_models(
-                     config.gfpgan_model_path
-                 )
-             else:
-                 print(">> Face restoration disabled")
-             if config.esrgan:
-                 esrgan = restoration.load_esrgan(config.esrgan_bg_tile)
-             else:
-                 print(">> Upscaling disabled")
-         else:
-             print(">> Face restoration and upscaling disabled")
-     except (ModuleNotFoundError, ImportError):
-         print(traceback.format_exc(), file=sys.stderr)
-         print(">> You may need to install the ESRGAN and/or GFPGAN modules")
-     # normalize the config directory relative to root
-     if not os.path.isabs(config.conf):
-         config.conf = os.path.normpath(os.path.join(Globals.root, config.conf))
-     if config.embeddings:
-         if not os.path.isabs(config.embedding_path):
-             embedding_path = os.path.normpath(
-                 os.path.join(Globals.root, config.embedding_path)
-             )
-     else:
-         embedding_path = None
-     # TODO: lazy-initialize this by wrapping it
-     try:
-         generate = Generate(
-             conf=config.conf,
-             model=config.model,
-             sampler_name=config.sampler_name,
-             embedding_path=embedding_path,
-             full_precision=config.full_precision,
-             precision=config.precision,
-             gfpgan=gfpgan,
-             codeformer=codeformer,
-             esrgan=esrgan,
-             free_gpu_mem=config.free_gpu_mem,
-             safety_checker=config.safety_checker,
-             max_loaded_models=config.max_loaded_models,
-         )
-     except (FileNotFoundError, TypeError, AssertionError):
-         # emergency_model_reconfigure() # TODO?
-         sys.exit(-1)
-     except (IOError, KeyError) as e:
-         print(f"{e}. Aborting.")
-         sys.exit(-1)
-     generate.free_gpu_mem = config.free_gpu_mem
-     return generate

View File

@@ -1,5 +1,5 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
- from invokeai.backend import Generate
+ from invokeai.backend import InvokeAIGeneratorFactory
from .events import EventServiceBase
from .image_storage import ImageStorageBase
@@ -10,7 +10,7 @@ from .item_storage import ItemStorageABC
class InvocationServices:
"""Services that can be used by invocations"""
- generate: Generate # TODO: wrap Generate, or split it up from model?
+ generator_factory: InvokeAIGeneratorFactory
events: EventServiceBase
images: ImageStorageBase
queue: InvocationQueueABC
@@ -20,15 +20,15 @@ class InvocationServices:
processor: "InvocationProcessorABC"
def __init__(
-     self,
-     generate: Generate,
-     events: EventServiceBase,
-     images: ImageStorageBase,
-     queue: InvocationQueueABC,
-     graph_execution_manager: ItemStorageABC["GraphExecutionState"],
-     processor: "InvocationProcessorABC",
+     self,
+     generator_factory: InvokeAIGeneratorFactory,
+     events: EventServiceBase,
+     images: ImageStorageBase,
+     queue: InvocationQueueABC,
+     graph_execution_manager: ItemStorageABC["GraphExecutionState"],
+     processor: "InvocationProcessorABC",
):
- self.generate = generate
+ self.generator_factory = generator_factory
self.events = events
self.images = images
self.queue = queue
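Condensed, the services container is now constructed like this at both call sites above (args, config, and output_folder come from the surrounding setup code; the graph_execution_manager and processor arguments are not shown in this diff, so they appear here as placeholders):

services = InvocationServices(
    generator_factory=get_generator_factory(args, config),
    events=EventServiceBase(),
    images=DiskImageStorage(output_folder),
    queue=MemoryInvocationQueue(),
    graph_execution_manager=graph_execution_manager,  # ItemStorageABC["GraphExecutionState"]
    processor=processor,  # an "InvocationProcessorABC" implementation
)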