remove factory pattern

Factory pattern is now removed. Typical usage of the InvokeAIGenerator is now:

```
from invokeai.backend.generator import (
    InvokeAIGeneratorBasicParams,
    Txt2Img,
    Img2Img,
    Inpaint,
)
    params = InvokeAIGeneratorBasicParams(
        model_name = 'stable-diffusion-1.5',
        steps = 30,
        scheduler = 'k_lms',
        cfg_scale = 8.0,
        height = 640,
        width = 640
        )
    print ('=== TXT2IMG TEST ===')
    txt2img = Txt2Img(manager, params)
    outputs = txt2img.generate(prompt='banana sushi', iterations=2)

    for i in outputs:
        print(f'image={output.image}, seed={output.seed}, model={output.params.model_name}, hash={output.model_hash}, steps={output.params.steps}')
```

The `params` argument is optional, so if you wish to accept default
parameters and selectively override them, just do this:

```
    outputs = Txt2Img(manager).generate(prompt='banana sushi',
                                        steps=50,
					scheduler='k_heun',
					model_name='stable-diffusion-2.1'
					)
```
This commit is contained in:
Lincoln Stein
2023-03-10 19:33:04 -05:00
parent c11e823ff3
commit 95954188b2
9 changed files with 44 additions and 104 deletions

View File

@ -4,7 +4,6 @@ Initialization file for invokeai.backend
from .generate import Generate
from .generator import (
InvokeAIGeneratorBasicParams,
InvokeAIGeneratorFactory,
InvokeAIGenerator,
InvokeAIGeneratorOutput,
Txt2Img,

View File

@ -2,7 +2,6 @@
Initialization file for the invokeai.generator package
"""
from .base import (
InvokeAIGeneratorFactory,
InvokeAIGenerator,
InvokeAIGeneratorBasicParams,
InvokeAIGeneratorOutput,

View File

@ -11,7 +11,8 @@ import diffusers
import os
import random
import traceback
from abc import ABCMeta, abstractmethod
from abc import ABCMeta
from argparse import Namespace
from contextlib import nullcontext
import cv2
@ -21,7 +22,7 @@ from PIL import Image, ImageChops, ImageFilter
from accelerate.utils import set_seed
from diffusers import DiffusionPipeline
from tqdm import trange
from typing import List, Type, Iterator
from typing import List, Iterator
from dataclasses import dataclass, field
from diffusers.schedulers import SchedulerMixin as Scheduler
@ -35,13 +36,13 @@ downsampling = 8
@dataclass
class InvokeAIGeneratorBasicParams:
model_name: str='stable-diffusion-1.5'
seed: int=None
width: int=512
height: int=512
cfg_scale: int=7.5
steps: int=20
ddim_eta: float=0.0
model_name: str='stable-diffusion-1.5'
scheduler: int='ddim'
precision: str='float16'
perlin: float=0.0
@ -62,41 +63,8 @@ class InvokeAIGeneratorOutput:
'''
image: Image
seed: int
model_name: str
model_hash: str
params: dict
def __getattribute__(self,name):
try:
return object.__getattribute__(self, name)
except AttributeError:
params = object.__getattribute__(self, 'params')
if name in params:
return params[name]
raise AttributeError(f"'{self.__class__.__name__}' has no attribute '{name}'")
class InvokeAIGeneratorFactory(object):
def __init__(self,
model_manager: ModelManager,
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
):
self.model_manager = model_manager
self.params = params
def make_generator(self, generatorclass: Type[InvokeAIGenerator], **keyword_args)->InvokeAIGenerator:
return generatorclass(self.model_manager,
self.params,
**keyword_args
)
# getter and setter shortcuts for commonly used parameters
@property
def model_name(self)->str:
return self.params.model_name
@model_name.setter
def model_name(self, model_name: str):
self.params.model_name=model_name
params: Namespace
# we are interposing a wrapper around the original Generator classes so that
# old code that calls Generate will continue to work.
@ -116,7 +84,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
def __init__(self,
model_manager: ModelManager,
params: InvokeAIGeneratorBasicParams,
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
):
self.model_manager=model_manager
self.params=params
@ -149,23 +117,24 @@ class InvokeAIGenerator(metaclass=ABCMeta):
print(o.image, o.seed)
'''
model_name = self.params.model_name or self.model_manager.current_model
generator_args = dataclasses.asdict(self.params)
generator_args.update(keyword_args)
model_name = generator_args.get('model_name') or self.model_manager.current_model
model_info: dict = self.model_manager.get_model(model_name)
model:StableDiffusionGeneratorPipeline = model_info['model']
model_hash = model_info['hash']
scheduler: Scheduler = self.get_scheduler(
model=model,
scheduler_name=self.params.scheduler
scheduler_name=generator_args.get('scheduler')
)
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)
generator = self.load_generator(model, self._generator_name())
if self.params.variation_amount > 0:
generator.set_variation(self.params.seed,
self.params.variation_amount,
self.params.with_variations)
generator_args = dataclasses.asdict(self.params)
generator_args.update(keyword_args)
generator.set_variation(generator_args.get('seed'),
generator_args.get('variation_amount'),
generator_args.get('with_variations')
)
iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
for i in iteration_count:
@ -177,9 +146,8 @@ class InvokeAIGenerator(metaclass=ABCMeta):
output = InvokeAIGeneratorOutput(
image=results[0][0],
seed=results[0][1],
model_name = model_name,
model_hash = model_hash,
params=generator_args,
params=Namespace(**generator_args),
)
if callback:
callback(output)
@ -205,18 +173,19 @@ class InvokeAIGenerator(metaclass=ABCMeta):
if not hasattr(scheduler, 'uses_inpainting_model'):
scheduler.uses_inpainting_model = lambda: False
return scheduler
@abstractmethod
def _generator_name(self)->str:
@classmethod
def _generator_name(cls):
'''
In derived classes will return the name of the generator to use.
In derived classes return the name of the generator to apply.
If you don't override will return the name of the derived
class, which nicely parallels the generator class names.
'''
pass
return cls.__name__
# ------------------------------------
class Txt2Img(InvokeAIGenerator):
def _generator_name(self)->str:
return 'Txt2Img'
pass
# ------------------------------------
class Img2Img(InvokeAIGenerator):
@ -230,9 +199,6 @@ class Img2Img(InvokeAIGenerator):
**keyword_args
)
def _generator_name(self)->str:
return 'Img2Img'
# ------------------------------------
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
class Inpaint(Img2Img):
@ -266,9 +232,6 @@ class Inpaint(Img2Img):
**keyword_args
)
def _generator_name(self)->str:
return 'Inpaint'
class Generator:
downsampling_factor: int
latent_channels: int

View File

@ -34,7 +34,7 @@ from picklescan.scanner import scan_file_path
from invokeai.backend.globals import Globals, global_cache_dir
from ..stable_diffusion import StableDiffusionGeneratorPipeline
from ..util import CUDA_DEVICE, ask_user, download_with_resume
from ..util import CUDA_DEVICE, CPU_DEVICE, ask_user, download_with_resume
class SDLegacyType(Enum):
V1 = 1