mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add InvokeAIGenerator and InvokeAIGeneratorFactory classes
This commit is contained in:
parent
d232a439f7
commit
87789c1de8
@ -1,5 +1,14 @@
|
|||||||
"""
|
"""
|
||||||
Initialization file for the invokeai.generator package
|
Initialization file for the invokeai.generator package
|
||||||
"""
|
"""
|
||||||
from .base import Generator
|
from .base import (
|
||||||
|
InvokeAIGeneratorFactory,
|
||||||
|
InvokeAIGenerator,
|
||||||
|
InvokeAIGeneratorBasicParams,
|
||||||
|
InvokeAIGeneratorOutput,
|
||||||
|
Txt2Img,
|
||||||
|
Img2Img,
|
||||||
|
Inpaint,
|
||||||
|
Generator,
|
||||||
|
)
|
||||||
from .inpaint import infill_methods
|
from .inpaint import infill_methods
|
||||||
|
@ -4,9 +4,14 @@ including img2img, txt2img, and inpaint
|
|||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import importlib
|
||||||
|
import dataclasses
|
||||||
|
import diffusers
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import traceback
|
import traceback
|
||||||
|
from abc import ABCMeta, abstractmethod
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@ -17,13 +22,204 @@ from PIL import Image, ImageChops, ImageFilter
|
|||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
|
from typing import List, Type, Callable
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||||
|
|
||||||
import invokeai.assets.web as web_assets
|
import invokeai.assets.web as web_assets
|
||||||
from ..util.util import rand_perlin_2d
|
from ..util.util import rand_perlin_2d
|
||||||
|
|
||||||
|
from ..prompting.conditioning import get_uc_and_c_and_ec
|
||||||
|
from ..model_management.model_manager import ModelManager
|
||||||
|
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||||
|
|
||||||
downsampling = 8
|
downsampling = 8
|
||||||
CAUTION_IMG = "caution.png"
|
CAUTION_IMG = "caution.png"
|
||||||
|
|
||||||
|
class InvokeAIGeneratorFactory(object):
|
||||||
|
def __init__(self,
|
||||||
|
model_manager: ModelManager,
|
||||||
|
params: InvokeAIGeneratorBasicParams
|
||||||
|
):
|
||||||
|
self.model_manager = model_manager
|
||||||
|
self.params = params
|
||||||
|
|
||||||
|
def make_generator(self, generatorclass: Type[InvokeAIGenerator], **keyword_args)->InvokeAIGenerator:
|
||||||
|
return generatorclass(self.model_manager,
|
||||||
|
self.params,
|
||||||
|
**keyword_args
|
||||||
|
)
|
||||||
|
@dataclass
|
||||||
|
class InvokeAIGeneratorBasicParams:
|
||||||
|
seed: int=None
|
||||||
|
width: int=512
|
||||||
|
height: int=512
|
||||||
|
cfg_scale: int=7.5
|
||||||
|
steps: int=20
|
||||||
|
ddim_eta: float=0.0
|
||||||
|
model: str='stable-diffusion-1.5'
|
||||||
|
scheduler: int='ddim'
|
||||||
|
precision: str='float16'
|
||||||
|
perlin: float=0.0
|
||||||
|
threshold: int=0.0
|
||||||
|
h_symmetry_time_pct: float=None
|
||||||
|
v_symmetry_time_pct: float=None
|
||||||
|
variation_amount: float = 0.0
|
||||||
|
with_variations: list = field(default_factory=list)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InvokeAIGeneratorOutput:
|
||||||
|
image: Image
|
||||||
|
seed: int
|
||||||
|
model_name: str
|
||||||
|
model_hash: str
|
||||||
|
params: InvokeAIGeneratorBasicParams
|
||||||
|
|
||||||
|
# we are interposing a wrapper around the original Generator classes so that
|
||||||
|
# old code that calls Generate will continue to work.
|
||||||
|
class InvokeAIGenerator(metaclass=ABCMeta):
|
||||||
|
scheduler_map = dict(
|
||||||
|
ddim=diffusers.DDIMScheduler,
|
||||||
|
dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
||||||
|
k_dpm_2=diffusers.KDPM2DiscreteScheduler,
|
||||||
|
k_dpm_2_a=diffusers.KDPM2AncestralDiscreteScheduler,
|
||||||
|
k_dpmpp_2=diffusers.DPMSolverMultistepScheduler,
|
||||||
|
k_euler=diffusers.EulerDiscreteScheduler,
|
||||||
|
k_euler_a=diffusers.EulerAncestralDiscreteScheduler,
|
||||||
|
k_heun=diffusers.HeunDiscreteScheduler,
|
||||||
|
k_lms=diffusers.LMSDiscreteScheduler,
|
||||||
|
plms=diffusers.PNDMScheduler,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
model_manager: ModelManager,
|
||||||
|
params: InvokeAIGeneratorBasicParams
|
||||||
|
):
|
||||||
|
self.model_manager=model_manager
|
||||||
|
self.params=params
|
||||||
|
|
||||||
|
def generate(self,
|
||||||
|
prompt: str='',
|
||||||
|
callback: callable=None,
|
||||||
|
step_callback: callable=None,
|
||||||
|
**keyword_args,
|
||||||
|
)->List[InvokeAIGeneratorOutput]:
|
||||||
|
|
||||||
|
model_name = self.params.model or self.model_manager.current_model
|
||||||
|
model_info: dict = self.model_manager.get_model(model_name)
|
||||||
|
model:StableDiffusionGeneratorPipeline = model_info['model']
|
||||||
|
model_hash = model_info['hash']
|
||||||
|
scheduler: Scheduler = self.get_scheduler(
|
||||||
|
model=model,
|
||||||
|
scheduler_name=self.params.scheduler
|
||||||
|
)
|
||||||
|
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)
|
||||||
|
|
||||||
|
def _wrap_results(image: Image, seed: int, **kwargs):
|
||||||
|
nonlocal results
|
||||||
|
results.append(output)
|
||||||
|
|
||||||
|
generator = self.load_generator(model, self._generator_name())
|
||||||
|
if self.params.variation_amount > 0:
|
||||||
|
generator.set_variation(self.params.seed,
|
||||||
|
self.params.variation_amount,
|
||||||
|
self.params.with_variations)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
results = generator.generate(prompt,
|
||||||
|
conditioning=(uc, c, extra_conditioning_info),
|
||||||
|
sampler=scheduler,
|
||||||
|
**dataclasses.asdict(self.params),
|
||||||
|
**keyword_args
|
||||||
|
)
|
||||||
|
output = InvokeAIGeneratorOutput(
|
||||||
|
image=results[0][0],
|
||||||
|
seed=results[0][1],
|
||||||
|
model_name = model_name,
|
||||||
|
model_hash = model_hash,
|
||||||
|
params=copy.copy(self.params)
|
||||||
|
)
|
||||||
|
if callback:
|
||||||
|
callback(output)
|
||||||
|
yield output
|
||||||
|
|
||||||
|
def load_generator(self, model: StableDiffusionGeneratorPipeline, class_name: str):
|
||||||
|
module_name = f'invokeai.backend.generator.{class_name.lower()}'
|
||||||
|
module = importlib.import_module(module_name)
|
||||||
|
constructor = getattr(module, class_name)
|
||||||
|
return constructor(model, self.params.precision)
|
||||||
|
|
||||||
|
def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
|
||||||
|
scheduler_class = self.scheduler_map.get(scheduler_name,'ddim')
|
||||||
|
scheduler = scheduler_class.from_config(model.scheduler.config)
|
||||||
|
# hack copied over from generate.py
|
||||||
|
if not hasattr(scheduler, 'uses_inpainting_model'):
|
||||||
|
scheduler.uses_inpainting_model = lambda: False
|
||||||
|
return scheduler
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _generator_name(self)->str:
|
||||||
|
'''
|
||||||
|
In derived classes will return the name of the generator to use.
|
||||||
|
'''
|
||||||
|
pass
|
||||||
|
|
||||||
|
# ------------------------------------
|
||||||
|
class Txt2Img(InvokeAIGenerator):
|
||||||
|
def _generator_name(self)->str:
|
||||||
|
return 'Txt2Img'
|
||||||
|
|
||||||
|
# ------------------------------------
|
||||||
|
class Img2Img(InvokeAIGenerator):
|
||||||
|
def generate(self,
|
||||||
|
init_image: Image | torch.FloatTensor,
|
||||||
|
strength: float=0.75,
|
||||||
|
**keyword_args
|
||||||
|
)->List[InvokeAIGeneratorOutput]:
|
||||||
|
return super().generate(init_image=init_image,
|
||||||
|
strength=strength,
|
||||||
|
**keyword_args
|
||||||
|
)
|
||||||
|
|
||||||
|
def _generator_name(self)->str:
|
||||||
|
return 'Img2Img'
|
||||||
|
|
||||||
|
# ------------------------------------
|
||||||
|
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
|
||||||
|
class Inpaint(Img2Img):
|
||||||
|
def generate(self,
|
||||||
|
mask_image: Image | torch.FloatTensor,
|
||||||
|
# Seam settings - when 0, doesn't fill seam
|
||||||
|
seam_size: int = 0,
|
||||||
|
seam_blur: int = 0,
|
||||||
|
seam_strength: float = 0.7,
|
||||||
|
seam_steps: int = 10,
|
||||||
|
tile_size: int = 32,
|
||||||
|
inpaint_replace=False,
|
||||||
|
infill_method=None,
|
||||||
|
inpaint_width=None,
|
||||||
|
inpaint_height=None,
|
||||||
|
inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
|
||||||
|
**keyword_args
|
||||||
|
)->List[InvokeAIGeneratorOutput]:
|
||||||
|
return super().generate(
|
||||||
|
mask_image=mask_image,
|
||||||
|
seam_size=seam_size,
|
||||||
|
seam_blur=seam_blur,
|
||||||
|
seam_strength=seam_strength,
|
||||||
|
seam_steps=seam_steps,
|
||||||
|
tile_size=tile_size,
|
||||||
|
inpaint_replace=inpaint_replace,
|
||||||
|
infill_method=infill_method,
|
||||||
|
inpaint_width=inpaint_width,
|
||||||
|
inpaint_height=inpaint_height,
|
||||||
|
inpaint_fill=inpaint_fill,
|
||||||
|
**keyword_args
|
||||||
|
)
|
||||||
|
|
||||||
|
def _generator_name(self)->str:
|
||||||
|
return 'Inpaint'
|
||||||
|
|
||||||
|
|
||||||
class Generator:
|
class Generator:
|
||||||
downsampling_factor: int
|
downsampling_factor: int
|
||||||
@ -64,10 +260,10 @@ class Generator:
|
|||||||
def generate(
|
def generate(
|
||||||
self,
|
self,
|
||||||
prompt,
|
prompt,
|
||||||
init_image,
|
|
||||||
width,
|
width,
|
||||||
height,
|
height,
|
||||||
sampler,
|
sampler,
|
||||||
|
init_image=None,
|
||||||
iterations=1,
|
iterations=1,
|
||||||
seed=None,
|
seed=None,
|
||||||
image_callback=None,
|
image_callback=None,
|
||||||
@ -293,16 +489,6 @@ class Generator:
|
|||||||
else:
|
else:
|
||||||
return (seed, None)
|
return (seed, None)
|
||||||
|
|
||||||
# returns a tensor filled with random numbers from a normal distribution
|
|
||||||
def get_noise(self, width, height):
|
|
||||||
"""
|
|
||||||
Returns a tensor filled with random numbers, either form a normal distribution
|
|
||||||
(txt2img) or from the latent image (img2img, inpaint)
|
|
||||||
"""
|
|
||||||
raise NotImplementedError(
|
|
||||||
"get_noise() must be implemented in a descendent class"
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_perlin_noise(self, width, height):
|
def get_perlin_noise(self, width, height):
|
||||||
fixdevice = "cpu" if (self.model.device.type == "mps") else self.model.device
|
fixdevice = "cpu" if (self.model.device.type == "mps") else self.model.device
|
||||||
# limit noise to only the diffusion image channels, not the mask channels
|
# limit noise to only the diffusion image channels, not the mask channels
|
||||||
|
Loading…
x
Reference in New Issue
Block a user