refiner only has clip2 not clip

This commit is contained in:
Lincoln Stein 2023-07-16 12:36:38 -04:00
parent 0a2964d8c0
commit 6534288b75

View File

@ -21,17 +21,20 @@ class SDXLModelLoaderOutput(BaseInvocationOutput):
unet: UNetField = Field(default=None, description="UNet submodel") unet: UNetField = Field(default=None, description="UNet submodel")
clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels") clip: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels (SDXL only)") clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
vae: VaeField = Field(default=None, description="Vae submodel") vae: VaeField = Field(default=None, description="Vae submodel")
# fmt: on # fmt: on
class SDXLRefinerModelLoaderOutput(SDXLModelLoaderOutput): class SDXLRefinerModelLoaderOutput(BaseInvocationOutput):
"""SDXL refiner model loader output""" """SDXL refiner model loader output"""
# fmt: off # fmt: off
type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output" type: Literal["sdxl_refiner_model_loader_output"] = "sdxl_refiner_model_loader_output"
unet: UNetField = Field(default=None, description="UNet submodel")
clip2: ClipField = Field(default=None, description="Tokenizer and text_encoder submodels")
vae: VaeField = Field(default=None, description="Vae submodel")
# fmt: on
#fmt: on #fmt: on
class SDXLModelLoaderInvocation(BaseInvocation): class SDXLModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl base model, outputting its submodels.""" """Loads an sdxl base model, outputting its submodels."""
@ -50,10 +53,6 @@ class SDXLModelLoaderInvocation(BaseInvocation):
}, },
} }
@classmethod
def _output_class(cls):
return SDXLModelLoaderOutput
def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput: def invoke(self, context: InvocationContext) -> SDXLModelLoaderOutput:
base_model = self.model.base_model base_model = self.model.base_model
model_name = self.model.model_name model_name = self.model.model_name
@ -67,7 +66,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
): ):
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}") raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
return self._output_class( return SDXLModelLoaderOutput(
unet=UNetField( unet=UNetField(
unet=ModelInfo( unet=ModelInfo(
model_name=model_name, model_name=model_name,
@ -125,7 +124,7 @@ class SDXLModelLoaderInvocation(BaseInvocation):
), ),
) )
class SDXLRefinerModelLoaderInvocation(SDXLModelLoaderInvocation): class SDXLRefinerModelLoaderInvocation(BaseInvocation):
"""Loads an sdxl refiner model, outputting its submodels.""" """Loads an sdxl refiner model, outputting its submodels."""
type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader" type: Literal["sdxl_refiner_model_loader"] = "sdxl_refiner_model_loader"
@ -139,10 +138,61 @@ class SDXLRefinerModelLoaderInvocation(SDXLModelLoaderInvocation):
}, },
} }
@classmethod def invoke(self, context: InvocationContext) -> SDXLRefinerModelLoaderOutput:
def _output_class(cls): base_model = self.model.base_model
return SDXLRefinerModelLoaderOutput model_name = self.model.model_name
model_type = ModelType.Main
# TODO: not found exceptions
if not context.services.model_manager.model_exists(
model_name=model_name,
base_model=base_model,
model_type=model_type,
):
raise Exception(f"Unknown {base_model} {model_type} model: {model_name}")
return SDXLRefinerModelLoaderOutput(
unet=UNetField(
unet=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.UNet,
),
scheduler=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Scheduler,
),
loras=[],
),
clip2=ClipField(
tokenizer=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Tokenizer2,
),
text_encoder=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.TextEncoder2,
),
loras=[],
skipped_layers=0,
),
vae=VaeField(
vae=ModelInfo(
model_name=model_name,
base_model=base_model,
model_type=model_type,
submodel=SubModelType.Vae,
),
),
)
# Text to image # Text to image
class SDXLTextToLatentsInvocation(BaseInvocation): class SDXLTextToLatentsInvocation(BaseInvocation):
"""Generates latents from conditionings.""" """Generates latents from conditionings."""