mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
initialize InvokeAIGenerator object with model, not manager
This commit is contained in:
parent
250b0ab182
commit
d612f11c11
@ -57,8 +57,9 @@ class TextToImageInvocation(BaseInvocation):
|
|||||||
# Handle invalid model parameter
|
# Handle invalid model parameter
|
||||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||||
# TODO: How to get the default model name now?
|
# TODO: How to get the default model name now?
|
||||||
manager = context.services.model_manager
|
# (right now uses whatever current model is set in model manager)
|
||||||
outputs = Txt2Img(manager).generate(
|
model= context.services.model_manager.get_model()
|
||||||
|
outputs = Txt2Img(model).generate(
|
||||||
prompt=self.prompt,
|
prompt=self.prompt,
|
||||||
step_callback=step_callback,
|
step_callback=step_callback,
|
||||||
**self.dict(
|
**self.dict(
|
||||||
@ -113,9 +114,9 @@ class ImageToImageInvocation(TextToImageInvocation):
|
|||||||
# Handle invalid model parameter
|
# Handle invalid model parameter
|
||||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||||
# TODO: How to get the default model name now?
|
# TODO: How to get the default model name now?
|
||||||
manager = context.services.model_manager
|
model = context.services.model_manager.get_model()
|
||||||
generator_output = next(
|
generator_output = next(
|
||||||
Img2Img(manager).generate(
|
Img2Img(model).generate(
|
||||||
prompt=self.prompt,
|
prompt=self.prompt,
|
||||||
init_img=image,
|
init_img=image,
|
||||||
init_mask=mask,
|
init_mask=mask,
|
||||||
@ -174,9 +175,9 @@ class InpaintInvocation(ImageToImageInvocation):
|
|||||||
# Handle invalid model parameter
|
# Handle invalid model parameter
|
||||||
# TODO: figure out if this can be done via a validator that uses the model_cache
|
# TODO: figure out if this can be done via a validator that uses the model_cache
|
||||||
# TODO: How to get the default model name now?
|
# TODO: How to get the default model name now?
|
||||||
manager = context.services.model_manager
|
manager = context.services.model_manager.get_model()
|
||||||
generator_output = next(
|
generator_output = next(
|
||||||
Inpaint(manager).generate(
|
Inpaint(model).generate(
|
||||||
prompt=self.prompt,
|
prompt=self.prompt,
|
||||||
init_img=image,
|
init_img=image,
|
||||||
init_mask=mask,
|
init_mask=mask,
|
||||||
|
@ -29,14 +29,12 @@ from ..image_util import configure_model_padding
|
|||||||
from ..util.util import rand_perlin_2d
|
from ..util.util import rand_perlin_2d
|
||||||
from ..safety_checker import SafetyChecker
|
from ..safety_checker import SafetyChecker
|
||||||
from ..prompting.conditioning import get_uc_and_c_and_ec
|
from ..prompting.conditioning import get_uc_and_c_and_ec
|
||||||
from ..model_management.model_manager import ModelManager
|
|
||||||
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
from ..stable_diffusion.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||||
|
|
||||||
downsampling = 8
|
downsampling = 8
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class InvokeAIGeneratorBasicParams:
|
class InvokeAIGeneratorBasicParams:
|
||||||
model_name: str='stable-diffusion-1.5'
|
|
||||||
seed: int=None
|
seed: int=None
|
||||||
width: int=512
|
width: int=512
|
||||||
height: int=512
|
height: int=512
|
||||||
@ -86,10 +84,10 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
model_manager: ModelManager,
|
model_info: dict,
|
||||||
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
|
params: InvokeAIGeneratorBasicParams=InvokeAIGeneratorBasicParams(),
|
||||||
):
|
):
|
||||||
self.model_manager=model_manager
|
self.model_info=model_info
|
||||||
self.params=params
|
self.params=params
|
||||||
|
|
||||||
def generate(self,
|
def generate(self,
|
||||||
@ -123,8 +121,8 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
|||||||
generator_args = dataclasses.asdict(self.params)
|
generator_args = dataclasses.asdict(self.params)
|
||||||
generator_args.update(keyword_args)
|
generator_args.update(keyword_args)
|
||||||
|
|
||||||
model_name = generator_args.get('model_name') or self.model_manager.current_model
|
model_info = self.model_info
|
||||||
model_info: dict = self.model_manager.get_model(model_name)
|
model_name = model_info['model_name']
|
||||||
model:StableDiffusionGeneratorPipeline = model_info['model']
|
model:StableDiffusionGeneratorPipeline = model_info['model']
|
||||||
model_hash = model_info['hash']
|
model_hash = model_info['hash']
|
||||||
scheduler: Scheduler = self.get_scheduler(
|
scheduler: Scheduler = self.get_scheduler(
|
||||||
@ -164,7 +162,7 @@ class InvokeAIGenerator(metaclass=ABCMeta):
|
|||||||
seed=results[0][1],
|
seed=results[0][1],
|
||||||
attention_maps_images=results[0][2],
|
attention_maps_images=results[0][2],
|
||||||
model_hash = model_hash,
|
model_hash = model_hash,
|
||||||
params=Namespace(**generator_args),
|
params=Namespace(model_name=model_name,**generator_args),
|
||||||
)
|
)
|
||||||
if callback:
|
if callback:
|
||||||
callback(output)
|
callback(output)
|
||||||
|
@ -88,12 +88,15 @@ class ModelManager(object):
|
|||||||
"""
|
"""
|
||||||
return model_name in self.config
|
return model_name in self.config
|
||||||
|
|
||||||
def get_model(self, model_name: str):
|
def get_model(self, model_name: str=None):
|
||||||
"""
|
"""
|
||||||
Given a model named identified in models.yaml, return
|
Given a model named identified in models.yaml, return
|
||||||
the model object. If in RAM will load into GPU VRAM.
|
the model object. If in RAM will load into GPU VRAM.
|
||||||
If on disk, will load from there.
|
If on disk, will load from there.
|
||||||
"""
|
"""
|
||||||
|
if not model_name:
|
||||||
|
return self.current_model if self.current_model else self.get_model(self.default_model())
|
||||||
|
|
||||||
if not self.valid_model(model_name):
|
if not self.valid_model(model_name):
|
||||||
print(
|
print(
|
||||||
f'** "{model_name}" is not a known model name. Please check your models.yaml file'
|
f'** "{model_name}" is not a known model name. Please check your models.yaml file'
|
||||||
@ -116,6 +119,7 @@ class ModelManager(object):
|
|||||||
else: # we're about to load a new model, so potentially offload the least recently used one
|
else: # we're about to load a new model, so potentially offload the least recently used one
|
||||||
requested_model, width, height, hash = self._load_model(model_name)
|
requested_model, width, height, hash = self._load_model(model_name)
|
||||||
self.models[model_name] = {
|
self.models[model_name] = {
|
||||||
|
"model_name": model_name,
|
||||||
"model": requested_model,
|
"model": requested_model,
|
||||||
"width": width,
|
"width": width,
|
||||||
"height": height,
|
"height": height,
|
||||||
@ -125,6 +129,7 @@ class ModelManager(object):
|
|||||||
self.current_model = model_name
|
self.current_model = model_name
|
||||||
self._push_newest_model(model_name)
|
self._push_newest_model(model_name)
|
||||||
return {
|
return {
|
||||||
|
"model_name": model_name,
|
||||||
"model": requested_model,
|
"model": requested_model,
|
||||||
"width": width,
|
"width": width,
|
||||||
"height": height,
|
"height": height,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user