Skip to content

Commit 1222b02

Browse files
committed
[jvm-packages] Support spark connect
1 parent 5fbab40 commit 1222b02

File tree

13 files changed

+1088
-4
lines changed

13 files changed

+1088
-4
lines changed

jvm-packages/pom.xml

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@
4747
<maven.compiler.target>1.8</maven.compiler.target>
4848
<flink.version>1.20.0</flink.version>
4949
<junit.version>4.13.2</junit.version>
50-
<spark.version>3.5.3</spark.version>
51-
<spark.version.gpu>3.5.1</spark.version.gpu>
50+
<spark.version>4.0.0-SNAPSHOT</spark.version>
51+
<spark.version.gpu>4.0.0-SNAPSHOT</spark.version.gpu>
5252
<fasterxml.jackson.version>2.15.0</fasterxml.jackson.version>
5353
<scala.version>2.12.18</scala.version>
5454
<scala.binary.version>2.12</scala.binary.version>
@@ -89,6 +89,17 @@
8989
<name>central maven</name>
9090
<url>https://repo1.maven.org/maven2</url>
9191
</repository>
92+
<repository>
93+
<id>apache-snapshots</id>
94+
<url>https://repository.apache.org/content/repositories/snapshots/</url>
95+
<snapshots>
96+
<enabled>true</enabled>
97+
</snapshots>
98+
<releases>
99+
<enabled>false</enabled>
100+
</releases>
101+
</repository>
102+
92103
</repositories>
93104
<modules>
94105
</modules>
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright (c) 2025, NVIDIA CORPORATION.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
[project]
16+
name = "xgboost4j"
17+
version = "3.1.0"
18+
authors = [
19+
{ name = "Bobby Wang", email = "[email protected]" },
20+
]
21+
description = "XGBoost4j-Spark pyspark"
22+
readme = "README.md"
23+
requires-python = ">=3.10"
24+
classifiers = [
25+
"Programming Language :: Python :: 3",
26+
"Programming Language :: Python :: 3.10",
27+
"Programming Language :: Python :: 3.11",
28+
"Programming Language :: Python :: 3.12",
29+
"License :: OSI Approved :: Apache Software License",
30+
"Operating System :: OS Independent",
31+
"Environment :: GPU :: NVIDIA CUDA :: 11",
32+
"Environment :: GPU :: NVIDIA CUDA :: 11.4",
33+
"Environment :: GPU :: NVIDIA CUDA :: 11.5",
34+
"Environment :: GPU :: NVIDIA CUDA :: 11.6",
35+
"Environment :: GPU :: NVIDIA CUDA :: 11.7",
36+
"Environment :: GPU :: NVIDIA CUDA :: 11.8",
37+
"Environment :: GPU :: NVIDIA CUDA :: 12",
38+
"Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.0",
39+
"Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.1",
40+
"Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.2",
41+
"Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.3",
42+
"Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.4",
43+
"Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.5",
44+
"Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.6",
45+
"Environment :: GPU :: NVIDIA CUDA :: 12 :: 12.8",
46+
]
47+
48+
[build-system]
49+
requires = ["setuptools>=61.0"]
50+
build-backend = "setuptools.build_meta"

jvm-packages/xgboost4j-spark/python/src/ml/__init__.py

Whitespace-only changes.

jvm-packages/xgboost4j-spark/python/src/ml/dmlc/__init__.py

Whitespace-only changes.

jvm-packages/xgboost4j-spark/python/src/ml/dmlc/xgboost4j/__init__.py

Whitespace-only changes.

jvm-packages/xgboost4j-spark/python/src/ml/dmlc/xgboost4j/scala/__init__.py

