initialize InvokeAIGenerator object with model, not manager

This commit is contained in:
Lincoln Stein 2023-03-11 09:06:46 -05:00
parent 250b0ab182
commit d612f11c11
3 changed files with 18 additions and 14 deletions

View File

@@ -57,8 +57,9 @@ class TextToImageInvocation(BaseInvocation):
# Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now?
manager = context.services.model_manager
outputs = Txt2Img(manager).generate(
# (right now uses whatever current model is set in model manager)
model = context.services.model_manager.get_model()
outputs = Txt2Img(model).generate(
prompt=self.prompt,
step_callback=step_callback,
**self.dict(
@@ -113,9 +114,9 @@ class ImageToImageInvocation(TextToImageInvocation):
# Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now?
manager = context.services.model_manager
model = context.services.model_manager.get_model()
generator_output = next(
Img2Img(manager).generate(
Img2Img(model).generate(
prompt=self.prompt,
init_img=image,
init_mask=mask,
@@ -174,9 +175,9 @@ class InpaintInvocation(ImageToImageInvocation):
# Handle invalid model parameter
# TODO: figure out if this can be done via a validator that uses the model_cache
# TODO: How to get the default model name now?
manager = context.services.model_manager
model = context.services.model_manager.get_model()
generator_output = next(
Inpaint(manager).generate(
Inpaint(model).generate(
prompt=self.prompt,
init_img=image,
init_mask=mask,

View File

@@ -29,14 +29,12 @@ from ..image_util import configure_model_padding
from ..util.util import rand_perlin_2d
from ..safety_checker import SafetyChecker
from ..prompting.conditioning import get_uc_and_c_and_ec
from ..model_management.model_manager import ModelManager
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
downsampling = 8
@dataclass
class InvokeAIGeneratorBasicParams:
model_name: str='stable-diffusion-1.5'
seed: int=None
width: int=512
height: int=512
@@ -86,10 +84,10 @@ class InvokeAIGenerator(metaclass=ABCMeta):
)
def __init__(self,
model_manager: ModelManager,
model_info: dict,
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
):
self.model_manager=model_manager
self.model_info=model_info
self.params=params
def generate(self,
@@ -123,8 +121,8 @@ class InvokeAIGenerator(metaclass=ABCMeta):
generator_args = dataclasses.asdict(self.params)
generator_args.update(keyword_args)
model_name = generator_args.get('model_name') or self.model_manager.current_model
model_info: dict = self.model_manager.get_model(model_name)
model_info = self.model_info
model_name = model_info['model_name']
model:StableDiffusionGeneratorPipeline = model_info['model']
model_hash = model_info['hash']
scheduler: Scheduler = self.get_scheduler(
@@ -164,7 +162,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
seed=results[0][1],
attention_maps_images=results[0][2],
model_hash = model_hash,
params=Namespace(**generator_args),
params=Namespace(model_name=model_name,**generator_args),
)
if callback:
callback(output)

View File

@@ -88,12 +88,15 @@ class ModelManager(object):
"""
return model_name in self.config
def get_model(self, model_name: str):
def get_model(self, model_name: str=None):
"""
Given a model named identified in models.yaml, return
the model object. If in RAM will load into GPU VRAM.
If on disk, will load from there.
"""
if not model_name:
return self.get_model(self.current_model) if self.current_model else self.get_model(self.default_model())
if not self.valid_model(model_name):
print(
f'** "{model_name}" is not a known model name. Please check your models.yaml file'
@@ -116,6 +119,7 @@ class ModelManager(object):
else: # we're about to load a new model, so potentially offload the least recently used one
requested_model, width, height, hash = self._load_model(model_name)
self.models[model_name] = {
"model_name": model_name,
"model": requested_model,
"width": width,
"height": height,
@@ -125,6 +129,7 @@ class ModelManager(object):
self.current_model = model_name
self._push_newest_model(model_name)
return {
"model_name": model_name,
"model": requested_model,
"width": width,
"height": height,