add embiggen, remove complicated constructor

This commit is contained in:
Lincoln Stein 2023-03-11 07:50:39 -05:00
parent fe75b95464
commit 7e76eea059

View File

@ -129,7 +129,8 @@ class InvokeAIGenerator(metaclass=ABCMeta):
scheduler_name=generator_args.get('scheduler')
)
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)
generator = self.load_generator(model, self._generator_name())
gen_class = self._generator_class()
generator = gen_class(model, self.params.precision)
if self.params.variation_amount > 0:
generator.set_variation(generator_args.get('seed'),
generator_args.get('variation_amount'),
@ -160,11 +161,8 @@ class InvokeAIGenerator(metaclass=ABCMeta):
'''
return list(self.scheduler_map.keys())
def load_generator(self, model: StableDiffusionGeneratorPipeline, class_name: str):
module_name = f'invokeai.backend.generator.{class_name.lower()}'
module = importlib.import_module(module_name)
constructor = getattr(module, class_name)
return constructor(model, self.params.precision)
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
return generator_class(model, self.params.precision)
def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
scheduler_class = self.scheduler_map.get(scheduler_name,'ddim')
@ -175,17 +173,20 @@ class InvokeAIGenerator(metaclass=ABCMeta):
return scheduler
@classmethod
def _generator_name(cls):
def _generator_class(cls)->Type[Generator]:
'''
In derived classes return the name of the generator to apply.
If you don't override will return the name of the derived
class, which nicely parallels the generator class names.
'''
return cls.__name__
return Generator
# ------------------------------------
class Txt2Img(InvokeAIGenerator):
pass
@classmethod
def _generator_class(cls):
from .txt2img import Txt2Img
return Txt2Img
# ------------------------------------
class Img2Img(InvokeAIGenerator):
@ -198,6 +199,10 @@ class Img2Img(InvokeAIGenerator):
strength=strength,
**keyword_args
)
@classmethod
def _generator_class(cls):
from .img2img import Img2Img
return Img2Img
# ------------------------------------
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
@ -231,6 +236,29 @@ class Inpaint(Img2Img):
inpaint_fill=inpaint_fill,
**keyword_args
)
@classmethod
def _generator_class(cls):
from .inpaint import Inpaint
return Inpaint
# ------------------------------------
class Embiggen(Txt2Img):
def generate(
self,
embiggen: list=None,
embiggen_tiles: list = None,
strength: float=0.75,
**kwargs)->List[InvokeAIGeneratorOutput]:
return super().generate(embiggen=embiggen,
embiggen_tiles=embiggen_tiles,
strength=strength,
**kwargs)
@classmethod
def _generator_class(cls):
from .embiggen import Embiggen
return Embiggen
class Generator:
downsampling_factor: int