mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add embiggen, remove complicated constructor
This commit is contained in:
parent
fe75b95464
commit
7e76eea059
@ -129,7 +129,8 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
|||||||
scheduler_name=generator_args.get('scheduler')
|
scheduler_name=generator_args.get('scheduler')
|
||||||
)
|
)
|
||||||
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)
|
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)
|
||||||
generator = self.load_generator(model, self._generator_name())
|
gen_class = self._generator_class()
|
||||||
|
generator = gen_class(model, self.params.precision)
|
||||||
if self.params.variation_amount > 0:
|
if self.params.variation_amount > 0:
|
||||||
generator.set_variation(generator_args.get('seed'),
|
generator.set_variation(generator_args.get('seed'),
|
||||||
generator_args.get('variation_amount'),
|
generator_args.get('variation_amount'),
|
||||||
@ -160,11 +161,8 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
|||||||
'''
|
'''
|
||||||
return list(self.scheduler_map.keys())
|
return list(self.scheduler_map.keys())
|
||||||
|
|
||||||
def load_generator(self, model: StableDiffusionGeneratorPipeline, class_name: str):
|
def load_generator(self, model: StableDiffusionGeneratorPipeline, generator_class: Type[Generator]):
|
||||||
module_name = f'invokeai.backend.generator.{class_name.lower()}'
|
return generator_class(model, self.params.precision)
|
||||||
module = importlib.import_module(module_name)
|
|
||||||
constructor = getattr(module, class_name)
|
|
||||||
return constructor(model, self.params.precision)
|
|
||||||
|
|
||||||
def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
|
def get_scheduler(self, scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
|
||||||
scheduler_class = self.scheduler_map.get(scheduler_name,'ddim')
|
scheduler_class = self.scheduler_map.get(scheduler_name,'ddim')
|
||||||
@ -175,17 +173,20 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
|||||||
return scheduler
|
return scheduler
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _generator_name(cls):
|
def _generator_class(cls)->Type[Generator]:
|
||||||
'''
|
'''
|
||||||
In derived classes return the name of the generator to apply.
|
In derived classes return the name of the generator to apply.
|
||||||
If you don't override will return the name of the derived
|
If you don't override will return the name of the derived
|
||||||
class, which nicely parallels the generator class names.
|
class, which nicely parallels the generator class names.
|
||||||
'''
|
'''
|
||||||
return cls.__name__
|
return Generator
|
||||||
|
|
||||||
# ------------------------------------
|
# ------------------------------------
|
||||||
class Txt2Img(InvokeAIGenerator):
|
class Txt2Img(InvokeAIGenerator):
|
||||||
pass
|
@classmethod
|
||||||
|
def _generator_class(cls):
|
||||||
|
from .txt2img import Txt2Img
|
||||||
|
return Txt2Img
|
||||||
|
|
||||||
# ------------------------------------
|
# ------------------------------------
|
||||||
class Img2Img(InvokeAIGenerator):
|
class Img2Img(InvokeAIGenerator):
|
||||||
@ -198,6 +199,10 @@ class Img2Img(InvokeAIGenerator):
|
|||||||
strength=strength,
|
strength=strength,
|
||||||
**keyword_args
|
**keyword_args
|
||||||
)
|
)
|
||||||
|
@classmethod
|
||||||
|
def _generator_class(cls):
|
||||||
|
from .img2img import Img2Img
|
||||||
|
return Img2Img
|
||||||
|
|
||||||
# ------------------------------------
|
# ------------------------------------
|
||||||
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
|
# Takes all the arguments of Img2Img and adds the mask image and the seam/infill stuff
|
||||||
@ -231,6 +236,29 @@ class Inpaint(Img2Img):
|
|||||||
inpaint_fill=inpaint_fill,
|
inpaint_fill=inpaint_fill,
|
||||||
**keyword_args
|
**keyword_args
|
||||||
)
|
)
|
||||||
|
@classmethod
|
||||||
|
def _generator_class(cls):
|
||||||
|
from .inpaint import Inpaint
|
||||||
|
return Inpaint
|
||||||
|
|
||||||
|
# ------------------------------------
|
||||||
|
class Embiggen(Txt2Img):
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
embiggen: list=None,
|
||||||
|
embiggen_tiles: list = None,
|
||||||
|
strength: float=0.75,
|
||||||
|
**kwargs)->List[InvokeAIGeneratorOutput]:
|
||||||
|
return super().generate(embiggen=embiggen,
|
||||||
|
embiggen_tiles=embiggen_tiles,
|
||||||
|
strength=strength,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _generator_class(cls):
|
||||||
|
from .embiggen import Embiggen
|
||||||
|
return Embiggen
|
||||||
|
|
||||||
|
|
||||||
class Generator:
|
class Generator:
|
||||||
downsampling_factor: int
|
downsampling_factor: int
|
||||||
|
Loading…
Reference in New Issue
Block a user