fix(backend): simple typing fixes

This commit is contained in:
psychedelicious 2023-03-26 16:54:46 +11:00
parent bd7e515290
commit 5fe38f7c88

View File

@ -21,7 +21,7 @@ from PIL import Image, ImageChops, ImageFilter
from accelerate.utils import set_seed
from diffusers import DiffusionPipeline
from tqdm import trange
from typing import List, Iterator, Type
from typing import Callable, List, Iterator, Optional, Type
from dataclasses import dataclass, field
from diffusers.schedulers import SchedulerMixin as Scheduler
@ -35,23 +35,23 @@ downsampling = 8
@dataclass
class InvokeAIGeneratorBasicParams:
seed: int=None
seed: Optional[int]=None
width: int=512
height: int=512
cfg_scale: int=7.5
cfg_scale: float=7.5
steps: int=20
ddim_eta: float=0.0
scheduler: int='ddim'
scheduler: str='ddim'
precision: str='float16'
perlin: float=0.0
threshold: int=0.0
threshold: float=0.0
seamless: bool=False
seamless_axes: List[str]=field(default_factory=lambda: ['x', 'y'])
h_symmetry_time_pct: float=None
v_symmetry_time_pct: float=None
h_symmetry_time_pct: Optional[float]=None
v_symmetry_time_pct: Optional[float]=None
variation_amount: float = 0.0
with_variations: list=field(default_factory=list)
safety_checker: SafetyChecker=None
safety_checker: Optional[SafetyChecker]=None
@dataclass
class InvokeAIGeneratorOutput:
@ -61,10 +61,10 @@ class InvokeAIGeneratorOutput:
and the model hash, as well as all the generate() parameters that went into
generating the image (in .params, also available as attributes)
'''
image: Image
image: Image.Image
seed: int
model_hash: str
attention_maps_images: List[Image]
attention_maps_images: List[Image.Image]
params: Namespace
# we are interposing a wrapper around the original Generator classes so that
@ -92,8 +92,8 @@ class InvokeAIGenerator(metaclass=ABCMeta):
def generate(self,
prompt: str='',
callback: callable=None,
step_callback: callable=None,
callback: Optional[Callable]=None,
step_callback: Optional[Callable]=None,
iterations: int=1,
**keyword_args,
)->Iterator[InvokeAIGeneratorOutput]:
@ -206,10 +206,10 @@ class Txt2Img(InvokeAIGenerator):
# ------------------------------------
class Img2Img(InvokeAIGenerator):
def generate(self,
init_image: Image | torch.FloatTensor,
init_image: Image.Image | torch.FloatTensor,
strength: float=0.75,
**keyword_args
)->List[InvokeAIGeneratorOutput]:
)->Iterator[InvokeAIGeneratorOutput]:
return super().generate(init_image=init_image,
strength=strength,
**keyword_args
@ -223,7 +223,7 @@ class Img2Img(InvokeAIGenerator):
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
class Inpaint(Img2Img):
def generate(self,
mask_image: Image | torch.FloatTensor,
mask_image: Image.Image | torch.FloatTensor,
# Seam settings - when 0, doesn't fill seam
seam_size: int = 0,
seam_blur: int = 0,
@ -236,7 +236,7 @@ class Inpaint(Img2Img):
inpaint_height=None,
inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
**keyword_args
)->List[InvokeAIGeneratorOutput]:
)->Iterator[InvokeAIGeneratorOutput]:
return super().generate(
mask_image=mask_image,
seam_size=seam_size,
@ -263,7 +263,7 @@ class Embiggen(Txt2Img):
embiggen: list=None,
embiggen_tiles: list = None,
strength: float=0.75,
**kwargs)->List[InvokeAIGeneratorOutput]:
**kwargs)->Iterator[InvokeAIGeneratorOutput]:
return super().generate(embiggen=embiggen,
embiggen_tiles=embiggen_tiles,
strength=strength,