Whitespace-only changes.
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
import sys
2+
3+
import xgboost4j
4+
5+
sys.modules["ml.dmlc.xgboost4j.scala.spark"] = xgboost4j
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .estimator import XGBoostClassificationModel, XGBoostClassifier
2+
3+
__version__ = "3.0.0"
4+
5+
__all__ = ["XGBoostClassifier", "XGBoostClassificationModel"]
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
from typing import Union, List, Any, Optional, Dict
2+
3+
from pyspark import keyword_only
4+
from pyspark.ml.classification import _JavaProbabilisticClassifier, _JavaProbabilisticClassificationModel
5+
6+
from .params import XGBoostParams
7+
8+
9+
class XGBoostClassifier(_JavaProbabilisticClassifier["XGBoostClassificationModel"], XGBoostParams):
10+
_input_kwargs: Dict[str, Any]
11+
12+
@keyword_only
13+
def __init__(
14+
self,
15+
*,
16+
featuresCol: Union[str, List[str]] = "features",
17+
labelCol: str = "label",
18+
predictionCol: str = "prediction",
19+
probabilityCol: str = "probability",
20+
rawPredictionCol: str = "rawPrediction",
21+
# SparkParams
22+
numWorkers: Optional[int] = None,
23+
numRound: Optional[int] = None,
24+
forceRepartition: Optional[bool] = None,
25+
numEarlyStoppingRounds: Optional[int] = None,
26+
inferBatchSize: Optional[int] = None,
27+
missing: Optional[float] = None,
28+
useExternalMemory: Optional[bool] = None,
29+
maxNumDevicePages: Optional[int] = None,
30+
maxQuantileBatches: Optional[int] = None,
31+
minCachePageBytes: Optional[int] = None,
32+
feature_names: Optional[List[str]] = None,
33+
feature_types: Optional[List[str]] = None,
34+
# RabitParams
35+
rabitTrackerTimeout: Optional[int] = None,
36+
rabitTrackerHostIp: Optional[str] = None,
37+
rabitTrackerPort: Optional[int] = None,
38+
# GeneralParams
39+
booster: Optional[str] = None,
40+
device: Optional[str] = None,
41+
verbosity: Optional[int] = None,
42+
validate_parameters: Optional[bool] = None,
43+
nthread: Optional[int] = None,
44+
# TreeBoosterParams
45+
eta: Optional[float] = None,
46+
gamma: Optional[float] = None,
47+
max_depth: Optional[int] = None,
48+
min_child_weight: Optional[float] = None,
49+
max_delta_step: Optional[float] = None,
50+
subsample: Optional[float] = None,
51+
sampling_method: Optional[str] = None,
52+
colsample_bytree: Optional[float] = None,
53+
colsample_bylevel: Optional[float] = None,
54+
colsample_bynode: Optional[float] = None,
55+
reg_lambda: Optional[float] = None,
56+
alpha: Optional[float] = None,
57+
tree_method: Optional[str] = None,
58+
scale_pos_weight: Optional[float] = None,
59+
updater: Optional[str] = None,
60+
refresh_leaf: Optional[bool] = None,
61+
process_type: Optional[str] = None,
62+
grow_policy: Optional[str] = None,
63+
max_leaves: Optional[int] = None,
64+
max_bin: Optional[int] = None,
65+
num_parallel_tree: Optional[int] = None,
66+
monotone_constraints: Optional[List[int]] = None,
67+
interaction_constraints: Optional[str] = None,
68+
max_cached_hist_node: Optional[int] = None,
69+
# LearningTaskParams
70+
objective: Optional[str] = None,
71+
num_class: Optional[int] = None,
72+
base_score: Optional[float] = None,
73+
eval_metric: Optional[str] = None,
74+
seed: Optional[int] = None,
75+
seed_per_iteration: Optional[bool] = None,
76+
tweedie_variance_power: Optional[float] = None,
77+
huber_slope: Optional[float] = None,
78+
aft_loss_distribution: Optional[str] = None,
79+
lambdarank_pair_method: Optional[str] = None,
80+
lambdarank_num_pair_per_sample: Optional[int] = None,
81+
lambdarank_unbiased: Optional[bool] = None,
82+
lambdarank_bias_norm: Optional[float] = None,
83+
ndcg_exp_gain: Optional[bool] = None,
84+
# DartBoosterParams
85+
sample_type: Optional[str] = None,
86+
normalize_type: Optional[str] = None,
87+
rate_drop: Optional[float] = None,
88+
one_drop: Optional[bool] = None,
89+
skip_drop: Optional[float] = None,
90+
**kwargs: Any,
91+
):
92+
super().__init__()
93+
self._java_obj = self._new_java_obj(
94+
"ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier", self.uid
95+
)
96+
self._set_params(**self._input_kwargs)
97+
98+
def _create_model(self, java_model: "JavaObject") -> "XGBoostClassificationModel":
99+
return XGBoostClassificationModel(java_model)
100+
101+
102+
class XGBoostClassificationModel(_JavaProbabilisticClassificationModel, XGBoostParams):
103+
pass

0 commit comments

Comments
 (0)