mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
node-based txt2img working without generate
This commit is contained in:
@ -4,7 +4,7 @@ import os
|
||||
from argparse import Namespace
|
||||
|
||||
from ...backend import Globals
|
||||
from ..services.generate_initializer import get_generate
|
||||
from ..services.generate_initializer import get_generator_factory
|
||||
from ..services.graph import GraphExecutionState
|
||||
from ..services.image_storage import DiskImageStorage
|
||||
from ..services.invocation_queue import MemoryInvocationQueue
|
||||
@ -47,7 +47,7 @@ class ApiDependencies:
|
||||
# TODO: Use a logger
|
||||
print(f">> Internet connectivity is {Globals.internet_available}")
|
||||
|
||||
generate = get_generate(args, config)
|
||||
generator_factory = get_generator_factory(args, config)
|
||||
|
||||
events = FastAPIEventService(event_handler_id)
|
||||
|
||||
@ -61,7 +61,7 @@ class ApiDependencies:
|
||||
db_location = os.path.join(output_folder, "invokeai.db")
|
||||
|
||||
services = InvocationServices(
|
||||
generate=generate,
|
||||
generator_factory=generator_factory,
|
||||
events=events,
|
||||
images=images,
|
||||
queue=MemoryInvocationQueue(),
|
||||
|
@ -17,7 +17,7 @@ from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_gra
|
||||
from .invocations import *
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
from .services.events import EventServiceBase
|
||||
from .services.generate_initializer import get_generate
|
||||
from .services.generate_initializer import get_generator_factory
|
||||
from .services.graph import EdgeConnection, GraphExecutionState
|
||||
from .services.image_storage import DiskImageStorage
|
||||
from .services.invocation_queue import MemoryInvocationQueue
|
||||
@ -106,11 +106,7 @@ def invoke_cli():
|
||||
args = Args()
|
||||
config = args.parse_args()
|
||||
|
||||
generate = get_generate(args, config)
|
||||
|
||||
# NOTE: load model on first use, uncomment to load at startup
|
||||
# TODO: Make this a config option?
|
||||
# generate.load_model()
|
||||
generator_factory = get_generator_factory(args, config)
|
||||
|
||||
events = EventServiceBase()
|
||||
|
||||
@ -122,7 +118,7 @@ def invoke_cli():
|
||||
db_location = os.path.join(output_folder, "invokeai.db")
|
||||
|
||||
services = InvocationServices(
|
||||
generate=generate,
|
||||
generator_factory=generator_factory,
|
||||
events=events,
|
||||
images=DiskImageStorage(output_folder),
|
||||
queue=MemoryInvocationQueue(),
|
||||
|
@ -12,9 +12,10 @@ from ..services.image_storage import ImageType
|
||||
from ..services.invocation_services import InvocationServices
|
||||
from .baseinvocation import BaseInvocation, InvocationContext
|
||||
from .image import ImageField, ImageOutput
|
||||
from ...backend.generator import Txt2Img, Img2Img, Inpaint, InvokeAIGenerator
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[
|
||||
"ddim", "plms", "k_lms", "k_dpm_2", "k_dpm_2_a", "k_euler", "k_euler_a", "k_heun"
|
||||
tuple(InvokeAIGenerator.schedulers())
|
||||
]
|
||||
|
||||
|
||||
@ -57,19 +58,24 @@ class TextToImageInvocation(BaseInvocation):
|
||||
# Handle invalid model parameter
|
||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||
# TODO: How to get the default model name now?
|
||||
if self.model is None or self.model == "":
|
||||
self.model = context.services.generate.model_name
|
||||
factory = context.services.generator_factory
|
||||
if self.model:
|
||||
factory.model_name = self.model
|
||||
else:
|
||||
self.model = factory.model_name
|
||||
|
||||
# Set the model (if already cached, this does nothing)
|
||||
context.services.generate.set_model(self.model)
|
||||
txt2img = factory.make_generator(Txt2Img)
|
||||
|
||||
results = context.services.generate.prompt2image(
|
||||
outputs = txt2img.generate(
|
||||
prompt=self.prompt,
|
||||
step_callback=step_callback,
|
||||
**self.dict(
|
||||
exclude={"prompt"}
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
)
|
||||
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
|
||||
# each time it is called. We only need the first one.
|
||||
generate_output = next(outputs)
|
||||
|
||||
# Results are image and seed, unwrap for now and ignore the seed
|
||||
# TODO: pre-seed?
|
||||
@ -78,7 +84,7 @@ class TextToImageInvocation(BaseInvocation):
|
||||
image_name = context.services.images.create_name(
|
||||
context.graph_execution_state_id, self.id
|
||||
)
|
||||
context.services.images.save(image_type, image_name, results[0][0])
|
||||
context.services.images.save(image_type, image_name, generate_output.image)
|
||||
return ImageOutput(
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
)
|
||||
@ -115,23 +121,24 @@ class ImageToImageInvocation(TextToImageInvocation):
|
||||
# Handle invalid model parameter
|
||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||
# TODO: How to get the default model name now?
|
||||
if self.model is None or self.model == "":
|
||||
self.model = context.services.generate.model_name
|
||||
factory = context.services.generator_factory
|
||||
self.model = self.model or factory.model_name
|
||||
factory.model_name = self.model
|
||||
img2img = factory.make_generator(Img2Img)
|
||||
|
||||
# Set the model (if already cached, this does nothing)
|
||||
context.services.generate.set_model(self.model)
|
||||
|
||||
results = context.services.generate.prompt2image(
|
||||
prompt=self.prompt,
|
||||
init_img=image,
|
||||
init_mask=mask,
|
||||
step_callback=step_callback,
|
||||
**self.dict(
|
||||
exclude={"prompt", "image", "mask"}
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
generator_output = next(
|
||||
img2img.generate(
|
||||
prompt=self.prompt,
|
||||
init_img=image,
|
||||
init_mask=mask,
|
||||
step_callback=step_callback,
|
||||
**self.dict(
|
||||
exclude={"prompt", "image", "mask"}
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
)
|
||||
)
|
||||
|
||||
result_image = results[0][0]
|
||||
result_image = generator_output.image
|
||||
|
||||
# Results are image and seed, unwrap for now and ignore the seed
|
||||
# TODO: pre-seed?
|
||||
@ -145,7 +152,6 @@ class ImageToImageInvocation(TextToImageInvocation):
|
||||
image=ImageField(image_type=image_type, image_name=image_name)
|
||||
)
|
||||
|
||||
|
||||
class InpaintInvocation(ImageToImageInvocation):
|
||||
"""Generates an image using inpaint."""
|
||||
|
||||
@ -180,23 +186,24 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
# Handle invalid model parameter
|
||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||
# TODO: How to get the default model name now?
|
||||
if self.model is None or self.model == "":
|
||||
self.model = context.services.generate.model_name
|
||||
factory = context.services.generator_factory
|
||||
self.model = self.model or factory.model_name
|
||||
factory.model_name = self.model
|
||||
inpaint = factory.make_generator(Inpaint)
|
||||
|
||||
# Set the model (if already cached, this does nothing)
|
||||
context.services.generate.set_model(self.model)
|
||||
|
||||
results = context.services.generate.prompt2image(
|
||||
prompt=self.prompt,
|
||||
init_img=image,
|
||||
init_mask=mask,
|
||||
step_callback=step_callback,
|
||||
**self.dict(
|
||||
exclude={"prompt", "image", "mask"}
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
generator_output = next(
|
||||
inpaint.generate(
|
||||
prompt=self.prompt,
|
||||
init_img=image,
|
||||
init_mask=mask,
|
||||
step_callback=step_callback,
|
||||
**self.dict(
|
||||
exclude={"prompt", "image", "mask"}
|
||||
), # Shorthand for passing all of the parameters above manually
|
||||
)
|
||||
)
|
||||
|
||||
result_image = results[0][0]
|
||||
result_image = generator_output.image
|
||||
|
||||
# Results are image and seed, unwrap for now and ignore the seed
|
||||
# TODO: pre-seed?
|
||||
|
@ -1,16 +1,17 @@
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import traceback
|
||||
from argparse import Namespace
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
import invokeai.version
|
||||
from invokeai.backend import Generate, ModelManager
|
||||
|
||||
from ...backend import ModelManager, InvokeAIGeneratorBasicParams, InvokeAIGeneratorFactory
|
||||
from ...backend.util import choose_precision, choose_torch_device
|
||||
from ...backend import Globals
|
||||
|
||||
|
||||
# TODO: most of this code should be split into individual services as the Generate.py code is deprecated
|
||||
def get_generate(args, config) -> Generate:
|
||||
def get_generator_factory(args, config) -> InvokeAIGeneratorFactory:
|
||||
if not args.conf:
|
||||
config_file = os.path.join(Globals.root, "configs", "models.yaml")
|
||||
if not os.path.exists(config_file):
|
||||
@ -63,49 +64,43 @@ def get_generate(args, config) -> Generate:
|
||||
print(f"{e}. Aborting.")
|
||||
sys.exit(-1)
|
||||
|
||||
# creating a Generate object:
|
||||
# creating an InvokeAIGeneratorFactory object:
|
||||
try:
|
||||
gen = Generate(
|
||||
conf=args.conf,
|
||||
model=args.model,
|
||||
sampler_name=args.sampler_name,
|
||||
embedding_path=embedding_path,
|
||||
full_precision=args.full_precision,
|
||||
precision=args.precision,
|
||||
gfpgan=gfpgan,
|
||||
codeformer=codeformer,
|
||||
esrgan=esrgan,
|
||||
free_gpu_mem=args.free_gpu_mem,
|
||||
safety_checker=args.safety_checker,
|
||||
device = torch.device(choose_torch_device())
|
||||
precision = 'float16' if args.precision=='float16' \
|
||||
else 'float32' if args.precision=='float32' \
|
||||
else choose_precision(device)
|
||||
|
||||
model_manager = ModelManager(
|
||||
OmegaConf.load(args.conf),
|
||||
precision=precision,
|
||||
device_type=device,
|
||||
max_loaded_models=args.max_loaded_models,
|
||||
)
|
||||
# TO DO: initialize and pass safety checker!!!
|
||||
params = InvokeAIGeneratorBasicParams(
|
||||
precision=precision,
|
||||
)
|
||||
factory = InvokeAIGeneratorFactory(model_manager, params)
|
||||
except (FileNotFoundError, TypeError, AssertionError) as e:
|
||||
report_model_error(opt, e)
|
||||
report_model_error(args, e)
|
||||
except (IOError, KeyError) as e:
|
||||
print(f"{e}. Aborting.")
|
||||
sys.exit(-1)
|
||||
|
||||
if args.seamless:
|
||||
#TODO: do something here ?
|
||||
print(">> changed to seamless tiling mode")
|
||||
|
||||
# preload the model
|
||||
try:
|
||||
gen.load_model()
|
||||
except KeyError:
|
||||
pass
|
||||
except Exception as e:
|
||||
report_model_error(args, e)
|
||||
|
||||
# try to autoconvert new models
|
||||
# autoimport new .ckpt files
|
||||
if path := args.autoconvert:
|
||||
gen.model_manager.autoconvert_weights(
|
||||
model_manager.autoconvert_weights(
|
||||
conf_path=args.conf,
|
||||
weights_directory=path,
|
||||
)
|
||||
|
||||
return gen
|
||||
|
||||
return factory
|
||||
|
||||
def load_face_restoration(opt):
|
||||
try:
|
||||
@ -171,85 +166,3 @@ def report_model_error(opt: Namespace, e: Exception):
|
||||
# sys.argv = previous_args
|
||||
# main() # would rather do a os.exec(), but doesn't exist?
|
||||
# sys.exit(0)
|
||||
|
||||
|
||||
# Temporary initializer for Generate until we migrate off of it
|
||||
def old_get_generate(args, config) -> Generate:
|
||||
# TODO: Remove the need for globals
|
||||
from invokeai.backend.globals import Globals
|
||||
|
||||
# alert - setting globals here
|
||||
Globals.root = os.path.expanduser(
|
||||
args.root_dir or os.environ.get("INVOKEAI_ROOT") or os.path.abspath(".")
|
||||
)
|
||||
Globals.try_patchmatch = args.patchmatch
|
||||
|
||||
print(f'>> InvokeAI runtime directory is "{Globals.root}"')
|
||||
|
||||
# these two lines prevent a horrible warning message from appearing
|
||||
# when the frozen CLIP tokenizer is imported
|
||||
import transformers
|
||||
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
# Loading Face Restoration and ESRGAN Modules
|
||||
gfpgan, codeformer, esrgan = None, None, None
|
||||
try:
|
||||
if config.restore or config.esrgan:
|
||||
from ldm.invoke.restoration import Restoration
|
||||
|
||||
restoration = Restoration()
|
||||
if config.restore:
|
||||
gfpgan, codeformer = restoration.load_face_restore_models(
|
||||
config.gfpgan_model_path
|
||||
)
|
||||
else:
|
||||
print(">> Face restoration disabled")
|
||||
if config.esrgan:
|
||||
esrgan = restoration.load_esrgan(config.esrgan_bg_tile)
|
||||
else:
|
||||
print(">> Upscaling disabled")
|
||||
else:
|
||||
print(">> Face restoration and upscaling disabled")
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
print(">> You may need to install the ESRGAN and/or GFPGAN modules")
|
||||
|
||||
# normalize the config directory relative to root
|
||||
if not os.path.isabs(config.conf):
|
||||
config.conf = os.path.normpath(os.path.join(Globals.root, config.conf))
|
||||
|
||||
if config.embeddings:
|
||||
if not os.path.isabs(config.embedding_path):
|
||||
embedding_path = os.path.normpath(
|
||||
os.path.join(Globals.root, config.embedding_path)
|
||||
)
|
||||
else:
|
||||
embedding_path = None
|
||||
|
||||
# TODO: lazy-initialize this by wrapping it
|
||||
try:
|
||||
generate = Generate(
|
||||
conf=config.conf,
|
||||
model=config.model,
|
||||
sampler_name=config.sampler_name,
|
||||
embedding_path=embedding_path,
|
||||
full_precision=config.full_precision,
|
||||
precision=config.precision,
|
||||
gfpgan=gfpgan,
|
||||
codeformer=codeformer,
|
||||
esrgan=esrgan,
|
||||
free_gpu_mem=config.free_gpu_mem,
|
||||
safety_checker=config.safety_checker,
|
||||
max_loaded_models=config.max_loaded_models,
|
||||
)
|
||||
except (FileNotFoundError, TypeError, AssertionError):
|
||||
# emergency_model_reconfigure() # TODO?
|
||||
sys.exit(-1)
|
||||
except (IOError, KeyError) as e:
|
||||
print(f"{e}. Aborting.")
|
||||
sys.exit(-1)
|
||||
|
||||
generate.free_gpu_mem = config.free_gpu_mem
|
||||
|
||||
return generate
|
||||
|
@ -1,5 +1,5 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
from invokeai.backend import Generate
|
||||
from invokeai.backend import InvokeAIGeneratorFactory
|
||||
|
||||
from .events import EventServiceBase
|
||||
from .image_storage import ImageStorageBase
|
||||
@ -10,7 +10,7 @@ from .item_storage import ItemStorageABC
|
||||
class InvocationServices:
|
||||
"""Services that can be used by invocations"""
|
||||
|
||||
generate: Generate # TODO: wrap Generate, or split it up from model?
|
||||
generator_factory: InvokeAIGeneratorFactory
|
||||
events: EventServiceBase
|
||||
images: ImageStorageBase
|
||||
queue: InvocationQueueABC
|
||||
@ -20,15 +20,15 @@ class InvocationServices:
|
||||
processor: "InvocationProcessorABC"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
generate: Generate,
|
||||
events: EventServiceBase,
|
||||
images: ImageStorageBase,
|
||||
queue: InvocationQueueABC,
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||
processor: "InvocationProcessorABC",
|
||||
self,
|
||||
generator_factory: InvokeAIGeneratorFactory,
|
||||
events: EventServiceBase,
|
||||
images: ImageStorageBase,
|
||||
queue: InvocationQueueABC,
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||
processor: "InvocationProcessorABC",
|
||||
):
|
||||
self.generate = generate
|
||||
self.generator_factory = generator_factory
|
||||
self.events = events
|
||||
self.images = images
|
||||
self.queue = queue
|
||||
|
Reference in New Issue
Block a user