mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(backend): simple typing fixes
This commit is contained in:
parent
bd7e515290
commit
5fe38f7c88
@ -21,7 +21,7 @@ from PIL import Image, ImageChops, ImageFilter
|
|||||||
from accelerate.utils import set_seed
|
from accelerate.utils import set_seed
|
||||||
from diffusers import DiffusionPipeline
|
from diffusers import DiffusionPipeline
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
from typing import List, Iterator, Type
|
from typing import Callable, List, Iterator, Optional, Type
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from diffusers.schedulers import SchedulerMixin as Scheduler
|
from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||||
|
|
||||||
@ -35,23 +35,23 @@ downsampling = 8
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class InvokeAIGeneratorBasicParams:
|
class InvokeAIGeneratorBasicParams:
|
||||||
seed: int=None
|
seed: Optional[int]=None
|
||||||
width: int=512
|
width: int=512
|
||||||
height: int=512
|
height: int=512
|
||||||
cfg_scale: int=7.5
|
cfg_scale: float=7.5
|
||||||
steps: int=20
|
steps: int=20
|
||||||
ddim_eta: float=0.0
|
ddim_eta: float=0.0
|
||||||
scheduler: int='ddim'
|
scheduler: str='ddim'
|
||||||
precision: str='float16'
|
precision: str='float16'
|
||||||
perlin: float=0.0
|
perlin: float=0.0
|
||||||
threshold: int=0.0
|
threshold: float=0.0
|
||||||
seamless: bool=False
|
seamless: bool=False
|
||||||
seamless_axes: List[str]=field(default_factory=lambda: ['x', 'y'])
|
seamless_axes: List[str]=field(default_factory=lambda: ['x', 'y'])
|
||||||
h_symmetry_time_pct: float=None
|
h_symmetry_time_pct: Optional[float]=None
|
||||||
v_symmetry_time_pct: float=None
|
v_symmetry_time_pct: Optional[float]=None
|
||||||
variation_amount: float = 0.0
|
variation_amount: float = 0.0
|
||||||
with_variations: list=field(default_factory=list)
|
with_variations: list=field(default_factory=list)
|
||||||
safety_checker: SafetyChecker=None
|
safety_checker: Optional[SafetyChecker]=None
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class InvokeAIGeneratorOutput:
|
class InvokeAIGeneratorOutput:
|
||||||
@ -61,10 +61,10 @@ class InvokeAIGeneratorOutput:
|
|||||||
and the model hash, as well as all the generate() parameters that went into
|
and the model hash, as well as all the generate() parameters that went into
|
||||||
generating the image (in .params, also available as attributes)
|
generating the image (in .params, also available as attributes)
|
||||||
'''
|
'''
|
||||||
image: Image
|
image: Image.Image
|
||||||
seed: int
|
seed: int
|
||||||
model_hash: str
|
model_hash: str
|
||||||
attention_maps_images: List[Image]
|
attention_maps_images: List[Image.Image]
|
||||||
params: Namespace
|
params: Namespace
|
||||||
|
|
||||||
# we are interposing a wrapper around the original Generator classes so that
|
# we are interposing a wrapper around the original Generator classes so that
|
||||||
@ -92,8 +92,8 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
|||||||
|
|
||||||
def generate(self,
|
def generate(self,
|
||||||
prompt: str='',
|
prompt: str='',
|
||||||
callback: callable=None,
|
callback: Optional[Callable]=None,
|
||||||
step_callback: callable=None,
|
step_callback: Optional[Callable]=None,
|
||||||
iterations: int=1,
|
iterations: int=1,
|
||||||
**keyword_args,
|
**keyword_args,
|
||||||
)->Iterator[InvokeAIGeneratorOutput]:
|
)->Iterator[InvokeAIGeneratorOutput]:
|
||||||
@ -206,10 +206,10 @@ class Txt2Img(InvokeAIGenerator):
|
|||||||
# ------------------------------------
|
# ------------------------------------
|
||||||
class Img2Img(InvokeAIGenerator):
|
class Img2Img(InvokeAIGenerator):
|
||||||
def generate(self,
|
def generate(self,
|
||||||
init_image: Image | torch.FloatTensor,
|
init_image: Image.Image | torch.FloatTensor,
|
||||||
strength: float=0.75,
|
strength: float=0.75,
|
||||||
**keyword_args
|
**keyword_args
|
||||||
)->List[InvokeAIGeneratorOutput]:
|
)->Iterator[InvokeAIGeneratorOutput]:
|
||||||
return super().generate(init_image=init_image,
|
return super().generate(init_image=init_image,
|
||||||
strength=strength,
|
strength=strength,
|
||||||
**keyword_args
|
**keyword_args
|
||||||
@ -223,7 +223,7 @@ class Img2Img(InvokeAIGenerator):
|
|||||||
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
|
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
|
||||||
class Inpaint(Img2Img):
|
class Inpaint(Img2Img):
|
||||||
def generate(self,
|
def generate(self,
|
||||||
mask_image: Image | torch.FloatTensor,
|
mask_image: Image.Image | torch.FloatTensor,
|
||||||
# Seam settings - when 0, doesn't fill seam
|
# Seam settings - when 0, doesn't fill seam
|
||||||
seam_size: int = 0,
|
seam_size: int = 0,
|
||||||
seam_blur: int = 0,
|
seam_blur: int = 0,
|
||||||
@ -236,7 +236,7 @@ class Inpaint(Img2Img):
|
|||||||
inpaint_height=None,
|
inpaint_height=None,
|
||||||
inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
|
inpaint_fill: tuple(int) = (0x7F, 0x7F, 0x7F, 0xFF),
|
||||||
**keyword_args
|
**keyword_args
|
||||||
)->List[InvokeAIGeneratorOutput]:
|
)->Iterator[InvokeAIGeneratorOutput]:
|
||||||
return super().generate(
|
return super().generate(
|
||||||
mask_image=mask_image,
|
mask_image=mask_image,
|
||||||
seam_size=seam_size,
|
seam_size=seam_size,
|
||||||
@ -263,7 +263,7 @@ class Embiggen(Txt2Img):
|
|||||||
embiggen: list=None,
|
embiggen: list=None,
|
||||||
embiggen_tiles: list = None,
|
embiggen_tiles: list = None,
|
||||||
strength: float=0.75,
|
strength: float=0.75,
|
||||||
**kwargs)->List[InvokeAIGeneratorOutput]:
|
**kwargs)->Iterator[InvokeAIGeneratorOutput]:
|
||||||
return super().generate(embiggen=embiggen,
|
return super().generate(embiggen=embiggen,
|
||||||
embiggen_tiles=embiggen_tiles,
|
embiggen_tiles=embiggen_tiles,
|
||||||
strength=strength,
|
strength=strength,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user