mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
tidy(nodes): rename sd_model_loader
to pipeline_model_loader
this is more accurate bc it can do eg kandinsky also
This commit is contained in:
parent
3722cdf5d6
commit
1bc170727b
@ -50,10 +50,10 @@ class PipelineModelField(BaseModel):
|
|||||||
base_model: BaseModelType = Field(description="Base model")
|
base_model: BaseModelType = Field(description="Base model")
|
||||||
|
|
||||||
|
|
||||||
class SDModelLoaderInvocation(BaseInvocation):
|
class PipelineModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loading submodels of selected model."""
|
"""Loads a pipeline model, outputting its submodels."""
|
||||||
|
|
||||||
type: Literal["sd_model_loader"] = "sd_model_loader"
|
type: Literal["pipeline_model_loader"] = "pipeline_model_loader"
|
||||||
|
|
||||||
model: PipelineModelField = Field(description="The model to load")
|
model: PipelineModelField = Field(description="The model to load")
|
||||||
# TODO: precision?
|
# TODO: precision?
|
||||||
@ -154,211 +154,6 @@ class SDModelLoaderInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
class SD1ModelLoaderInvocation(BaseInvocation):
|
|
||||||
"""Loading submodels of selected model."""
|
|
||||||
|
|
||||||
type: Literal["sd1_model_loader"] = "sd1_model_loader"
|
|
||||||
|
|
||||||
model_name: str = Field(default="", description="Model to load")
|
|
||||||
# TODO: precision?
|
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"tags": ["model", "loader"],
|
|
||||||
"type_hints": {
|
|
||||||
"model_name": "model" # TODO: rename to model_name?
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
|
||||||
|
|
||||||
base_model = BaseModelType.StableDiffusion1 # TODO:
|
|
||||||
|
|
||||||
# TODO: not found exceptions
|
|
||||||
if not context.services.model_manager.model_exists(
|
|
||||||
model_name=self.model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.Pipeline,
|
|
||||||
):
|
|
||||||
raise Exception(f"Unkown model name: {self.model_name}!")
|
|
||||||
|
|
||||||
"""
|
|
||||||
if not context.services.model_manager.model_exists(
|
|
||||||
model_name=self.model_name,
|
|
||||||
model_type=SDModelType.Diffusers,
|
|
||||||
submodel=SDModelType.Tokenizer,
|
|
||||||
):
|
|
||||||
raise Exception(
|
|
||||||
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not context.services.model_manager.model_exists(
|
|
||||||
model_name=self.model_name,
|
|
||||||
model_type=SDModelType.Diffusers,
|
|
||||||
submodel=SDModelType.TextEncoder,
|
|
||||||
):
|
|
||||||
raise Exception(
|
|
||||||
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not context.services.model_manager.model_exists(
|
|
||||||
model_name=self.model_name,
|
|
||||||
model_type=SDModelType.Diffusers,
|
|
||||||
submodel=SDModelType.UNet,
|
|
||||||
):
|
|
||||||
raise Exception(
|
|
||||||
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
return ModelLoaderOutput(
|
|
||||||
unet=UNetField(
|
|
||||||
unet=ModelInfo(
|
|
||||||
model_name=self.model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.Pipeline,
|
|
||||||
submodel=SubModelType.UNet,
|
|
||||||
),
|
|
||||||
scheduler=ModelInfo(
|
|
||||||
model_name=self.model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.Pipeline,
|
|
||||||
submodel=SubModelType.Scheduler,
|
|
||||||
),
|
|
||||||
loras=[],
|
|
||||||
),
|
|
||||||
clip=ClipField(
|
|
||||||
tokenizer=ModelInfo(
|
|
||||||
model_name=self.model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.Pipeline,
|
|
||||||
submodel=SubModelType.Tokenizer,
|
|
||||||
),
|
|
||||||
text_encoder=ModelInfo(
|
|
||||||
model_name=self.model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.Pipeline,
|
|
||||||
submodel=SubModelType.TextEncoder,
|
|
||||||
),
|
|
||||||
loras=[],
|
|
||||||
),
|
|
||||||
vae=VaeField(
|
|
||||||
vae=ModelInfo(
|
|
||||||
model_name=self.model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.Pipeline,
|
|
||||||
submodel=SubModelType.Vae,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# TODO: optimize(less code copy)
|
|
||||||
class SD2ModelLoaderInvocation(BaseInvocation):
|
|
||||||
"""Loading submodels of selected model."""
|
|
||||||
|
|
||||||
type: Literal["sd2_model_loader"] = "sd2_model_loader"
|
|
||||||
|
|
||||||
model_name: str = Field(default="", description="Model to load")
|
|
||||||
# TODO: precision?
|
|
||||||
|
|
||||||
# Schema customisation
|
|
||||||
class Config(InvocationConfig):
|
|
||||||
schema_extra = {
|
|
||||||
"ui": {
|
|
||||||
"tags": ["model", "loader"],
|
|
||||||
"type_hints": {
|
|
||||||
"model_name": "model" # TODO: rename to model_name?
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ModelLoaderOutput:
|
|
||||||
|
|
||||||
base_model = BaseModelType.StableDiffusion2 # TODO:
|
|
||||||
|
|
||||||
# TODO: not found exceptions
|
|
||||||
if not context.services.model_manager.model_exists(
|
|
||||||
model_name=self.model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.Pipeline,
|
|
||||||
):
|
|
||||||
raise Exception(f"Unkown model name: {self.model_name}!")
|
|
||||||
|
|
||||||
"""
|
|
||||||
if not context.services.model_manager.model_exists(
|
|
||||||
model_name=self.model_name,
|
|
||||||
model_type=SDModelType.Diffusers,
|
|
||||||
submodel=SDModelType.Tokenizer,
|
|
||||||
):
|
|
||||||
raise Exception(
|
|
||||||
f"Failed to find tokenizer submodel in {self.model_name}! Check if model corrupted"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not context.services.model_manager.model_exists(
|
|
||||||
model_name=self.model_name,
|
|
||||||
model_type=SDModelType.Diffusers,
|
|
||||||
submodel=SDModelType.TextEncoder,
|
|
||||||
):
|
|
||||||
raise Exception(
|
|
||||||
f"Failed to find text_encoder submodel in {self.model_name}! Check if model corrupted"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not context.services.model_manager.model_exists(
|
|
||||||
model_name=self.model_name,
|
|
||||||
model_type=SDModelType.Diffusers,
|
|
||||||
submodel=SDModelType.UNet,
|
|
||||||
):
|
|
||||||
raise Exception(
|
|
||||||
f"Failed to find unet submodel from {self.model_name}! Check if model corrupted"
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
return ModelLoaderOutput(
|
|
||||||
unet=UNetField(
|
|
||||||
unet=ModelInfo(
|
|
||||||
model_name=self.model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.Pipeline,
|
|
||||||
submodel=SubModelType.UNet,
|
|
||||||
),
|
|
||||||
scheduler=ModelInfo(
|
|
||||||
model_name=self.model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.Pipeline,
|
|
||||||
submodel=SubModelType.Scheduler,
|
|
||||||
),
|
|
||||||
loras=[],
|
|
||||||
),
|
|
||||||
clip=ClipField(
|
|
||||||
tokenizer=ModelInfo(
|
|
||||||
model_name=self.model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.Pipeline,
|
|
||||||
submodel=SubModelType.Tokenizer,
|
|
||||||
),
|
|
||||||
text_encoder=ModelInfo(
|
|
||||||
model_name=self.model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.Pipeline,
|
|
||||||
submodel=SubModelType.TextEncoder,
|
|
||||||
),
|
|
||||||
loras=[],
|
|
||||||
),
|
|
||||||
vae=VaeField(
|
|
||||||
vae=ModelInfo(
|
|
||||||
model_name=self.model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
model_type=ModelType.Pipeline,
|
|
||||||
submodel=SubModelType.Vae,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
class LoraLoaderOutput(BaseInvocationOutput):
|
class LoraLoaderOutput(BaseInvocationOutput):
|
||||||
"""Model loader output"""
|
"""Model loader output"""
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user