mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refiner only has clip2 not clip
This commit is contained in:
parent
0a2964d8c0
commit
6534288b75
@ -21,17 +21,20 @@ class SDXLModelLoaderOutput(BaseInvocationOutput):
|
|||||||
|
|
||||||
unet: UNetField = Field(default=None, description="UNet submodel")
|
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||||
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||||
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels (SDXL only)")
|
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||||
vae: VaeField = Field(default=None, description="Vae submodel")
|
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
class SDXLRefinerModelLoaderOutput(SDXLModelLoaderOutput):
|
class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
|
||||||
"""SDXL refiner model loader output"""
|
"""SDXL refiner model loader output"""
|
||||||
# fmt: off
|
# fmt: off
|
||||||
type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output"
|
type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output"
|
||||||
|
unet: UNetField = Field(default=None, description="UNet submodel")
|
||||||
|
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
|
||||||
|
vae: VaeField = Field(default=None, description="Vae submodel")
|
||||||
|
# fmt: on
|
||||||
#fmt: on
|
#fmt: on
|
||||||
|
|
||||||
|
|
||||||
class SDXLModelLoaderInvocation(BaseInvocation):
|
class SDXLModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads an sdxl base model, outputting its submodels."""
|
"""Loads an sdxl base model, outputting its submodels."""
|
||||||
|
|
||||||
@ -50,10 +53,6 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _output_class(cls):
|
|
||||||
return SDXLModelLoaderOutput
|
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
|
def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
|
||||||
base_model = self.model.base_model
|
base_model = self.model.base_model
|
||||||
model_name = self.model.model_name
|
model_name = self.model.model_name
|
||||||
@ -67,7 +66,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
|||||||
):
|
):
|
||||||
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
|
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
|
||||||
|
|
||||||
return self._output_class(
|
return SDXLModelLoaderOutput(
|
||||||
unet=UNetField(
|
unet=UNetField(
|
||||||
unet=ModelInfo(
|
unet=ModelInfo(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
@ -125,7 +124,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
class SDXLRefinerModelLoaderInvocation(SDXLModelLoaderInvocation):
|
class SDXLRefinerModelLoaderInvocation(BaseInvocation):
|
||||||
"""Loads an sdxl refiner model, outputting its submodels."""
|
"""Loads an sdxl refiner model, outputting its submodels."""
|
||||||
type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader"
|
type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader"
|
||||||
|
|
||||||
@ -139,10 +138,61 @@ class SDXLRefinerModelLoaderInvocation(SDXLModelLoaderInvocation):
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput:
|
||||||
def _output_class(cls):
|
base_model = self.model.base_model
|
||||||
return SDXLRefinerModelLoaderOutput
|
model_name = self.model.model_name
|
||||||
|
model_type = ModelType.Main
|
||||||
|
|
||||||
|
# TODO: not found exceptions
|
||||||
|
if not context.services.model_manager.model_exists(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
|
):
|
||||||
|
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
|
||||||
|
|
||||||
|
return SDXLRefinerModelLoaderOutput(
|
||||||
|
unet=UNetField(
|
||||||
|
unet=ModelInfo(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
|
submodel=SubModelType.UNet,
|
||||||
|
),
|
||||||
|
scheduler=ModelInfo(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
|
submodel=SubModelType.Scheduler,
|
||||||
|
),
|
||||||
|
loras=[],
|
||||||
|
),
|
||||||
|
clip2=ClipField(
|
||||||
|
tokenizer=ModelInfo(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
|
submodel=SubModelType.Tokenizer2,
|
||||||
|
),
|
||||||
|
text_encoder=ModelInfo(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
|
submodel=SubModelType.TextEncoder2,
|
||||||
|
),
|
||||||
|
loras=[],
|
||||||
|
skipped_layers=0,
|
||||||
|
),
|
||||||
|
vae=VaeField(
|
||||||
|
vae=ModelInfo(
|
||||||
|
model_name=model_name,
|
||||||
|
base_model=base_model,
|
||||||
|
model_type=model_type,
|
||||||
|
submodel=SubModelType.Vae,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
# Text to image
|
# Text to image
|
||||||
class SDXLTextToLatentsInvocation(BaseInvocation):
|
class SDXLTextToLatentsInvocation(BaseInvocation):
|
||||||
"""Generates latents from conditionings."""
|
"""Generates latents from conditionings."""
|
||||||
|
Loading…
x
Reference in New Issue
Block a user