From 7e76eea059a3a30569e882b1a33a37a812c58c5c Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sat, 11 Mar 2023 07:50:39 -0500 Subject: [PATCH] add embiggen, remove complicated constructor --- invokeai/backend/generator/base.py | 46 ++++++++++++++++++++++++------ 1 file changed, 37 insertions(+), 9 deletions(-) diff --git a/invokeai/backend/generator/base.py b/invokeai/backend/generator/base.py index db1afa0f88..e2ff81beb7 100644 --- a/invokeai/backend/generator/base.py +++ b/invokeai/backend/generator/base.py @@ -129,7 +129,8 @@ class InvokeAIGenerator(metaclass=ABCMeta): scheduler_name=generator_args.get('scheduler') ) uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model) - generator = self.load_generator(model, self._generator_name()) + gen_class = self._generator_class() + generator = gen_class(model, self.params.precision) if self.params.variation_amount > 0: generator.set_variation(generator_args.get('seed'), generator_args.get('variation_amount'), @@ -160,11 +161,8 @@ class InvokeAIGenerator(metaclass=ABCMeta): ''' return list(self.scheduler_map.keys()) - def load_generator(self, model: StableDiffusionGeneratorPipeline, class_name: str): - module_name = f'invokeai.backend.generator.{class_name.lower()}' - module = importlib.import_module(module_name) - constructor = getattr(module, class_name) - return constructor(model, self.params.precision) + def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]): + return generator_class(model, self.params.precision) def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler: scheduler_class = self.scheduler_map.get(scheduler_name,'ddim') @@ -175,17 +173,20 @@ class InvokeAIGenerator(metaclass=ABCMeta): return scheduler @classmethod - def _generator_name(cls): + def _generator_class(cls)->Type[Generator]: ''' In derived classes return the name of the generator to apply. If you don't override will return the name of the derived class, which nicely parallels the generator class names. ''' - return cls.__name__ + return Generator # ------------------------------------ class Txt2Img(InvokeAIGenerator): - pass + @classmethod + def _generator_class(cls): + from .txt2img import Txt2Img + return Txt2Img # ------------------------------------ class Img2Img(InvokeAIGenerator): @@ -198,6 +199,10 @@ class Img2Img(InvokeAIGenerator): strength=strength, **keyword_args ) + @classmethod + def _generator_class(cls): + from .img2img import Img2Img + return Img2Img # ------------------------------------ # Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff @@ -231,6 +236,29 @@ class Inpaint(Img2Img): inpaint_fill=inpaint_fill, **keyword_args ) + @classmethod + def _generator_class(cls): + from .inpaint import Inpaint + return Inpaint + +# ------------------------------------ +class Embiggen(Txt2Img): + def generate( + self, + embiggen: list=None, + embiggen_tiles: list = None, + strength: float=0.75, + **kwargs)->List[InvokeAIGeneratorOutput]: + return super().generate(embiggen=embiggen, + embiggen_tiles=embiggen_tiles, + strength=strength, + **kwargs) + + @classmethod + def _generator_class(cls): + from .embiggen import Embiggen + return Embiggen + class Generator: downsampling_factor: int