initialize InvokeAIGenerator object with model, not manager

This commit is contained in:
Lincoln Stein 2023-03-11 09:06:46 -05:00
parent 250b0ab182
commit d612f11c11
3 changed files with 18 additions and 14 deletions

View File

@ -57,8 +57,9 @@ class TextToImageInvocation(BaseInvocation):
# Handle invalid model parameter # Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache # TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now? # TODO: How to get the default model name now?
manager = context.services.model_manager # (right now uses whatever current model is set in model manager)
outputs = Txt2Img(manager).generate( model= context.services.model_manager.get_model()
outputs = Txt2Img(model).generate(
prompt=self.prompt, prompt=self.prompt,
step_callback=step_callback, step_callback=step_callback,
**self.dict( **self.dict(
@ -113,9 +114,9 @@ class ImageToImageInvocation(TextToImageInvocation):
# Handle invalid model parameter # Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache # TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now? # TODO: How to get the default model name now?
manager = context.services.model_manager model = context.services.model_manager.get_model()
generator_output = next( generator_output = next(
Img2Img(manager).generate( Img2Img(model).generate(
prompt=self.prompt, prompt=self.prompt,
init_img=image, init_img=image,
init_mask=mask, init_mask=mask,
@ -174,9 +175,9 @@ class InpaintInvocation(ImageToImageInvocation):
# Handle invalid model parameter # Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache # TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now? # TODO: How to get the default model name now?
manager = context.services.model_manager manager = context.services.model_manager.get_model()
generator_output = next( generator_output = next(
Inpaint(manager).generate( Inpaint(model).generate(
prompt=self.prompt, prompt=self.prompt,
init_img=image, init_img=image,
init_mask=mask, init_mask=mask,

View File

@ -29,14 +29,12 @@ from ..image_util import configure_model_padding
from ..util.util import rand_perlin_2d from ..util.util import rand_perlin_2d
from ..safety_checker import SafetyChecker from ..safety_checker import SafetyChecker
from ..prompting.conditioning import get_uc_and_c_and_ec from ..prompting.conditioning import get_uc_and_c_and_ec
from ..model_management.model_manager import ModelManager
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
downsampling = 8 downsampling = 8
@dataclass @dataclass
class InvokeAIGeneratorBasicParams: class InvokeAIGeneratorBasicParams:
model_name: str='stable-diffusion-1.5'
seed: int=None seed: int=None
width: int=512 width: int=512
height: int=512 height: int=512
@ -86,10 +84,10 @@ class InvokeAIGenerator(metaclass=ABCMeta):
) )
def __init__(self, def __init__(self,
model_manager: ModelManager, model_info: dict,
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(), params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
): ):
self.model_manager=model_manager self.model_info=model_info
self.params=params self.params=params
def generate(self, def generate(self,
@ -123,8 +121,8 @@ class InvokeAIGenerator(metaclass=ABCMeta):
generator_args = dataclasses.asdict(self.params) generator_args = dataclasses.asdict(self.params)
generator_args.update(keyword_args) generator_args.update(keyword_args)
model_name = generator_args.get('model_name') or self.model_manager.current_model model_info = self.model_info
model_info: dict = self.model_manager.get_model(model_name) model_name = model_info['model_name']
model:StableDiffusionGeneratorPipeline = model_info['model'] model:StableDiffusionGeneratorPipeline = model_info['model']
model_hash = model_info['hash'] model_hash = model_info['hash']
scheduler: Scheduler = self.get_scheduler( scheduler: Scheduler = self.get_scheduler(
@ -164,7 +162,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
seed=results[0][1], seed=results[0][1],
attention_maps_images=results[0][2], attention_maps_images=results[0][2],
model_hash = model_hash, model_hash = model_hash,
params=Namespace(**generator_args), params=Namespace(model_name=model_name,**generator_args),
) )
if callback: if callback:
callback(output) callback(output)

View File

@ -88,12 +88,15 @@ class ModelManager(object):
""" """
return model_name in self.config return model_name in self.config
def get_model(self, model_name: str): def get_model(self, model_name: str=None):
""" """
Given a model named identified in models.yaml, return Given a model named identified in models.yaml, return
the model object. If in RAM will load into GPU VRAM. the model object. If in RAM will load into GPU VRAM.
If on disk, will load from there. If on disk, will load from there.
""" """
if not model_name:
return self.current_model if self.current_model else self.get_model(self.default_model())
if not self.valid_model(model_name): if not self.valid_model(model_name):
print( print(
f'** "{model_name}" is not a known model name. Please check your models.yaml file' f'** "{model_name}" is not a known model name. Please check your models.yaml file'
@ -116,6 +119,7 @@ class ModelManager(object):
else: # we're about to load a new model, so potentially offload the least recently used one else: # we're about to load a new model, so potentially offload the least recently used one
requested_model, width, height, hash = self._load_model(model_name) requested_model, width, height, hash = self._load_model(model_name)
self.models[model_name] = { self.models[model_name] = {
"model_name": model_name,
"model": requested_model, "model": requested_model,
"width": width, "width": width,
"height": height, "height": height,
@ -125,6 +129,7 @@ class ModelManager(object):
self.current_model = model_name self.current_model = model_name
self._push_newest_model(model_name) self._push_newest_model(model_name)
return { return {
"model_name": model_name,
"model": requested_model, "model": requested_model,
"width": width, "width": width,
"height": height, "height": height,