
Commit fd74b5e

[SPARK-52470][ML][CONNECT] Support model summary offloading
### What changes were proposed in this pull request?

This PR makes Spark Connect ML support model summary offloading.

Model summary offloading is hard to support because a summary holds a Spark dataset, which can't be easily serialized in the Spark driver. (Note: we can't use the Java serializer on the Spark dataset's logical plan, otherwise it becomes an RCE vulnerability.) To address this, when saving a summary to disk only the necessary data fields are saved; when loading the summary back, the client needs to send the dataset to the Spark driver again. To achieve this, two new proto messages are introduced:

1. `CreateSummary` in `MlCommand`:

```
// This is for re-creating the model summary when the model summary is lost
// (model summary is lost when the model is offloaded and then loaded back)
message CreateSummary {
  ObjectRef model_ref = 1;
  Relation dataset = 2;
}
```

2. `model_summary_dataset` in `MlRelation`:

```
// (Optional) the dataset for restoring the model summary
optional Relation model_summary_dataset = 3;
```

### Why are the changes needed?

To support model summary offloading. Without it, the model summary is evicted from Spark driver memory after the default 15-minute timeout, making the `model.summary` API unavailable.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #51187 from WeichenXu123/SPARK-52470.

Authored-by: Weichen Xu <[email protected]>
Signed-off-by: Weichen Xu <[email protected]>
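To make the offload/restore flow concrete, here is a minimal sketch, assuming a simplified server-side cache and handler; the trait, `handleCreateSummary`, `cache`, and `summaryPath` are illustrative names, not the actual Spark Connect server code:

```scala
import org.apache.spark.sql.DataFrame

// Sketch of the per-model hooks this PR adds: persist only plain data fields,
// and rebuild the summary from a dataset the client re-sends.
trait OffloadableSummaryModel {
  def saveSummary(path: String): Unit                     // plain fields only, never the dataset
  def loadSummary(path: String, dataset: DataFrame): Unit // rebuild summary from the re-sent dataset
}

// Sketch of handling the new `CreateSummary` command: resolve the re-loaded
// model by its ObjectRef and restore its summary from the re-sent Relation.
def handleCreateSummary(
    modelRef: String,
    dataset: DataFrame,
    cache: scala.collection.mutable.Map[String, OffloadableSummaryModel],
    summaryPath: String): Unit = {
  cache(modelRef).loadSummary(summaryPath, dataset)
}
```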
1 parent 7d0b921 commit fd74b5e

File tree

31 files changed (+1194, −735 lines)


common/utils/src/main/resources/error/error-conditions.json

Lines changed: 5 additions & 0 deletions
```diff
@@ -853,6 +853,11 @@
         "Please fit or load a model smaller than <modelMaxSize> bytes."
       ]
     },
+    "MODEL_SUMMARY_LOST" : {
+      "message" : [
+        "The model <objectName> summary is lost because the cached model is offloaded."
+      ]
+    },
     "UNSUPPORTED_EXCEPTION" : {
       "message" : [
         "<message>"
```

mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala

Lines changed: 38 additions & 11 deletions
```diff
@@ -224,17 +224,8 @@ class FMClassifier @Since("3.0.0") (
       factors: Matrix,
       objectiveHistory: Array[Double]): FMClassificationModel = {
     val model = copyValues(new FMClassificationModel(uid, intercept, linear, factors))
-    val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
-
-    val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
-    val summary = new FMClassificationTrainingSummaryImpl(
-      summaryModel.transform(dataset),
-      probabilityColName,
-      predictionColName,
-      $(labelCol),
-      weightColName,
-      objectiveHistory)
-    model.setSummary(Some(summary))
+    model.createSummary(dataset, objectiveHistory)
+    model
   }
 
   @Since("3.0.0")
@@ -343,6 +334,42 @@ class FMClassificationModel private[classification] (
     s"uid=${super.toString}, numClasses=$numClasses, numFeatures=$numFeatures, " +
       s"factorSize=${$(factorSize)}, fitLinear=${$(fitLinear)}, fitIntercept=${$(fitIntercept)}"
   }
+
+  private[spark] def createSummary(
+    dataset: Dataset[_], objectiveHistory: Array[Double]
+  ): Unit = {
+    val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
+
+    val (summaryModel, probabilityColName, predictionColName) = findSummaryModel()
+    val summary = new FMClassificationTrainingSummaryImpl(
+      summaryModel.transform(dataset),
+      probabilityColName,
+      predictionColName,
+      $(labelCol),
+      weightColName,
+      objectiveHistory)
+    setSummary(Some(summary))
+  }
+
+  override private[spark] def saveSummary(path: String): Unit = {
+    ReadWriteUtils.saveObjectToLocal[Tuple1[Array[Double]]](
+      path, Tuple1(summary.objectiveHistory),
+      (data, dos) => {
+        ReadWriteUtils.serializeDoubleArray(data._1, dos)
+      }
+    )
+  }
+
+  override private[spark] def loadSummary(path: String, dataset: DataFrame): Unit = {
+    val Tuple1(objectiveHistory: Array[Double])
+      = ReadWriteUtils.loadObjectFromLocal[Tuple1[Array[Double]]](
+      path,
+      dis => {
+        Tuple1(ReadWriteUtils.deserializeDoubleArray(dis))
+      }
+    )
+    createSummary(dataset, objectiveHistory)
+  }
 }
 
 @Since("3.0.0")
```

mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala

Lines changed: 38 additions & 11 deletions
```diff
@@ -277,17 +277,8 @@ class LinearSVC @Since("2.2.0") (
       intercept: Double,
       objectiveHistory: Array[Double]): LinearSVCModel = {
     val model = copyValues(new LinearSVCModel(uid, coefficients, intercept))
-    val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
-
-    val (summaryModel, rawPredictionColName, predictionColName) = model.findSummaryModel()
-    val summary = new LinearSVCTrainingSummaryImpl(
-      summaryModel.transform(dataset),
-      rawPredictionColName,
-      predictionColName,
-      $(labelCol),
-      weightColName,
-      objectiveHistory)
-    model.setSummary(Some(summary))
+    model.createSummary(dataset, objectiveHistory)
+    model
   }
 
   private def trainImpl(
@@ -445,6 +436,42 @@ class LinearSVCModel private[classification] (
   override def toString: String = {
     s"LinearSVCModel: uid=$uid, numClasses=$numClasses, numFeatures=$numFeatures"
   }
+
+  private[spark] def createSummary(
+    dataset: Dataset[_], objectiveHistory: Array[Double]
+  ): Unit = {
+    val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
+
+    val (summaryModel, rawPredictionColName, predictionColName) = findSummaryModel()
+    val summary = new LinearSVCTrainingSummaryImpl(
+      summaryModel.transform(dataset),
+      rawPredictionColName,
+      predictionColName,
+      $(labelCol),
+      weightColName,
+      objectiveHistory)
+    setSummary(Some(summary))
+  }
+
+  override private[spark] def saveSummary(path: String): Unit = {
+    ReadWriteUtils.saveObjectToLocal[Tuple1[Array[Double]]](
+      path, Tuple1(summary.objectiveHistory),
+      (data, dos) => {
+        ReadWriteUtils.serializeDoubleArray(data._1, dos)
+      }
+    )
+  }
+
+  override private[spark] def loadSummary(path: String, dataset: DataFrame): Unit = {
+    val Tuple1(objectiveHistory: Array[Double])
+      = ReadWriteUtils.loadObjectFromLocal[Tuple1[Array[Double]]](
+      path,
+      dis => {
+        Tuple1(ReadWriteUtils.deserializeDoubleArray(dis))
+      }
+    )
+    createSummary(dataset, objectiveHistory)
+  }
 }
 
 @Since("2.2.0")
```

mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala

Lines changed: 50 additions & 23 deletions
```diff
@@ -718,29 +718,8 @@ class LogisticRegression @Since("1.2.0") (
       objectiveHistory: Array[Double]): LogisticRegressionModel = {
     val model = copyValues(new LogisticRegressionModel(uid, coefficientMatrix, interceptVector,
       numClasses, checkMultinomial(numClasses)))
-    val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
-
-    val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
-    val logRegSummary = if (numClasses <= 2) {
-      new BinaryLogisticRegressionTrainingSummaryImpl(
-        summaryModel.transform(dataset),
-        probabilityColName,
-        predictionColName,
-        $(labelCol),
-        $(featuresCol),
-        weightColName,
-        objectiveHistory)
-    } else {
-      new LogisticRegressionTrainingSummaryImpl(
-        summaryModel.transform(dataset),
-        probabilityColName,
-        predictionColName,
-        $(labelCol),
-        $(featuresCol),
-        weightColName,
-        objectiveHistory)
-    }
-    model.setSummary(Some(logRegSummary))
+    model.createSummary(dataset, objectiveHistory)
+    model
   }
 
   private def createBounds(
@@ -1323,6 +1302,54 @@ class LogisticRegressionModel private[spark] (
   override def toString: String = {
     s"LogisticRegressionModel: uid=$uid, numClasses=$numClasses, numFeatures=$numFeatures"
   }
+
+  private[spark] def createSummary(
+    dataset: Dataset[_], objectiveHistory: Array[Double]
+  ): Unit = {
+    val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
+
+    val (summaryModel, probabilityColName, predictionColName) = findSummaryModel()
+    val logRegSummary = if (numClasses <= 2) {
+      new BinaryLogisticRegressionTrainingSummaryImpl(
+        summaryModel.transform(dataset),
+        probabilityColName,
+        predictionColName,
+        $(labelCol),
+        $(featuresCol),
+        weightColName,
+        objectiveHistory)
+    } else {
+      new LogisticRegressionTrainingSummaryImpl(
+        summaryModel.transform(dataset),
+        probabilityColName,
+        predictionColName,
+        $(labelCol),
+        $(featuresCol),
+        weightColName,
+        objectiveHistory)
+    }
+    setSummary(Some(logRegSummary))
+  }
+
+  override private[spark] def saveSummary(path: String): Unit = {
+    ReadWriteUtils.saveObjectToLocal[Tuple1[Array[Double]]](
+      path, Tuple1(summary.objectiveHistory),
+      (data, dos) => {
+        ReadWriteUtils.serializeDoubleArray(data._1, dos)
+      }
+    )
+  }
+
+  override private[spark] def loadSummary(path: String, dataset: DataFrame): Unit = {
+    val Tuple1(objectiveHistory: Array[Double])
+      = ReadWriteUtils.loadObjectFromLocal[Tuple1[Array[Double]]](
+      path,
+      dis => {
+        Tuple1(ReadWriteUtils.deserializeDoubleArray(dis))
+      }
+    )
+    createSummary(dataset, objectiveHistory)
+  }
 }
 
 @Since("1.6.0")
```

mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala

Lines changed: 35 additions & 8 deletions
```diff
@@ -251,14 +251,8 @@ class MultilayerPerceptronClassifier @Since("1.5.0") (
       objectiveHistory: Array[Double]): MultilayerPerceptronClassificationModel = {
     val model = copyValues(new MultilayerPerceptronClassificationModel(uid, weights))
 
-    val (summaryModel, _, predictionColName) = model.findSummaryModel()
-    val summary = new MultilayerPerceptronClassificationTrainingSummaryImpl(
-      summaryModel.transform(dataset),
-      predictionColName,
-      $(labelCol),
-      "",
-      objectiveHistory)
-    model.setSummary(Some(summary))
+    model.createSummary(dataset, objectiveHistory)
+    model
   }
 }
 
@@ -365,6 +359,39 @@ class MultilayerPerceptronClassificationModel private[ml] (
     s"MultilayerPerceptronClassificationModel: uid=$uid, numLayers=${$(layers).length}, " +
       s"numClasses=$numClasses, numFeatures=$numFeatures"
   }
+
+  private[spark] def createSummary(
+    dataset: Dataset[_], objectiveHistory: Array[Double]
+  ): Unit = {
+    val (summaryModel, _, predictionColName) = findSummaryModel()
+    val summary = new MultilayerPerceptronClassificationTrainingSummaryImpl(
+      summaryModel.transform(dataset),
+      predictionColName,
+      $(labelCol),
+      "",
+      objectiveHistory)
+    setSummary(Some(summary))
+  }
+
+  override private[spark] def saveSummary(path: String): Unit = {
+    ReadWriteUtils.saveObjectToLocal[Tuple1[Array[Double]]](
+      path, Tuple1(summary.objectiveHistory),
+      (data, dos) => {
+        ReadWriteUtils.serializeDoubleArray(data._1, dos)
+      }
+    )
+  }
+
+  override private[spark] def loadSummary(path: String, dataset: DataFrame): Unit = {
+    val Tuple1(objectiveHistory: Array[Double])
+      = ReadWriteUtils.loadObjectFromLocal[Tuple1[Array[Double]]](
+      path,
+      dis => {
+        Tuple1(ReadWriteUtils.deserializeDoubleArray(dis))
+      }
+    )
+    createSummary(dataset, objectiveHistory)
+  }
 }
 
 @Since("2.0.0")
```

mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala

Lines changed: 31 additions & 20 deletions
```diff
@@ -182,26 +182,8 @@ class RandomForestClassifier @Since("1.4.0") (
       numFeatures: Int,
       numClasses: Int): RandomForestClassificationModel = {
     val model = copyValues(new RandomForestClassificationModel(uid, trees, numFeatures, numClasses))
-    val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
-
-    val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
-    val rfSummary = if (numClasses <= 2) {
-      new BinaryRandomForestClassificationTrainingSummaryImpl(
-        summaryModel.transform(dataset),
-        probabilityColName,
-        predictionColName,
-        $(labelCol),
-        weightColName,
-        Array(0.0))
-    } else {
-      new RandomForestClassificationTrainingSummaryImpl(
-        summaryModel.transform(dataset),
-        predictionColName,
-        $(labelCol),
-        weightColName,
-        Array(0.0))
-    }
-    model.setSummary(Some(rfSummary))
+    model.createSummary(dataset)
+    model
   }
 
   @Since("1.4.1")
@@ -393,6 +375,35 @@ class RandomForestClassificationModel private[ml] (
   @Since("2.0.0")
   override def write: MLWriter =
     new RandomForestClassificationModel.RandomForestClassificationModelWriter(this)
+
+  private[spark] def createSummary(dataset: Dataset[_]): Unit = {
+    val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
+
+    val (summaryModel, probabilityColName, predictionColName) = findSummaryModel()
+    val rfSummary = if (numClasses <= 2) {
+      new BinaryRandomForestClassificationTrainingSummaryImpl(
+        summaryModel.transform(dataset),
+        probabilityColName,
+        predictionColName,
+        $(labelCol),
+        weightColName,
+        Array(0.0))
+    } else {
+      new RandomForestClassificationTrainingSummaryImpl(
+        summaryModel.transform(dataset),
+        predictionColName,
+        $(labelCol),
+        weightColName,
+        Array(0.0))
+    }
+    setSummary(Some(rfSummary))
+  }
+
+  override private[spark] def saveSummary(path: String): Unit = {}
+
+  override private[spark] def loadSummary(path: String, dataset: DataFrame): Unit = {
+    createSummary(dataset)
+  }
 }
 
 @Since("2.0.0")
```

mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala

Lines changed: 3 additions & 0 deletions
```diff
@@ -180,6 +180,9 @@ class BisectingKMeansModel private[ml] (
   override def summary: BisectingKMeansSummary = super.summary
 
   override def estimatedSize: Long = SizeEstimator.estimate(parentModel)
+
+  // BisectingKMeans model hasn't supported offloading, so put an empty `saveSummary` here for now
+  override private[spark] def saveSummary(path: String): Unit = {}
 }
 
 object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] {
```
