mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
initialize InvokeAIGenerator object with model, not manager
This commit is contained in:
parent
250b0ab182
commit
d612f11c11
@ -57,8 +57,9 @@ class TextToImageInvocation(BaseInvocation):
|
||||
# Handle invalid model parameter
|
||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||
# TODO: How to get the default model name now?
|
||||
manager = context.services.model_manager
|
||||
outputs = Txt2Img(manager).generate(
|
||||
# (right now uses whatever current model is set in model manager)
|
||||
model= context.services.model_manager.get_model()
|
||||
outputs = Txt2Img(model).generate(
|
||||
prompt=self.prompt,
|
||||
step_callback=step_callback,
|
||||
**self.dict(
|
||||
@ -113,9 +114,9 @@ class ImageToImageInvocation(TextToImageInvocation):
|
||||
# Handle invalid model parameter
|
||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||
# TODO: How to get the default model name now?
|
||||
manager = context.services.model_manager
|
||||
model = context.services.model_manager.get_model()
|
||||
generator_output = next(
|
||||
Img2Img(manager).generate(
|
||||
Img2Img(model).generate(
|
||||
prompt=self.prompt,
|
||||
init_img=image,
|
||||
init_mask=mask,
|
||||
@ -174,9 +175,9 @@ class InpaintInvocation(ImageToImageInvocation):
|
||||
# Handle invalid model parameter
|
||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||
# TODO: How to get the default model name now?
|
||||
manager = context.services.model_manager
|
||||
manager = context.services.model_manager.get_model()
|
||||
generator_output = next(
|
||||
Inpaint(manager).generate(
|
||||
Inpaint(model).generate(
|
||||
prompt=self.prompt,
|
||||
init_img=image,
|
||||
init_mask=mask,
|
||||
|
@ -29,14 +29,12 @@ from ..image_util import configure_model_padding
|
||||
from ..util.util import rand_perlin_2d
|
||||
from ..safety_checker import SafetyChecker
|
||||
from ..prompting.conditioning import get_uc_and_c_and_ec
|
||||
from ..model_management.model_manager import ModelManager
|
||||
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||
|
||||
downsampling = 8
|
||||
|
||||
@dataclass
|
||||
class InvokeAIGeneratorBasicParams:
|
||||
model_name: str='stable-diffusion-1.5'
|
||||
seed: int=None
|
||||
width: int=512
|
||||
height: int=512
|
||||
@ -86,10 +84,10 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
)
|
||||
|
||||
def __init__(self,
|
||||
model_manager: ModelManager,
|
||||
model_info: dict,
|
||||
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
|
||||
):
|
||||
self.model_manager=model_manager
|
||||
self.model_info=model_info
|
||||
self.params=params
|
||||
|
||||
def generate(self,
|
||||
@ -123,8 +121,8 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
generator_args = dataclasses.asdict(self.params)
|
||||
generator_args.update(keyword_args)
|
||||
|
||||
model_name = generator_args.get('model_name') or self.model_manager.current_model
|
||||
model_info: dict = self.model_manager.get_model(model_name)
|
||||
model_info = self.model_info
|
||||
model_name = model_info['model_name']
|
||||
model:StableDiffusionGeneratorPipeline = model_info['model']
|
||||
model_hash = model_info['hash']
|
||||
scheduler: Scheduler = self.get_scheduler(
|
||||
@ -164,7 +162,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
||||
seed=results[0][1],
|
||||
attention_maps_images=results[0][2],
|
||||
model_hash = model_hash,
|
||||
params=Namespace(**generator_args),
|
||||
params=Namespace(model_name=model_name,**generator_args),
|
||||
)
|
||||
if callback:
|
||||
callback(output)
|
||||
|
@ -88,12 +88,15 @@ class ModelManager(object):
|
||||
"""
|
||||
return model_name in self.config
|
||||
|
||||
def get_model(self, model_name: str):
|
||||
def get_model(self, model_name: str=None):
|
||||
"""
|
||||
Given a model named identified in models.yaml, return
|
||||
the model object. If in RAM will load into GPU VRAM.
|
||||
If on disk, will load from there.
|
||||
"""
|
||||
if not model_name:
|
||||
return self.current_model if self.current_model else self.get_model(self.default_model())
|
||||
|
||||
if not self.valid_model(model_name):
|
||||
print(
|
||||
f'** "{model_name}" is not a known model name. Please check your models.yaml file'
|
||||
@ -116,6 +119,7 @@ class ModelManager(object):
|
||||
else: # we're about to load a new model, so potentially offload the least recently used one
|
||||
requested_model, width, height, hash = self._load_model(model_name)
|
||||
self.models[model_name] = {
|
||||
"model_name": model_name,
|
||||
"model": requested_model,
|
||||
"width": width,
|
||||
"height": height,
|
||||
@ -125,6 +129,7 @@ class ModelManager(object):
|
||||
self.current_model = model_name
|
||||
self._push_newest_model(model_name)
|
||||
return {
|
||||
"model_name": model_name,
|
||||
"model": requested_model,
|
||||
"width": width,
|
||||
"height": height,
|
||||
|
Loading…
Reference in New Issue
Block a user