diff --git a/invokeai/backend/generator/base.py b/invokeai/backend/generator/base.py index e30b77ec33..ee56077fa8 100644 --- a/invokeai/backend/generator/base.py +++ b/invokeai/backend/generator/base.py @@ -21,7 +21,7 @@ from PIL import Image, ImageChops, ImageFilter from accelerate.utils import set_seed from diffusers import DiffusionPipeline from tqdm import trange -from typing import List, Iterator, Type +from typing import Callable, List, Iterator, Optional, Type from dataclasses import dataclass, field from diffusers.schedulers import SchedulerMixin as Scheduler @@ -35,23 +35,23 @@ downsampling = 8 @dataclass class InvokeAIGeneratorBasicParams: - seed: int=None + seed: Optional[int]=None width: int=512 height: int=512 - cfg_scale: int=7.5 + cfg_scale: float=7.5 steps: int=20 ddim_eta: float=0.0 - scheduler: int='ddim' + scheduler: str='ddim' precision: str='float16' perlin: float=0.0 - threshold: int=0.0 + threshold: float=0.0 seamless: bool=False seamless_axes: List[str]=field(default_factory=lambda: ['x', 'y']) - h_symmetry_time_pct: float=None - v_symmetry_time_pct: float=None + h_symmetry_time_pct: Optional[float]=None + v_symmetry_time_pct: Optional[float]=None variation_amount: float = 0.0 with_variations: list=field(default_factory=list) - safety_checker: SafetyChecker=None + safety_checker: Optional[SafetyChecker]=None @dataclass class InvokeAIGeneratorOutput: @@ -61,10 +61,10 @@ class InvokeAIGeneratorOutput: and the model hash, as well as all the generate() parameters that went into generating the image (in .params, also available as attributes) ''' - image: Image + image: Image.Image seed: int model_hash: str - attention_maps_images: List[Image] + attention_maps_images: List[Image.Image] params: Namespace # we are interposing a wrapper around the original Generator classes so that @@ -92,8 +92,8 @@ class InvokeAIGenerator(metaclass=ABCMeta): def generate(self, prompt: str='', - callback: callable=None, - step_callback: callable=None, + callback: Optional[Callable]=None, + step_callback: Optional[Callable]=None, iterations: int=1, **keyword_args, )->Iterator[InvokeAIGeneratorOutput]: @@ -206,10 +206,10 @@ class Txt2Img(InvokeAIGenerator): # ------------------------------------ class Img2Img(InvokeAIGenerator): def generate(self, - init_image: Image | torch.FloatTensor, + init_image: Image.Image | torch.FloatTensor, strength: float=0.75, **keyword_args - )->List[InvokeAIGeneratorOutput]: + )->Iterator[InvokeAIGeneratorOutput]: return super().generate(init_image=init_image, strength=strength, **keyword_args @@ -223,7 +223,7 @@ class Img2Img(InvokeAIGenerator): # Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff class Inpaint(Img2Img): def generate(self, - mask_image: Image | torch.FloatTensor, + mask_image: Image.Image | torch.FloatTensor, # Seam settings - when 0, doesn't fill seam seam_size: int = 0, seam_blur: int = 0, @@ -236,7 +236,7 @@ class Inpaint(Img2Img): inpaint_height=None, inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF), **keyword_args - )->List[InvokeAIGeneratorOutput]: + )->Iterator[InvokeAIGeneratorOutput]: return super().generate( mask_image=mask_image, seam_size=seam_size, @@ -263,7 +263,7 @@ class Embiggen(Txt2Img): embiggen: list=None, embiggen_tiles: list = None, strength: float=0.75, - **kwargs)->List[InvokeAIGeneratorOutput]: + **kwargs)->Iterator[InvokeAIGeneratorOutput]: return super().generate(embiggen=embiggen, embiggen_tiles=embiggen_tiles, strength=strength,