add seamless tiling support

This commit is contained in:
Lincoln Stein 2023-03-11 08:33:23 -05:00
parent 675dd12b6c
commit 250b0ab182

View File

@ -4,7 +4,6 @@ including img2img, txt2img, and inpaint
""" """
from __future__ import annotations from __future__ import annotations
import importlib
import itertools import itertools
import dataclasses import dataclasses
import diffusers import diffusers
@ -22,10 +21,11 @@ from PIL import Image, ImageChops, ImageFilter
from accelerate.utils import set_seed from accelerate.utils import set_seed
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
from tqdm import trange from tqdm import trange
from typing import List, Iterator from typing import List, Iterator, Type
from dataclasses import dataclass, field from dataclasses import dataclass, field
from diffusers.schedulers import SchedulerMixin as Scheduler from diffusers.schedulers import SchedulerMixin as Scheduler
from ..image_util import configure_model_padding
from ..util.util import rand_perlin_2d from ..util.util import rand_perlin_2d
from ..safety_checker import SafetyChecker from ..safety_checker import SafetyChecker
from ..prompting.conditioning import get_uc_and_c_and_ec from ..prompting.conditioning import get_uc_and_c_and_ec
@ -47,6 +47,8 @@ class InvokeAIGeneratorBasicParams:
precision: str='float16' precision: str='float16'
perlin: float=0.0 perlin: float=0.0
threshold: int=0.0 threshold: int=0.0
seamless: bool=False
seamless_axes: List[str]=field(default_factory=lambda: ['x', 'y'])
h_symmetry_time_pct: float=None h_symmetry_time_pct: float=None
v_symmetry_time_pct: float=None v_symmetry_time_pct: float=None
variation_amount: float = 0.0 variation_amount: float = 0.0
@ -138,6 +140,18 @@ class InvokeAIGenerator(metaclass=ABCMeta):
generator_args.get('with_variations') generator_args.get('with_variations')
) )
if isinstance(model, DiffusionPipeline):
for component in [model.unet, model.vae]:
configure_model_padding(component,
generator_args.get('seamless',False),
generator_args.get('seamless_axes')
)
else:
configure_model_padding(model,
generator_args.get('seamless',False),
generator_args.get('seamless_axes')
)
iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1) iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
for i in iteration_count: for i in iteration_count:
results = generator.generate(prompt, results = generator.generate(prompt,