mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
remove factory pattern
Factory pattern is now removed. Typical usage of the InvokeAIGenerator is now: ``` from invokeai.backend.generator import ( InvokeAIGeneratorBasicParams, Txt2Img, Img2Img, Inpaint, ) params = InvokeAIGeneratorBasicParams( model_name = 'stable-diffusion-1.5', steps = 30, scheduler = 'k_lms', cfg_scale = 8.0, height = 640, width = 640 ) print ('=== TXT2IMG TEST ===') txt2img = Txt2Img(manager, params) outputs = txt2img.generate(prompt='banana sushi', iterations=2) for i in outputs: print(f'image={output.image}, seed={output.seed}, model={output.params.model_name}, hash={output.model_hash}, steps={output.params.steps}') ``` The `params` argument is optional, so if you wish to accept default parameters and selectively override them, just do this: ``` outputs = Txt2Img(manager).generate(prompt='banana sushi', steps=50, scheduler='k_heun', model_name='stable-diffusion-2.1' ) ```
This commit is contained in:
parent
c11e823ff3
commit
95954188b2
@ -4,7 +4,7 @@ import os
|
||||
from argparse import Namespace
|
||||
|
||||
from ...backend import Globals
|
||||
from ..services.generate_initializer import get_generator_factory
|
||||
from ..services.generate_initializer import get_model_manager
|
||||
from ..services.graph import GraphExecutionState
|
||||
from ..services.image_storage import DiskImageStorage
|
||||
from ..services.invocation_queue import MemoryInvocationQueue
|
||||
@ -47,7 +47,7 @@ class ApiDependencies:
|
||||
# TODO: Use a logger
|
||||
print(f">> Internet connectivity is {Globals.internet_available}")
|
||||
|
||||
generator_factory = get_generator_factory(args, config)
|
||||
model_manager = get_model_manager(args, config)
|
||||
|
||||
events = FastAPIEventService(event_handler_id)
|
||||
|
||||
|
@ -17,7 +17,7 @@ from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_gra
|
||||
from .invocations import *
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
from .services.events import EventServiceBase
|
||||
from .services.generate_initializer import get_generator_factory
|
||||
from .services.generate_initializer import get_model_manager
|
||||
from .services.graph import EdgeConnection, GraphExecutionState
|
||||
from .services.image_storage import DiskImageStorage
|
||||
from .services.invocation_queue import MemoryInvocationQueue
|
||||
@ -129,7 +129,7 @@ def invoke_cli():
|
||||
args = Args()
|
||||
config = args.parse_args()
|
||||
|
||||
generator_factory = get_generator_factory(args, config)
|
||||
model_manager = get_model_manager(args, config)
|
||||
|
||||
events = EventServiceBase()
|
||||
|
||||
@ -141,7 +141,7 @@ def invoke_cli():
|
||||
db_location = os.path.join(output_folder, "invokeai.db")
|
||||
|
||||
services = InvocationServices(
|
||||
generator_factory=generator_factory,
|
||||
model_manager=model_manager,
|
||||
events=events,
|
||||
images=DiskImageStorage(output_folder),
|
||||
queue=MemoryInvocationQueue(),
|
||||
|
@ -18,7 +18,6 @@ SAMPLER_NAME_VALUES = Literal[
|
||||
tuple(InvokeAIGenerator.schedulers())
|
||||
]
|
||||
|
||||
|
||||
# Text to image
|
||||
class TextToImageInvocation(BaseInvocation):
|
||||
"""Generates an image using text2img."""
|
||||
@ -58,15 +57,8 @@ class TextToImageInvocation(BaseInvocation):
|
||||
# Handle invalid model parameter
|
||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||
# TODO: How to get the default model name now?
|
||||
factory = context.services.generator_factory
|
||||
if self.model:
|
||||
factory.model_name = self.model
|
||||
else:
|
||||
self.model = factory.model_name
|
||||
|
||||
txt2img = factory.make_generator(Txt2Img)
|
||||
|
||||
outputs = txt2img.generate(
|
||||
manager = context.services.model_manager
|
||||
outputs = Txt2Img(manager).generate(
|
||||
prompt=self.prompt,
|
||||
step_callback=step_callback,
|
||||
**self.dict(
|
||||
@ -121,13 +113,9 @@ class ImageToImageInvocation(TextToImageInvocation):
|
||||
# Handle invalid model parameter
|
||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||
# TODO: How to get the default model name now?
|
||||
factory = context.services.generator_factory
|
||||
self.model = self.model or factory.model_name
|
||||
factory.model_name = self.model
|
||||
img2img = factory.make_generator(Img2Img)
|
||||
|
||||
manager = context.services.model_manager
|
||||
generator_output = next(
|
||||
img2img.generate(
|
||||
Img2Img(manager).generate(
|
||||
prompt=self.prompt,
|
||||
init_img=image,
|
||||
init_mask=mask,
|
||||
@ -186,13 +174,9 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
# Handle invalid model parameter
|
||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||
# TODO: How to get the default model name now?
|
||||
factory = context.services.generator_factory
|
||||
self.model = self.model or factory.model_name
|
||||
factory.model_name = self.model
|
||||
inpaint = factory.make_generator(Inpaint)
|
||||
|
||||
manager = context.services.model_manager
|
||||
generator_output = next(
|
||||
inpaint.generate(
|
||||
Inpaint(manager).generate(
|
||||
prompt=self.prompt,
|
||||
init_img=image,
|
||||
init_mask=mask,
|
||||
|
@ -6,12 +6,12 @@ from argparse import Namespace
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
import invokeai.version
|
||||
from ...backend import ModelManager, InvokeAIGeneratorBasicParams, InvokeAIGeneratorFactory
|
||||
from ...backend import ModelManager
|
||||
from ...backend.util import choose_precision, choose_torch_device
|
||||
from ...backend import Globals
|
||||
|
||||
# TODO: most of this code should be split into individual services as the Generate.py code is deprecated
|
||||
def get_generator_factory(args, config) -> InvokeAIGeneratorFactory:
|
||||
def get_model_manager(args, config) -> ModelManager:
|
||||
if not args.conf:
|
||||
config_file = os.path.join(Globals.root, "configs", "models.yaml")
|
||||
if not os.path.exists(config_file):
|
||||
@ -64,7 +64,7 @@ def get_generator_factory(args, config) -> InvokeAIGeneratorFactory:
|
||||
print(f"{e}. Aborting.")
|
||||
sys.exit(-1)
|
||||
|
||||
# creating an InvokeAIGeneratorFactory object:
|
||||
# creating the model manager
|
||||
try:
|
||||
device = torch.device(choose_torch_device())
|
||||
precision = 'float16' if args.precision=='float16' \
|
||||
@ -77,11 +77,6 @@ def get_generator_factory(args, config) -> InvokeAIGeneratorFactory:
|
||||
device_type=device,
|
||||
max_loaded_models=args.max_loaded_models,
|
||||
)
|
||||
# TO DO: initialize and pass safety checker!!!
|
||||
params = InvokeAIGeneratorBasicParams(
|
||||
precision=precision,
|
||||
)
|
||||
factory = InvokeAIGeneratorFactory(model_manager, params)
|
||||
except (FileNotFoundError, TypeError, AssertionError) as e:
|
||||
report_model_error(args, e)
|
||||
except (IOError, KeyError) as e:
|
||||
@ -100,7 +95,7 @@ def get_generator_factory(args, config) -> InvokeAIGeneratorFactory:
|
||||
weights_directory=path,
|
||||
)
|
||||
|
||||
return factory
|
||||
return model_manager
|
||||
|
||||
def load_face_restoration(opt):
|
||||
try:
|
||||
|
@ -1,5 +1,5 @@
|
||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
|
||||
from invokeai.backend import InvokeAIGeneratorFactory
|
||||
from invokeai.backend import ModelManager
|
||||
|
||||
from .events import EventServiceBase
|
||||
from .image_storage import ImageStorageBase
|
||||
@ -10,7 +10,7 @@ from .item_storage import ItemStorageABC
|
||||
class InvocationServices:
|
||||
"""Services that can be used by invocations"""
|
||||
|
||||
generator_factory: InvokeAIGeneratorFactory
|
||||
model_manager: ModelManager
|
||||
events: EventServiceBase
|
||||
images: ImageStorageBase
|
||||
queue: InvocationQueueABC
|
||||
@ -21,14 +21,14 @@ class InvocationServices:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
generator_factory: InvokeAIGeneratorFactory,
|
||||
model_manager: ModelManager,
|
||||
events: EventServiceBase,
|
||||
images: ImageStorageBase,
|
||||
queue: InvocationQueueABC,
|
||||
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
|
||||
processor: "InvocationProcessorABC",
|
||||
):
|
||||
self.generator_factory = generator_factory
|
||||
self.model_manager = model_manager
|
||||
self.events = events
|
||||
self.images = images
|
||||
self.queue = queue
|
||||
|
@ -4,7 +4,6 @@ Initialization file for invokeai.backend
|
||||
from .generate import Generate
|
||||
from .generator import (
|
||||
InvokeAIGeneratorBasicParams,
|
||||
InvokeAIGeneratorFactory,
|
||||
InvokeAIGenerator,
|
||||
InvokeAIGeneratorOutput,
|
||||
Txt2Img,
|
||||
|
@ -2,7 +2,6 @@
|
||||
Initialization file for the invokeai.generator package
|
||||
"""
|
||||
from .base import (
|
||||
InvokeAIGeneratorFactory,
|
||||
InvokeAIGenerator,
|
||||
InvokeAIGeneratorBasicParams,
|
||||
InvokeAIGeneratorOutput,
|
||||
|
@ -11,7 +11,8 @@ import diffusers
|
||||
import os
|
||||
import random
|
||||
import traceback
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from abc import ABCMeta
|
||||
from argparse import Namespace
|
||||
from contextlib import nullcontext
|
||||
|
||||
import cv2
|
||||
@ -21,7 +22,7 @@ from PIL import Image, ImageChops, ImageFilter
|
||||
from accelerate.utils import set_seed
|
||||
from diffusers import DiffusionPipeline
|
||||
from tqdm import trange
|
||||
from typing import List, Type, Iterator
|
||||
from typing import List, Iterator
|
||||
from dataclasses import dataclass, field
|
||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
|
||||
@ -35,13 +36,13 @@ downsampling = 8
|
||||
|
||||
@dataclass
|
||||
class InvokeAIGeneratorBasicParams:
|
||||
model_name: str='stable-diffusion-1.5'
|
||||
seed: int=None
|
||||
width: int=512
|
||||
height: int=512
|
||||
cfg_scale: int=7.5
|
||||
steps: int=20
|
||||
ddim_eta: float=0.0
|
||||
model_name: str='stable-diffusion-1.5'
|
||||
scheduler: int='ddim'
|
||||
precision: str='float16'
|
||||
perlin: float=0.0
|
||||
@ -62,41 +63,8 @@ class InvokeAIGeneratorOutput:
|
||||
'''
|
||||
image: Image
|
||||
seed: int
|
||||
model_name: str
|
||||
model_hash: str
|
||||
params: dict
|
||||
|
||||
def __getattribute__(self,name):
|
||||
try:
|
||||
return object.__getattribute__(self, name)
|
||||
except AttributeError:
|
||||
params = object.__getattribute__(self, 'params')
|
||||
if name in params:
|
||||
return params[name]
|
||||
raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{name}'")
|
||||
|
||||
class InvokeAIGeneratorFactory(object):
|
||||
def __init__(self,
|
||||
model_manager: ModelManager,
|
||||
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
|
||||
):
|
||||
self.model_manager = model_manager
|
||||
self.params = params
|
||||
|
||||
def make_generator(self, generatorclass: Type[InvokeAIGenerator], **keyword_args)->InvokeAIGenerator:
|
||||
return generatorclass(self.model_manager,
|
||||
self.params,
|
||||
**keyword_args
|
||||
)
|
||||
|
||||
# getter and setter shortcuts for commonly used parameters
|
||||
@property
|
||||
def model_name(self)->str:
|
||||
return self.params.model_name
|
||||
|
||||
@model_name.setter
|
||||
def model_name(self, model_name: str):
|
||||
self.params.model_name=model_name
|
||||
params: Namespace
|
||||
|
||||
# we are interposing a wrapper around the original Generator classes so that
|
||||
# old code that calls Generate will continue to work.
|
||||
@ -116,7 +84,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
|
||||
def __init__(self,
|
||||
model_manager: ModelManager,
|
||||
params: InvokeAIGeneratorBasicParams,
|
||||
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
|
||||
):
|
||||
self.model_manager=model_manager
|
||||
self.params=params
|
||||
@ -149,23 +117,24 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
print(o.image, o.seed)
|
||||
|
||||
'''
|
||||
model_name = self.params.model_name or self.model_manager.current_model
|
||||
generator_args = dataclasses.asdict(self.params)
|
||||
generator_args.update(keyword_args)
|
||||
|
||||
model_name = generator_args.get('model_name') or self.model_manager.current_model
|
||||
model_info: dict = self.model_manager.get_model(model_name)
|
||||
model:StableDiffusionGeneratorPipeline = model_info['model']
|
||||
model_hash = model_info['hash']
|
||||
scheduler: Scheduler = self.get_scheduler(
|
||||
model=model,
|
||||
scheduler_name=self.params.scheduler
|
||||
scheduler_name=generator_args.get('scheduler')
|
||||
)
|
||||
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)
|
||||
generator = self.load_generator(model, self._generator_name())
|
||||
if self.params.variation_amount > 0:
|
||||
generator.set_variation(self.params.seed,
|
||||
self.params.variation_amount,
|
||||
self.params.with_variations)
|
||||
|
||||
generator_args = dataclasses.asdict(self.params)
|
||||
generator_args.update(keyword_args)
|
||||
generator.set_variation(generator_args.get('seed'),
|
||||
generator_args.get('variation_amount'),
|
||||
generator_args.get('with_variations')
|
||||
)
|
||||
|
||||
iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
|
||||
for i in iteration_count:
|
||||
@ -177,9 +146,8 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
output = InvokeAIGeneratorOutput(
|
||||
image=results[0][0],
|
||||
seed=results[0][1],
|
||||
model_name = model_name,
|
||||
model_hash = model_hash,
|
||||
params=generator_args,
|
||||
params=Namespace(**generator_args),
|
||||
)
|
||||
if callback:
|
||||
callback(output)
|
||||
@ -205,18 +173,19 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
if not hasattr(scheduler, 'uses_inpainting_model'):
|
||||
scheduler.uses_inpainting_model = lambda: False
|
||||
return scheduler
|
||||
|
||||
@abstractmethod
|
||||
def _generator_name(self)->str:
|
||||
|
||||
@classmethod
|
||||
def _generator_name(cls):
|
||||
'''
|
||||
In derived classes will return the name of the generator to use.
|
||||
In derived classes return the name of the generator to apply.
|
||||
If you don't override will return the name of the derived
|
||||
class, which nicely parallels the generator class names.
|
||||
'''
|
||||
pass
|
||||
return cls.__name__
|
||||
|
||||
# ------------------------------------
|
||||
class Txt2Img(InvokeAIGenerator):
|
||||
def _generator_name(self)->str:
|
||||
return 'Txt2Img'
|
||||
pass
|
||||
|
||||
# ------------------------------------
|
||||
class Img2Img(InvokeAIGenerator):
|
||||
@ -230,9 +199,6 @@ class Img2Img(InvokeAIGenerator):
|
||||
**keyword_args
|
||||
)
|
||||
|
||||
def _generator_name(self)->str:
|
||||
return 'Img2Img'
|
||||
|
||||
# ------------------------------------
|
||||
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
|
||||
class Inpaint(Img2Img):
|
||||
@ -266,9 +232,6 @@ class Inpaint(Img2Img):
|
||||
**keyword_args
|
||||
)
|
||||
|
||||
def _generator_name(self)->str:
|
||||
return 'Inpaint'
|
||||
|
||||
class Generator:
|
||||
downsampling_factor: int
|
||||
latent_channels: int
|
||||
|
@ -34,7 +34,7 @@ from picklescan.scanner import scan_file_path
|
||||
from invokeai.backend.globals import Globals, global_cache_dir
|
||||
|
||||
from ..stable_diffusion import StableDiffusionGeneratorPipeline
|
||||
from ..util import CUDA_DEVICE, ask_user, download_with_resume
|
||||
from ..util import CUDA_DEVICE, CPU_DEVICE, ask_user, download_with_resume
|
||||
|
||||
class SDLegacyType(Enum):
|
||||
V1 = 1
|
||||
|
Loading…
Reference in New Issue
Block a user