add embiggen, remove complicated constructor

This commit is contained in:
Lincoln Stein 2023-03-11 07:50:39 -05:00
parent fe75b95464
commit 7e76eea059

View File

@ -129,7 +129,8 @@ class InvokeAIGenerator(metaclass=ABCMeta):
scheduler_name=generator_args.get('scheduler') scheduler_name=generator_args.get('scheduler')
) )
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model) uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)
generator = self.load_generator(model, self._generator_name()) gen_class = self._generator_class()
generator = gen_class(model, self.params.precision)
if self.params.variation_amount > 0: if self.params.variation_amount > 0:
generator.set_variation(generator_args.get('seed'), generator.set_variation(generator_args.get('seed'),
generator_args.get('variation_amount'), generator_args.get('variation_amount'),
@ -160,11 +161,8 @@ class InvokeAIGenerator(metaclass=ABCMeta):
''' '''
return list(self.scheduler_map.keys()) return list(self.scheduler_map.keys())
def load_generator(self, model: StableDiffusionGeneratorPipeline, class_name: str): def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
module_name = f'invokeai.backend.generator.{class_name.lower()}' return generator_class(model, self.params.precision)
module = importlib.import_module(module_name)
constructor = getattr(module, class_name)
return constructor(model, self.params.precision)
def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler: def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
scheduler_class = self.scheduler_map.get(scheduler_name,'ddim') scheduler_class = self.scheduler_map.get(scheduler_name,'ddim')
@ -175,17 +173,20 @@ class InvokeAIGenerator(metaclass=ABCMeta):
return scheduler return scheduler
@classmethod @classmethod
def _generator_name(cls): def _generator_class(cls)->Type[Generator]:
''' '''
In derived classes return the name of the generator to apply. In derived classes return the name of the generator to apply.
If you don't override will return the name of the derived If you don't override will return the name of the derived
class, which nicely parallels the generator class names. class, which nicely parallels the generator class names.
''' '''
return cls.__name__ return Generator
# ------------------------------------ # ------------------------------------
class Txt2Img(InvokeAIGenerator): class Txt2Img(InvokeAIGenerator):
pass @classmethod
def _generator_class(cls):
from .txt2img import Txt2Img
return Txt2Img
# ------------------------------------ # ------------------------------------
class Img2Img(InvokeAIGenerator): class Img2Img(InvokeAIGenerator):
@ -198,6 +199,10 @@ class Img2Img(InvokeAIGenerator):
strength=strength, strength=strength,
**keyword_args **keyword_args
) )
@classmethod
def _generator_class(cls):
from .img2img import Img2Img
return Img2Img
# ------------------------------------ # ------------------------------------
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff # Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
@ -231,6 +236,29 @@ class Inpaint(Img2Img):
inpaint_fill=inpaint_fill, inpaint_fill=inpaint_fill,
**keyword_args **keyword_args
) )
@classmethod
def _generator_class(cls):
from .inpaint import Inpaint
return Inpaint
# ------------------------------------
class Embiggen(Txt2Img):
def generate(
self,
embiggen: list=None,
embiggen_tiles: list = None,
strength: float=0.75,
**kwargs)->List[InvokeAIGeneratorOutput]:
return super().generate(embiggen=embiggen,
embiggen_tiles=embiggen_tiles,
strength=strength,
**kwargs)
@classmethod
def _generator_class(cls):
from .embiggen import Embiggen
return Embiggen
class Generator: class Generator:
downsampling_factor: int downsampling_factor: int