add seamless tiling support

This commit is contained in:
Lincoln Stein 2023-03-11 08:33:23 -05:00
parent 675dd12b6c
commit 250b0ab182

View File

@ -4,7 +4,6 @@ including img2img, txt2img, and inpaint
"""
from __future__ import annotations
import importlib
import itertools
import dataclasses
import diffusers
@ -22,10 +21,11 @@ from PIL import Image, ImageChops, ImageFilter
from accelerate.utils import set_seed
from diffusers import DiffusionPipeline
from tqdm import trange
from typing import List, Iterator
from typing import List, Iterator, Type
from dataclasses import dataclass, field
from diffusers.schedulers import SchedulerMixin as Scheduler
from ..image_util import configure_model_padding
from ..util.util import rand_perlin_2d
from ..safety_checker import SafetyChecker
from ..prompting.conditioning import get_uc_and_c_and_ec
@ -47,6 +47,8 @@ class InvokeAIGeneratorBasicParams:
precision: str='float16'
perlin: float=0.0
threshold: int=0.0
seamless: bool=False
seamless_axes: List[str]=field(default_factory=lambda: ['x', 'y'])
h_symmetry_time_pct: float=None
v_symmetry_time_pct: float=None
variation_amount: float = 0.0
@ -138,6 +140,18 @@ class InvokeAIGenerator(metaclass=ABCMeta):
generator_args.get('with_variations')
)
if isinstance(model, DiffusionPipeline):
for component in [model.unet, model.vae]:
configure_model_padding(component,
generator_args.get('seamless',False),
generator_args.get('seamless_axes')
)
else:
configure_model_padding(model,
generator_args.get('seamless',False),
generator_args.get('seamless_axes')
)
iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
for i in iteration_count:
results = generator.generate(prompt,