In the previous commit, the LLaVA model was updated to support partial loading.
In this commit, the SigLIP model is updated in the same way.
This model is used for FLUX Redux. It's <4GB and only ever run in isolation, so it won't benefit from partial loading for the vast majority of users. Regardless, I think it is best if we make _all_ models work with partial loading.
PS: I also fixed the initial load dtype issue, described in the prev commit. It's probably a non-issue for this model, but we may as well fix it.
The model manager has two types of model cache entries:
- `CachedModelOnlyFullLoad`: The model may only ever be loaded and unloaded as a single object.
- `CachedModelWithPartialLoad`: The model may be partially loaded and unloaded.
Partial loaded is enabled by overwriting certain torch layer classes, adding the ability to autocast the layer to a device on-the-fly. See `CustomLinear` for an example.
So, to take advantage of partial loading and be cached as a `CachedModelWithPartialLoad`, the model must inherit from `torch.nn.Module`.
The LLaVA classes provided by `transformers` do inherit from `torch.nn.Module`, but we wrap those classes in a separate class called `LlavaOnevisionModel`. The wrapper encapsulate both the LLaVA model and its "processor" - a lightweight class that prepares model inputs like text and images.
While it is more elegant to encapsulate both model and processor classes in a single entity, this prevents the model cache from enabling partial loading for the chunky vLLM model.
Fixing this involved a few changes.
- Update the `LlavaOnevisionModelLoader` class to operate on the vLLM model directly, instead the `LlavaOnevisionModel` wrapper class.
- Instantiate the processor directly in the node. The processor is lightweight and does its business on the CPU. We don't need to worry about caching in the model manager.
- Remove caching support code from the `LlavaOnevisionModel` wrapper class. It's not needed, because we do not cache this class. The class now only handles running the models provided to it.
- Rename `LlavaOnevisionModel` to `LlavaOnevisionPipeline` to better represent its purpose.
These changes have a bonus effect of fixing an OOM crash when initially loading the models. This was most apparent when loading LLaVA 7B, which is pretty chunky.
The initial load is onto CPU RAM. In the old version of the loaders, we ignored the loader's target dtype for the initial load. Instead, we loaded the model at `transformers`'s "default" dtype of fp32.
LLaVA 7B is fp16 and weighs ~17GB. Loading as fp32 means we need double that amount (~34GB) of CPU RAM. Many users only have 32GB RAM, so this causes a _CPU_ OOM - which is a hard crash of the whole process.
With the updated loaders, the initial load logic now uses the target dtype for the initial load. LLaVA now needs the expected ~17GB RAM for its initial load.
PS: If we didn't make the accompanying partial loading changes, we still could have solved this OOM. We'd just need to pass the initial load dtype to the wrapper class and have it load on that dtype. But we may as well fix both issues.
PPS: There are other models whose model classes are wrappers around a torch module class, and thus cannot be partially loaded. However, these models are typically fairly small and/or are run only on their own, so they don't benefit as much from partial loading. It's the really big models (like LLaVA 7B) that benefit most from the partial loading.
Before FLUX Fill was merged, we didn't do any checks for the model variant. We always returned "normal".
To determine if a model is a FLUX Fill model, we need to check the state dict for a specific key. Initially, this logic was too strict and rejected quantized FLUX models. This issue was resolved, but it turns out there is another failure mode - some fine-tunes use a different key.
This change further reduces the strictness, handling the alternate key and also falling back to "normal" if we don't see either key. This effectively restores the previous probing behaviour for all FLUX models.
Closes#7856Closes#7859
In #7780 we added FLUX Fill support, and needed the probe to be able to distinguish between "normal" FLUX models and FLUX Fill models.
Logic was added to the probe to check a particular state dict key (input channels), which should be 384 for FLUX Fill and 64 for other FLUX models.
The new logic was stricter and instead of falling back on the "normal" variant, it raised when an unexpected value for input channels was detected.
This caused failures to probe for BNB-NF4 quantized FLUX Dev/Schnell, which apparently only have 1 input channel.
After checking a variety of FLUX models, I loosened the strictness of the variant probing logic to only special-case the new FLUX Fill model, and otherwise fall back to returning the "normal" variant. This better matches the old behaviour and fixes the import errors.
Closes#7822