final tidying before marking PR as ready for review

- Replace AnyModelLoader with ModelLoaderRegistry
- Fix type check errors in multiple files
- Remove apparently unneeded `get_model_config_enum()` method from model manager
- Remove last vestiges of old model manager
- Update tests and documentation

resolve conflict with seamless.py
This commit is contained in:
psychedelicious
2024-02-18 17:27:42 +11:00
parent 2ad0752582
commit be8b99eed5
74 changed files with 672 additions and 10362 deletions

View File

@ -35,7 +35,7 @@ class Struct_mallinfo2(ctypes.Structure):
("keepcost", ctypes.c_size_t),
]
def __str__(self):
def __str__(self) -> str:
s = ""
s += f"{'arena': <10}= {(self.arena/2**30):15.5f} # Non-mmapped space allocated (GB) (uordblks + fordblks)\n"
s += f"{'ordblks': <10}= {(self.ordblks): >15} # Number of free chunks\n"
@ -62,7 +62,7 @@ class LibcUtil:
TODO: Improve cross-OS compatibility of this class.
"""
def __init__(self):
def __init__(self) -> None:
self._libc = ctypes.cdll.LoadLibrary("libc.so.6")
def mallinfo2(self) -> Struct_mallinfo2:
@ -72,4 +72,5 @@ class LibcUtil:
"""
mallinfo2 = self._libc.mallinfo2
mallinfo2.restype = Struct_mallinfo2
return mallinfo2()
result: Struct_mallinfo2 = mallinfo2()
return result

View File

@ -1,12 +1,15 @@
"""Utilities for parsing model files, used mostly by probe.py"""
import json
import torch
from typing import Union
from pathlib import Path
from typing import Dict, Optional, Union
import safetensors
import torch
from picklescan.scanner import scan_file_path
def _fast_safetensors_reader(path: str):
def _fast_safetensors_reader(path: str) -> Dict[str, torch.Tensor]:
checkpoint = {}
device = torch.device("meta")
with open(path, "rb") as f:
@ -37,10 +40,12 @@ def _fast_safetensors_reader(path: str):
return checkpoint
def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
def read_checkpoint_meta(path: Union[str, Path], scan: bool = False) -> Dict[str, torch.Tensor]:
if str(path).endswith(".safetensors"):
try:
checkpoint = _fast_safetensors_reader(path)
path_str = path.as_posix() if isinstance(path, Path) else path
checkpoint = _fast_safetensors_reader(path_str)
except Exception:
# TODO: create an issue requesting support for the "meta" device?
checkpoint = safetensors.torch.load_file(path, device="cpu")
@ -52,14 +57,15 @@ def read_checkpoint_meta(path: Union[str, Path], scan: bool = False):
checkpoint = torch.load(path, map_location=torch.device("meta"))
return checkpoint
def lora_token_vector_length(checkpoint: dict) -> int:
def lora_token_vector_length(checkpoint: Dict[str, torch.Tensor]) -> Optional[int]:
"""
Given a checkpoint in memory, return the lora token vector length
:param checkpoint: The checkpoint
"""
def _get_shape_1(key: str, tensor, checkpoint) -> int:
def _get_shape_1(key: str, tensor: torch.Tensor, checkpoint: Dict[str, torch.Tensor]) -> Optional[int]:
lora_token_vector_length = None
if "." not in key: