This commit is contained in:
Brandon Rising 2024-08-14 11:53:07 -04:00 committed by Brandon
parent 56fda669fd
commit 9ed53af520
6 changed files with 8 additions and 11 deletions

View File

@ -101,10 +101,7 @@ class FluxTextToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
# if the cache is not empty. # if the cache is not empty.
# context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30) # context.models._services.model_manager.load.ram_cache.make_room(24 * 2**30)
with ( with transformer_info as transformer, scheduler_info as scheduler:
transformer_info as transformer,
scheduler_info as scheduler
):
assert isinstance(transformer, FluxTransformer2DModel) assert isinstance(transformer, FluxTransformer2DModel)
assert isinstance(scheduler, FlowMatchEulerDiscreteScheduler) assert isinstance(scheduler, FlowMatchEulerDiscreteScheduler)

View File

@ -60,11 +60,11 @@ class CLIPField(BaseModel):
loras: List[LoRAField] = Field(description="LoRAs to apply on model loading") loras: List[LoRAField] = Field(description="LoRAs to apply on model loading")
class TransformerField(BaseModel): class TransformerField(BaseModel):
transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel") transformer: ModelIdentifierField = Field(description="Info to load Transformer submodel")
scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel") scheduler: ModelIdentifierField = Field(description="Info to load scheduler submodel")
class T5EncoderField(BaseModel): class T5EncoderField(BaseModel):
tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel") tokenizer: ModelIdentifierField = Field(description="Info to load tokenizer submodel")
text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel") text_encoder: ModelIdentifierField = Field(description="Info to load text_encoder submodel")

View File

@ -52,9 +52,7 @@ def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
return model.calc_size() return model.calc_size()
elif isinstance( elif isinstance(
model, model,
( (T5TokenizerFast,),
T5TokenizerFast,
),
): ):
return len(model) return len(model)
else: else:

View File

@ -54,7 +54,7 @@ def filter_files(
"lora_weights.safetensors", "lora_weights.safetensors",
"weights.pb", "weights.pb",
"onnx_data", "onnx_data",
"spiece.model", # Added for `black-forest-labs/FLUX.1-schnell`. "spiece.model", # Added for `black-forest-labs/FLUX.1-schnell`.
) )
): ):
paths.append(file) paths.append(file)

View File

@ -19,7 +19,7 @@ from invokeai.backend.requantize import requantize
class FastQuantizedDiffusersModel(QuantizedDiffusersModel): class FastQuantizedDiffusersModel(QuantizedDiffusersModel):
@classmethod @classmethod
def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike], base_class = FluxTransformer2DModel, **kwargs): def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike], base_class=FluxTransformer2DModel, **kwargs):
"""We override the `from_pretrained()` method in order to use our custom `requantize()` implementation.""" """We override the `from_pretrained()` method in order to use our custom `requantize()` implementation."""
base_class = base_class or cls.base_class base_class = base_class or cls.base_class
if base_class is None: if base_class is None:

View File

@ -15,7 +15,9 @@ from invokeai.backend.requantize import requantize
class FastQuantizedTransformersModel(QuantizedTransformersModel): class FastQuantizedTransformersModel(QuantizedTransformersModel):
@classmethod @classmethod
def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike], auto_class = AutoModelForTextEncoding, **kwargs): def from_pretrained(
cls, model_name_or_path: Union[str, os.PathLike], auto_class=AutoModelForTextEncoding, **kwargs
):
"""We override the `from_pretrained()` method in order to use our custom `requantize()` implementation.""" """We override the `from_pretrained()` method in order to use our custom `requantize()` implementation."""
auto_class = auto_class or cls.auto_class auto_class = auto_class or cls.auto_class
if auto_class is None: if auto_class is None: