Skip to content

Commit 6a0ed19

Browse files
authored
feat: Add EmbeddingGemma300M (#184)
* feat: adding EmbeddingGemma300M model * add output_key field to ModelInfo and update models to include it. * chore: Misc. updates Signed-off-by: Anush008 <[email protected]> --------- Signed-off-by: Anush008 <[email protected]>
1 parent 99f93e9 commit 6a0ed19

File tree

8 files changed

+84
-10
lines changed

8 files changed

+84
-10
lines changed

src/models/image_embedding.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ pub fn models_list() -> Vec<ModelInfo<ImageEmbeddingModel>> {
2626
model_code: String::from("Qdrant/clip-ViT-B-32-vision"),
2727
model_file: String::from("model.onnx"),
2828
additional_files: Vec::new(),
29+
output_key: None,
2930
},
3031
ModelInfo {
3132
model: ImageEmbeddingModel::Resnet50,
@@ -34,6 +35,7 @@ pub fn models_list() -> Vec<ModelInfo<ImageEmbeddingModel>> {
3435
model_code: String::from("Qdrant/resnet50-onnx"),
3536
model_file: String::from("model.onnx"),
3637
additional_files: Vec::new(),
38+
output_key: None,
3739
},
3840
ModelInfo {
3941
model: ImageEmbeddingModel::UnicomVitB16,
@@ -42,6 +44,7 @@ pub fn models_list() -> Vec<ModelInfo<ImageEmbeddingModel>> {
4244
model_code: String::from("Qdrant/Unicom-ViT-B-16"),
4345
model_file: String::from("model.onnx"),
4446
additional_files: Vec::new(),
47+
output_key: None,
4548
},
4649
ModelInfo {
4750
model: ImageEmbeddingModel::UnicomVitB32,
@@ -50,6 +53,7 @@ pub fn models_list() -> Vec<ModelInfo<ImageEmbeddingModel>> {
5053
model_code: String::from("Qdrant/Unicom-ViT-B-32"),
5154
model_file: String::from("model.onnx"),
5255
additional_files: Vec::new(),
56+
output_key: None,
5357
},
5458
ModelInfo {
5559
model: ImageEmbeddingModel::NomicEmbedVisionV15,
@@ -58,6 +62,7 @@ pub fn models_list() -> Vec<ModelInfo<ImageEmbeddingModel>> {
5862
model_code: String::from("nomic-ai/nomic-embed-vision-v1.5"),
5963
model_file: String::from("onnx/model.onnx"),
6064
additional_files: Vec::new(),
65+
output_key: None,
6166
},
6267
];
6368

src/models/model_info.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
1-
use crate::RerankerModel;
1+
use crate::{OutputKey, RerankerModel};
22

33
/// Data struct about the available models
44
#[derive(Debug, Clone)]
5+
#[non_exhaustive]
56
pub struct ModelInfo<T> {
67
pub model: T,
78
pub dim: usize,
89
pub description: String,
910
pub model_code: String,
1011
pub model_file: String,
1112
pub additional_files: Vec<String>,
13+
pub output_key: Option<OutputKey>,
1214
}
1315

1416
/// Data struct about the available reranker models
1517
#[derive(Debug, Clone)]
18+
#[non_exhaustive]
1619
pub struct RerankerModelInfo {
1720
pub model: RerankerModel,
1821
pub description: String,

src/models/sparse.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ pub fn models_list() -> Vec<ModelInfo<SparseModel>> {
1717
model_code: String::from("Qdrant/Splade_PP_en_v1"),
1818
model_file: String::from("model.onnx"),
1919
additional_files: Vec::new(),
20+
output_key: None,
2021
}]
2122
}
2223

src/models/text_embedding.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ pub enum EmbeddingModel {
6868
ClipVitB32,
6969
/// jinaai/jina-embeddings-v2-base-code
7070
JinaEmbeddingsV2BaseCode,
71+
/// onnx-community/embeddinggemma-300m-ONNX
72+
EmbeddingGemma300M,
7173
}
7274

7375
/// Centralized function to initialize the models map.
@@ -80,6 +82,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
8082
model_code: String::from("Qdrant/all-MiniLM-L6-v2-onnx"),
8183
model_file: String::from("model.onnx"),
8284
additional_files: Vec::new(),
85+
output_key: None,
8386
},
8487
ModelInfo {
8588
model: EmbeddingModel::AllMiniLML6V2Q,
@@ -88,6 +91,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
8891
model_code: String::from("Xenova/all-MiniLM-L6-v2"),
8992
model_file: String::from("onnx/model_quantized.onnx"),
9093
additional_files: Vec::new(),
94+
output_key: None,
9195
},
9296
ModelInfo {
9397
model: EmbeddingModel::AllMiniLML12V2,
@@ -96,6 +100,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
96100
model_code: String::from("Xenova/all-MiniLM-L12-v2"),
97101
model_file: String::from("onnx/model.onnx"),
98102
additional_files: Vec::new(),
103+
output_key: None,
99104
},
100105
ModelInfo {
101106
model: EmbeddingModel::AllMiniLML12V2Q,
@@ -104,6 +109,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
104109
model_code: String::from("Xenova/all-MiniLM-L12-v2"),
105110
model_file: String::from("onnx/model_quantized.onnx"),
106111
additional_files: Vec::new(),
112+
output_key: None,
107113
},
108114
ModelInfo {
109115
model: EmbeddingModel::BGEBaseENV15,
@@ -112,6 +118,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
112118
model_code: String::from("Xenova/bge-base-en-v1.5"),
113119
model_file: String::from("onnx/model.onnx"),
114120
additional_files: Vec::new(),
121+
output_key: None,
115122
},
116123
ModelInfo {
117124
model: EmbeddingModel::BGEBaseENV15Q,
@@ -120,6 +127,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
120127
model_code: String::from("Qdrant/bge-base-en-v1.5-onnx-Q"),
121128
model_file: String::from("model_optimized.onnx"),
122129
additional_files: Vec::new(),
130+
output_key: None,
123131
},
124132
ModelInfo {
125133
model: EmbeddingModel::BGELargeENV15,
@@ -128,6 +136,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
128136
model_code: String::from("Xenova/bge-large-en-v1.5"),
129137
model_file: String::from("onnx/model.onnx"),
130138
additional_files: Vec::new(),
139+
output_key: None,
131140
},
132141
ModelInfo {
133142
model: EmbeddingModel::BGELargeENV15Q,
@@ -136,6 +145,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
136145
model_code: String::from("Qdrant/bge-large-en-v1.5-onnx-Q"),
137146
model_file: String::from("model_optimized.onnx"),
138147
additional_files: Vec::new(),
148+
output_key: None,
139149
},
140150
ModelInfo {
141151
model: EmbeddingModel::BGESmallENV15,
@@ -144,6 +154,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
144154
model_code: String::from("Xenova/bge-small-en-v1.5"),
145155
model_file: String::from("onnx/model.onnx"),
146156
additional_files: Vec::new(),
157+
output_key: None,
147158
},
148159
ModelInfo {
149160
model: EmbeddingModel::BGESmallENV15Q,
@@ -154,6 +165,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
154165
model_code: String::from("Qdrant/bge-small-en-v1.5-onnx-Q"),
155166
model_file: String::from("model_optimized.onnx"),
156167
additional_files: Vec::new(),
168+
output_key: None,
157169
},
158170
ModelInfo {
159171
model: EmbeddingModel::NomicEmbedTextV1,
@@ -162,6 +174,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
162174
model_code: String::from("nomic-ai/nomic-embed-text-v1"),
163175
model_file: String::from("onnx/model.onnx"),
164176
additional_files: Vec::new(),
177+
output_key: None,
165178
},
166179
ModelInfo {
167180
model: EmbeddingModel::NomicEmbedTextV15,
@@ -170,6 +183,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
170183
model_code: String::from("nomic-ai/nomic-embed-text-v1.5"),
171184
model_file: String::from("onnx/model.onnx"),
172185
additional_files: Vec::new(),
186+
output_key: None,
173187
},
174188
ModelInfo {
175189
model: EmbeddingModel::NomicEmbedTextV15Q,
@@ -180,6 +194,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
180194
model_code: String::from("nomic-ai/nomic-embed-text-v1.5"),
181195
model_file: String::from("onnx/model_quantized.onnx"),
182196
additional_files: Vec::new(),
197+
output_key: None,
183198
},
184199
ModelInfo {
185200
model: EmbeddingModel::ParaphraseMLMiniLML12V2Q,
@@ -188,6 +203,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
188203
model_code: String::from("Qdrant/paraphrase-multilingual-MiniLM-L12-v2-onnx-Q"),
189204
model_file: String::from("model_optimized.onnx"),
190205
additional_files: Vec::new(),
206+
output_key: None,
191207
},
192208
ModelInfo {
193209
model: EmbeddingModel::ParaphraseMLMiniLML12V2,
@@ -196,6 +212,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
196212
model_code: String::from("Xenova/paraphrase-multilingual-MiniLM-L12-v2"),
197213
model_file: String::from("onnx/model.onnx"),
198214
additional_files: Vec::new(),
215+
output_key: None,
199216
},
200217
ModelInfo {
201218
model: EmbeddingModel::ParaphraseMLMpnetBaseV2,
@@ -206,6 +223,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
206223
model_code: String::from("Xenova/paraphrase-multilingual-mpnet-base-v2"),
207224
model_file: String::from("onnx/model.onnx"),
208225
additional_files: Vec::new(),
226+
output_key: None,
209227
},
210228
ModelInfo {
211229
model: EmbeddingModel::BGESmallZHV15,
@@ -214,6 +232,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
214232
model_code: String::from("Xenova/bge-small-zh-v1.5"),
215233
model_file: String::from("onnx/model.onnx"),
216234
additional_files: Vec::new(),
235+
output_key: None,
217236
},
218237
ModelInfo {
219238
model: EmbeddingModel::BGELargeZHV15,
@@ -222,6 +241,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
222241
model_code: String::from("Xenova/bge-large-zh-v1.5"),
223242
model_file: String::from("onnx/model.onnx"),
224243
additional_files: Vec::new(),
244+
output_key: None,
225245
},
226246
ModelInfo {
227247
model: EmbeddingModel::ModernBertEmbedLarge,
@@ -230,6 +250,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
230250
model_code: String::from("lightonai/modernbert-embed-large"),
231251
model_file: String::from("onnx/model.onnx"),
232252
additional_files: Vec::new(),
253+
output_key: None,
233254
},
234255
ModelInfo {
235256
model: EmbeddingModel::MultilingualE5Small,
@@ -238,6 +259,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
238259
model_code: String::from("intfloat/multilingual-e5-small"),
239260
model_file: String::from("onnx/model.onnx"),
240261
additional_files: Vec::new(),
262+
output_key: None,
241263
},
242264
ModelInfo {
243265
model: EmbeddingModel::MultilingualE5Base,
@@ -246,6 +268,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
246268
model_code: String::from("intfloat/multilingual-e5-base"),
247269
model_file: String::from("onnx/model.onnx"),
248270
additional_files: Vec::new(),
271+
output_key: None,
249272
},
250273
ModelInfo {
251274
model: EmbeddingModel::MultilingualE5Large,
@@ -254,6 +277,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
254277
model_code: String::from("Qdrant/multilingual-e5-large-onnx"),
255278
model_file: String::from("model.onnx"),
256279
additional_files: vec!["model.onnx_data".to_string()],
280+
output_key: None,
257281
},
258282
ModelInfo {
259283
model: EmbeddingModel::MxbaiEmbedLargeV1,
@@ -262,6 +286,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
262286
model_code: String::from("mixedbread-ai/mxbai-embed-large-v1"),
263287
model_file: String::from("onnx/model.onnx"),
264288
additional_files: Vec::new(),
289+
output_key: None,
265290
},
266291
ModelInfo {
267292
model: EmbeddingModel::MxbaiEmbedLargeV1Q,
@@ -270,6 +295,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
270295
model_code: String::from("mixedbread-ai/mxbai-embed-large-v1"),
271296
model_file: String::from("onnx/model_quantized.onnx"),
272297
additional_files: Vec::new(),
298+
output_key: None,
273299
},
274300
ModelInfo {
275301
model: EmbeddingModel::GTEBaseENV15,
@@ -278,6 +304,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
278304
model_code: String::from("Alibaba-NLP/gte-base-en-v1.5"),
279305
model_file: String::from("onnx/model.onnx"),
280306
additional_files: Vec::new(),
307+
output_key: None,
281308
},
282309
ModelInfo {
283310
model: EmbeddingModel::GTEBaseENV15Q,
@@ -286,6 +313,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
286313
model_code: String::from("Alibaba-NLP/gte-base-en-v1.5"),
287314
model_file: String::from("onnx/model_quantized.onnx"),
288315
additional_files: Vec::new(),
316+
output_key: None,
289317
},
290318
ModelInfo {
291319
model: EmbeddingModel::GTELargeENV15,
@@ -294,6 +322,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
294322
model_code: String::from("Alibaba-NLP/gte-large-en-v1.5"),
295323
model_file: String::from("onnx/model.onnx"),
296324
additional_files: Vec::new(),
325+
output_key: None,
297326
},
298327
ModelInfo {
299328
model: EmbeddingModel::GTELargeENV15Q,
@@ -302,6 +331,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
302331
model_code: String::from("Alibaba-NLP/gte-large-en-v1.5"),
303332
model_file: String::from("onnx/model_quantized.onnx"),
304333
additional_files: Vec::new(),
334+
output_key: None,
305335
},
306336
ModelInfo {
307337
model: EmbeddingModel::ClipVitB32,
@@ -310,6 +340,7 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
310340
model_code: String::from("Qdrant/clip-ViT-B-32-text"),
311341
model_file: String::from("model.onnx"),
312342
additional_files: Vec::new(),
343+
output_key: None,
313344
},
314345
ModelInfo {
315346
model: EmbeddingModel::JinaEmbeddingsV2BaseCode,
@@ -318,6 +349,16 @@ fn init_models_map() -> HashMap<EmbeddingModel, ModelInfo<EmbeddingModel>> {
318349
model_code: String::from("jinaai/jina-embeddings-v2-base-code"),
319350
model_file: String::from("onnx/model.onnx"),
320351
additional_files: Vec::new(),
352+
output_key: None,
353+
},
354+
ModelInfo {
355+
model: EmbeddingModel::EmbeddingGemma300M,
356+
dim: 768,
357+
description: String::from("EmbeddingGemma is a 300M parameter from Google"),
358+
model_code: String::from("onnx-community/embeddinggemma-300m-ONNX"),
359+
model_file: String::from("onnx/model.onnx"),
360+
additional_files: vec!["onnx/model.onnx_data".to_string()],
361+
output_key: Some(crate::OutputKey::ByName("sentence_embedding")),
321362
},
322363
];
323364

src/output/output_precedence.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
//! e.g. reading the output keys from the model file.
88
99
/// Enum for defining the key of the output.
10-
#[derive(Debug, Clone)]
10+
#[derive(Debug, Clone, PartialEq, Eq)]
1111
pub enum OutputKey {
1212
OnlyOne,
1313
ByOrder(usize),
@@ -41,3 +41,9 @@ impl OutputPrecedence for &[OutputKey] {
4141
self.iter()
4242
}
4343
}
44+
45+
impl OutputPrecedence for &OutputKey {
46+
fn key_precedence(&self) -> impl Iterator<Item = &OutputKey> {
47+
std::iter::once(*self)
48+
}
49+
}

0 commit comments

Comments
 (0)