26 changes: 25 additions & 1 deletion mlx_vlm/convert.py
@@ -1,5 +1,8 @@
# Copyright © 2023-2024 Apple Inc.

import argparse
import glob
import os
import shutil
from pathlib import Path
from typing import Callable, Optional, Union
@@ -113,11 +116,13 @@ def convert(
dequantize: bool = False,
trust_remote_code: bool = True,
quant_predicate: Optional[str] = None,
only_llm: bool = False,
skip_vision: bool = False,
):
print("[INFO] Loading")
model_path = get_model_path(hf_path, revision=revision)
model, config, processor = fetch_from_hub(
model_path, lazy=True, trust_remote_code=trust_remote_code
model_path, lazy=True, trust_remote_code=trust_remote_code, only_llm=only_llm
)

def base_quant_predicate(path, module):
@@ -178,6 +183,13 @@ def set_dtype(k, v):

save_config(config, config_path=mlx_path / "config.json")

    # Copy over any Core ML .mlpackage bundles if found (each is a directory, hence copytree)
coreml_files = glob.glob(str(model_path / "*.mlpackage"))

    for file in coreml_files:
        dest_path = os.path.join(mlx_path, os.path.basename(file))
        shutil.copytree(file, dest_path)

if upload_repo is not None:
upload_to_hub(mlx_path, upload_repo, hf_path)

@@ -233,6 +245,18 @@ def configure_parser() -> argparse.ArgumentParser:
action="store_true",
default=False,
)
parser.add_argument(
"--skip-vision",
help="Skip vision module quantization.",
action="store_true",
default=False,
)
parser.add_argument(
"--only-llm",
help="Convert only LLM.",
action="store_true",
default=False,
)
return parser


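Together, the two new flags let a Core ML-backed model be converted without touching its vision tower: `--only-llm` loads and converts just the language model, while `--skip-vision` keeps the vision module out of quantization. A minimal sketch of calling the updated function directly (the repo id is hypothetical, and parameters other than `only_llm` and `skip_vision` assume the existing `convert()` signature):

```python
from mlx_vlm.convert import convert

# Hypothetical repo id, for illustration; only_llm loads just the language
# model, skip_vision leaves the vision module unquantized.
convert(
    "apple/FastVLM-0.5B",
    quantize=True,
    only_llm=True,
    skip_vision=True,
)
```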
3 changes: 3 additions & 0 deletions mlx_vlm/hf_tools/__init__.py
@@ -0,0 +1,3 @@
from .mlpackage_cache import resolve_coreml_mlpackage

__all__ = ["resolve_coreml_mlpackage"]
96 changes: 96 additions & 0 deletions mlx_vlm/hf_tools/mlpackage_cache.py
@@ -0,0 +1,96 @@
import glob
import os
from pathlib import Path
from typing import Optional

from filelock import FileLock
from huggingface_hub import HfApi, snapshot_download


def _repo_sha_and_files(repo_id: str, revision: str | None = None):
api = HfApi()
info = api.repo_info(repo_id, revision=revision, repo_type="model")
files = api.list_repo_files(repo_id, revision=info.sha, repo_type="model")
return info.sha, files


def _find_mlpackages(files: list[str]) -> list[str]:
return sorted(
{
p.split("/")[0]
for p in files
if p.endswith(".mlpackage") or ".mlpackage/" in p
}
)
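# Example (hypothetical listing): given repo files
#   ["Vision.mlpackage/Manifest.json", "Vision.mlpackage/Data/weights.bin", "config.json"]
# this returns ["Vision.mlpackage"]: top-level bundle names only, deduplicated and sorted.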


def _stage_dir(repo_id: str, sha: str, root: str | None = None) -> Path:
root = root or os.path.expanduser("~/.cache/mlx-vlm/materialized")
return Path(root) / f"{repo_id.replace('/', '__')}-{sha}"


def _fetch_repo_path(repo_or_path: str, force_download: bool = False) -> Path:
p = Path(repo_or_path)
if p.exists():
return p
sha, files = _repo_sha_and_files(repo_or_path)
mlps = _find_mlpackages(files)
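    # Stage repos that ship .mlpackage bundles outside the regular HF snapshot
    # cache, since Core ML manifests cannot be loaded from snapshot cache paths.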
if mlps:
stage = _stage_dir(repo_or_path, sha)
lock = FileLock(str(stage) + ".lock")
with lock:
if not stage.exists():
patterns = [f"{m}/**" for m in mlps]
snapshot_download(
repo_or_path,
revision=sha,
allow_patterns=patterns,
local_dir=stage,
force_download=force_download,
)
return stage
return Path(snapshot_download(repo_or_path, revision=sha))


def resolve_coreml_mlpackage(
model_path: Path, path_or_hf_repo: Optional[str], force_download: bool = False
) -> Optional[str]:
"""
Resolve the Core ML .mlpackage path to load. This is required for Core ML models since model manifests are
incompatible with the HF snapshot cache.

Logic:
    - If exactly one local .mlpackage exists in model_path, use _fetch_repo_path(path_or_hf_repo)
      to locate and return the corresponding .mlpackage path from the repo cache.
      This avoids loading from HF snapshot cache paths that are invalid for Core ML.
- If no local .mlpackage is present, return None.

Returns:
Optional[str]: The resolved .mlpackage path from the repo cache, or None if none should be loaded.

Raises:
ValueError: If multiple .mlpackage files are found locally or in the repo cache, or if
a local .mlpackage is found but path_or_hf_repo is not provided.
FileNotFoundError: If a local .mlpackage is detected but no .mlpackage exists in the
resolved repo cache path.
"""
local_candidates = glob.glob(str(model_path / "*.mlpackage"))
if len(local_candidates) == 0:
return None
if len(local_candidates) > 1:
raise ValueError("Found multiple vision model packages, aborting.")

if not path_or_hf_repo:
raise ValueError(
"Found a .mlpackage locally, but path_or_hf_repo is required to resolve the correct Core ML package path."
)

repo_path = _fetch_repo_path(path_or_hf_repo, force_download)
repo_candidates = glob.glob(str(repo_path / "*.mlpackage"))
if len(repo_candidates) == 0:
raise FileNotFoundError(
f"No Core ML .mlpackage found in resolved repo path: {repo_path}"
)
if len(repo_candidates) > 1:
raise ValueError("Found multiple vision model packages, aborting.")
return repo_candidates[0]
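Downstream, the resolved path is what gets handed to Core ML. A sketch of the intended wiring, assuming the tower is loaded with coremltools and attached to an already-constructed model (`model`, `model_path`, and the repo id here are illustrative):

```python
import coremltools as ct

from mlx_vlm.hf_tools import resolve_coreml_mlpackage

# `model` and `model_path` are assumed to come from the usual loading path.
pkg = resolve_coreml_mlpackage(model_path, "apple/FastVLM-0.5B")
if pkg is not None:
    model.vision_tower = ct.models.MLModel(pkg)
```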
7 changes: 7 additions & 0 deletions mlx_vlm/models/fastvlm/__init__.py
@@ -0,0 +1,7 @@
from .fastvlm import (
LanguageModel,
Model,
ModelConfig,
TextConfig,
VisionConfig,
)
149 changes: 149 additions & 0 deletions mlx_vlm/models/fastvlm/fastvlm.py
@@ -0,0 +1,149 @@
import glob
import inspect
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

import coremltools
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from huggingface_hub import snapshot_download

from .language import LanguageModel, TextConfig


@dataclass
class VisionConfig:
mm_hidden_size: int
mm_vision_tower: str

@classmethod
def from_dict(cls, params):
return cls(
**{
k: v
for k, v in params.items()
if k in inspect.signature(cls).parameters
}
)


@dataclass
class ModelConfig:
text_config: TextConfig
vision_config: VisionConfig
model_type: str
eos_token_id: int
ignore_index: int = -100
image_token_index: int = 32000
vision_feature_select_strategy: str = "default"
vision_feature_layer: int = -2
vocab_size: int = 151936

@classmethod
def from_dict(cls, params):
# Copy text config parameters from root level
params["text_config"] = dict(filter(lambda x: "mm" not in x[0], params.items()))
# Copy vision config parameters from root level
params["vision_config"] = dict(filter(lambda x: "mm" in x[0], params.items()))

return cls(
**{
k: v
for k, v in params.items()
if k in inspect.signature(cls).parameters
}
)


class FastVLMMultiModalProjector(nn.Module):
def __init__(self, config: ModelConfig):
super().__init__()
self.linear_0 = nn.Linear(
config.vision_config.mm_hidden_size,
config.text_config.hidden_size,
bias=True,
)
self.gelu = nn.GELU()
self.linear_2 = nn.Linear(
config.text_config.hidden_size, config.text_config.hidden_size, bias=True
)

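    # Maps Core ML image features (..., mm_hidden_size) into the language
    # model's embedding space (..., hidden_size) via linear -> GELU -> linear.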
def __call__(self, x: mx.array) -> mx.array:
x = self.linear_0(x)
x = self.gelu(x)
x = self.linear_2(x)
return x


class Model(nn.Module):
def __init__(self, config: ModelConfig):
super().__init__()
self.config = config
self.vision_tower = None
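        # Set externally to a Core ML model (e.g. a coremltools MLModel loaded
        # from the resolved .mlpackage); its predict() is called in
        # get_input_embeddings below.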
self.language_model = LanguageModel(config.text_config)
self.multi_modal_projector = FastVLMMultiModalProjector(config)
self.vision_feature_layer = config.vision_feature_layer
self.vision_feature_select_strategy = config.vision_feature_select_strategy

def get_input_embeddings(
self,
input_ids: Optional[mx.array] = None,
pixel_values: Optional[mx.array] = None,
):
if pixel_values is None:
return self.language_model.model.embed_tokens(input_ids)

# Get the input embeddings from the language model
inputs_embeds = self.language_model.model.embed_tokens(input_ids)

# Get image features from CoreML model
coreml_out_dict = self.vision_tower.predict(
{"images": np.array(pixel_values, copy=False)}
)

# Pass image features through the multi-modal projector
image_features = self.multi_modal_projector(
mx.array(coreml_out_dict["image_features"])
)

# Insert special image tokens in the input_ids
final_inputs_embeds = self._merge_input_ids_with_image_features(
image_features, inputs_embeds, input_ids
)
return final_inputs_embeds

def _merge_input_ids_with_image_features(
self, image_features, inputs_embeds, input_ids
):
image_token_index = self.config.image_token_index
        num_images, num_image_patches, embed_dim = image_features.shape

        # Positions of <image> tokens in input_ids, assuming batch size is 1
        image_positions = np.where(input_ids[0] == image_token_index)[0].tolist()

        # Flatten all patches from all images into one sequence of embeddings
        reshaped_image_hidden_states = image_features.reshape(-1, embed_dim)

# cast to the dtype of the input_embeds to support quantized models
reshaped_image_hidden_states = reshaped_image_hidden_states.astype(
inputs_embeds.dtype
)
inputs_embeds[:, image_positions, :] = reshaped_image_hidden_states
return inputs_embeds

def __call__(
self,
input_ids: mx.array,
pixel_values: mx.array,
mask: mx.array,
cache=None,
**kwargs,
):
input_embeddings = self.get_input_embeddings(input_ids, pixel_values)
logits = self.language_model(
input_ids, cache=cache, inputs_embeds=input_embeddings
)
return logits
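Once a checkpoint has been converted with its .mlpackage sitting alongside the MLX weights, inference should go through the usual top-level helpers. A hedged sketch, assuming mlx_vlm's existing load/generate API and a hypothetical converted repo:

```python
from mlx_vlm import load, generate

# Hypothetical converted checkpoint bundling the Core ML vision tower.
model, processor = load("mlx-community/FastVLM-0.5B-4bit")
output = generate(model, processor, "Describe this image.", image="cat.png")
print(output)
```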