fix(backend): simple typing fixes

This commit is contained in:
psychedelicious 2023-03-26 16:54:46 +11:00
parent bd7e515290
commit 5fe38f7c88

View File

@ -21,7 +21,7 @@ from PIL import Image, ImageChops, ImageFilter
from accelerate.utils import set_seed from accelerate.utils import set_seed
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
from tqdm import trange from tqdm import trange
from typing import List, Iterator, Type from typing import Callable, List, Iterator, Optional, Type
from dataclasses import dataclass, field from dataclasses import dataclass, field
from diffusers.schedulers import SchedulerMixin as Scheduler from diffusers.schedulers import SchedulerMixin as Scheduler
@ -35,23 +35,23 @@ downsampling = 8
@dataclass @dataclass
class InvokeAIGeneratorBasicParams: class InvokeAIGeneratorBasicParams:
seed: int=None seed: Optional[int]=None
width: int=512 width: int=512
height: int=512 height: int=512
cfg_scale: int=7.5 cfg_scale: float=7.5
steps: int=20 steps: int=20
ddim_eta: float=0.0 ddim_eta: float=0.0
scheduler: int='ddim' scheduler: str='ddim'
precision: str='float16' precision: str='float16'
perlin: float=0.0 perlin: float=0.0
threshold: int=0.0 threshold: float=0.0
seamless: bool=False seamless: bool=False
seamless_axes: List[str]=field(default_factory=lambda: ['x', 'y']) seamless_axes: List[str]=field(default_factory=lambda: ['x', 'y'])
h_symmetry_time_pct: float=None h_symmetry_time_pct: Optional[float]=None
v_symmetry_time_pct: float=None v_symmetry_time_pct: Optional[float]=None
variation_amount: float = 0.0 variation_amount: float = 0.0
with_variations: list=field(default_factory=list) with_variations: list=field(default_factory=list)
safety_checker: SafetyChecker=None safety_checker: Optional[SafetyChecker]=None
@dataclass @dataclass
class InvokeAIGeneratorOutput: class InvokeAIGeneratorOutput:
@ -61,10 +61,10 @@ class InvokeAIGeneratorOutput:
and the model hash, as well as all the generate() parameters that went into and the model hash, as well as all the generate() parameters that went into
generating the image (in .params, also available as attributes) generating the image (in .params, also available as attributes)
''' '''
image: Image image: Image.Image
seed: int seed: int
model_hash: str model_hash: str
attention_maps_images: List[Image] attention_maps_images: List[Image.Image]
params: Namespace params: Namespace
# we are interposing a wrapper around the original Generator classes so that # we are interposing a wrapper around the original Generator classes so that
@ -92,8 +92,8 @@ class InvokeAIGenerator(metaclass=ABCMeta):
def generate(self, def generate(self,
prompt: str='', prompt: str='',
callback: callable=None, callback: Optional[Callable]=None,
step_callback: callable=None, step_callback: Optional[Callable]=None,
iterations: int=1, iterations: int=1,
**keyword_args, **keyword_args,
)->Iterator[InvokeAIGeneratorOutput]: )->Iterator[InvokeAIGeneratorOutput]:
@ -206,10 +206,10 @@ class Txt2Img(InvokeAIGenerator):
# ------------------------------------ # ------------------------------------
class Img2Img(InvokeAIGenerator): class Img2Img(InvokeAIGenerator):
def generate(self, def generate(self,
init_image: Image | torch.FloatTensor, init_image: Image.Image | torch.FloatTensor,
strength: float=0.75, strength: float=0.75,
**keyword_args **keyword_args
)->List[InvokeAIGeneratorOutput]: )->Iterator[InvokeAIGeneratorOutput]:
return super().generate(init_image=init_image, return super().generate(init_image=init_image,
strength=strength, strength=strength,
**keyword_args **keyword_args
@ -223,7 +223,7 @@ class Img2Img(InvokeAIGenerator):
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff # Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
class Inpaint(Img2Img): class Inpaint(Img2Img):
def generate(self, def generate(self,
mask_image: Image | torch.FloatTensor, mask_image: Image.Image | torch.FloatTensor,
# Seam settings - when 0, doesn't fill seam # Seam settings - when 0, doesn't fill seam
seam_size: int = 0, seam_size: int = 0,
seam_blur: int = 0, seam_blur: int = 0,
@ -236,7 +236,7 @@ class Inpaint(Img2Img):
inpaint_height=None, inpaint_height=None,
inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF), inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
**keyword_args **keyword_args
)->List[InvokeAIGeneratorOutput]: )->Iterator[InvokeAIGeneratorOutput]:
return super().generate( return super().generate(
mask_image=mask_image, mask_image=mask_image,
seam_size=seam_size, seam_size=seam_size,
@ -263,7 +263,7 @@ class Embiggen(Txt2Img):
embiggen: list=None, embiggen: list=None,
embiggen_tiles: list = None, embiggen_tiles: list = None,
strength: float=0.75, strength: float=0.75,
**kwargs)->List[InvokeAIGeneratorOutput]: **kwargs)->Iterator[InvokeAIGeneratorOutput]:
return super().generate(embiggen=embiggen, return super().generate(embiggen=embiggen,
embiggen_tiles=embiggen_tiles, embiggen_tiles=embiggen_tiles,
strength=strength, strength=strength,