mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add embiggen, remove complicated constructor
This commit is contained in:
parent
fe75b95464
commit
7e76eea059
@ -129,7 +129,8 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
scheduler_name=generator_args.get('scheduler')
|
||||
)
|
||||
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)
|
||||
generator = self.load_generator(model, self._generator_name())
|
||||
gen_class = self._generator_class()
|
||||
generator = gen_class(model, self.params.precision)
|
||||
if self.params.variation_amount > 0:
|
||||
generator.set_variation(generator_args.get('seed'),
|
||||
generator_args.get('variation_amount'),
|
||||
@ -160,11 +161,8 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
'''
|
||||
return list(self.scheduler_map.keys())
|
||||
|
||||
def load_generator(self, model: StableDiffusionGeneratorPipeline, class_name: str):
|
||||
module_name = f'invokeai.backend.generator.{class_name.lower()}'
|
||||
module = importlib.import_module(module_name)
|
||||
constructor = getattr(module, class_name)
|
||||
return constructor(model, self.params.precision)
|
||||
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
|
||||
return generator_class(model, self.params.precision)
|
||||
|
||||
def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
|
||||
scheduler_class = self.scheduler_map.get(scheduler_name,'ddim')
|
||||
@ -175,17 +173,20 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
return scheduler
|
||||
|
||||
@classmethod
|
||||
def _generator_name(cls):
|
||||
def _generator_class(cls)->Type[Generator]:
|
||||
'''
|
||||
In derived classes return the name of the generator to apply.
|
||||
If you don't override will return the name of the derived
|
||||
class, which nicely parallels the generator class names.
|
||||
'''
|
||||
return cls.__name__
|
||||
return Generator
|
||||
|
||||
# ------------------------------------
|
||||
class Txt2Img(InvokeAIGenerator):
|
||||
pass
|
||||
@classmethod
|
||||
def _generator_class(cls):
|
||||
from .txt2img import Txt2Img
|
||||
return Txt2Img
|
||||
|
||||
# ------------------------------------
|
||||
class Img2Img(InvokeAIGenerator):
|
||||
@ -198,6 +199,10 @@ class Img2Img(InvokeAIGenerator):
|
||||
strength=strength,
|
||||
**keyword_args
|
||||
)
|
||||
@classmethod
|
||||
def _generator_class(cls):
|
||||
from .img2img import Img2Img
|
||||
return Img2Img
|
||||
|
||||
# ------------------------------------
|
||||
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
|
||||
@ -231,6 +236,29 @@ class Inpaint(Img2Img):
|
||||
inpaint_fill=inpaint_fill,
|
||||
**keyword_args
|
||||
)
|
||||
@classmethod
|
||||
def _generator_class(cls):
|
||||
from .inpaint import Inpaint
|
||||
return Inpaint
|
||||
|
||||
# ------------------------------------
|
||||
class Embiggen(Txt2Img):
|
||||
def generate(
|
||||
self,
|
||||
embiggen: list=None,
|
||||
embiggen_tiles: list = None,
|
||||
strength: float=0.75,
|
||||
**kwargs)->List[InvokeAIGeneratorOutput]:
|
||||
return super().generate(embiggen=embiggen,
|
||||
embiggen_tiles=embiggen_tiles,
|
||||
strength=strength,
|
||||
**kwargs)
|
||||
|
||||
@classmethod
|
||||
def _generator_class(cls):
|
||||
from .embiggen import Embiggen
|
||||
return Embiggen
|
||||
|
||||
|
||||
class Generator:
|
||||
downsampling_factor: int
|
||||
|
Loading…
Reference in New Issue
Block a user