remove factory pattern

The factory pattern has been removed. Typical usage of the `InvokeAIGenerator` classes is now:

```
from invokeai.backend.generator import (
    InvokeAIGeneratorBasicParams,
    Txt2Img,
    Img2Img,
    Inpaint,
)

params = InvokeAIGeneratorBasicParams(
    model_name='stable-diffusion-1.5',
    steps=30,
    scheduler='k_lms',
    cfg_scale=8.0,
    height=640,
    width=640,
)

print('=== TXT2IMG TEST ===')
txt2img = Txt2Img(manager, params)  # `manager` is a ModelManager; see the sketch below
outputs = txt2img.generate(prompt='banana sushi', iterations=2)

for output in outputs:
    print(f'image={output.image}, seed={output.seed}, model={output.params.model_name}, hash={output.model_hash}, steps={output.params.steps}')
```
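In the example above, `manager` is an already-initialized `ModelManager`. Inside the app it comes from `get_model_manager(args, config)` (see the diff below); for standalone use, a minimal sketch of constructing one directly, with an illustrative config path and settings, might look like:

```
from omegaconf import OmegaConf
from invokeai.backend import ModelManager

# Illustrative values -- point these at your own models.yaml,
# preferred precision, and device.
manager = ModelManager(
    OmegaConf.load('configs/models.yaml'),
    precision='float16',
    device_type='cuda',
    max_loaded_models=2,
)
```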

The `params` argument is optional; to accept the default
parameters and selectively override them, pass the overrides directly to `generate()`:

```
outputs = Txt2Img(manager).generate(prompt='banana sushi',
                                    steps=50,
                                    scheduler='k_heun',
                                    model_name='stable-diffusion-2.1'
                                    )
```
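
`Img2Img` and `Inpaint` follow the same calling convention. Because `generate()` yields `InvokeAIGeneratorOutput` objects one at a time, a single result can be pulled with `next()`, mirroring the invocation code in the diff below. A sketch, with a hypothetical input image:

```
from PIL import Image

init_image = Image.open('banana_sushi.png')  # hypothetical input image

output = next(Img2Img(manager).generate(prompt='banana sushi',
                                        init_img=init_image,
                                        steps=30))
print(output.image, output.seed)
```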
Lincoln Stein 2023-03-10 19:33:04 -05:00
parent c11e823ff3
commit 95954188b2
9 changed files with 44 additions and 104 deletions


@@ -4,7 +4,7 @@ import os
from argparse import Namespace
from ...backend import Globals
from ..services.generate_initializer import get_generator_factory
from ..services.generate_initializer import get_model_manager
from ..services.graph import GraphExecutionState
from ..services.image_storage import DiskImageStorage
from ..services.invocation_queue import MemoryInvocationQueue
@@ -47,7 +47,7 @@ class ApiDependencies:
# TODO: Use a logger
print(f">> Internet connectivity is {Globals.internet_available}")
generator_factory = get_generator_factory(args, config)
model_manager = get_model_manager(args, config)
events = FastAPIEventService(event_handler_id)


@@ -17,7 +17,7 @@ from .cli.commands import BaseCommand, CliContext, ExitCli, add_parsers, get_gra
from .invocations import *
from .invocations.baseinvocation import BaseInvocation
from .services.events import EventServiceBase
from .services.generate_initializer import get_generator_factory
from .services.generate_initializer import get_model_manager
from .services.graph import EdgeConnection, GraphExecutionState
from .services.image_storage import DiskImageStorage
from .services.invocation_queue import MemoryInvocationQueue
@@ -129,7 +129,7 @@ def invoke_cli():
args = Args()
config = args.parse_args()
generator_factory = get_generator_factory(args, config)
model_manager = get_model_manager(args, config)
events = EventServiceBase()
@@ -141,7 +141,7 @@ def invoke_cli():
db_location = os.path.join(output_folder, "invokeai.db")
services = InvocationServices(
generator_factory=generator_factory,
model_manager=model_manager,
events=events,
images=DiskImageStorage(output_folder),
queue=MemoryInvocationQueue(),


@@ -18,7 +18,6 @@ SAMPLER_NAME_VALUES = Literal[
tuple(InvokeAIGenerator.schedulers())
]
# Text to image
class TextToImageInvocation(BaseInvocation):
"""Generates an image using text2img."""
@@ -58,15 +57,8 @@ class TextToImageInvocation(BaseInvocation):
# Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now?
factory = context.services.generator_factory
if self.model:
factory.model_name = self.model
else:
self.model = factory.model_name
txt2img = factory.make_generator(Txt2Img)
outputs = txt2img.generate(
manager = context.services.model_manager
outputs = Txt2Img(manager).generate(
prompt=self.prompt,
step_callback=step_callback,
**self.dict(
@@ -121,13 +113,9 @@ class ImageToImageInvocation(TextToImageInvocation):
# Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now?
factory = context.services.generator_factory
self.model = self.model or factory.model_name
factory.model_name = self.model
img2img = factory.make_generator(Img2Img)
manager = context.services.model_manager
generator_output = next(
img2img.generate(
Img2Img(manager).generate(
prompt=self.prompt,
init_img=image,
init_mask=mask,
@@ -186,13 +174,9 @@ class InpaintInvocation(ImageToImageInvocation):
# Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now?
factory = context.services.generator_factory
self.model = self.model or factory.model_name
factory.model_name = self.model
inpaint = factory.make_generator(Inpaint)
manager = context.services.model_manager
generator_output = next(
inpaint.generate(
Inpaint(manager).generate(
prompt=self.prompt,
init_img=image,
init_mask=mask,


@@ -6,12 +6,12 @@ from argparse import Namespace
from omegaconf import OmegaConf
import invokeai.version
from ...backend import ModelManager, InvokeAIGeneratorBasicParams, InvokeAIGeneratorFactory
from ...backend import ModelManager
from ...backend.util import choose_precision, choose_torch_device
from ...backend import Globals
# TODO: most of this code should be split into individual services as the Generate.py code is deprecated
def get_generator_factory(args, config) -> InvokeAIGeneratorFactory:
def get_model_manager(args, config) -> ModelManager:
if not args.conf:
config_file = os.path.join(Globals.root, "configs", "models.yaml")
if not os.path.exists(config_file):
@@ -64,7 +64,7 @@ def get_generator_factory(args, config) -> InvokeAIGeneratorFactory:
print(f"{e}. Aborting.")
sys.exit(-1)
# creating an InvokeAIGeneratorFactory object:
# creating the model manager
try:
device = torch.device(choose_torch_device())
precision = 'float16' if args.precision=='float16' \
@@ -77,11 +77,6 @@ def get_generator_factory(args, config) -> InvokeAIGeneratorFactory:
device_type=device,
max_loaded_models=args.max_loaded_models,
)
# TO DO: initialize and pass safety checker!!!
params = InvokeAIGeneratorBasicParams(
precision=precision,
)
factory = InvokeAIGeneratorFactory(model_manager, params)
except (FileNotFoundError, TypeError, AssertionError) as e:
report_model_error(args, e)
except (IOError, KeyError) as e:
@@ -100,7 +95,7 @@ def get_generator_factory(args, config) -> InvokeAIGeneratorFactory:
weights_directory=path,
)
return factory
return model_manager
def load_face_restoration(opt):
try:


@@ -1,5 +1,5 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
from invokeai.backend import InvokeAIGeneratorFactory
from invokeai.backend import ModelManager
from .events import EventServiceBase
from .image_storage import ImageStorageBase
@@ -10,7 +10,7 @@ from .item_storage import ItemStorageABC
class InvocationServices:
"""Services that can be used by invocations"""
generator_factory: InvokeAIGeneratorFactory
model_manager: ModelManager
events: EventServiceBase
images: ImageStorageBase
queue: InvocationQueueABC
@@ -21,14 +21,14 @@ class InvocationServices:
def __init__(
self,
generator_factory: InvokeAIGeneratorFactory,
model_manager: ModelManager,
events: EventServiceBase,
images: ImageStorageBase,
queue: InvocationQueueABC,
graph_execution_manager: ItemStorageABC["GraphExecutionState"],
processor: "InvocationProcessorABC",
):
self.generator_factory = generator_factory
self.model_manager = model_manager
self.events = events
self.images = images
self.queue = queue


@@ -4,7 +4,6 @@ Initialization file for invokeai.backend
from .generate import Generate
from .generator import (
InvokeAIGeneratorBasicParams,
InvokeAIGeneratorFactory,
InvokeAIGenerator,
InvokeAIGeneratorOutput,
Txt2Img,


@@ -2,7 +2,6 @@
Initialization file for the invokeai.generator package
"""
from .base import (
InvokeAIGeneratorFactory,
InvokeAIGenerator,
InvokeAIGeneratorBasicParams,
InvokeAIGeneratorOutput,


@@ -11,7 +11,8 @@ import diffusers
import os
import random
import traceback
from abc import ABCMeta, abstractmethod
from abc import ABCMeta
from argparse import Namespace
from contextlib import nullcontext
import cv2
@@ -21,7 +22,7 @@ from PIL import Image, ImageChops, ImageFilter
from accelerate.utils import set_seed
from diffusers import DiffusionPipeline
from tqdm import trange
from typing import List, Type, Iterator
from typing import List, Iterator
from dataclasses import dataclass, field
from diffusers.schedulers import SchedulerMixin as Scheduler
@@ -35,13 +36,13 @@ downsampling = 8
@dataclass
class InvokeAIGeneratorBasicParams:
model_name: str='stable-diffusion-1.5'
seed: int=None
width: int=512
height: int=512
cfg_scale: int=7.5
steps: int=20
ddim_eta: float=0.0
model_name: str='stable-diffusion-1.5'
scheduler: int='ddim'
precision: str='float16'
perlin: float=0.0
@@ -62,41 +63,8 @@ class InvokeAIGeneratorOutput:
'''
image: Image
seed: int
model_name: str
model_hash: str
params: dict
def __getattribute__(self,name):
try:
return object.__getattribute__(self, name)
except AttributeError:
params = object.__getattribute__(self, 'params')
if name in params:
return params[name]
raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{name}'")
class InvokeAIGeneratorFactory(object):
def __init__(self,
model_manager: ModelManager,
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
):
self.model_manager = model_manager
self.params = params
def make_generator(self, generatorclass: Type[InvokeAIGenerator], **keyword_args)->InvokeAIGenerator:
return generatorclass(self.model_manager,
self.params,
**keyword_args
)
# getter and setter shortcuts for commonly used parameters
@property
def model_name(self)->str:
return self.params.model_name
@model_name.setter
def model_name(self, model_name: str):
self.params.model_name=model_name
params: Namespace
# we are interposing a wrapper around the original Generator classes so that
# old code that calls Generate will continue to work.
@@ -116,7 +84,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
def __init__(self,
model_manager: ModelManager,
params: InvokeAIGeneratorBasicParams,
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
):
self.model_manager=model_manager
self.params=params
@@ -149,23 +117,24 @@
print(o.image, o.seed)
'''
model_name = self.params.model_name or self.model_manager.current_model
generator_args = dataclasses.asdict(self.params)
generator_args.update(keyword_args)
model_name = generator_args.get('model_name') or self.model_manager.current_model
model_info: dict = self.model_manager.get_model(model_name)
model:StableDiffusionGeneratorPipeline = model_info['model']
model_hash = model_info['hash']
scheduler: Scheduler = self.get_scheduler(
model=model,
scheduler_name=self.params.scheduler
scheduler_name=generator_args.get('scheduler')
)
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)
generator = self.load_generator(model, self._generator_name())
if self.params.variation_amount > 0:
generator.set_variation(self.params.seed,
self.params.variation_amount,
self.params.with_variations)
generator_args = dataclasses.asdict(self.params)
generator_args.update(keyword_args)
generator.set_variation(generator_args.get('seed'),
generator_args.get('variation_amount'),
generator_args.get('with_variations')
)
iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
for i in iteration_count:
@@ -177,9 +146,8 @@ class InvokeAIGenerator(metaclass=ABCMeta):
output = InvokeAIGeneratorOutput(
image=results[0][0],
seed=results[0][1],
model_name = model_name,
model_hash = model_hash,
params=generator_args,
params=Namespace(**generator_args),
)
if callback:
callback(output)
@@ -205,18 +173,19 @@ class InvokeAIGenerator(metaclass=ABCMeta):
if not hasattr(scheduler, 'uses_inpainting_model'):
scheduler.uses_inpainting_model = lambda: False
return scheduler
@abstractmethod
def _generator_name(self)->str:
@classmethod
def _generator_name(cls):
'''
In derived classes will return the name of the generator to use.
In derived classes return the name of the generator to apply.
If you don't override will return the name of the derived
class, which nicely parallels the generator class names.
'''
pass
return cls.__name__
# ------------------------------------
class Txt2Img(InvokeAIGenerator):
def _generator_name(self)->str:
return 'Txt2Img'
pass
# ------------------------------------
class Img2Img(InvokeAIGenerator):
@@ -230,9 +199,6 @@ class Img2Img(InvokeAIGenerator):
**keyword_args
)
def _generator_name(self)->str:
return 'Img2Img'
# ------------------------------------
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
class Inpaint(Img2Img):
@@ -266,9 +232,6 @@ class Inpaint(Img2Img):
**keyword_args
)
def _generator_name(self)->str:
return 'Inpaint'
class Generator:
downsampling_factor: int
latent_channels: int


@@ -34,7 +34,7 @@ from picklescan.scanner import scan_file_path
from invokeai.backend.globals import Globals, global_cache_dir
from ..stable_diffusion import StableDiffusionGeneratorPipeline
from ..util import CUDA_DEVICE, ask_user, download_with_resume
from ..util import CUDA_DEVICE, CPU_DEVICE, ask_user, download_with_resume
class SDLegacyType(Enum):
V1 = 1