diff --git a/README.md b/README.md
index 0f0bf039550d7..42f3a1280df91 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,6 @@
 # Apache Spark
+
 Spark is a unified analytics engine for large-scale data processing. It provides
 high-level APIs in Scala, Java, Python, and R (Deprecated), and an optimized engine that
 supports general computation graphs for data analysis. It also supports a
diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index a280887da845e..5def48196cf30 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -853,11 +853,6 @@
       "Please fit or load a model smaller than <modelMaxSize> bytes."
     ]
   },
-  "MODEL_SUMMARY_LOST" : {
-    "message" : [
-      "The model <objectName> summary is lost because the cached model is offloaded."
-    ]
-  },
   "UNSUPPORTED_EXCEPTION" : {
     "message" : [
       "<message>"
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
index b653383161e74..cefa13b2bbe71 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/FMClassifier.scala
@@ -224,8 +224,17 @@ class FMClassifier @Since("3.0.0") (
       factors: Matrix,
       objectiveHistory: Array[Double]): FMClassificationModel = {
     val model = copyValues(new FMClassificationModel(uid, intercept, linear, factors))
-    model.createSummary(dataset, objectiveHistory)
-    model
+    val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
+
+    val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel()
+    val summary = new FMClassificationTrainingSummaryImpl(
+      summaryModel.transform(dataset),
+      probabilityColName,
+      predictionColName,
+      $(labelCol),
+      weightColName,
+      objectiveHistory)
+    model.setSummary(Some(summary))
   }
 
   @Since("3.0.0")
@@ -334,42 +343,6 @@ class FMClassificationModel private[classification] (
     s"uid=${super.toString}, numClasses=$numClasses, numFeatures=$numFeatures, " +
       s"factorSize=${$(factorSize)}, fitLinear=${$(fitLinear)}, fitIntercept=${$(fitIntercept)}"
   }
-
-  private[spark] def createSummary(
-    dataset: Dataset[_], objectiveHistory: Array[Double]
-  ): Unit = {
-    val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol)
-
-    val (summaryModel, probabilityColName, predictionColName) = findSummaryModel()
-    val summary = new FMClassificationTrainingSummaryImpl(
-      summaryModel.transform(dataset),
-      probabilityColName,
-      predictionColName,
-      $(labelCol),
-      weightColName,
-      objectiveHistory)
-    setSummary(Some(summary))
-  }
-
-  override private[spark] def saveSummary(path: String): Unit = {
-    ReadWriteUtils.saveObjectToLocal[Tuple1[Array[Double]]](
-      path, Tuple1(summary.objectiveHistory),
-      (data, dos) => {
-        ReadWriteUtils.serializeDoubleArray(data._1, dos)
-      }
-    )
-  }
-
-  override private[spark] def loadSummary(path: String, dataset: DataFrame): Unit = {
-    val Tuple1(objectiveHistory: Array[Double])
-      = ReadWriteUtils.loadObjectFromLocal[Tuple1[Array[Double]]](
-      path,
-      dis => {
-        Tuple1(ReadWriteUtils.deserializeDoubleArray(dis))
-      }
-    )
-    createSummary(dataset, objectiveHistory)
-  }
 }
 
 @Since("3.0.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
index 0d163b761686d..a50346ae88f4c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -277,8 +277,17 @@ class LinearSVC @Since("2.2.0") ( intercept: Double, objectiveHistory: Array[Double]): LinearSVCModel = { val model = copyValues(new LinearSVCModel(uid, coefficients, intercept)) - model.createSummary(dataset, objectiveHistory) - model + val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol) + + val (summaryModel, rawPredictionColName, predictionColName) = model.findSummaryModel() + val summary = new LinearSVCTrainingSummaryImpl( + summaryModel.transform(dataset), + rawPredictionColName, + predictionColName, + $(labelCol), + weightColName, + objectiveHistory) + model.setSummary(Some(summary)) } private def trainImpl( @@ -436,42 +445,6 @@ class LinearSVCModel private[classification] ( override def toString: String = { s"LinearSVCModel: uid=$uid, numClasses=$numClasses, numFeatures=$numFeatures" } - - private[spark] def createSummary( - dataset: Dataset[_], objectiveHistory: Array[Double] - ): Unit = { - val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol) - - val (summaryModel, rawPredictionColName, predictionColName) = findSummaryModel() - val summary = new LinearSVCTrainingSummaryImpl( - summaryModel.transform(dataset), - rawPredictionColName, - predictionColName, - $(labelCol), - weightColName, - objectiveHistory) - setSummary(Some(summary)) - } - - override private[spark] def saveSummary(path: String): Unit = { - ReadWriteUtils.saveObjectToLocal[Tuple1[Array[Double]]]( - path, Tuple1(summary.objectiveHistory), - (data, dos) => { - ReadWriteUtils.serializeDoubleArray(data._1, dos) - } - ) - } - - override private[spark] def loadSummary(path: String, dataset: DataFrame): Unit = { - val Tuple1(objectiveHistory: Array[Double]) - = ReadWriteUtils.loadObjectFromLocal[Tuple1[Array[Double]]]( - path, - dis => { - Tuple1(ReadWriteUtils.deserializeDoubleArray(dis)) - } - ) - createSummary(dataset, objectiveHistory) - } } @Since("2.2.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 8c010f67f5e0a..58a2652d0eab9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -718,8 +718,29 @@ class LogisticRegression @Since("1.2.0") ( objectiveHistory: Array[Double]): LogisticRegressionModel = { val model = copyValues(new LogisticRegressionModel(uid, coefficientMatrix, interceptVector, numClasses, checkMultinomial(numClasses))) - model.createSummary(dataset, objectiveHistory) - model + val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol) + + val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel() + val logRegSummary = if (numClasses <= 2) { + new BinaryLogisticRegressionTrainingSummaryImpl( + summaryModel.transform(dataset), + probabilityColName, + predictionColName, + $(labelCol), + $(featuresCol), + weightColName, + objectiveHistory) + } else { + new LogisticRegressionTrainingSummaryImpl( + summaryModel.transform(dataset), + probabilityColName, + predictionColName, + $(labelCol), + $(featuresCol), + weightColName, + objectiveHistory) + } + model.setSummary(Some(logRegSummary)) } private def createBounds( @@ -1302,54 +1323,6 @@ class LogisticRegressionModel private[spark] ( 
override def toString: String = { s"LogisticRegressionModel: uid=$uid, numClasses=$numClasses, numFeatures=$numFeatures" } - - private[spark] def createSummary( - dataset: Dataset[_], objectiveHistory: Array[Double] - ): Unit = { - val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol) - - val (summaryModel, probabilityColName, predictionColName) = findSummaryModel() - val logRegSummary = if (numClasses <= 2) { - new BinaryLogisticRegressionTrainingSummaryImpl( - summaryModel.transform(dataset), - probabilityColName, - predictionColName, - $(labelCol), - $(featuresCol), - weightColName, - objectiveHistory) - } else { - new LogisticRegressionTrainingSummaryImpl( - summaryModel.transform(dataset), - probabilityColName, - predictionColName, - $(labelCol), - $(featuresCol), - weightColName, - objectiveHistory) - } - setSummary(Some(logRegSummary)) - } - - override private[spark] def saveSummary(path: String): Unit = { - ReadWriteUtils.saveObjectToLocal[Tuple1[Array[Double]]]( - path, Tuple1(summary.objectiveHistory), - (data, dos) => { - ReadWriteUtils.serializeDoubleArray(data._1, dos) - } - ) - } - - override private[spark] def loadSummary(path: String, dataset: DataFrame): Unit = { - val Tuple1(objectiveHistory: Array[Double]) - = ReadWriteUtils.loadObjectFromLocal[Tuple1[Array[Double]]]( - path, - dis => { - Tuple1(ReadWriteUtils.deserializeDoubleArray(dis)) - } - ) - createSummary(dataset, objectiveHistory) - } } @Since("1.6.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index 5e52d62fb83cb..6bd46cff815d7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -251,8 +251,14 @@ class MultilayerPerceptronClassifier @Since("1.5.0") ( objectiveHistory: Array[Double]): MultilayerPerceptronClassificationModel = { val model = copyValues(new MultilayerPerceptronClassificationModel(uid, weights)) - model.createSummary(dataset, objectiveHistory) - model + val (summaryModel, _, predictionColName) = model.findSummaryModel() + val summary = new MultilayerPerceptronClassificationTrainingSummaryImpl( + summaryModel.transform(dataset), + predictionColName, + $(labelCol), + "", + objectiveHistory) + model.setSummary(Some(summary)) } } @@ -359,39 +365,6 @@ class MultilayerPerceptronClassificationModel private[ml] ( s"MultilayerPerceptronClassificationModel: uid=$uid, numLayers=${$(layers).length}, " + s"numClasses=$numClasses, numFeatures=$numFeatures" } - - private[spark] def createSummary( - dataset: Dataset[_], objectiveHistory: Array[Double] - ): Unit = { - val (summaryModel, _, predictionColName) = findSummaryModel() - val summary = new MultilayerPerceptronClassificationTrainingSummaryImpl( - summaryModel.transform(dataset), - predictionColName, - $(labelCol), - "", - objectiveHistory) - setSummary(Some(summary)) - } - - override private[spark] def saveSummary(path: String): Unit = { - ReadWriteUtils.saveObjectToLocal[Tuple1[Array[Double]]]( - path, Tuple1(summary.objectiveHistory), - (data, dos) => { - ReadWriteUtils.serializeDoubleArray(data._1, dos) - } - ) - } - - override private[spark] def loadSummary(path: String, dataset: DataFrame): Unit = { - val Tuple1(objectiveHistory: Array[Double]) - = ReadWriteUtils.loadObjectFromLocal[Tuple1[Array[Double]]]( - path, - dis => 
{ - Tuple1(ReadWriteUtils.deserializeDoubleArray(dis)) - } - ) - createSummary(dataset, objectiveHistory) - } } @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 8b580b1e075c5..f64e2a6d4efc3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -182,8 +182,26 @@ class RandomForestClassifier @Since("1.4.0") ( numFeatures: Int, numClasses: Int): RandomForestClassificationModel = { val model = copyValues(new RandomForestClassificationModel(uid, trees, numFeatures, numClasses)) - model.createSummary(dataset) - model + val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol) + + val (summaryModel, probabilityColName, predictionColName) = model.findSummaryModel() + val rfSummary = if (numClasses <= 2) { + new BinaryRandomForestClassificationTrainingSummaryImpl( + summaryModel.transform(dataset), + probabilityColName, + predictionColName, + $(labelCol), + weightColName, + Array(0.0)) + } else { + new RandomForestClassificationTrainingSummaryImpl( + summaryModel.transform(dataset), + predictionColName, + $(labelCol), + weightColName, + Array(0.0)) + } + model.setSummary(Some(rfSummary)) } @Since("1.4.1") @@ -375,35 +393,6 @@ class RandomForestClassificationModel private[ml] ( @Since("2.0.0") override def write: MLWriter = new RandomForestClassificationModel.RandomForestClassificationModelWriter(this) - - private[spark] def createSummary(dataset: Dataset[_]): Unit = { - val weightColName = if (!isDefined(weightCol)) "weightCol" else $(weightCol) - - val (summaryModel, probabilityColName, predictionColName) = findSummaryModel() - val rfSummary = if (numClasses <= 2) { - new BinaryRandomForestClassificationTrainingSummaryImpl( - summaryModel.transform(dataset), - probabilityColName, - predictionColName, - $(labelCol), - weightColName, - Array(0.0)) - } else { - new RandomForestClassificationTrainingSummaryImpl( - summaryModel.transform(dataset), - predictionColName, - $(labelCol), - weightColName, - Array(0.0)) - } - setSummary(Some(rfSummary)) - } - - override private[spark] def saveSummary(path: String): Unit = {} - - override private[spark] def loadSummary(path: String, dataset: DataFrame): Unit = { - createSummary(dataset) - } } @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 9e09ee00c3e30..3248b4b391d0a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -180,9 +180,6 @@ class BisectingKMeansModel private[ml] ( override def summary: BisectingKMeansSummary = super.summary override def estimatedSize: Long = SizeEstimator.estimate(parentModel) - - // BisectingKMeans model hasn't supported offloading, so put an empty `saveSummary` here for now - override private[spark] def saveSummary(path: String): Unit = {} } object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] { diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index e7f930065486b..a94b8a87d8fc7 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -223,36 +223,6 @@ class GaussianMixtureModel private[ml] ( override def summary: GaussianMixtureSummary = super.summary override def estimatedSize: Long = SizeEstimator.estimate((weights, gaussians)) - - private[spark] def createSummary( - predictions: DataFrame, logLikelihood: Double, iteration: Int - ): Unit = { - val summary = new GaussianMixtureSummary(predictions, - $(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood, iteration) - setSummary(Some(summary)) - } - - override private[spark] def saveSummary(path: String): Unit = { - ReadWriteUtils.saveObjectToLocal[(Double, Int)]( - path, (summary.logLikelihood, summary.numIter), - (data, dos) => { - dos.writeDouble(data._1) - dos.writeInt(data._2) - } - ) - } - - override private[spark] def loadSummary(path: String, dataset: DataFrame): Unit = { - val (logLikelihood: Double, numIter: Int) = ReadWriteUtils.loadObjectFromLocal[(Double, Int)]( - path, - dis => { - val logLikelihood = dis.readDouble() - val numIter = dis.readInt() - (logLikelihood, numIter) - } - ) - createSummary(dataset, logLikelihood, numIter) - } } @Since("2.0.0") @@ -483,10 +453,11 @@ class GaussianMixture @Since("2.0.0") ( val model = copyValues(new GaussianMixtureModel(uid, weights, gaussianDists)) .setParent(this) - model.createSummary(model.transform(dataset), logLikelihood, iteration) + val summary = new GaussianMixtureSummary(model.transform(dataset), + $(predictionCol), $(probabilityCol), $(featuresCol), $(k), logLikelihood, iteration) instr.logNamedValue("logLikelihood", logLikelihood) - instr.logNamedValue("clusterSizes", model.summary.clusterSizes) - model + instr.logNamedValue("clusterSizes", summary.clusterSizes) + model.setSummary(Some(summary)) } private def trainImpl( diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index ccae39cedd20f..f3ac58e670e5a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -215,42 +215,6 @@ class KMeansModel private[ml] ( override def summary: KMeansSummary = super.summary override def estimatedSize: Long = SizeEstimator.estimate(parentModel.clusterCenters) - - private[spark] def createSummary( - predictions: DataFrame, numIter: Int, trainingCost: Double - ): Unit = { - val summary = new KMeansSummary( - predictions, - $(predictionCol), - $(featuresCol), - $(k), - numIter, - trainingCost) - - setSummary(Some(summary)) - } - - override private[spark] def saveSummary(path: String): Unit = { - ReadWriteUtils.saveObjectToLocal[(Int, Double)]( - path, (summary.numIter, summary.trainingCost), - (data, dos) => { - dos.writeInt(data._1) - dos.writeDouble(data._2) - } - ) - } - - override private[spark] def loadSummary(path: String, dataset: DataFrame): Unit = { - val (numIter: Int, trainingCost: Double) = ReadWriteUtils.loadObjectFromLocal[(Int, Double)]( - path, - dis => { - val numIter = dis.readInt() - val trainingCost = dis.readDouble() - (numIter, trainingCost) - } - ) - createSummary(dataset, numIter, trainingCost) - } } /** Helper class for storing model data */ @@ -450,9 +414,16 @@ class KMeans @Since("1.5.0") ( } val model = copyValues(new KMeansModel(uid, oldModel).setParent(this)) + val summary = new KMeansSummary( + model.transform(dataset), + 
$(predictionCol), + $(featuresCol), + $(k), + oldModel.numIter, + oldModel.trainingCost) - model.createSummary(model.transform(dataset), oldModel.numIter, oldModel.trainingCost) - instr.logNamedValue("clusterSizes", model.summary.clusterSizes) + model.setSummary(Some(summary)) + instr.logNamedValue("clusterSizes", summary.clusterSizes) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index cf62c2bf41b6d..14467c761b216 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -419,8 +419,9 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val val model = copyValues( new GeneralizedLinearRegressionModel(uid, wlsModel.coefficients, wlsModel.intercept) .setParent(this)) - model.createSummary(dataset, wlsModel.diagInvAtWA.toArray, 1) - model + val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model, + wlsModel.diagInvAtWA.toArray, 1, getSolver) + model.setSummary(Some(trainingSummary)) } else { val instances = validated.rdd.map { case Row(label: Double, weight: Double, offset: Double, features: Vector) => @@ -435,8 +436,9 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val val model = copyValues( new GeneralizedLinearRegressionModel(uid, irlsModel.coefficients, irlsModel.intercept) .setParent(this)) - model.createSummary(dataset, irlsModel.diagInvAtWA.toArray, irlsModel.numIterations) - model + val trainingSummary = new GeneralizedLinearRegressionTrainingSummary(dataset, model, + irlsModel.diagInvAtWA.toArray, irlsModel.numIterations, getSolver) + model.setSummary(Some(trainingSummary)) } model @@ -1138,39 +1140,6 @@ class GeneralizedLinearRegressionModel private[ml] ( s"GeneralizedLinearRegressionModel: uid=$uid, family=${$(family)}, link=${$(link)}, " + s"numFeatures=$numFeatures" } - - private[spark] def createSummary( - dataset: Dataset[_], diagInvAtWA: Array[Double], numIter: Int - ): Unit = { - val summary = new GeneralizedLinearRegressionTrainingSummary( - dataset, this, diagInvAtWA, numIter, $(solver) - ) - - setSummary(Some(summary)) - } - - override private[spark] def saveSummary(path: String): Unit = { - ReadWriteUtils.saveObjectToLocal[(Array[Double], Int)]( - path, (summary.diagInvAtWA, summary.numIterations), - (data, dos) => { - ReadWriteUtils.serializeDoubleArray(data._1, dos) - dos.writeInt(data._2) - } - ) - } - - override private[spark] def loadSummary(path: String, dataset: DataFrame): Unit = { - val (diagInvAtWA: Array[Double], numIterations: Int) = - ReadWriteUtils.loadObjectFromLocal[(Array[Double], Int)]( - path, - dis => { - val diagInvAtWA = ReadWriteUtils.deserializeDoubleArray(dis) - val numIterations = dis.readInt() - (diagInvAtWA, numIterations) - } - ) - createSummary(dataset, diagInvAtWA, numIterations) - } } @Since("2.0.0") @@ -1498,7 +1467,7 @@ class GeneralizedLinearRegressionSummary private[regression] ( class GeneralizedLinearRegressionTrainingSummary private[regression] ( dataset: Dataset[_], origModel: GeneralizedLinearRegressionModel, - private[spark] val diagInvAtWA: Array[Double], + private val diagInvAtWA: Array[Double], @Since("2.0.0") val numIterations: Int, @Since("2.0.0") val solver: String) extends GeneralizedLinearRegressionSummary(dataset, origModel) with Serializable { diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 822df270c0bf7..b06140e48338c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -433,8 +433,15 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String } val model = createModel(parameters, yMean, yStd, featuresMean, featuresStd) - model.createSummary(dataset, Array(0.0), objectiveHistory, Array.emptyDoubleArray) - model + + // Handle possible missing or invalid prediction columns + val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() + val trainingSummary = new LinearRegressionTrainingSummary( + summaryModel.transform(dataset), predictionColName, $(labelCol), $(featuresCol), + summaryModel.get(summaryModel.weightCol).getOrElse(""), + summaryModel.numFeatures, summaryModel.getFitIntercept, + Array(0.0), objectiveHistory) + model.setSummary(Some(trainingSummary)) } private def trainWithNormal( @@ -452,16 +459,20 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String // attach returned model. val lrModel = copyValues(new LinearRegressionModel( uid, model.coefficients.compressed, model.intercept)) + val (summaryModel, predictionColName) = lrModel.findSummaryModelAndPredictionCol() - val coefficientArray = if (lrModel.getFitIntercept) { - lrModel.coefficients.toArray ++ Array(lrModel.intercept) + val coefficientArray = if (summaryModel.getFitIntercept) { + summaryModel.coefficients.toArray ++ Array(summaryModel.intercept) } else { - lrModel.coefficients.toArray + summaryModel.coefficients.toArray } - lrModel.createSummary( - dataset, model.diagInvAtWA.toArray, model.objectiveHistory, coefficientArray - ) - lrModel + val trainingSummary = new LinearRegressionTrainingSummary( + summaryModel.transform(dataset), predictionColName, $(labelCol), $(featuresCol), + summaryModel.get(summaryModel.weightCol).getOrElse(""), + summaryModel.numFeatures, summaryModel.getFitIntercept, + model.diagInvAtWA.toArray, model.objectiveHistory, coefficientArray) + + lrModel.setSummary(Some(trainingSummary)) } private def trainWithConstantLabel( @@ -486,9 +497,16 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String val intercept = yMean val model = copyValues(new LinearRegressionModel(uid, coefficients, intercept)) + // Handle possible missing or invalid prediction columns + val (summaryModel, predictionColName) = model.findSummaryModelAndPredictionCol() - model.createSummary(dataset, Array(0.0), Array(0.0), Array.emptyDoubleArray) - model + val trainingSummary = new LinearRegressionTrainingSummary( + summaryModel.transform(dataset), predictionColName, $(labelCol), $(featuresCol), + summaryModel.get(summaryModel.weightCol).getOrElse(""), + summaryModel.numFeatures, summaryModel.getFitIntercept, + Array(0.0), Array(0.0)) + + model.setSummary(Some(trainingSummary)) } private def createOptimizer( @@ -782,53 +800,6 @@ class LinearRegressionModel private[ml] ( override def toString: String = { s"LinearRegressionModel: uid=$uid, numFeatures=$numFeatures" } - - private[spark] def createSummary( - dataset: Dataset[_], - diagInvAtWA: Array[Double], - objectiveHistory: Array[Double], - coefficientArray: Array[Double] - ): Unit = { - // Handle possible missing or invalid prediction columns - val (summaryModel, predictionColName) = 
findSummaryModelAndPredictionCol() - - val trainingSummary = new LinearRegressionTrainingSummary( - summaryModel.transform(dataset), predictionColName, $(labelCol), $(featuresCol), - summaryModel.get(summaryModel.weightCol).getOrElse(""), - summaryModel.numFeatures, summaryModel.getFitIntercept, - diagInvAtWA, objectiveHistory, coefficientArray) - - setSummary(Some(trainingSummary)) - } - - override private[spark] def saveSummary(path: String): Unit = { - ReadWriteUtils.saveObjectToLocal[(Array[Double], Array[Double], Array[Double])]( - path, (summary.diagInvAtWA, summary.objectiveHistory, summary.coefficientArray), - (data, dos) => { - ReadWriteUtils.serializeDoubleArray(data._1, dos) - ReadWriteUtils.serializeDoubleArray(data._2, dos) - ReadWriteUtils.serializeDoubleArray(data._3, dos) - } - ) - } - - override private[spark] def loadSummary(path: String, dataset: DataFrame): Unit = { - val ( - diagInvAtWA: Array[Double], - objectiveHistory: Array[Double], - coefficientArray: Array[Double] - ) - = ReadWriteUtils.loadObjectFromLocal[(Array[Double], Array[Double], Array[Double])]( - path, - dis => { - val diagInvAtWA = ReadWriteUtils.deserializeDoubleArray(dis) - val objectiveHistory = ReadWriteUtils.deserializeDoubleArray(dis) - val coefficientArray = ReadWriteUtils.deserializeDoubleArray(dis) - (diagInvAtWA, objectiveHistory, coefficientArray) - } - ) - createSummary(dataset, diagInvAtWA, objectiveHistory, coefficientArray) - } } private[ml] case class LinearModelData(intercept: Double, coefficients: Vector, scale: Double) @@ -955,7 +926,7 @@ class LinearRegressionTrainingSummary private[regression] ( private val fitIntercept: Boolean, diagInvAtWA: Array[Double], val objectiveHistory: Array[Double], - override private[regression] val coefficientArray: Array[Double] = Array.emptyDoubleArray) + private val coefficientArray: Array[Double] = Array.emptyDoubleArray) extends LinearRegressionSummary( predictions, predictionCol, @@ -1001,8 +972,8 @@ class LinearRegressionSummary private[regression] ( private val weightCol: String, private val numFeatures: Int, private val fitIntercept: Boolean, - private[regression] val diagInvAtWA: Array[Double], - private[regression] val coefficientArray: Array[Double] = Array.emptyDoubleArray) + private val diagInvAtWA: Array[Double], + private val coefficientArray: Array[Double] = Array.emptyDoubleArray) extends Summary with Serializable { @transient private val metrics = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala b/mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala index c6f6babf71a2b..0ba8ce072ab4a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/HasTrainingSummary.scala @@ -19,7 +19,6 @@ package org.apache.spark.ml.util import org.apache.spark.SparkException import org.apache.spark.annotation.Since -import org.apache.spark.sql.DataFrame /** @@ -50,14 +49,4 @@ private[spark] trait HasTrainingSummary[T] { this.trainingSummary = summary this } - - private[spark] def loadSummary(path: String, dataset: DataFrame): Unit = { - throw new SparkException( - s"No loadSummary implementation for this ${this.getClass.getSimpleName}") - } - - private[spark] def saveSummary(path: String): Unit = { - throw new SparkException( - s"No saveSummary implementation for this ${this.getClass.getSimpleName}") - } } diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index f66fc762971b5..a5fdaed0db2c4 
100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -889,14 +889,15 @@ def summary(self) -> "LinearSVCTrainingSummary": # type: ignore[override] Gets summary (accuracy/precision/recall, objective history, total iterations) of model trained on the training set. An exception is thrown if `trainingSummary is None`. """ - return super().summary - - @property - def _summaryCls(self) -> type: - return LinearSVCTrainingSummary - - def _summary_dataset(self, train_dataset: DataFrame) -> DataFrame: - return train_dataset + if self.hasSummary: + s = LinearSVCTrainingSummary(super(LinearSVCModel, self).summary) + if is_remote(): + s.__source_transformer__ = self # type: ignore[attr-defined] + return s + else: + raise RuntimeError( + "No training summary available for this %s" % self.__class__.__name__ + ) def evaluate(self, dataset: DataFrame) -> "LinearSVCSummary": """ @@ -1576,6 +1577,29 @@ def interceptVector(self) -> Vector: """ return self._call_java("interceptVector") + @property + @since("2.0.0") + def summary(self) -> "LogisticRegressionTrainingSummary": + """ + Gets summary (accuracy/precision/recall, objective history, total iterations) of model + trained on the training set. An exception is thrown if `trainingSummary is None`. + """ + if self.hasSummary: + s: LogisticRegressionTrainingSummary + if self.numClasses <= 2: + s = BinaryLogisticRegressionTrainingSummary( + super(LogisticRegressionModel, self).summary + ) + else: + s = LogisticRegressionTrainingSummary(super(LogisticRegressionModel, self).summary) + if is_remote(): + s.__source_transformer__ = self # type: ignore[attr-defined] + return s + else: + raise RuntimeError( + "No training summary available for this %s" % self.__class__.__name__ + ) + def evaluate(self, dataset: DataFrame) -> "LogisticRegressionSummary": """ Evaluates the model on a test dataset. @@ -1599,15 +1623,6 @@ def evaluate(self, dataset: DataFrame) -> "LogisticRegressionSummary": s.__source_transformer__ = self # type: ignore[attr-defined] return s - @property - def _summaryCls(self) -> type: - if self.numClasses <= 2: - return BinaryLogisticRegressionTrainingSummary - return LogisticRegressionTrainingSummary - - def _summary_dataset(self, train_dataset: DataFrame) -> DataFrame: - return train_dataset - class LogisticRegressionSummary(_ClassificationSummary): """ @@ -2300,13 +2315,29 @@ def trees(self) -> List[DecisionTreeClassificationModel]: return [DecisionTreeClassificationModel(m) for m in list(self._call_java("trees"))] @property - def _summaryCls(self) -> type: - if self.numClasses <= 2: - return BinaryRandomForestClassificationTrainingSummary - return RandomForestClassificationTrainingSummary - - def _summary_dataset(self, train_dataset: DataFrame) -> DataFrame: - return train_dataset + @since("3.1.0") + def summary(self) -> "RandomForestClassificationTrainingSummary": + """ + Gets summary (accuracy/precision/recall, objective history, total iterations) of model + trained on the training set. An exception is thrown if `trainingSummary is None`. 
+ """ + if self.hasSummary: + s: RandomForestClassificationTrainingSummary + if self.numClasses <= 2: + s = BinaryRandomForestClassificationTrainingSummary( + super(RandomForestClassificationModel, self).summary + ) + else: + s = RandomForestClassificationTrainingSummary( + super(RandomForestClassificationModel, self).summary + ) + if is_remote(): + s.__source_transformer__ = self # type: ignore[attr-defined] + return s + else: + raise RuntimeError( + "No training summary available for this %s" % self.__class__.__name__ + ) def evaluate(self, dataset: DataFrame) -> "RandomForestClassificationSummary": """ @@ -3341,14 +3372,17 @@ def summary( # type: ignore[override] Gets summary (accuracy/precision/recall, objective history, total iterations) of model trained on the training set. An exception is thrown if `trainingSummary is None`. """ - return super().summary - - @property - def _summaryCls(self) -> type: - return MultilayerPerceptronClassificationTrainingSummary - - def _summary_dataset(self, train_dataset: DataFrame) -> DataFrame: - return train_dataset + if self.hasSummary: + s = MultilayerPerceptronClassificationTrainingSummary( + super(MultilayerPerceptronClassificationModel, self).summary + ) + if is_remote(): + s.__source_transformer__ = self # type: ignore[attr-defined] + return s + else: + raise RuntimeError( + "No training summary available for this %s" % self.__class__.__name__ + ) def evaluate(self, dataset: DataFrame) -> "MultilayerPerceptronClassificationSummary": """ @@ -4287,6 +4321,22 @@ def factors(self) -> Matrix: """ return self._call_java("factors") + @since("3.1.0") + def summary(self) -> "FMClassificationTrainingSummary": + """ + Gets summary (accuracy/precision/recall, objective history, total iterations) of model + trained on the training set. An exception is thrown if `trainingSummary is None`. + """ + if self.hasSummary: + s = FMClassificationTrainingSummary(super(FMClassificationModel, self).summary) + if is_remote(): + s.__source_transformer__ = self # type: ignore[attr-defined] + return s + else: + raise RuntimeError( + "No training summary available for this %s" % self.__class__.__name__ + ) + def evaluate(self, dataset: DataFrame) -> "FMClassificationSummary": """ Evaluates the model on a test dataset. @@ -4306,21 +4356,6 @@ def evaluate(self, dataset: DataFrame) -> "FMClassificationSummary": s.__source_transformer__ = self # type: ignore[attr-defined] return s - @since("3.1.0") - def summary(self) -> "FMClassificationTrainingSummary": - """ - Gets summary (accuracy/precision/recall, objective history, total iterations) of model - trained on the training set. An exception is thrown if `trainingSummary is None`. - """ - return super().summary - - @property - def _summaryCls(self) -> type: - return FMClassificationTrainingSummary - - def _summary_dataset(self, train_dataset: DataFrame) -> DataFrame: - return train_dataset - class FMClassificationSummary(_BinaryClassificationSummary): """ diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 0e26398de3c45..7267ee2805987 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -255,6 +255,23 @@ def gaussiansDF(self) -> DataFrame: """ return self._call_java("gaussiansDF") + @property + @since("2.1.0") + def summary(self) -> "GaussianMixtureSummary": + """ + Gets summary (cluster assignments, cluster sizes) of the model trained on the + training set. An exception is thrown if no summary exists. 
+ """ + if self.hasSummary: + s = GaussianMixtureSummary(super(GaussianMixtureModel, self).summary) + if is_remote(): + s.__source_transformer__ = self # type: ignore[attr-defined] + return s + else: + raise RuntimeError( + "No training summary available for this %s" % self.__class__.__name__ + ) + @since("3.0.0") def predict(self, value: Vector) -> int: """ @@ -269,10 +286,6 @@ def predictProbability(self, value: Vector) -> Vector: """ return self._call_java("predictProbability", value) - @property - def _summaryCls(self) -> type: - return GaussianMixtureSummary - @inherit_doc class GaussianMixture( @@ -692,6 +705,23 @@ def numFeatures(self) -> int: """ return self._call_java("numFeatures") + @property + @since("2.1.0") + def summary(self) -> KMeansSummary: + """ + Gets summary (cluster assignments, cluster sizes) of the model trained on the + training set. An exception is thrown if no summary exists. + """ + if self.hasSummary: + s = KMeansSummary(super(KMeansModel, self).summary) + if is_remote(): + s.__source_transformer__ = self # type: ignore[attr-defined] + return s + else: + raise RuntimeError( + "No training summary available for this %s" % self.__class__.__name__ + ) + @since("3.0.0") def predict(self, value: Vector) -> int: """ @@ -699,10 +729,6 @@ def predict(self, value: Vector) -> int: """ return self._call_java("predict", value) - @property - def _summaryCls(self) -> type: - return KMeansSummary - @inherit_doc class KMeans(JavaEstimator[KMeansModel], _KMeansParams, JavaMLWritable, JavaMLReadable["KMeans"]): @@ -1029,6 +1055,23 @@ def numFeatures(self) -> int: """ return self._call_java("numFeatures") + @property + @since("2.1.0") + def summary(self) -> "BisectingKMeansSummary": + """ + Gets summary (cluster assignments, cluster sizes) of the model trained on the + training set. An exception is thrown if no summary exists. + """ + if self.hasSummary: + s = BisectingKMeansSummary(super(BisectingKMeansModel, self).summary) + if is_remote(): + s.__source_transformer__ = self # type: ignore[attr-defined] + return s + else: + raise RuntimeError( + "No training summary available for this %s" % self.__class__.__name__ + ) + @since("3.0.0") def predict(self, value: Vector) -> int: """ @@ -1036,10 +1079,6 @@ def predict(self, value: Vector) -> int: """ return self._call_java("predict", value) - @property - def _summaryCls(self) -> type: - return BisectingKMeansSummary - @inherit_doc class BisectingKMeans( diff --git a/python/pyspark/ml/connect/proto.py b/python/pyspark/ml/connect/proto.py index 7cffd32631ba5..31f100859281a 100644 --- a/python/pyspark/ml/connect/proto.py +++ b/python/pyspark/ml/connect/proto.py @@ -70,13 +70,8 @@ class AttributeRelation(LogicalPlan): could be a model or a summary. This attribute returns a DataFrame. 
""" - def __init__( - self, - ref_id: str, - methods: List[pb2.Fetch.Method], - child: Optional["LogicalPlan"] = None, - ) -> None: - super().__init__(child) + def __init__(self, ref_id: str, methods: List[pb2.Fetch.Method]) -> None: + super().__init__(None) self._ref_id = ref_id self._methods = methods @@ -84,6 +79,4 @@ def plan(self, session: "SparkConnectClient") -> pb2.Relation: plan = self._create_proto_relation() plan.ml_relation.fetch.obj_ref.CopyFrom(pb2.ObjectRef(id=self._ref_id)) plan.ml_relation.fetch.methods.extend(self._methods) - if self._child is not None: - plan.ml_relation.model_summary_dataset.CopyFrom(self._child.plan(session)) return plan diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index ce97b98f6665c..66d6dbd6a2678 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -479,11 +479,22 @@ def scale(self) -> float: return self._call_java("scale") @property - def _summaryCls(self) -> type: - return LinearRegressionTrainingSummary - - def _summary_dataset(self, train_dataset: DataFrame) -> DataFrame: - return train_dataset + @since("2.0.0") + def summary(self) -> "LinearRegressionTrainingSummary": + """ + Gets summary (residuals, MSE, r-squared ) of model on + training set. An exception is thrown if + `trainingSummary is None`. + """ + if self.hasSummary: + s = LinearRegressionTrainingSummary(super(LinearRegressionModel, self).summary) + if is_remote(): + s.__source_transformer__ = self # type: ignore[attr-defined] + return s + else: + raise RuntimeError( + "No training summary available for this %s" % self.__class__.__name__ + ) def evaluate(self, dataset: DataFrame) -> "LinearRegressionSummary": """ @@ -2763,11 +2774,24 @@ def intercept(self) -> float: return self._call_java("intercept") @property - def _summaryCls(self) -> type: - return GeneralizedLinearRegressionTrainingSummary - - def _summary_dataset(self, train_dataset: DataFrame) -> DataFrame: - return train_dataset + @since("2.0.0") + def summary(self) -> "GeneralizedLinearRegressionTrainingSummary": + """ + Gets summary (residuals, deviance, p-values) of model on + training set. An exception is thrown if + `trainingSummary is None`. + """ + if self.hasSummary: + s = GeneralizedLinearRegressionTrainingSummary( + super(GeneralizedLinearRegressionModel, self).summary + ) + if is_remote(): + s.__source_transformer__ = self # type: ignore[attr-defined] + return s + else: + raise RuntimeError( + "No training summary available for this %s" % self.__class__.__name__ + ) def evaluate(self, dataset: DataFrame) -> "GeneralizedLinearRegressionSummary": """ diff --git a/python/pyspark/ml/tests/connect/test_connect_cache.py b/python/pyspark/ml/tests/connect/test_connect_cache.py index f911ab22286c0..8d156a0f11e1d 100644 --- a/python/pyspark/ml/tests/connect/test_connect_cache.py +++ b/python/pyspark/ml/tests/connect/test_connect_cache.py @@ -48,24 +48,20 @@ def test_delete_model(self): "obj: class org.apache.spark.ml.classification.LinearSVCModel" in cache_info[0], cache_info, ) - # the `model._summary` holds another ref to the remote model. 
- assert model._java_obj._ref_count == 2 + assert model._java_obj._ref_count == 1 model2 = model.copy() cache_info = spark.client._get_ml_cache_info() self.assertEqual(len(cache_info), 1) - assert model._java_obj._ref_count == 3 - assert model2._java_obj._ref_count == 3 + assert model._java_obj._ref_count == 2 + assert model2._java_obj._ref_count == 2 # explicitly delete the model del model cache_info = spark.client._get_ml_cache_info() self.assertEqual(len(cache_info), 1) - # Note the copied model 'model2' also holds the `_summary` object, - # and the `_summary` object holds another ref to the remote model. - # so the ref count is 2. - assert model2._java_obj._ref_count == 2 + assert model2._java_obj._ref_count == 1 del model2 cache_info = spark.client._get_ml_cache_info() @@ -103,6 +99,7 @@ def test_cleanup_ml_cache(self): cache_info, ) + # explicitly delete the model1 del model1 cache_info = spark.client._get_ml_cache_info() diff --git a/python/pyspark/ml/tests/test_classification.py b/python/pyspark/ml/tests/test_classification.py index 21bce70e8735b..57e4c0ef86dc0 100644 --- a/python/pyspark/ml/tests/test_classification.py +++ b/python/pyspark/ml/tests/test_classification.py @@ -55,7 +55,6 @@ MultilayerPerceptronClassificationTrainingSummary, ) from pyspark.ml.regression import DecisionTreeRegressionModel -from pyspark.sql import is_remote from pyspark.testing.sqlutils import ReusedSQLTestCase @@ -242,45 +241,37 @@ def test_binary_logistic_regression_summary(self): model = lr.fit(df) self.assertEqual(lr.uid, model.uid) self.assertTrue(model.hasSummary) - - def check_summary(): - s = model.summary - # test that api is callable and returns expected types - self.assertTrue(isinstance(s.predictions, DataFrame)) - self.assertEqual(s.probabilityCol, "probability") - self.assertEqual(s.labelCol, "label") - self.assertEqual(s.featuresCol, "features") - self.assertEqual(s.predictionCol, "prediction") - objHist = s.objectiveHistory - self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) - self.assertGreater(s.totalIterations, 0) - self.assertTrue(isinstance(s.labels, list)) - self.assertTrue(isinstance(s.truePositiveRateByLabel, list)) - self.assertTrue(isinstance(s.falsePositiveRateByLabel, list)) - self.assertTrue(isinstance(s.precisionByLabel, list)) - self.assertTrue(isinstance(s.recallByLabel, list)) - self.assertTrue(isinstance(s.fMeasureByLabel(), list)) - self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list)) - self.assertTrue(isinstance(s.roc, DataFrame)) - self.assertAlmostEqual(s.areaUnderROC, 1.0, 2) - self.assertTrue(isinstance(s.pr, DataFrame)) - self.assertTrue(isinstance(s.fMeasureByThreshold, DataFrame)) - self.assertTrue(isinstance(s.precisionByThreshold, DataFrame)) - self.assertTrue(isinstance(s.recallByThreshold, DataFrame)) - self.assertAlmostEqual(s.accuracy, 1.0, 2) - self.assertAlmostEqual(s.weightedTruePositiveRate, 1.0, 2) - self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.0, 2) - self.assertAlmostEqual(s.weightedRecall, 1.0, 2) - self.assertAlmostEqual(s.weightedPrecision, 1.0, 2) - self.assertAlmostEqual(s.weightedFMeasure(), 1.0, 2) - self.assertAlmostEqual(s.weightedFMeasure(1.0), 1.0, 2) - - check_summary() - if is_remote(): - self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True) - check_summary() - s = model.summary + # test that api is callable and returns expected types + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.probabilityCol, "probability") + 
self.assertEqual(s.labelCol, "label") + self.assertEqual(s.featuresCol, "features") + self.assertEqual(s.predictionCol, "prediction") + objHist = s.objectiveHistory + self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) + self.assertGreater(s.totalIterations, 0) + self.assertTrue(isinstance(s.labels, list)) + self.assertTrue(isinstance(s.truePositiveRateByLabel, list)) + self.assertTrue(isinstance(s.falsePositiveRateByLabel, list)) + self.assertTrue(isinstance(s.precisionByLabel, list)) + self.assertTrue(isinstance(s.recallByLabel, list)) + self.assertTrue(isinstance(s.fMeasureByLabel(), list)) + self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list)) + self.assertTrue(isinstance(s.roc, DataFrame)) + self.assertAlmostEqual(s.areaUnderROC, 1.0, 2) + self.assertTrue(isinstance(s.pr, DataFrame)) + self.assertTrue(isinstance(s.fMeasureByThreshold, DataFrame)) + self.assertTrue(isinstance(s.precisionByThreshold, DataFrame)) + self.assertTrue(isinstance(s.recallByThreshold, DataFrame)) + self.assertAlmostEqual(s.accuracy, 1.0, 2) + self.assertAlmostEqual(s.weightedTruePositiveRate, 1.0, 2) + self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.0, 2) + self.assertAlmostEqual(s.weightedRecall, 1.0, 2) + self.assertAlmostEqual(s.weightedPrecision, 1.0, 2) + self.assertAlmostEqual(s.weightedFMeasure(), 1.0, 2) + self.assertAlmostEqual(s.weightedFMeasure(1.0), 1.0, 2) + # test evaluation (with training dataset) produces a summary with same values # one check is enough to verify a summary is returned, Scala version runs full test sameSummary = model.evaluate(df) @@ -301,39 +292,31 @@ def test_multiclass_logistic_regression_summary(self): lr = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", fitIntercept=False) model = lr.fit(df) self.assertTrue(model.hasSummary) - - def check_summary(): - s = model.summary - # test that api is callable and returns expected types - self.assertTrue(isinstance(s.predictions, DataFrame)) - self.assertEqual(s.probabilityCol, "probability") - self.assertEqual(s.labelCol, "label") - self.assertEqual(s.featuresCol, "features") - self.assertEqual(s.predictionCol, "prediction") - objHist = s.objectiveHistory - self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) - self.assertGreater(s.totalIterations, 0) - self.assertTrue(isinstance(s.labels, list)) - self.assertTrue(isinstance(s.truePositiveRateByLabel, list)) - self.assertTrue(isinstance(s.falsePositiveRateByLabel, list)) - self.assertTrue(isinstance(s.precisionByLabel, list)) - self.assertTrue(isinstance(s.recallByLabel, list)) - self.assertTrue(isinstance(s.fMeasureByLabel(), list)) - self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list)) - self.assertAlmostEqual(s.accuracy, 0.75, 2) - self.assertAlmostEqual(s.weightedTruePositiveRate, 0.75, 2) - self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.25, 2) - self.assertAlmostEqual(s.weightedRecall, 0.75, 2) - self.assertAlmostEqual(s.weightedPrecision, 0.583, 2) - self.assertAlmostEqual(s.weightedFMeasure(), 0.65, 2) - self.assertAlmostEqual(s.weightedFMeasure(1.0), 0.65, 2) - - check_summary() - if is_remote(): - self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True) - check_summary() - s = model.summary + # test that api is callable and returns expected types + self.assertTrue(isinstance(s.predictions, DataFrame)) + self.assertEqual(s.probabilityCol, "probability") + self.assertEqual(s.labelCol, "label") + self.assertEqual(s.featuresCol, "features") + 
self.assertEqual(s.predictionCol, "prediction") + objHist = s.objectiveHistory + self.assertTrue(isinstance(objHist, list) and isinstance(objHist[0], float)) + self.assertGreater(s.totalIterations, 0) + self.assertTrue(isinstance(s.labels, list)) + self.assertTrue(isinstance(s.truePositiveRateByLabel, list)) + self.assertTrue(isinstance(s.falsePositiveRateByLabel, list)) + self.assertTrue(isinstance(s.precisionByLabel, list)) + self.assertTrue(isinstance(s.recallByLabel, list)) + self.assertTrue(isinstance(s.fMeasureByLabel(), list)) + self.assertTrue(isinstance(s.fMeasureByLabel(1.0), list)) + self.assertAlmostEqual(s.accuracy, 0.75, 2) + self.assertAlmostEqual(s.weightedTruePositiveRate, 0.75, 2) + self.assertAlmostEqual(s.weightedFalsePositiveRate, 0.25, 2) + self.assertAlmostEqual(s.weightedRecall, 0.75, 2) + self.assertAlmostEqual(s.weightedPrecision, 0.583, 2) + self.assertAlmostEqual(s.weightedFMeasure(), 0.65, 2) + self.assertAlmostEqual(s.weightedFMeasure(1.0), 0.65, 2) + # test evaluation (with training dataset) produces a summary with same values # one check is enough to verify a summary is returned, Scala version runs full test sameSummary = model.evaluate(df) @@ -443,21 +426,15 @@ def test_linear_svc(self): self.assertEqual(output.columns, expected_cols) self.assertEqual(output.count(), 4) - def check_summary(): - # model summary - self.assertTrue(model.hasSummary) - summary = model.summary() - self.assertIsInstance(summary, LinearSVCSummary) - self.assertIsInstance(summary, LinearSVCTrainingSummary) - self.assertEqual(summary.labels, [0.0, 1.0]) - self.assertEqual(summary.accuracy, 0.5) - self.assertEqual(summary.areaUnderROC, 0.75) - self.assertEqual(summary.predictions.columns, expected_cols) - - check_summary() - if is_remote(): - self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True) - check_summary() + # model summary + self.assertTrue(model.hasSummary) + summary = model.summary() + self.assertIsInstance(summary, LinearSVCSummary) + self.assertIsInstance(summary, LinearSVCTrainingSummary) + self.assertEqual(summary.labels, [0.0, 1.0]) + self.assertEqual(summary.accuracy, 0.5) + self.assertEqual(summary.areaUnderROC, 0.75) + self.assertEqual(summary.predictions.columns, expected_cols) summary2 = model.evaluate(df) self.assertIsInstance(summary2, LinearSVCSummary) @@ -549,20 +526,13 @@ def test_factorization_machine(self): # model summary self.assertTrue(model.hasSummary) - - def check_summary(): - summary = model.summary() - self.assertIsInstance(summary, FMClassificationSummary) - self.assertIsInstance(summary, FMClassificationTrainingSummary) - self.assertEqual(summary.labels, [0.0, 1.0]) - self.assertEqual(summary.accuracy, 0.25) - self.assertEqual(summary.areaUnderROC, 0.5) - self.assertEqual(summary.predictions.columns, expected_cols) - - check_summary() - if is_remote(): - self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True) - check_summary() + summary = model.summary() + self.assertIsInstance(summary, FMClassificationSummary) + self.assertIsInstance(summary, FMClassificationTrainingSummary) + self.assertEqual(summary.labels, [0.0, 1.0]) + self.assertEqual(summary.accuracy, 0.25) + self.assertEqual(summary.areaUnderROC, 0.5) + self.assertEqual(summary.predictions.columns, expected_cols) summary2 = model.evaluate(df) self.assertIsInstance(summary2, FMClassificationSummary) @@ -803,27 +773,21 @@ def test_binary_random_forest_classifier(self): self.assertEqual(tree.transform(df).count(), 4) 
self.assertEqual(tree.transform(df).columns, expected_cols) - def check_summary(): - # model summary - summary = model.summary - self.assertTrue(isinstance(summary, BinaryRandomForestClassificationSummary)) - self.assertTrue(isinstance(summary, BinaryRandomForestClassificationTrainingSummary)) - self.assertEqual(summary.labels, [0.0, 1.0]) - self.assertEqual(summary.accuracy, 0.75) - self.assertEqual(summary.areaUnderROC, 0.875) - self.assertEqual(summary.predictions.columns, expected_cols) - - check_summary() - if is_remote(): - self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True) - check_summary() + # model summary + summary = model.summary + self.assertTrue(isinstance(summary, BinaryRandomForestClassificationSummary)) + self.assertTrue(isinstance(summary, BinaryRandomForestClassificationTrainingSummary)) + self.assertEqual(summary.labels, [0.0, 1.0]) + self.assertEqual(summary.accuracy, 0.75) + self.assertEqual(summary.areaUnderROC, 0.875) + self.assertEqual(summary.predictions.columns, expected_cols) summary2 = model.evaluate(df) self.assertTrue(isinstance(summary2, BinaryRandomForestClassificationSummary)) self.assertFalse(isinstance(summary2, BinaryRandomForestClassificationTrainingSummary)) self.assertEqual(summary2.labels, [0.0, 1.0]) self.assertEqual(summary2.accuracy, 0.75) - self.assertEqual(summary2.areaUnderROC, 0.875) + self.assertEqual(summary.areaUnderROC, 0.875) self.assertEqual(summary2.predictions.columns, expected_cols) # Model save & load @@ -895,19 +859,13 @@ def test_multiclass_random_forest_classifier(self): self.assertEqual(output.columns, expected_cols) self.assertEqual(output.count(), 4) - def check_summary(): - # model summary - summary = model.summary - self.assertTrue(isinstance(summary, RandomForestClassificationSummary)) - self.assertTrue(isinstance(summary, RandomForestClassificationTrainingSummary)) - self.assertEqual(summary.labels, [0.0, 1.0, 2.0]) - self.assertEqual(summary.accuracy, 0.5) - self.assertEqual(summary.predictions.columns, expected_cols) - - check_summary() - if is_remote(): - self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True) - check_summary() + # model summary + summary = model.summary + self.assertTrue(isinstance(summary, RandomForestClassificationSummary)) + self.assertTrue(isinstance(summary, RandomForestClassificationTrainingSummary)) + self.assertEqual(summary.labels, [0.0, 1.0, 2.0]) + self.assertEqual(summary.accuracy, 0.5) + self.assertEqual(summary.predictions.columns, expected_cols) summary2 = model.evaluate(df) self.assertTrue(isinstance(summary2, RandomForestClassificationSummary)) @@ -995,20 +953,14 @@ def test_mlp(self): self.assertEqual(output.columns, expected_cols) self.assertEqual(output.count(), 4) - def check_summary(): - # model summary - self.assertTrue(model.hasSummary) - summary = model.summary() - self.assertIsInstance(summary, MultilayerPerceptronClassificationSummary) - self.assertIsInstance(summary, MultilayerPerceptronClassificationTrainingSummary) - self.assertEqual(summary.labels, [0.0, 1.0]) - self.assertEqual(summary.accuracy, 0.75) - self.assertEqual(summary.predictions.columns, expected_cols) - - check_summary() - if is_remote(): - self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True) - check_summary() + # model summary + self.assertTrue(model.hasSummary) + summary = model.summary() + self.assertIsInstance(summary, MultilayerPerceptronClassificationSummary) + self.assertIsInstance(summary, 
MultilayerPerceptronClassificationTrainingSummary) + self.assertEqual(summary.labels, [0.0, 1.0]) + self.assertEqual(summary.accuracy, 0.75) + self.assertEqual(summary.predictions.columns, expected_cols) summary2 = model.evaluate(df) self.assertIsInstance(summary2, MultilayerPerceptronClassificationSummary) diff --git a/python/pyspark/ml/tests/test_clustering.py b/python/pyspark/ml/tests/test_clustering.py index fbf012babcc3d..1b8eb73135a96 100644 --- a/python/pyspark/ml/tests/test_clustering.py +++ b/python/pyspark/ml/tests/test_clustering.py @@ -85,39 +85,23 @@ def test_kmeans(self): self.assertTrue(np.allclose(model.predict(Vectors.dense(0.0, 5.0)), 1, atol=1e-4)) - def check_summary(): - # Model summary - self.assertTrue(model.hasSummary) - summary = model.summary - self.assertTrue(isinstance(summary, KMeansSummary)) - self.assertEqual(summary.k, 2) - self.assertEqual(summary.numIter, 2) - self.assertEqual(summary.clusterSizes, [4, 2]) - self.assertTrue(np.allclose(summary.trainingCost, 1.35710375, atol=1e-4)) - - self.assertEqual(summary.featuresCol, "features") - self.assertEqual(summary.predictionCol, "prediction") - - self.assertEqual(summary.cluster.columns, ["prediction"]) - self.assertEqual(summary.cluster.count(), 6) - - self.assertEqual(summary.predictions.columns, expected_cols) - self.assertEqual(summary.predictions.count(), 6) - - # check summary before model offloading occurs - check_summary() + # Model summary + self.assertTrue(model.hasSummary) + summary = model.summary + self.assertTrue(isinstance(summary, KMeansSummary)) + self.assertEqual(summary.k, 2) + self.assertEqual(summary.numIter, 2) + self.assertEqual(summary.clusterSizes, [4, 2]) + self.assertTrue(np.allclose(summary.trainingCost, 1.35710375, atol=1e-4)) - if is_remote(): - self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True) - # check summary "try_remote_call" path after model offloading occurs - self.assertEqual(model.summary.numIter, 2) + self.assertEqual(summary.featuresCol, "features") + self.assertEqual(summary.predictionCol, "prediction") - self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True) - # check summary "invoke_remote_attribute_relation" path after model offloading occurs - self.assertEqual(model.summary.cluster.count(), 6) + self.assertEqual(summary.cluster.columns, ["prediction"]) + self.assertEqual(summary.cluster.count(), 6) - self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True) - check_summary() + self.assertEqual(summary.predictions.columns, expected_cols) + self.assertEqual(summary.predictions.count(), 6) # save & load with tempfile.TemporaryDirectory(prefix="kmeans_model") as d: @@ -128,9 +112,6 @@ def check_summary(): model.write().overwrite().save(d) model2 = KMeansModel.load(d) self.assertEqual(str(model), str(model2)) - self.assertFalse(model2.hasSummary) - with self.assertRaisesRegex(Exception, "No training summary available"): - model2.summary def test_bisecting_kmeans(self): df = ( @@ -297,36 +278,30 @@ def test_gaussian_mixture(self): self.assertEqual(output.columns, expected_cols) self.assertEqual(output.count(), 6) - def check_summary(): - # Model summary - self.assertTrue(model.hasSummary) - summary = model.summary - self.assertTrue(isinstance(summary, GaussianMixtureSummary)) - self.assertEqual(summary.k, 2) - self.assertEqual(summary.numIter, 2) - self.assertEqual(len(summary.clusterSizes), 2) - self.assertEqual(summary.clusterSizes, [3, 3]) - ll = summary.logLikelihood - self.assertTrue(ll < 0, 
ll) - self.assertTrue(np.allclose(ll, -1.311264553744033, atol=1e-4), ll) - - self.assertEqual(summary.featuresCol, "features") - self.assertEqual(summary.predictionCol, "prediction") - self.assertEqual(summary.probabilityCol, "probability") - - self.assertEqual(summary.cluster.columns, ["prediction"]) - self.assertEqual(summary.cluster.count(), 6) - - self.assertEqual(summary.predictions.columns, expected_cols) - self.assertEqual(summary.predictions.count(), 6) - - self.assertEqual(summary.probability.columns, ["probability"]) - self.assertEqual(summary.predictions.count(), 6) - - check_summary() - if is_remote(): - self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True) - check_summary() + # Model summary + self.assertTrue(model.hasSummary) + summary = model.summary + self.assertTrue(isinstance(summary, GaussianMixtureSummary)) + self.assertEqual(summary.k, 2) + self.assertEqual(summary.numIter, 2) + self.assertEqual(len(summary.clusterSizes), 2) + self.assertEqual(summary.clusterSizes, [3, 3]) + ll = summary.logLikelihood + self.assertTrue(ll < 0, ll) + self.assertTrue(np.allclose(ll, -1.311264553744033, atol=1e-4), ll) + + self.assertEqual(summary.featuresCol, "features") + self.assertEqual(summary.predictionCol, "prediction") + self.assertEqual(summary.probabilityCol, "probability") + + self.assertEqual(summary.cluster.columns, ["prediction"]) + self.assertEqual(summary.cluster.count(), 6) + + self.assertEqual(summary.predictions.columns, expected_cols) + self.assertEqual(summary.predictions.count(), 6) + + self.assertEqual(summary.probability.columns, ["probability"]) + self.assertEqual(summary.predictions.count(), 6) # save & load with tempfile.TemporaryDirectory(prefix="gaussian_mixture") as d: diff --git a/python/pyspark/ml/tests/test_regression.py b/python/pyspark/ml/tests/test_regression.py index 52688fdd63cf2..8638fb4d6078e 100644 --- a/python/pyspark/ml/tests/test_regression.py +++ b/python/pyspark/ml/tests/test_regression.py @@ -43,7 +43,6 @@ GBTRegressor, GBTRegressionModel, ) -from pyspark.sql import is_remote from pyspark.testing.sqlutils import ReusedSQLTestCase @@ -194,58 +193,50 @@ def test_linear_regression(self): np.allclose(model.predict(Vectors.dense(0.0, 5.0)), 0.21249999999999963, atol=1e-4) ) - def check_summary(): - # Model summary - summary = model.summary - self.assertTrue(isinstance(summary, LinearRegressionSummary)) - self.assertTrue(isinstance(summary, LinearRegressionTrainingSummary)) - self.assertEqual(summary.predictions.columns, expected_cols) - self.assertEqual(summary.predictions.count(), 4) - self.assertEqual(summary.residuals.columns, ["residuals"]) - self.assertEqual(summary.residuals.count(), 4) - - self.assertEqual(summary.degreesOfFreedom, 1) - self.assertEqual(summary.numInstances, 4) - self.assertEqual(summary.objectiveHistory, [0.0]) - self.assertTrue( - np.allclose( - summary.coefficientStandardErrors, - [1.2859821149611763, 0.6248749874975031, 3.1645497310044184], - atol=1e-4, - ) - ) - self.assertTrue( - np.allclose( - summary.devianceResiduals, [-0.7424621202458727, 0.7875000000000003], atol=1e-4 - ) + # Model summary + summary = model.summary + self.assertTrue(isinstance(summary, LinearRegressionSummary)) + self.assertTrue(isinstance(summary, LinearRegressionTrainingSummary)) + self.assertEqual(summary.predictions.columns, expected_cols) + self.assertEqual(summary.predictions.count(), 4) + self.assertEqual(summary.residuals.columns, ["residuals"]) + self.assertEqual(summary.residuals.count(), 4) + + 
self.assertEqual(summary.degreesOfFreedom, 1) + self.assertEqual(summary.numInstances, 4) + self.assertEqual(summary.objectiveHistory, [0.0]) + self.assertTrue( + np.allclose( + summary.coefficientStandardErrors, + [1.2859821149611763, 0.6248749874975031, 3.1645497310044184], + atol=1e-4, ) - self.assertTrue( - np.allclose( - summary.pValues, - [0.7020630236843428, 0.8866003086182783, 0.9298746994547682], - atol=1e-4, - ) + ) + self.assertTrue( + np.allclose( + summary.devianceResiduals, [-0.7424621202458727, 0.7875000000000003], atol=1e-4 ) - self.assertTrue( - np.allclose( - summary.tValues, - [0.5054502643838291, 0.1800360108036021, -0.11060025272186746], - atol=1e-4, - ) + ) + self.assertTrue( + np.allclose( + summary.pValues, + [0.7020630236843428, 0.8866003086182783, 0.9298746994547682], + atol=1e-4, ) - self.assertTrue(np.allclose(summary.explainedVariance, 0.07997500000000031, atol=1e-4)) - self.assertTrue(np.allclose(summary.meanAbsoluteError, 0.4200000000000002, atol=1e-4)) - self.assertTrue(np.allclose(summary.meanSquaredError, 0.20212500000000005, atol=1e-4)) - self.assertTrue( - np.allclose(summary.rootMeanSquaredError, 0.44958314025327956, atol=1e-4) + ) + self.assertTrue( + np.allclose( + summary.tValues, + [0.5054502643838291, 0.1800360108036021, -0.11060025272186746], + atol=1e-4, ) - self.assertTrue(np.allclose(summary.r2, 0.4427212572373862, atol=1e-4)) - self.assertTrue(np.allclose(summary.r2adj, -0.6718362282878414, atol=1e-4)) - - check_summary() - if is_remote(): - self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True) - check_summary() + ) + self.assertTrue(np.allclose(summary.explainedVariance, 0.07997500000000031, atol=1e-4)) + self.assertTrue(np.allclose(summary.meanAbsoluteError, 0.4200000000000002, atol=1e-4)) + self.assertTrue(np.allclose(summary.meanSquaredError, 0.20212500000000005, atol=1e-4)) + self.assertTrue(np.allclose(summary.rootMeanSquaredError, 0.44958314025327956, atol=1e-4)) + self.assertTrue(np.allclose(summary.r2, 0.4427212572373862, atol=1e-4)) + self.assertTrue(np.allclose(summary.r2adj, -0.6718362282878414, atol=1e-4)) summary2 = model.evaluate(df) self.assertTrue(isinstance(summary2, LinearRegressionSummary)) @@ -327,43 +318,36 @@ def test_generalized_linear_regression(self): self.assertEqual(output.columns, expected_cols) self.assertEqual(output.count(), 4) - def check_summary(): - # Model summary - self.assertTrue(model.hasSummary) - - summary = model.summary - self.assertIsInstance(summary, GeneralizedLinearRegressionSummary) - self.assertIsInstance(summary, GeneralizedLinearRegressionTrainingSummary) - self.assertEqual(summary.numIterations, 1) - self.assertEqual(summary.numInstances, 4) - self.assertEqual(summary.rank, 3) - self.assertTrue( - np.allclose( - summary.tValues, - [0.3725037662281711, -0.49418209022924164, 2.6589353685797654], - atol=1e-4, - ), + # Model summary + self.assertTrue(model.hasSummary) + + summary = model.summary + self.assertIsInstance(summary, GeneralizedLinearRegressionSummary) + self.assertIsInstance(summary, GeneralizedLinearRegressionTrainingSummary) + self.assertEqual(summary.numIterations, 1) + self.assertEqual(summary.numInstances, 4) + self.assertEqual(summary.rank, 3) + self.assertTrue( + np.allclose( summary.tValues, - ) - self.assertTrue( - np.allclose( - summary.pValues, - [0.7729938686180984, 0.707802691825973, 0.22900885781807023], - atol=1e-4, - ), + [0.3725037662281711, -0.49418209022924164, 2.6589353685797654], + atol=1e-4, + ), + summary.tValues, + ) + self.assertTrue( 
+ np.allclose( summary.pValues, - ) - self.assertEqual(summary.predictions.columns, expected_cols) - self.assertEqual(summary.predictions.count(), 4) - self.assertEqual(summary.residuals().columns, ["devianceResiduals"]) - self.assertEqual(summary.residuals().count(), 4) - - check_summary() - if is_remote(): - self.spark.client._delete_ml_cache([model._java_obj._ref_id], evict_only=True) - check_summary() + [0.7729938686180984, 0.707802691825973, 0.22900885781807023], + atol=1e-4, + ), + summary.pValues, + ) + self.assertEqual(summary.predictions.columns, expected_cols) + self.assertEqual(summary.predictions.count(), 4) + self.assertEqual(summary.residuals().columns, ["devianceResiduals"]) + self.assertEqual(summary.residuals().count(), 4) - summary = model.summary summary2 = model.evaluate(df) self.assertIsInstance(summary2, GeneralizedLinearRegressionSummary) self.assertNotIsInstance(summary2, GeneralizedLinearRegressionTrainingSummary) diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index f9a532de10f93..b86178a97c382 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -72,6 +72,20 @@ _logger = logging.getLogger("pyspark.ml.util") +def try_remote_intermediate_result(f: FuncT) -> FuncT: + """Mark the function/property that returns the intermediate result of the remote call. + Eg, model.summary""" + + @functools.wraps(f) + def wrapped(self: "JavaWrapper") -> Any: + if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ: + return f"{str(self._java_obj)}.{f.__name__}" + else: + return f(self) + + return cast(FuncT, wrapped) + + def invoke_helper_attr(method: str, *args: Any) -> Any: from pyspark.ml.wrapper import JavaWrapper @@ -111,12 +125,7 @@ def invoke_remote_attribute_relation( object_id = instance._java_obj # type: ignore methods, obj_ref = _extract_id_methods(object_id) methods.append(pb2.Fetch.Method(method=method, args=serialize(session.client, *args))) - - if methods[0].method == "summary": - child = instance._summary_dataset._plan # type: ignore - else: - child = None - plan = AttributeRelation(obj_ref, methods, child=child) + plan = AttributeRelation(obj_ref, methods) # To delay the GC of the model, keep a reference to the source instance, # might be a model or a summary. 
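For context on the util.py hunk above: the restored `try_remote_intermediate_result` decorator short-circuits a property such as `model.summary` under Spark Connect, handing back a string handle of the form "<object id>.<attribute>" instead of invoking the wrapped function; later Fetch calls split that dotted handle back into an object reference and method chain (see `_extract_id_methods` in the `invoke_remote_attribute_relation` hunk). The following is a minimal, self-contained sketch of that behavior only; `FakeModel`, `REMOTE`, and the literal values are illustrative stand-ins, not code from this patch.

from functools import wraps

REMOTE = True  # stand-in for is_remote() plus the PYSPARK_NO_NAMESPACE_SHARE check


def try_remote_intermediate_result(f):
    # Same shape as the decorator restored above: in remote mode, return a
    # "<object id>.<attribute>" handle rather than calling the wrapped function.
    @wraps(f)
    def wrapped(self):
        if REMOTE:
            return f"{self._java_obj}.{f.__name__}"
        return f(self)

    return wrapped


class FakeModel:  # hypothetical stand-in for a JavaWrapper-backed model
    _java_obj = "model-123"  # plays the role of the server-side cache id

    @property
    @try_remote_intermediate_result
    def summary(self):
        return "local JVM summary"  # non-remote path

print(FakeModel().summary)  # remote mode -> 'model-123.summary'; local mode -> 'local JVM summary'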
@@ -195,15 +204,6 @@ def wrapped(self: "JavaEstimator", dataset: "ConnectDataFrame") -> Any: _logger.warning(warning_msg) remote_model_ref = RemoteModelRef(model_info.obj_ref.id) model = self._create_model(remote_model_ref) - if isinstance(model, HasTrainingSummary): - summary_dataset = model._summary_dataset(dataset) - - summary = model._summaryCls(f"{str(model._java_obj)}.summary") # type: ignore - summary._summary_dataset = summary_dataset - summary._remote_model_obj = model._java_obj # type: ignore - summary._remote_model_obj.add_ref() - - model._summary = summary # type: ignore if model.__class__.__name__ not in ["Bucketizer"]: model._resetUid(self.uid) return self._copyValues(model) @@ -278,16 +278,15 @@ def try_remote_call(f: FuncT) -> FuncT: @functools.wraps(f) def wrapped(self: "JavaWrapper", name: str, *args: Any) -> Any: - import pyspark.sql.connect.proto as pb2 - from pyspark.sql.connect.session import SparkSession - - session = SparkSession.getActiveSession() - - def remote_call() -> Any: + if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ: + # Launch a remote call if possible + import pyspark.sql.connect.proto as pb2 + from pyspark.sql.connect.session import SparkSession from pyspark.ml.connect.util import _extract_id_methods from pyspark.ml.connect.serialize import serialize, deserialize from pyspark.ml.wrapper import JavaModel + session = SparkSession.getActiveSession() assert session is not None if self._java_obj == ML_CONNECT_HELPER_ID: obj_id = ML_CONNECT_HELPER_ID @@ -316,30 +315,6 @@ def remote_call() -> Any: return model_info.obj_ref.id else: return deserialize(properties) - - if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ: - from pyspark.errors.exceptions.connect import SparkException - - try: - return remote_call() - except SparkException as e: - if e.getErrorClass() == "CONNECT_ML.MODEL_SUMMARY_LOST": - # the model summary is lost because the remote model was offloaded, - # send request to restore model.summary - create_summary_command = pb2.Command() - create_summary_command.ml_command.create_summary.CopyFrom( - pb2.MlCommand.CreateSummary( - model_ref=pb2.ObjectRef( - id=self._remote_model_obj.ref_id # type: ignore - ), - dataset=self._summary_dataset._plan.plan( # type: ignore - session.client # type: ignore - ), - ) - ) - session.client.execute_command(create_summary_command) # type: ignore - - return remote_call() else: return f(self, name, *args) @@ -371,11 +346,8 @@ def wrapped(self: "JavaWrapper") -> Any: except Exception: return - if in_remote: - if isinstance(self._java_obj, RemoteModelRef): - self._java_obj.release_ref() - if hasattr(self, "_remote_model_obj"): - self._remote_model_obj.release_ref() + if in_remote and isinstance(self._java_obj, RemoteModelRef): + self._java_obj.release_ref() return else: return f(self) @@ -1104,32 +1076,17 @@ def hasSummary(self) -> bool: Indicates whether a training summary exists for this model instance. """ - if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ: - return hasattr(self, "_summary") return cast("JavaWrapper", self)._call_java("hasSummary") @property @since("2.1.0") + @try_remote_intermediate_result def summary(self) -> T: """ Gets summary of the model trained on the training set. An exception is thrown if no summary exists. 
""" - if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ: - if hasattr(self, "_summary"): - return self._summary - else: - raise RuntimeError( - "No training summary available for this %s" % self.__class__.__name__ - ) - return self._summaryCls(cast("JavaWrapper", self)._call_java("summary")) - - @property - def _summaryCls(self) -> type: - raise NotImplementedError() - - def _summary_dataset(self, train_dataset: "DataFrame") -> "DataFrame": - return self.transform(train_dataset) # type: ignore + return cast("JavaWrapper", self)._call_java("summary") class MetaAlgorithmReadWrite: diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 3cfb38fdfa7da..34719f2b0ba6e 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -1985,7 +1985,7 @@ def _create_profile(self, profile: pb2.ResourceProfile) -> int: profile_id = properties["create_resource_profile_command_result"] return profile_id - def _delete_ml_cache(self, cache_ids: List[str], evict_only: bool = False) -> List[str]: + def _delete_ml_cache(self, cache_ids: List[str]) -> List[str]: # try best to delete the cache try: if len(cache_ids) > 0: @@ -1993,7 +1993,6 @@ def _delete_ml_cache(self, cache_ids: List[str], evict_only: bool = False) -> Li command.ml_command.delete.obj_refs.extend( [pb2.ObjectRef(id=cache_id) for cache_id in cache_ids] ) - command.ml_command.delete.evict_only = evict_only (_, properties, _) = self.execute_command(command) assert properties is not None diff --git a/python/pyspark/sql/connect/proto/ml_pb2.py b/python/pyspark/sql/connect/proto/ml_pb2.py index 1ede558b94140..46fc82131a9e7 100644 --- a/python/pyspark/sql/connect/proto/ml_pb2.py +++ b/python/pyspark/sql/connect/proto/ml_pb2.py @@ -40,7 +40,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x16spark/connect/ml.proto\x12\rspark.connect\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/ml_common.proto"\xb1\r\n\tMlCommand\x12\x30\n\x03\x66it\x18\x01 \x01(\x0b\x32\x1c.spark.connect.MlCommand.FitH\x00R\x03\x66it\x12,\n\x05\x66\x65tch\x18\x02 \x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12\x39\n\x06\x64\x65lete\x18\x03 \x01(\x0b\x32\x1f.spark.connect.MlCommand.DeleteH\x00R\x06\x64\x65lete\x12\x36\n\x05write\x18\x04 \x01(\x0b\x32\x1e.spark.connect.MlCommand.WriteH\x00R\x05write\x12\x33\n\x04read\x18\x05 \x01(\x0b\x32\x1d.spark.connect.MlCommand.ReadH\x00R\x04read\x12?\n\x08\x65valuate\x18\x06 \x01(\x0b\x32!.spark.connect.MlCommand.EvaluateH\x00R\x08\x65valuate\x12\x46\n\x0b\x63lean_cache\x18\x07 \x01(\x0b\x32#.spark.connect.MlCommand.CleanCacheH\x00R\ncleanCache\x12M\n\x0eget_cache_info\x18\x08 \x01(\x0b\x32%.spark.connect.MlCommand.GetCacheInfoH\x00R\x0cgetCacheInfo\x12O\n\x0e\x63reate_summary\x18\t \x01(\x0b\x32&.spark.connect.MlCommand.CreateSummaryH\x00R\rcreateSummary\x1a\xb2\x01\n\x03\x46it\x12\x37\n\testimator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\testimator\x12\x34\n\x06params\x18\x02 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x00R\x06params\x88\x01\x01\x12\x31\n\x07\x64\x61taset\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61tasetB\t\n\x07_params\x1ap\n\x06\x44\x65lete\x12\x33\n\x08obj_refs\x18\x01 \x03(\x0b\x32\x18.spark.connect.ObjectRefR\x07objRefs\x12"\n\nevict_only\x18\x02 \x01(\x08H\x00R\tevictOnly\x88\x01\x01\x42\r\n\x0b_evict_only\x1a\x0c\n\nCleanCache\x1a\x0e\n\x0cGetCacheInfo\x1a\x9a\x03\n\x05Write\x12\x37\n\x08operator\x18\x01 
\x01(\x0b\x32\x19.spark.connect.MlOperatorH\x00R\x08operator\x12\x33\n\x07obj_ref\x18\x02 \x01(\x0b\x32\x18.spark.connect.ObjectRefH\x00R\x06objRef\x12\x34\n\x06params\x18\x03 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x01R\x06params\x88\x01\x01\x12\x12\n\x04path\x18\x04 \x01(\tR\x04path\x12.\n\x10should_overwrite\x18\x05 \x01(\x08H\x02R\x0fshouldOverwrite\x88\x01\x01\x12\x45\n\x07options\x18\x06 \x03(\x0b\x32+.spark.connect.MlCommand.Write.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x06\n\x04typeB\t\n\x07_paramsB\x13\n\x11_should_overwrite\x1aQ\n\x04Read\x12\x35\n\x08operator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\x08operator\x12\x12\n\x04path\x18\x02 \x01(\tR\x04path\x1a\xb7\x01\n\x08\x45valuate\x12\x37\n\tevaluator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\tevaluator\x12\x34\n\x06params\x18\x02 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x00R\x06params\x88\x01\x01\x12\x31\n\x07\x64\x61taset\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61tasetB\t\n\x07_params\x1ay\n\rCreateSummary\x12\x35\n\tmodel_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefR\x08modelRef\x12\x31\n\x07\x64\x61taset\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61tasetB\t\n\x07\x63ommand"\xd5\x03\n\x0fMlCommandResult\x12\x39\n\x05param\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x05param\x12\x1a\n\x07summary\x18\x02 \x01(\tH\x00R\x07summary\x12T\n\roperator_info\x18\x03 \x01(\x0b\x32-.spark.connect.MlCommandResult.MlOperatorInfoH\x00R\x0coperatorInfo\x1a\x85\x02\n\x0eMlOperatorInfo\x12\x33\n\x07obj_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefH\x00R\x06objRef\x12\x14\n\x04name\x18\x02 \x01(\tH\x00R\x04name\x12\x15\n\x03uid\x18\x03 \x01(\tH\x01R\x03uid\x88\x01\x01\x12\x34\n\x06params\x18\x04 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x02R\x06params\x88\x01\x01\x12,\n\x0fwarning_message\x18\x05 \x01(\tH\x03R\x0ewarningMessage\x88\x01\x01\x42\x06\n\x04typeB\x06\n\x04_uidB\t\n\x07_paramsB\x12\n\x10_warning_messageB\r\n\x0bresult_typeB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' + b'\n\x16spark/connect/ml.proto\x12\rspark.connect\x1a\x1dspark/connect/relations.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/ml_common.proto"\xb2\x0b\n\tMlCommand\x12\x30\n\x03\x66it\x18\x01 \x01(\x0b\x32\x1c.spark.connect.MlCommand.FitH\x00R\x03\x66it\x12,\n\x05\x66\x65tch\x18\x02 \x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12\x39\n\x06\x64\x65lete\x18\x03 \x01(\x0b\x32\x1f.spark.connect.MlCommand.DeleteH\x00R\x06\x64\x65lete\x12\x36\n\x05write\x18\x04 \x01(\x0b\x32\x1e.spark.connect.MlCommand.WriteH\x00R\x05write\x12\x33\n\x04read\x18\x05 \x01(\x0b\x32\x1d.spark.connect.MlCommand.ReadH\x00R\x04read\x12?\n\x08\x65valuate\x18\x06 \x01(\x0b\x32!.spark.connect.MlCommand.EvaluateH\x00R\x08\x65valuate\x12\x46\n\x0b\x63lean_cache\x18\x07 \x01(\x0b\x32#.spark.connect.MlCommand.CleanCacheH\x00R\ncleanCache\x12M\n\x0eget_cache_info\x18\x08 \x01(\x0b\x32%.spark.connect.MlCommand.GetCacheInfoH\x00R\x0cgetCacheInfo\x1a\xb2\x01\n\x03\x46it\x12\x37\n\testimator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\testimator\x12\x34\n\x06params\x18\x02 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x00R\x06params\x88\x01\x01\x12\x31\n\x07\x64\x61taset\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61tasetB\t\n\x07_params\x1a=\n\x06\x44\x65lete\x12\x33\n\x08obj_refs\x18\x01 
\x03(\x0b\x32\x18.spark.connect.ObjectRefR\x07objRefs\x1a\x0c\n\nCleanCache\x1a\x0e\n\x0cGetCacheInfo\x1a\x9a\x03\n\x05Write\x12\x37\n\x08operator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorH\x00R\x08operator\x12\x33\n\x07obj_ref\x18\x02 \x01(\x0b\x32\x18.spark.connect.ObjectRefH\x00R\x06objRef\x12\x34\n\x06params\x18\x03 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x01R\x06params\x88\x01\x01\x12\x12\n\x04path\x18\x04 \x01(\tR\x04path\x12.\n\x10should_overwrite\x18\x05 \x01(\x08H\x02R\x0fshouldOverwrite\x88\x01\x01\x12\x45\n\x07options\x18\x06 \x03(\x0b\x32+.spark.connect.MlCommand.Write.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x06\n\x04typeB\t\n\x07_paramsB\x13\n\x11_should_overwrite\x1aQ\n\x04Read\x12\x35\n\x08operator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\x08operator\x12\x12\n\x04path\x18\x02 \x01(\tR\x04path\x1a\xb7\x01\n\x08\x45valuate\x12\x37\n\tevaluator\x18\x01 \x01(\x0b\x32\x19.spark.connect.MlOperatorR\tevaluator\x12\x34\n\x06params\x18\x02 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x00R\x06params\x88\x01\x01\x12\x31\n\x07\x64\x61taset\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x07\x64\x61tasetB\t\n\x07_paramsB\t\n\x07\x63ommand"\xd5\x03\n\x0fMlCommandResult\x12\x39\n\x05param\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x05param\x12\x1a\n\x07summary\x18\x02 \x01(\tH\x00R\x07summary\x12T\n\roperator_info\x18\x03 \x01(\x0b\x32-.spark.connect.MlCommandResult.MlOperatorInfoH\x00R\x0coperatorInfo\x1a\x85\x02\n\x0eMlOperatorInfo\x12\x33\n\x07obj_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefH\x00R\x06objRef\x12\x14\n\x04name\x18\x02 \x01(\tH\x00R\x04name\x12\x15\n\x03uid\x18\x03 \x01(\tH\x01R\x03uid\x88\x01\x01\x12\x34\n\x06params\x18\x04 \x01(\x0b\x32\x17.spark.connect.MlParamsH\x02R\x06params\x88\x01\x01\x12,\n\x0fwarning_message\x18\x05 \x01(\tH\x03R\x0ewarningMessage\x88\x01\x01\x42\x06\n\x04typeB\x06\n\x04_uidB\t\n\x07_paramsB\x12\n\x10_warning_messageB\r\n\x0bresult_typeB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' ) _globals = globals() @@ -54,27 +54,25 @@ _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._loaded_options = None _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_options = b"8\001" _globals["_MLCOMMAND"]._serialized_start = 137 - _globals["_MLCOMMAND"]._serialized_end = 1850 - _globals["_MLCOMMAND_FIT"]._serialized_start = 712 - _globals["_MLCOMMAND_FIT"]._serialized_end = 890 - _globals["_MLCOMMAND_DELETE"]._serialized_start = 892 - _globals["_MLCOMMAND_DELETE"]._serialized_end = 1004 - _globals["_MLCOMMAND_CLEANCACHE"]._serialized_start = 1006 - _globals["_MLCOMMAND_CLEANCACHE"]._serialized_end = 1018 - _globals["_MLCOMMAND_GETCACHEINFO"]._serialized_start = 1020 - _globals["_MLCOMMAND_GETCACHEINFO"]._serialized_end = 1034 - _globals["_MLCOMMAND_WRITE"]._serialized_start = 1037 - _globals["_MLCOMMAND_WRITE"]._serialized_end = 1447 - _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_start = 1349 - _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_end = 1407 - _globals["_MLCOMMAND_READ"]._serialized_start = 1449 - _globals["_MLCOMMAND_READ"]._serialized_end = 1530 - _globals["_MLCOMMAND_EVALUATE"]._serialized_start = 1533 - _globals["_MLCOMMAND_EVALUATE"]._serialized_end = 1716 - _globals["_MLCOMMAND_CREATESUMMARY"]._serialized_start = 1718 - _globals["_MLCOMMAND_CREATESUMMARY"]._serialized_end = 1839 - _globals["_MLCOMMANDRESULT"]._serialized_start = 1853 - 
_globals["_MLCOMMANDRESULT"]._serialized_end = 2322 - _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_start = 2046 - _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_end = 2307 + _globals["_MLCOMMAND"]._serialized_end = 1595 + _globals["_MLCOMMAND_FIT"]._serialized_start = 631 + _globals["_MLCOMMAND_FIT"]._serialized_end = 809 + _globals["_MLCOMMAND_DELETE"]._serialized_start = 811 + _globals["_MLCOMMAND_DELETE"]._serialized_end = 872 + _globals["_MLCOMMAND_CLEANCACHE"]._serialized_start = 874 + _globals["_MLCOMMAND_CLEANCACHE"]._serialized_end = 886 + _globals["_MLCOMMAND_GETCACHEINFO"]._serialized_start = 888 + _globals["_MLCOMMAND_GETCACHEINFO"]._serialized_end = 902 + _globals["_MLCOMMAND_WRITE"]._serialized_start = 905 + _globals["_MLCOMMAND_WRITE"]._serialized_end = 1315 + _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_start = 1217 + _globals["_MLCOMMAND_WRITE_OPTIONSENTRY"]._serialized_end = 1275 + _globals["_MLCOMMAND_READ"]._serialized_start = 1317 + _globals["_MLCOMMAND_READ"]._serialized_end = 1398 + _globals["_MLCOMMAND_EVALUATE"]._serialized_start = 1401 + _globals["_MLCOMMAND_EVALUATE"]._serialized_end = 1584 + _globals["_MLCOMMANDRESULT"]._serialized_start = 1598 + _globals["_MLCOMMANDRESULT"]._serialized_end = 2067 + _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_start = 1791 + _globals["_MLCOMMANDRESULT_MLOPERATORINFO"]._serialized_end = 2052 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/ml_pb2.pyi b/python/pyspark/sql/connect/proto/ml_pb2.pyi index 0a72c207b5264..88cc6cb625ded 100644 --- a/python/pyspark/sql/connect/proto/ml_pb2.pyi +++ b/python/pyspark/sql/connect/proto/ml_pb2.pyi @@ -118,39 +118,21 @@ class MlCommand(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor OBJ_REFS_FIELD_NUMBER: builtins.int - EVICT_ONLY_FIELD_NUMBER: builtins.int @property def obj_refs( self, ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ pyspark.sql.connect.proto.ml_common_pb2.ObjectRef ]: ... - evict_only: builtins.bool - """if set `evict_only` to true, only evict the cached model from memory, - but keep the offloaded model in Spark driver local disk. - """ def __init__( self, *, obj_refs: collections.abc.Iterable[pyspark.sql.connect.proto.ml_common_pb2.ObjectRef] | None = ..., - evict_only: builtins.bool | None = ..., ) -> None: ... - def HasField( - self, - field_name: typing_extensions.Literal[ - "_evict_only", b"_evict_only", "evict_only", b"evict_only" - ], - ) -> builtins.bool: ... def ClearField( - self, - field_name: typing_extensions.Literal[ - "_evict_only", b"_evict_only", "evict_only", b"evict_only", "obj_refs", b"obj_refs" - ], + self, field_name: typing_extensions.Literal["obj_refs", b"obj_refs"] ) -> None: ... - def WhichOneof( - self, oneof_group: typing_extensions.Literal["_evict_only", b"_evict_only"] - ) -> typing_extensions.Literal["evict_only"] | None: ... class CleanCache(google.protobuf.message.Message): """Force to clean up all the ML cached objects""" @@ -360,34 +342,6 @@ class MlCommand(google.protobuf.message.Message): self, oneof_group: typing_extensions.Literal["_params", b"_params"] ) -> typing_extensions.Literal["params"] | None: ... 
- class CreateSummary(google.protobuf.message.Message): - """This is for re-creating the model summary when the model summary is lost - (model summary is lost when the model is offloaded and then loaded back) - """ - - DESCRIPTOR: google.protobuf.descriptor.Descriptor - - MODEL_REF_FIELD_NUMBER: builtins.int - DATASET_FIELD_NUMBER: builtins.int - @property - def model_ref(self) -> pyspark.sql.connect.proto.ml_common_pb2.ObjectRef: ... - @property - def dataset(self) -> pyspark.sql.connect.proto.relations_pb2.Relation: ... - def __init__( - self, - *, - model_ref: pyspark.sql.connect.proto.ml_common_pb2.ObjectRef | None = ..., - dataset: pyspark.sql.connect.proto.relations_pb2.Relation | None = ..., - ) -> None: ... - def HasField( - self, - field_name: typing_extensions.Literal["dataset", b"dataset", "model_ref", b"model_ref"], - ) -> builtins.bool: ... - def ClearField( - self, - field_name: typing_extensions.Literal["dataset", b"dataset", "model_ref", b"model_ref"], - ) -> None: ... - FIT_FIELD_NUMBER: builtins.int FETCH_FIELD_NUMBER: builtins.int DELETE_FIELD_NUMBER: builtins.int @@ -396,7 +350,6 @@ class MlCommand(google.protobuf.message.Message): EVALUATE_FIELD_NUMBER: builtins.int CLEAN_CACHE_FIELD_NUMBER: builtins.int GET_CACHE_INFO_FIELD_NUMBER: builtins.int - CREATE_SUMMARY_FIELD_NUMBER: builtins.int @property def fit(self) -> global___MlCommand.Fit: ... @property @@ -413,8 +366,6 @@ class MlCommand(google.protobuf.message.Message): def clean_cache(self) -> global___MlCommand.CleanCache: ... @property def get_cache_info(self) -> global___MlCommand.GetCacheInfo: ... - @property - def create_summary(self) -> global___MlCommand.CreateSummary: ... def __init__( self, *, @@ -426,7 +377,6 @@ class MlCommand(google.protobuf.message.Message): evaluate: global___MlCommand.Evaluate | None = ..., clean_cache: global___MlCommand.CleanCache | None = ..., get_cache_info: global___MlCommand.GetCacheInfo | None = ..., - create_summary: global___MlCommand.CreateSummary | None = ..., ) -> None: ... def HasField( self, @@ -435,8 +385,6 @@ class MlCommand(google.protobuf.message.Message): b"clean_cache", "command", b"command", - "create_summary", - b"create_summary", "delete", b"delete", "evaluate", @@ -460,8 +408,6 @@ class MlCommand(google.protobuf.message.Message): b"clean_cache", "command", b"command", - "create_summary", - b"create_summary", "delete", b"delete", "evaluate", @@ -482,15 +428,7 @@ class MlCommand(google.protobuf.message.Message): self, oneof_group: typing_extensions.Literal["command", b"command"] ) -> ( typing_extensions.Literal[ - "fit", - "fetch", - "delete", - "write", - "read", - "evaluate", - "clean_cache", - "get_cache_info", - "create_summary", + "fit", "fetch", "delete", "write", "read", "evaluate", "clean_cache", "get_cache_info" ] | None ): ... 
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py index 3774bcbdbfb0e..525ba88ff67c6 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.py +++ b/python/pyspark/sql/connect/proto/relations_pb2.py @@ -43,7 +43,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto\x1a\x1aspark/connect/common.proto\x1a\x1dspark/connect/ml_common.proto"\x9c\x1d\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\x32\x13.spark.connect.JoinH\x00R\x04join\x12\x34\n\x06set_op\x18\x06 \x01(\x0b\x32\x1b.spark.connect.SetOperationH\x00R\x05setOp\x12)\n\x04sort\x18\x07 \x01(\x0b\x32\x13.spark.connect.SortH\x00R\x04sort\x12,\n\x05limit\x18\x08 \x01(\x0b\x32\x14.spark.connect.LimitH\x00R\x05limit\x12\x38\n\taggregate\x18\t \x01(\x0b\x32\x18.spark.connect.AggregateH\x00R\taggregate\x12&\n\x03sql\x18\n \x01(\x0b\x32\x12.spark.connect.SQLH\x00R\x03sql\x12\x45\n\x0elocal_relation\x18\x0b \x01(\x0b\x32\x1c.spark.connect.LocalRelationH\x00R\rlocalRelation\x12/\n\x06sample\x18\x0c \x01(\x0b\x32\x15.spark.connect.SampleH\x00R\x06sample\x12/\n\x06offset\x18\r \x01(\x0b\x32\x15.spark.connect.OffsetH\x00R\x06offset\x12>\n\x0b\x64\x65\x64uplicate\x18\x0e \x01(\x0b\x32\x1a.spark.connect.DeduplicateH\x00R\x0b\x64\x65\x64uplicate\x12,\n\x05range\x18\x0f \x01(\x0b\x32\x14.spark.connect.RangeH\x00R\x05range\x12\x45\n\x0esubquery_alias\x18\x10 \x01(\x0b\x32\x1c.spark.connect.SubqueryAliasH\x00R\rsubqueryAlias\x12>\n\x0brepartition\x18\x11 \x01(\x0b\x32\x1a.spark.connect.RepartitionH\x00R\x0brepartition\x12*\n\x05to_df\x18\x12 \x01(\x0b\x32\x13.spark.connect.ToDFH\x00R\x04toDf\x12U\n\x14with_columns_renamed\x18\x13 \x01(\x0b\x32!.spark.connect.WithColumnsRenamedH\x00R\x12withColumnsRenamed\x12<\n\x0bshow_string\x18\x14 \x01(\x0b\x32\x19.spark.connect.ShowStringH\x00R\nshowString\x12)\n\x04\x64rop\x18\x15 \x01(\x0b\x32\x13.spark.connect.DropH\x00R\x04\x64rop\x12)\n\x04tail\x18\x16 \x01(\x0b\x32\x13.spark.connect.TailH\x00R\x04tail\x12?\n\x0cwith_columns\x18\x17 \x01(\x0b\x32\x1a.spark.connect.WithColumnsH\x00R\x0bwithColumns\x12)\n\x04hint\x18\x18 \x01(\x0b\x32\x13.spark.connect.HintH\x00R\x04hint\x12\x32\n\x07unpivot\x18\x19 \x01(\x0b\x32\x16.spark.connect.UnpivotH\x00R\x07unpivot\x12\x36\n\tto_schema\x18\x1a \x01(\x0b\x32\x17.spark.connect.ToSchemaH\x00R\x08toSchema\x12\x64\n\x19repartition_by_expression\x18\x1b \x01(\x0b\x32&.spark.connect.RepartitionByExpressionH\x00R\x17repartitionByExpression\x12\x45\n\x0emap_partitions\x18\x1c \x01(\x0b\x32\x1c.spark.connect.MapPartitionsH\x00R\rmapPartitions\x12H\n\x0f\x63ollect_metrics\x18\x1d \x01(\x0b\x32\x1d.spark.connect.CollectMetricsH\x00R\x0e\x63ollectMetrics\x12,\n\x05parse\x18\x1e \x01(\x0b\x32\x14.spark.connect.ParseH\x00R\x05parse\x12\x36\n\tgroup_map\x18\x1f \x01(\x0b\x32\x17.spark.connect.GroupMapH\x00R\x08groupMap\x12=\n\x0c\x63o_group_map\x18 \x01(\x0b\x32\x19.spark.connect.CoGroupMapH\x00R\ncoGroupMap\x12\x45\n\x0ewith_watermark\x18! 
\x01(\x0b\x32\x1c.spark.connect.WithWatermarkH\x00R\rwithWatermark\x12\x63\n\x1a\x61pply_in_pandas_with_state\x18" \x01(\x0b\x32%.spark.connect.ApplyInPandasWithStateH\x00R\x16\x61pplyInPandasWithState\x12<\n\x0bhtml_string\x18# \x01(\x0b\x32\x19.spark.connect.HtmlStringH\x00R\nhtmlString\x12X\n\x15\x63\x61\x63hed_local_relation\x18$ \x01(\x0b\x32".spark.connect.CachedLocalRelationH\x00R\x13\x63\x61\x63hedLocalRelation\x12[\n\x16\x63\x61\x63hed_remote_relation\x18% \x01(\x0b\x32#.spark.connect.CachedRemoteRelationH\x00R\x14\x63\x61\x63hedRemoteRelation\x12\x8e\x01\n)common_inline_user_defined_table_function\x18& \x01(\x0b\x32\x33.spark.connect.CommonInlineUserDefinedTableFunctionH\x00R$commonInlineUserDefinedTableFunction\x12\x37\n\nas_of_join\x18\' \x01(\x0b\x32\x17.spark.connect.AsOfJoinH\x00R\x08\x61sOfJoin\x12\x85\x01\n&common_inline_user_defined_data_source\x18( \x01(\x0b\x32\x30.spark.connect.CommonInlineUserDefinedDataSourceH\x00R!commonInlineUserDefinedDataSource\x12\x45\n\x0ewith_relations\x18) \x01(\x0b\x32\x1c.spark.connect.WithRelationsH\x00R\rwithRelations\x12\x38\n\ttranspose\x18* \x01(\x0b\x32\x18.spark.connect.TransposeH\x00R\ttranspose\x12w\n unresolved_table_valued_function\x18+ \x01(\x0b\x32,.spark.connect.UnresolvedTableValuedFunctionH\x00R\x1dunresolvedTableValuedFunction\x12?\n\x0clateral_join\x18, \x01(\x0b\x32\x1a.spark.connect.LateralJoinH\x00R\x0blateralJoin\x12\x30\n\x07\x66ill_na\x18Z \x01(\x0b\x32\x15.spark.connect.NAFillH\x00R\x06\x66illNa\x12\x30\n\x07\x64rop_na\x18[ \x01(\x0b\x32\x15.spark.connect.NADropH\x00R\x06\x64ropNa\x12\x34\n\x07replace\x18\\ \x01(\x0b\x32\x18.spark.connect.NAReplaceH\x00R\x07replace\x12\x36\n\x07summary\x18\x64 \x01(\x0b\x32\x1a.spark.connect.StatSummaryH\x00R\x07summary\x12\x39\n\x08\x63rosstab\x18\x65 \x01(\x0b\x32\x1b.spark.connect.StatCrosstabH\x00R\x08\x63rosstab\x12\x39\n\x08\x64\x65scribe\x18\x66 \x01(\x0b\x32\x1b.spark.connect.StatDescribeH\x00R\x08\x64\x65scribe\x12*\n\x03\x63ov\x18g \x01(\x0b\x32\x16.spark.connect.StatCovH\x00R\x03\x63ov\x12-\n\x04\x63orr\x18h \x01(\x0b\x32\x17.spark.connect.StatCorrH\x00R\x04\x63orr\x12L\n\x0f\x61pprox_quantile\x18i \x01(\x0b\x32!.spark.connect.StatApproxQuantileH\x00R\x0e\x61pproxQuantile\x12=\n\nfreq_items\x18j \x01(\x0b\x32\x1c.spark.connect.StatFreqItemsH\x00R\tfreqItems\x12:\n\tsample_by\x18k \x01(\x0b\x32\x1b.spark.connect.StatSampleByH\x00R\x08sampleBy\x12\x33\n\x07\x63\x61talog\x18\xc8\x01 \x01(\x0b\x32\x16.spark.connect.CatalogH\x00R\x07\x63\x61talog\x12=\n\x0bml_relation\x18\xac\x02 \x01(\x0b\x32\x19.spark.connect.MlRelationH\x00R\nmlRelation\x12\x35\n\textension\x18\xe6\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12\x33\n\x07unknown\x18\xe7\x07 \x01(\x0b\x32\x16.spark.connect.UnknownH\x00R\x07unknownB\n\n\x08rel_type"\xe4\x03\n\nMlRelation\x12\x43\n\ttransform\x18\x01 \x01(\x0b\x32#.spark.connect.MlRelation.TransformH\x00R\ttransform\x12,\n\x05\x66\x65tch\x18\x02 \x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x12P\n\x15model_summary_dataset\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationH\x01R\x13modelSummaryDataset\x88\x01\x01\x1a\xeb\x01\n\tTransform\x12\x33\n\x07obj_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefH\x00R\x06objRef\x12=\n\x0btransformer\x18\x02 \x01(\x0b\x32\x19.spark.connect.MlOperatorH\x00R\x0btransformer\x12-\n\x05input\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12/\n\x06params\x18\x04 
\x01(\x0b\x32\x17.spark.connect.MlParamsR\x06paramsB\n\n\x08operatorB\t\n\x07ml_typeB\x18\n\x16_model_summary_dataset"\xcb\x02\n\x05\x46\x65tch\x12\x31\n\x07obj_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefR\x06objRef\x12\x35\n\x07methods\x18\x02 \x03(\x0b\x32\x1b.spark.connect.Fetch.MethodR\x07methods\x1a\xd7\x01\n\x06Method\x12\x16\n\x06method\x18\x01 \x01(\tR\x06method\x12\x34\n\x04\x61rgs\x18\x02 \x03(\x0b\x32 .spark.connect.Fetch.Method.ArgsR\x04\x61rgs\x1a\x7f\n\x04\x41rgs\x12\x39\n\x05param\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x05param\x12/\n\x05input\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x05inputB\x0b\n\targs_type"\t\n\x07Unknown"\x8e\x01\n\x0eRelationCommon\x12#\n\x0bsource_info\x18\x01 \x01(\tB\x02\x18\x01R\nsourceInfo\x12\x1c\n\x07plan_id\x18\x02 \x01(\x03H\x00R\x06planId\x88\x01\x01\x12-\n\x06origin\x18\x03 \x01(\x0b\x32\x15.spark.connect.OriginR\x06originB\n\n\x08_plan_id"\xde\x03\n\x03SQL\x12\x14\n\x05query\x18\x01 \x01(\tR\x05query\x12\x34\n\x04\x61rgs\x18\x02 \x03(\x0b\x32\x1c.spark.connect.SQL.ArgsEntryB\x02\x18\x01R\x04\x61rgs\x12@\n\x08pos_args\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralB\x02\x18\x01R\x07posArgs\x12O\n\x0fnamed_arguments\x18\x04 \x03(\x0b\x32&.spark.connect.SQL.NamedArgumentsEntryR\x0enamedArguments\x12>\n\rpos_arguments\x18\x05 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0cposArguments\x1aZ\n\tArgsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x37\n\x05value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x05value:\x02\x38\x01\x1a\\\n\x13NamedArgumentsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12/\n\x05value\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05value:\x02\x38\x01"u\n\rWithRelations\x12+\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04root\x12\x37\n\nreferences\x18\x02 \x03(\x0b\x32\x17.spark.connect.RelationR\nreferences"\x97\x05\n\x04Read\x12\x41\n\x0bnamed_table\x18\x01 \x01(\x0b\x32\x1e.spark.connect.Read.NamedTableH\x00R\nnamedTable\x12\x41\n\x0b\x64\x61ta_source\x18\x02 \x01(\x0b\x32\x1e.spark.connect.Read.DataSourceH\x00R\ndataSource\x12!\n\x0cis_streaming\x18\x03 \x01(\x08R\x0bisStreaming\x1a\xc0\x01\n\nNamedTable\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x12\x45\n\x07options\x18\x02 \x03(\x0b\x32+.spark.connect.Read.NamedTable.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x95\x02\n\nDataSource\x12\x1b\n\x06\x66ormat\x18\x01 \x01(\tH\x00R\x06\x66ormat\x88\x01\x01\x12\x1b\n\x06schema\x18\x02 \x01(\tH\x01R\x06schema\x88\x01\x01\x12\x45\n\x07options\x18\x03 \x03(\x0b\x32+.spark.connect.Read.DataSource.OptionsEntryR\x07options\x12\x14\n\x05paths\x18\x04 \x03(\tR\x05paths\x12\x1e\n\npredicates\x18\x05 \x03(\tR\npredicates\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07_formatB\t\n\x07_schemaB\x0b\n\tread_type"u\n\x07Project\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12;\n\x0b\x65xpressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0b\x65xpressions"p\n\x06\x46ilter\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x37\n\tcondition\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\x95\x05\n\x04Join\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 
\x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinType\x12#\n\rusing_columns\x18\x05 \x03(\tR\x0cusingColumns\x12K\n\x0ejoin_data_type\x18\x06 \x01(\x0b\x32 .spark.connect.Join.JoinDataTypeH\x00R\x0cjoinDataType\x88\x01\x01\x1a\\\n\x0cJoinDataType\x12$\n\x0eis_left_struct\x18\x01 \x01(\x08R\x0cisLeftStruct\x12&\n\x0fis_right_struct\x18\x02 \x01(\x08R\risRightStruct"\xd0\x01\n\x08JoinType\x12\x19\n\x15JOIN_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOIN_TYPE_INNER\x10\x01\x12\x18\n\x14JOIN_TYPE_FULL_OUTER\x10\x02\x12\x18\n\x14JOIN_TYPE_LEFT_OUTER\x10\x03\x12\x19\n\x15JOIN_TYPE_RIGHT_OUTER\x10\x04\x12\x17\n\x13JOIN_TYPE_LEFT_ANTI\x10\x05\x12\x17\n\x13JOIN_TYPE_LEFT_SEMI\x10\x06\x12\x13\n\x0fJOIN_TYPE_CROSS\x10\x07\x42\x11\n\x0f_join_data_type"\xdf\x03\n\x0cSetOperation\x12\x36\n\nleft_input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\tleftInput\x12\x38\n\x0bright_input\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\nrightInput\x12\x45\n\x0bset_op_type\x18\x03 \x01(\x0e\x32%.spark.connect.SetOperation.SetOpTypeR\tsetOpType\x12\x1a\n\x06is_all\x18\x04 \x01(\x08H\x00R\x05isAll\x88\x01\x01\x12\x1c\n\x07\x62y_name\x18\x05 \x01(\x08H\x01R\x06\x62yName\x88\x01\x01\x12\x37\n\x15\x61llow_missing_columns\x18\x06 \x01(\x08H\x02R\x13\x61llowMissingColumns\x88\x01\x01"r\n\tSetOpType\x12\x1b\n\x17SET_OP_TYPE_UNSPECIFIED\x10\x00\x12\x19\n\x15SET_OP_TYPE_INTERSECT\x10\x01\x12\x15\n\x11SET_OP_TYPE_UNION\x10\x02\x12\x16\n\x12SET_OP_TYPE_EXCEPT\x10\x03\x42\t\n\x07_is_allB\n\n\x08_by_nameB\x18\n\x16_allow_missing_columns"L\n\x05Limit\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"O\n\x06Offset\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06offset\x18\x02 \x01(\x05R\x06offset"K\n\x04Tail\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"\xfe\x05\n\tAggregate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x41\n\ngroup_type\x18\x02 \x01(\x0e\x32".spark.connect.Aggregate.GroupTypeR\tgroupType\x12L\n\x14grouping_expressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12N\n\x15\x61ggregate_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x14\x61ggregateExpressions\x12\x34\n\x05pivot\x18\x05 \x01(\x0b\x32\x1e.spark.connect.Aggregate.PivotR\x05pivot\x12J\n\rgrouping_sets\x18\x06 \x03(\x0b\x32%.spark.connect.Aggregate.GroupingSetsR\x0cgroupingSets\x1ao\n\x05Pivot\x12+\n\x03\x63ol\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x03\x63ol\x12\x39\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x1aL\n\x0cGroupingSets\x12<\n\x0cgrouping_set\x18\x01 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0bgroupingSet"\x9f\x01\n\tGroupType\x12\x1a\n\x16GROUP_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12GROUP_TYPE_GROUPBY\x10\x01\x12\x15\n\x11GROUP_TYPE_ROLLUP\x10\x02\x12\x13\n\x0fGROUP_TYPE_CUBE\x10\x03\x12\x14\n\x10GROUP_TYPE_PIVOT\x10\x04\x12\x1c\n\x18GROUP_TYPE_GROUPING_SETS\x10\x05"\xa0\x01\n\x04Sort\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x39\n\x05order\x18\x02 \x03(\x0b\x32#.spark.connect.Expression.SortOrderR\x05order\x12 \n\tis_global\x18\x03 
\x01(\x08H\x00R\x08isGlobal\x88\x01\x01\x42\x0c\n\n_is_global"\x8d\x01\n\x04\x44rop\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x33\n\x07\x63olumns\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x07\x63olumns\x12!\n\x0c\x63olumn_names\x18\x03 \x03(\tR\x0b\x63olumnNames"\xf0\x01\n\x0b\x44\x65\x64uplicate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames\x12\x32\n\x13\x61ll_columns_as_keys\x18\x03 \x01(\x08H\x00R\x10\x61llColumnsAsKeys\x88\x01\x01\x12.\n\x10within_watermark\x18\x04 \x01(\x08H\x01R\x0fwithinWatermark\x88\x01\x01\x42\x16\n\x14_all_columns_as_keysB\x13\n\x11_within_watermark"Y\n\rLocalRelation\x12\x17\n\x04\x64\x61ta\x18\x01 \x01(\x0cH\x00R\x04\x64\x61ta\x88\x01\x01\x12\x1b\n\x06schema\x18\x02 \x01(\tH\x01R\x06schema\x88\x01\x01\x42\x07\n\x05_dataB\t\n\x07_schema"H\n\x13\x43\x61\x63hedLocalRelation\x12\x12\n\x04hash\x18\x03 \x01(\tR\x04hashJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03R\x06userIdR\tsessionId"7\n\x14\x43\x61\x63hedRemoteRelation\x12\x1f\n\x0brelation_id\x18\x01 \x01(\tR\nrelationId"\x91\x02\n\x06Sample\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1f\n\x0blower_bound\x18\x02 \x01(\x01R\nlowerBound\x12\x1f\n\x0bupper_bound\x18\x03 \x01(\x01R\nupperBound\x12.\n\x10with_replacement\x18\x04 \x01(\x08H\x00R\x0fwithReplacement\x88\x01\x01\x12\x17\n\x04seed\x18\x05 \x01(\x03H\x01R\x04seed\x88\x01\x01\x12/\n\x13\x64\x65terministic_order\x18\x06 \x01(\x08R\x12\x64\x65terministicOrderB\x13\n\x11_with_replacementB\x07\n\x05_seed"\x91\x01\n\x05Range\x12\x19\n\x05start\x18\x01 \x01(\x03H\x00R\x05start\x88\x01\x01\x12\x10\n\x03\x65nd\x18\x02 \x01(\x03R\x03\x65nd\x12\x12\n\x04step\x18\x03 \x01(\x03R\x04step\x12*\n\x0enum_partitions\x18\x04 \x01(\x05H\x01R\rnumPartitions\x88\x01\x01\x42\x08\n\x06_startB\x11\n\x0f_num_partitions"r\n\rSubqueryAlias\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias\x12\x1c\n\tqualifier\x18\x03 \x03(\tR\tqualifier"\x8e\x01\n\x0bRepartition\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12%\n\x0enum_partitions\x18\x02 \x01(\x05R\rnumPartitions\x12\x1d\n\x07shuffle\x18\x03 \x01(\x08H\x00R\x07shuffle\x88\x01\x01\x42\n\n\x08_shuffle"\x8e\x01\n\nShowString\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x19\n\x08num_rows\x18\x02 \x01(\x05R\x07numRows\x12\x1a\n\x08truncate\x18\x03 \x01(\x05R\x08truncate\x12\x1a\n\x08vertical\x18\x04 \x01(\x08R\x08vertical"r\n\nHtmlString\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x19\n\x08num_rows\x18\x02 \x01(\x05R\x07numRows\x12\x1a\n\x08truncate\x18\x03 \x01(\x05R\x08truncate"\\\n\x0bStatSummary\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1e\n\nstatistics\x18\x02 \x03(\tR\nstatistics"Q\n\x0cStatDescribe\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols"e\n\x0cStatCrosstab\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2"`\n\x07StatCov\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2"\x89\x01\n\x08StatCorr\x12-\n\x05input\x18\x01 
\x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2\x12\x1b\n\x06method\x18\x04 \x01(\tH\x00R\x06method\x88\x01\x01\x42\t\n\x07_method"\xa4\x01\n\x12StatApproxQuantile\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12$\n\rprobabilities\x18\x03 \x03(\x01R\rprobabilities\x12%\n\x0erelative_error\x18\x04 \x01(\x01R\rrelativeError"}\n\rStatFreqItems\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\x1d\n\x07support\x18\x03 \x01(\x01H\x00R\x07support\x88\x01\x01\x42\n\n\x08_support"\xb5\x02\n\x0cStatSampleBy\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12+\n\x03\x63ol\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x03\x63ol\x12\x42\n\tfractions\x18\x03 \x03(\x0b\x32$.spark.connect.StatSampleBy.FractionR\tfractions\x12\x17\n\x04seed\x18\x05 \x01(\x03H\x00R\x04seed\x88\x01\x01\x1a\x63\n\x08\x46raction\x12;\n\x07stratum\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x07stratum\x12\x1a\n\x08\x66raction\x18\x02 \x01(\x01R\x08\x66ractionB\x07\n\x05_seed"\x86\x01\n\x06NAFill\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\x39\n\x06values\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values"\x86\x01\n\x06NADrop\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\'\n\rmin_non_nulls\x18\x03 \x01(\x05H\x00R\x0bminNonNulls\x88\x01\x01\x42\x10\n\x0e_min_non_nulls"\xa8\x02\n\tNAReplace\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12H\n\x0creplacements\x18\x03 \x03(\x0b\x32$.spark.connect.NAReplace.ReplacementR\x0creplacements\x1a\x8d\x01\n\x0bReplacement\x12>\n\told_value\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x08oldValue\x12>\n\tnew_value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x08newValue"X\n\x04ToDF\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames"\xfe\x02\n\x12WithColumnsRenamed\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12i\n\x12rename_columns_map\x18\x02 \x03(\x0b\x32\x37.spark.connect.WithColumnsRenamed.RenameColumnsMapEntryB\x02\x18\x01R\x10renameColumnsMap\x12\x42\n\x07renames\x18\x03 \x03(\x0b\x32(.spark.connect.WithColumnsRenamed.RenameR\x07renames\x1a\x43\n\x15RenameColumnsMapEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x45\n\x06Rename\x12\x19\n\x08\x63ol_name\x18\x01 \x01(\tR\x07\x63olName\x12 \n\x0cnew_col_name\x18\x02 \x01(\tR\nnewColName"w\n\x0bWithColumns\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x39\n\x07\x61liases\x18\x02 \x03(\x0b\x32\x1f.spark.connect.Expression.AliasR\x07\x61liases"\x86\x01\n\rWithWatermark\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1d\n\nevent_time\x18\x02 \x01(\tR\teventTime\x12\'\n\x0f\x64\x65lay_threshold\x18\x03 \x01(\tR\x0e\x64\x65layThreshold"\x84\x01\n\x04Hint\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x39\n\nparameters\x18\x03 
\x03(\x0b\x32\x19.spark.connect.ExpressionR\nparameters"\xc7\x02\n\x07Unpivot\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12+\n\x03ids\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x03ids\x12:\n\x06values\x18\x03 \x01(\x0b\x32\x1d.spark.connect.Unpivot.ValuesH\x00R\x06values\x88\x01\x01\x12\x30\n\x14variable_column_name\x18\x04 \x01(\tR\x12variableColumnName\x12*\n\x11value_column_name\x18\x05 \x01(\tR\x0fvalueColumnName\x1a;\n\x06Values\x12\x31\n\x06values\x18\x01 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x06valuesB\t\n\x07_values"z\n\tTranspose\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12>\n\rindex_columns\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0cindexColumns"}\n\x1dUnresolvedTableValuedFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments"j\n\x08ToSchema\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema"\xcb\x01\n\x17RepartitionByExpression\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x42\n\x0fpartition_exprs\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0epartitionExprs\x12*\n\x0enum_partitions\x18\x03 \x01(\x05H\x00R\rnumPartitions\x88\x01\x01\x42\x11\n\x0f_num_partitions"\xe8\x01\n\rMapPartitions\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x42\n\x04\x66unc\x18\x02 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12"\n\nis_barrier\x18\x03 \x01(\x08H\x00R\tisBarrier\x88\x01\x01\x12"\n\nprofile_id\x18\x04 \x01(\x05H\x01R\tprofileId\x88\x01\x01\x42\r\n\x0b_is_barrierB\r\n\x0b_profile_id"\xd2\x06\n\x08GroupMap\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12\x42\n\x04\x66unc\x18\x03 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12J\n\x13sorting_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x12sortingExpressions\x12<\n\rinitial_input\x18\x05 \x01(\x0b\x32\x17.spark.connect.RelationR\x0cinitialInput\x12[\n\x1cinitial_grouping_expressions\x18\x06 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x1ainitialGroupingExpressions\x12;\n\x18is_map_groups_with_state\x18\x07 \x01(\x08H\x00R\x14isMapGroupsWithState\x88\x01\x01\x12$\n\x0boutput_mode\x18\x08 \x01(\tH\x01R\noutputMode\x88\x01\x01\x12&\n\x0ctimeout_conf\x18\t \x01(\tH\x02R\x0btimeoutConf\x88\x01\x01\x12?\n\x0cstate_schema\x18\n \x01(\x0b\x32\x17.spark.connect.DataTypeH\x03R\x0bstateSchema\x88\x01\x01\x12\x65\n\x19transform_with_state_info\x18\x0b \x01(\x0b\x32%.spark.connect.TransformWithStateInfoH\x04R\x16transformWithStateInfo\x88\x01\x01\x42\x1b\n\x19_is_map_groups_with_stateB\x0e\n\x0c_output_modeB\x0f\n\r_timeout_confB\x0f\n\r_state_schemaB\x1c\n\x1a_transform_with_state_info"\xdf\x01\n\x16TransformWithStateInfo\x12\x1b\n\ttime_mode\x18\x01 \x01(\tR\x08timeMode\x12\x38\n\x16\x65vent_time_column_name\x18\x02 \x01(\tH\x00R\x13\x65ventTimeColumnName\x88\x01\x01\x12\x41\n\routput_schema\x18\x03 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x01R\x0coutputSchema\x88\x01\x01\x42\x19\n\x17_event_time_column_nameB\x10\n\x0e_output_schema"\x8e\x04\n\nCoGroupMap\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12W\n\x1ainput_grouping_expressions\x18\x02 
\x03(\x0b\x32\x19.spark.connect.ExpressionR\x18inputGroupingExpressions\x12-\n\x05other\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x05other\x12W\n\x1aother_grouping_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x18otherGroupingExpressions\x12\x42\n\x04\x66unc\x18\x05 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12U\n\x19input_sorting_expressions\x18\x06 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x17inputSortingExpressions\x12U\n\x19other_sorting_expressions\x18\x07 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x17otherSortingExpressions"\xe5\x02\n\x16\x41pplyInPandasWithState\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12\x42\n\x04\x66unc\x18\x03 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12#\n\routput_schema\x18\x04 \x01(\tR\x0coutputSchema\x12!\n\x0cstate_schema\x18\x05 \x01(\tR\x0bstateSchema\x12\x1f\n\x0boutput_mode\x18\x06 \x01(\tR\noutputMode\x12!\n\x0ctimeout_conf\x18\x07 \x01(\tR\x0btimeoutConf"\xf4\x01\n$CommonInlineUserDefinedTableFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12$\n\rdeterministic\x18\x02 \x01(\x08R\rdeterministic\x12\x37\n\targuments\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12<\n\x0bpython_udtf\x18\x04 \x01(\x0b\x32\x19.spark.connect.PythonUDTFH\x00R\npythonUdtfB\n\n\x08\x66unction"\xb1\x01\n\nPythonUDTF\x12=\n\x0breturn_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\nreturnType\x88\x01\x01\x12\x1b\n\teval_type\x18\x02 \x01(\x05R\x08\x65valType\x12\x18\n\x07\x63ommand\x18\x03 \x01(\x0cR\x07\x63ommand\x12\x1d\n\npython_ver\x18\x04 \x01(\tR\tpythonVerB\x0e\n\x0c_return_type"\x97\x01\n!CommonInlineUserDefinedDataSource\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12O\n\x12python_data_source\x18\x02 \x01(\x0b\x32\x1f.spark.connect.PythonDataSourceH\x00R\x10pythonDataSourceB\r\n\x0b\x64\x61ta_source"K\n\x10PythonDataSource\x12\x18\n\x07\x63ommand\x18\x01 \x01(\x0cR\x07\x63ommand\x12\x1d\n\npython_ver\x18\x02 \x01(\tR\tpythonVer"\x88\x01\n\x0e\x43ollectMetrics\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x33\n\x07metrics\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x07metrics"\x84\x03\n\x05Parse\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x38\n\x06\x66ormat\x18\x02 \x01(\x0e\x32 .spark.connect.Parse.ParseFormatR\x06\x66ormat\x12\x34\n\x06schema\x18\x03 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x06schema\x88\x01\x01\x12;\n\x07options\x18\x04 \x03(\x0b\x32!.spark.connect.Parse.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01"X\n\x0bParseFormat\x12\x1c\n\x18PARSE_FORMAT_UNSPECIFIED\x10\x00\x12\x14\n\x10PARSE_FORMAT_CSV\x10\x01\x12\x15\n\x11PARSE_FORMAT_JSON\x10\x02\x42\t\n\x07_schema"\xdb\x03\n\x08\x41sOfJoin\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12\x37\n\nleft_as_of\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x08leftAsOf\x12\x39\n\x0bright_as_of\x18\x04 \x01(\x0b\x32\x19.spark.connect.ExpressionR\trightAsOf\x12\x36\n\tjoin_expr\x18\x05 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x08joinExpr\x12#\n\rusing_columns\x18\x06 
\x03(\tR\x0cusingColumns\x12\x1b\n\tjoin_type\x18\x07 \x01(\tR\x08joinType\x12\x37\n\ttolerance\x18\x08 \x01(\x0b\x32\x19.spark.connect.ExpressionR\ttolerance\x12.\n\x13\x61llow_exact_matches\x18\t \x01(\x08R\x11\x61llowExactMatches\x12\x1c\n\tdirection\x18\n \x01(\tR\tdirection"\xe6\x01\n\x0bLateralJoin\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinTypeB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' + b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto\x1a\x1aspark/connect/common.proto\x1a\x1dspark/connect/ml_common.proto"\x9c\x1d\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66ilter\x18\x04 \x01(\x0b\x32\x15.spark.connect.FilterH\x00R\x06\x66ilter\x12)\n\x04join\x18\x05 \x01(\x0b\x32\x13.spark.connect.JoinH\x00R\x04join\x12\x34\n\x06set_op\x18\x06 \x01(\x0b\x32\x1b.spark.connect.SetOperationH\x00R\x05setOp\x12)\n\x04sort\x18\x07 \x01(\x0b\x32\x13.spark.connect.SortH\x00R\x04sort\x12,\n\x05limit\x18\x08 \x01(\x0b\x32\x14.spark.connect.LimitH\x00R\x05limit\x12\x38\n\taggregate\x18\t \x01(\x0b\x32\x18.spark.connect.AggregateH\x00R\taggregate\x12&\n\x03sql\x18\n \x01(\x0b\x32\x12.spark.connect.SQLH\x00R\x03sql\x12\x45\n\x0elocal_relation\x18\x0b \x01(\x0b\x32\x1c.spark.connect.LocalRelationH\x00R\rlocalRelation\x12/\n\x06sample\x18\x0c \x01(\x0b\x32\x15.spark.connect.SampleH\x00R\x06sample\x12/\n\x06offset\x18\r \x01(\x0b\x32\x15.spark.connect.OffsetH\x00R\x06offset\x12>\n\x0b\x64\x65\x64uplicate\x18\x0e \x01(\x0b\x32\x1a.spark.connect.DeduplicateH\x00R\x0b\x64\x65\x64uplicate\x12,\n\x05range\x18\x0f \x01(\x0b\x32\x14.spark.connect.RangeH\x00R\x05range\x12\x45\n\x0esubquery_alias\x18\x10 \x01(\x0b\x32\x1c.spark.connect.SubqueryAliasH\x00R\rsubqueryAlias\x12>\n\x0brepartition\x18\x11 \x01(\x0b\x32\x1a.spark.connect.RepartitionH\x00R\x0brepartition\x12*\n\x05to_df\x18\x12 \x01(\x0b\x32\x13.spark.connect.ToDFH\x00R\x04toDf\x12U\n\x14with_columns_renamed\x18\x13 \x01(\x0b\x32!.spark.connect.WithColumnsRenamedH\x00R\x12withColumnsRenamed\x12<\n\x0bshow_string\x18\x14 \x01(\x0b\x32\x19.spark.connect.ShowStringH\x00R\nshowString\x12)\n\x04\x64rop\x18\x15 \x01(\x0b\x32\x13.spark.connect.DropH\x00R\x04\x64rop\x12)\n\x04tail\x18\x16 \x01(\x0b\x32\x13.spark.connect.TailH\x00R\x04tail\x12?\n\x0cwith_columns\x18\x17 \x01(\x0b\x32\x1a.spark.connect.WithColumnsH\x00R\x0bwithColumns\x12)\n\x04hint\x18\x18 \x01(\x0b\x32\x13.spark.connect.HintH\x00R\x04hint\x12\x32\n\x07unpivot\x18\x19 \x01(\x0b\x32\x16.spark.connect.UnpivotH\x00R\x07unpivot\x12\x36\n\tto_schema\x18\x1a \x01(\x0b\x32\x17.spark.connect.ToSchemaH\x00R\x08toSchema\x12\x64\n\x19repartition_by_expression\x18\x1b \x01(\x0b\x32&.spark.connect.RepartitionByExpressionH\x00R\x17repartitionByExpression\x12\x45\n\x0emap_partitions\x18\x1c \x01(\x0b\x32\x1c.spark.connect.MapPartitionsH\x00R\rmapPartitions\x12H\n\x0f\x63ollect_metrics\x18\x1d 
\x01(\x0b\x32\x1d.spark.connect.CollectMetricsH\x00R\x0e\x63ollectMetrics\x12,\n\x05parse\x18\x1e \x01(\x0b\x32\x14.spark.connect.ParseH\x00R\x05parse\x12\x36\n\tgroup_map\x18\x1f \x01(\x0b\x32\x17.spark.connect.GroupMapH\x00R\x08groupMap\x12=\n\x0c\x63o_group_map\x18 \x01(\x0b\x32\x19.spark.connect.CoGroupMapH\x00R\ncoGroupMap\x12\x45\n\x0ewith_watermark\x18! \x01(\x0b\x32\x1c.spark.connect.WithWatermarkH\x00R\rwithWatermark\x12\x63\n\x1a\x61pply_in_pandas_with_state\x18" \x01(\x0b\x32%.spark.connect.ApplyInPandasWithStateH\x00R\x16\x61pplyInPandasWithState\x12<\n\x0bhtml_string\x18# \x01(\x0b\x32\x19.spark.connect.HtmlStringH\x00R\nhtmlString\x12X\n\x15\x63\x61\x63hed_local_relation\x18$ \x01(\x0b\x32".spark.connect.CachedLocalRelationH\x00R\x13\x63\x61\x63hedLocalRelation\x12[\n\x16\x63\x61\x63hed_remote_relation\x18% \x01(\x0b\x32#.spark.connect.CachedRemoteRelationH\x00R\x14\x63\x61\x63hedRemoteRelation\x12\x8e\x01\n)common_inline_user_defined_table_function\x18& \x01(\x0b\x32\x33.spark.connect.CommonInlineUserDefinedTableFunctionH\x00R$commonInlineUserDefinedTableFunction\x12\x37\n\nas_of_join\x18\' \x01(\x0b\x32\x17.spark.connect.AsOfJoinH\x00R\x08\x61sOfJoin\x12\x85\x01\n&common_inline_user_defined_data_source\x18( \x01(\x0b\x32\x30.spark.connect.CommonInlineUserDefinedDataSourceH\x00R!commonInlineUserDefinedDataSource\x12\x45\n\x0ewith_relations\x18) \x01(\x0b\x32\x1c.spark.connect.WithRelationsH\x00R\rwithRelations\x12\x38\n\ttranspose\x18* \x01(\x0b\x32\x18.spark.connect.TransposeH\x00R\ttranspose\x12w\n unresolved_table_valued_function\x18+ \x01(\x0b\x32,.spark.connect.UnresolvedTableValuedFunctionH\x00R\x1dunresolvedTableValuedFunction\x12?\n\x0clateral_join\x18, \x01(\x0b\x32\x1a.spark.connect.LateralJoinH\x00R\x0blateralJoin\x12\x30\n\x07\x66ill_na\x18Z \x01(\x0b\x32\x15.spark.connect.NAFillH\x00R\x06\x66illNa\x12\x30\n\x07\x64rop_na\x18[ \x01(\x0b\x32\x15.spark.connect.NADropH\x00R\x06\x64ropNa\x12\x34\n\x07replace\x18\\ \x01(\x0b\x32\x18.spark.connect.NAReplaceH\x00R\x07replace\x12\x36\n\x07summary\x18\x64 \x01(\x0b\x32\x1a.spark.connect.StatSummaryH\x00R\x07summary\x12\x39\n\x08\x63rosstab\x18\x65 \x01(\x0b\x32\x1b.spark.connect.StatCrosstabH\x00R\x08\x63rosstab\x12\x39\n\x08\x64\x65scribe\x18\x66 \x01(\x0b\x32\x1b.spark.connect.StatDescribeH\x00R\x08\x64\x65scribe\x12*\n\x03\x63ov\x18g \x01(\x0b\x32\x16.spark.connect.StatCovH\x00R\x03\x63ov\x12-\n\x04\x63orr\x18h \x01(\x0b\x32\x17.spark.connect.StatCorrH\x00R\x04\x63orr\x12L\n\x0f\x61pprox_quantile\x18i \x01(\x0b\x32!.spark.connect.StatApproxQuantileH\x00R\x0e\x61pproxQuantile\x12=\n\nfreq_items\x18j \x01(\x0b\x32\x1c.spark.connect.StatFreqItemsH\x00R\tfreqItems\x12:\n\tsample_by\x18k \x01(\x0b\x32\x1b.spark.connect.StatSampleByH\x00R\x08sampleBy\x12\x33\n\x07\x63\x61talog\x18\xc8\x01 \x01(\x0b\x32\x16.spark.connect.CatalogH\x00R\x07\x63\x61talog\x12=\n\x0bml_relation\x18\xac\x02 \x01(\x0b\x32\x19.spark.connect.MlRelationH\x00R\nmlRelation\x12\x35\n\textension\x18\xe6\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12\x33\n\x07unknown\x18\xe7\x07 \x01(\x0b\x32\x16.spark.connect.UnknownH\x00R\x07unknownB\n\n\x08rel_type"\xf8\x02\n\nMlRelation\x12\x43\n\ttransform\x18\x01 \x01(\x0b\x32#.spark.connect.MlRelation.TransformH\x00R\ttransform\x12,\n\x05\x66\x65tch\x18\x02 \x01(\x0b\x32\x14.spark.connect.FetchH\x00R\x05\x66\x65tch\x1a\xeb\x01\n\tTransform\x12\x33\n\x07obj_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefH\x00R\x06objRef\x12=\n\x0btransformer\x18\x02 
\x01(\x0b\x32\x19.spark.connect.MlOperatorH\x00R\x0btransformer\x12-\n\x05input\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12/\n\x06params\x18\x04 \x01(\x0b\x32\x17.spark.connect.MlParamsR\x06paramsB\n\n\x08operatorB\t\n\x07ml_type"\xcb\x02\n\x05\x46\x65tch\x12\x31\n\x07obj_ref\x18\x01 \x01(\x0b\x32\x18.spark.connect.ObjectRefR\x06objRef\x12\x35\n\x07methods\x18\x02 \x03(\x0b\x32\x1b.spark.connect.Fetch.MethodR\x07methods\x1a\xd7\x01\n\x06Method\x12\x16\n\x06method\x18\x01 \x01(\tR\x06method\x12\x34\n\x04\x61rgs\x18\x02 \x03(\x0b\x32 .spark.connect.Fetch.Method.ArgsR\x04\x61rgs\x1a\x7f\n\x04\x41rgs\x12\x39\n\x05param\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralH\x00R\x05param\x12/\n\x05input\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x05inputB\x0b\n\targs_type"\t\n\x07Unknown"\x8e\x01\n\x0eRelationCommon\x12#\n\x0bsource_info\x18\x01 \x01(\tB\x02\x18\x01R\nsourceInfo\x12\x1c\n\x07plan_id\x18\x02 \x01(\x03H\x00R\x06planId\x88\x01\x01\x12-\n\x06origin\x18\x03 \x01(\x0b\x32\x15.spark.connect.OriginR\x06originB\n\n\x08_plan_id"\xde\x03\n\x03SQL\x12\x14\n\x05query\x18\x01 \x01(\tR\x05query\x12\x34\n\x04\x61rgs\x18\x02 \x03(\x0b\x32\x1c.spark.connect.SQL.ArgsEntryB\x02\x18\x01R\x04\x61rgs\x12@\n\x08pos_args\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralB\x02\x18\x01R\x07posArgs\x12O\n\x0fnamed_arguments\x18\x04 \x03(\x0b\x32&.spark.connect.SQL.NamedArgumentsEntryR\x0enamedArguments\x12>\n\rpos_arguments\x18\x05 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0cposArguments\x1aZ\n\tArgsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x37\n\x05value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x05value:\x02\x38\x01\x1a\\\n\x13NamedArgumentsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12/\n\x05value\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x05value:\x02\x38\x01"u\n\rWithRelations\x12+\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04root\x12\x37\n\nreferences\x18\x02 \x03(\x0b\x32\x17.spark.connect.RelationR\nreferences"\x97\x05\n\x04Read\x12\x41\n\x0bnamed_table\x18\x01 \x01(\x0b\x32\x1e.spark.connect.Read.NamedTableH\x00R\nnamedTable\x12\x41\n\x0b\x64\x61ta_source\x18\x02 \x01(\x0b\x32\x1e.spark.connect.Read.DataSourceH\x00R\ndataSource\x12!\n\x0cis_streaming\x18\x03 \x01(\x08R\x0bisStreaming\x1a\xc0\x01\n\nNamedTable\x12/\n\x13unparsed_identifier\x18\x01 \x01(\tR\x12unparsedIdentifier\x12\x45\n\x07options\x18\x02 \x03(\x0b\x32+.spark.connect.Read.NamedTable.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x95\x02\n\nDataSource\x12\x1b\n\x06\x66ormat\x18\x01 \x01(\tH\x00R\x06\x66ormat\x88\x01\x01\x12\x1b\n\x06schema\x18\x02 \x01(\tH\x01R\x06schema\x88\x01\x01\x12\x45\n\x07options\x18\x03 \x03(\x0b\x32+.spark.connect.Read.DataSource.OptionsEntryR\x07options\x12\x14\n\x05paths\x18\x04 \x03(\tR\x05paths\x12\x1e\n\npredicates\x18\x05 \x03(\tR\npredicates\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07_formatB\t\n\x07_schemaB\x0b\n\tread_type"u\n\x07Project\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12;\n\x0b\x65xpressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0b\x65xpressions"p\n\x06\x46ilter\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x37\n\tcondition\x18\x02 
\x01(\x0b\x32\x19.spark.connect.ExpressionR\tcondition"\x95\x05\n\x04Join\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinType\x12#\n\rusing_columns\x18\x05 \x03(\tR\x0cusingColumns\x12K\n\x0ejoin_data_type\x18\x06 \x01(\x0b\x32 .spark.connect.Join.JoinDataTypeH\x00R\x0cjoinDataType\x88\x01\x01\x1a\\\n\x0cJoinDataType\x12$\n\x0eis_left_struct\x18\x01 \x01(\x08R\x0cisLeftStruct\x12&\n\x0fis_right_struct\x18\x02 \x01(\x08R\risRightStruct"\xd0\x01\n\x08JoinType\x12\x19\n\x15JOIN_TYPE_UNSPECIFIED\x10\x00\x12\x13\n\x0fJOIN_TYPE_INNER\x10\x01\x12\x18\n\x14JOIN_TYPE_FULL_OUTER\x10\x02\x12\x18\n\x14JOIN_TYPE_LEFT_OUTER\x10\x03\x12\x19\n\x15JOIN_TYPE_RIGHT_OUTER\x10\x04\x12\x17\n\x13JOIN_TYPE_LEFT_ANTI\x10\x05\x12\x17\n\x13JOIN_TYPE_LEFT_SEMI\x10\x06\x12\x13\n\x0fJOIN_TYPE_CROSS\x10\x07\x42\x11\n\x0f_join_data_type"\xdf\x03\n\x0cSetOperation\x12\x36\n\nleft_input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\tleftInput\x12\x38\n\x0bright_input\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\nrightInput\x12\x45\n\x0bset_op_type\x18\x03 \x01(\x0e\x32%.spark.connect.SetOperation.SetOpTypeR\tsetOpType\x12\x1a\n\x06is_all\x18\x04 \x01(\x08H\x00R\x05isAll\x88\x01\x01\x12\x1c\n\x07\x62y_name\x18\x05 \x01(\x08H\x01R\x06\x62yName\x88\x01\x01\x12\x37\n\x15\x61llow_missing_columns\x18\x06 \x01(\x08H\x02R\x13\x61llowMissingColumns\x88\x01\x01"r\n\tSetOpType\x12\x1b\n\x17SET_OP_TYPE_UNSPECIFIED\x10\x00\x12\x19\n\x15SET_OP_TYPE_INTERSECT\x10\x01\x12\x15\n\x11SET_OP_TYPE_UNION\x10\x02\x12\x16\n\x12SET_OP_TYPE_EXCEPT\x10\x03\x42\t\n\x07_is_allB\n\n\x08_by_nameB\x18\n\x16_allow_missing_columns"L\n\x05Limit\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"O\n\x06Offset\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x16\n\x06offset\x18\x02 \x01(\x05R\x06offset"K\n\x04Tail\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05limit\x18\x02 \x01(\x05R\x05limit"\xfe\x05\n\tAggregate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x41\n\ngroup_type\x18\x02 \x01(\x0e\x32".spark.connect.Aggregate.GroupTypeR\tgroupType\x12L\n\x14grouping_expressions\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12N\n\x15\x61ggregate_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x14\x61ggregateExpressions\x12\x34\n\x05pivot\x18\x05 \x01(\x0b\x32\x1e.spark.connect.Aggregate.PivotR\x05pivot\x12J\n\rgrouping_sets\x18\x06 \x03(\x0b\x32%.spark.connect.Aggregate.GroupingSetsR\x0cgroupingSets\x1ao\n\x05Pivot\x12+\n\x03\x63ol\x18\x01 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x03\x63ol\x12\x39\n\x06values\x18\x02 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values\x1aL\n\x0cGroupingSets\x12<\n\x0cgrouping_set\x18\x01 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0bgroupingSet"\x9f\x01\n\tGroupType\x12\x1a\n\x16GROUP_TYPE_UNSPECIFIED\x10\x00\x12\x16\n\x12GROUP_TYPE_GROUPBY\x10\x01\x12\x15\n\x11GROUP_TYPE_ROLLUP\x10\x02\x12\x13\n\x0fGROUP_TYPE_CUBE\x10\x03\x12\x14\n\x10GROUP_TYPE_PIVOT\x10\x04\x12\x1c\n\x18GROUP_TYPE_GROUPING_SETS\x10\x05"\xa0\x01\n\x04Sort\x12-\n\x05input\x18\x01 
\x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x39\n\x05order\x18\x02 \x03(\x0b\x32#.spark.connect.Expression.SortOrderR\x05order\x12 \n\tis_global\x18\x03 \x01(\x08H\x00R\x08isGlobal\x88\x01\x01\x42\x0c\n\n_is_global"\x8d\x01\n\x04\x44rop\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x33\n\x07\x63olumns\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x07\x63olumns\x12!\n\x0c\x63olumn_names\x18\x03 \x03(\tR\x0b\x63olumnNames"\xf0\x01\n\x0b\x44\x65\x64uplicate\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames\x12\x32\n\x13\x61ll_columns_as_keys\x18\x03 \x01(\x08H\x00R\x10\x61llColumnsAsKeys\x88\x01\x01\x12.\n\x10within_watermark\x18\x04 \x01(\x08H\x01R\x0fwithinWatermark\x88\x01\x01\x42\x16\n\x14_all_columns_as_keysB\x13\n\x11_within_watermark"Y\n\rLocalRelation\x12\x17\n\x04\x64\x61ta\x18\x01 \x01(\x0cH\x00R\x04\x64\x61ta\x88\x01\x01\x12\x1b\n\x06schema\x18\x02 \x01(\tH\x01R\x06schema\x88\x01\x01\x42\x07\n\x05_dataB\t\n\x07_schema"H\n\x13\x43\x61\x63hedLocalRelation\x12\x12\n\x04hash\x18\x03 \x01(\tR\x04hashJ\x04\x08\x01\x10\x02J\x04\x08\x02\x10\x03R\x06userIdR\tsessionId"7\n\x14\x43\x61\x63hedRemoteRelation\x12\x1f\n\x0brelation_id\x18\x01 \x01(\tR\nrelationId"\x91\x02\n\x06Sample\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1f\n\x0blower_bound\x18\x02 \x01(\x01R\nlowerBound\x12\x1f\n\x0bupper_bound\x18\x03 \x01(\x01R\nupperBound\x12.\n\x10with_replacement\x18\x04 \x01(\x08H\x00R\x0fwithReplacement\x88\x01\x01\x12\x17\n\x04seed\x18\x05 \x01(\x03H\x01R\x04seed\x88\x01\x01\x12/\n\x13\x64\x65terministic_order\x18\x06 \x01(\x08R\x12\x64\x65terministicOrderB\x13\n\x11_with_replacementB\x07\n\x05_seed"\x91\x01\n\x05Range\x12\x19\n\x05start\x18\x01 \x01(\x03H\x00R\x05start\x88\x01\x01\x12\x10\n\x03\x65nd\x18\x02 \x01(\x03R\x03\x65nd\x12\x12\n\x04step\x18\x03 \x01(\x03R\x04step\x12*\n\x0enum_partitions\x18\x04 \x01(\x05H\x01R\rnumPartitions\x88\x01\x01\x42\x08\n\x06_startB\x11\n\x0f_num_partitions"r\n\rSubqueryAlias\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x14\n\x05\x61lias\x18\x02 \x01(\tR\x05\x61lias\x12\x1c\n\tqualifier\x18\x03 \x03(\tR\tqualifier"\x8e\x01\n\x0bRepartition\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12%\n\x0enum_partitions\x18\x02 \x01(\x05R\rnumPartitions\x12\x1d\n\x07shuffle\x18\x03 \x01(\x08H\x00R\x07shuffle\x88\x01\x01\x42\n\n\x08_shuffle"\x8e\x01\n\nShowString\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x19\n\x08num_rows\x18\x02 \x01(\x05R\x07numRows\x12\x1a\n\x08truncate\x18\x03 \x01(\x05R\x08truncate\x12\x1a\n\x08vertical\x18\x04 \x01(\x08R\x08vertical"r\n\nHtmlString\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x19\n\x08num_rows\x18\x02 \x01(\x05R\x07numRows\x12\x1a\n\x08truncate\x18\x03 \x01(\x05R\x08truncate"\\\n\x0bStatSummary\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1e\n\nstatistics\x18\x02 \x03(\tR\nstatistics"Q\n\x0cStatDescribe\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols"e\n\x0cStatCrosstab\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2"`\n\x07StatCov\x12-\n\x05input\x18\x01 
\x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2"\x89\x01\n\x08StatCorr\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ol1\x18\x02 \x01(\tR\x04\x63ol1\x12\x12\n\x04\x63ol2\x18\x03 \x01(\tR\x04\x63ol2\x12\x1b\n\x06method\x18\x04 \x01(\tH\x00R\x06method\x88\x01\x01\x42\t\n\x07_method"\xa4\x01\n\x12StatApproxQuantile\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12$\n\rprobabilities\x18\x03 \x03(\x01R\rprobabilities\x12%\n\x0erelative_error\x18\x04 \x01(\x01R\rrelativeError"}\n\rStatFreqItems\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\x1d\n\x07support\x18\x03 \x01(\x01H\x00R\x07support\x88\x01\x01\x42\n\n\x08_support"\xb5\x02\n\x0cStatSampleBy\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12+\n\x03\x63ol\x18\x02 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x03\x63ol\x12\x42\n\tfractions\x18\x03 \x03(\x0b\x32$.spark.connect.StatSampleBy.FractionR\tfractions\x12\x17\n\x04seed\x18\x05 \x01(\x03H\x00R\x04seed\x88\x01\x01\x1a\x63\n\x08\x46raction\x12;\n\x07stratum\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x07stratum\x12\x1a\n\x08\x66raction\x18\x02 \x01(\x01R\x08\x66ractionB\x07\n\x05_seed"\x86\x01\n\x06NAFill\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\x39\n\x06values\x18\x03 \x03(\x0b\x32!.spark.connect.Expression.LiteralR\x06values"\x86\x01\n\x06NADrop\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12\'\n\rmin_non_nulls\x18\x03 \x01(\x05H\x00R\x0bminNonNulls\x88\x01\x01\x42\x10\n\x0e_min_non_nulls"\xa8\x02\n\tNAReplace\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04\x63ols\x18\x02 \x03(\tR\x04\x63ols\x12H\n\x0creplacements\x18\x03 \x03(\x0b\x32$.spark.connect.NAReplace.ReplacementR\x0creplacements\x1a\x8d\x01\n\x0bReplacement\x12>\n\told_value\x18\x01 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x08oldValue\x12>\n\tnew_value\x18\x02 \x01(\x0b\x32!.spark.connect.Expression.LiteralR\x08newValue"X\n\x04ToDF\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12!\n\x0c\x63olumn_names\x18\x02 \x03(\tR\x0b\x63olumnNames"\xfe\x02\n\x12WithColumnsRenamed\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12i\n\x12rename_columns_map\x18\x02 \x03(\x0b\x32\x37.spark.connect.WithColumnsRenamed.RenameColumnsMapEntryB\x02\x18\x01R\x10renameColumnsMap\x12\x42\n\x07renames\x18\x03 \x03(\x0b\x32(.spark.connect.WithColumnsRenamed.RenameR\x07renames\x1a\x43\n\x15RenameColumnsMapEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x45\n\x06Rename\x12\x19\n\x08\x63ol_name\x18\x01 \x01(\tR\x07\x63olName\x12 \n\x0cnew_col_name\x18\x02 \x01(\tR\nnewColName"w\n\x0bWithColumns\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x39\n\x07\x61liases\x18\x02 \x03(\x0b\x32\x1f.spark.connect.Expression.AliasR\x07\x61liases"\x86\x01\n\rWithWatermark\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x1d\n\nevent_time\x18\x02 \x01(\tR\teventTime\x12\'\n\x0f\x64\x65lay_threshold\x18\x03 
\x01(\tR\x0e\x64\x65layThreshold"\x84\x01\n\x04Hint\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x39\n\nparameters\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\nparameters"\xc7\x02\n\x07Unpivot\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12+\n\x03ids\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x03ids\x12:\n\x06values\x18\x03 \x01(\x0b\x32\x1d.spark.connect.Unpivot.ValuesH\x00R\x06values\x88\x01\x01\x12\x30\n\x14variable_column_name\x18\x04 \x01(\tR\x12variableColumnName\x12*\n\x11value_column_name\x18\x05 \x01(\tR\x0fvalueColumnName\x1a;\n\x06Values\x12\x31\n\x06values\x18\x01 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x06valuesB\t\n\x07_values"z\n\tTranspose\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12>\n\rindex_columns\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0cindexColumns"}\n\x1dUnresolvedTableValuedFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12\x37\n\targuments\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments"j\n\x08ToSchema\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12/\n\x06schema\x18\x02 \x01(\x0b\x32\x17.spark.connect.DataTypeR\x06schema"\xcb\x01\n\x17RepartitionByExpression\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x42\n\x0fpartition_exprs\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x0epartitionExprs\x12*\n\x0enum_partitions\x18\x03 \x01(\x05H\x00R\rnumPartitions\x88\x01\x01\x42\x11\n\x0f_num_partitions"\xe8\x01\n\rMapPartitions\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x42\n\x04\x66unc\x18\x02 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12"\n\nis_barrier\x18\x03 \x01(\x08H\x00R\tisBarrier\x88\x01\x01\x12"\n\nprofile_id\x18\x04 \x01(\x05H\x01R\tprofileId\x88\x01\x01\x42\r\n\x0b_is_barrierB\r\n\x0b_profile_id"\xd2\x06\n\x08GroupMap\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12\x42\n\x04\x66unc\x18\x03 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12J\n\x13sorting_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x12sortingExpressions\x12<\n\rinitial_input\x18\x05 \x01(\x0b\x32\x17.spark.connect.RelationR\x0cinitialInput\x12[\n\x1cinitial_grouping_expressions\x18\x06 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x1ainitialGroupingExpressions\x12;\n\x18is_map_groups_with_state\x18\x07 \x01(\x08H\x00R\x14isMapGroupsWithState\x88\x01\x01\x12$\n\x0boutput_mode\x18\x08 \x01(\tH\x01R\noutputMode\x88\x01\x01\x12&\n\x0ctimeout_conf\x18\t \x01(\tH\x02R\x0btimeoutConf\x88\x01\x01\x12?\n\x0cstate_schema\x18\n \x01(\x0b\x32\x17.spark.connect.DataTypeH\x03R\x0bstateSchema\x88\x01\x01\x12\x65\n\x19transform_with_state_info\x18\x0b \x01(\x0b\x32%.spark.connect.TransformWithStateInfoH\x04R\x16transformWithStateInfo\x88\x01\x01\x42\x1b\n\x19_is_map_groups_with_stateB\x0e\n\x0c_output_modeB\x0f\n\r_timeout_confB\x0f\n\r_state_schemaB\x1c\n\x1a_transform_with_state_info"\xdf\x01\n\x16TransformWithStateInfo\x12\x1b\n\ttime_mode\x18\x01 \x01(\tR\x08timeMode\x12\x38\n\x16\x65vent_time_column_name\x18\x02 \x01(\tH\x00R\x13\x65ventTimeColumnName\x88\x01\x01\x12\x41\n\routput_schema\x18\x03 
\x01(\x0b\x32\x17.spark.connect.DataTypeH\x01R\x0coutputSchema\x88\x01\x01\x42\x19\n\x17_event_time_column_nameB\x10\n\x0e_output_schema"\x8e\x04\n\nCoGroupMap\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12W\n\x1ainput_grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x18inputGroupingExpressions\x12-\n\x05other\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationR\x05other\x12W\n\x1aother_grouping_expressions\x18\x04 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x18otherGroupingExpressions\x12\x42\n\x04\x66unc\x18\x05 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12U\n\x19input_sorting_expressions\x18\x06 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x17inputSortingExpressions\x12U\n\x19other_sorting_expressions\x18\x07 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x17otherSortingExpressions"\xe5\x02\n\x16\x41pplyInPandasWithState\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12L\n\x14grouping_expressions\x18\x02 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x13groupingExpressions\x12\x42\n\x04\x66unc\x18\x03 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionR\x04\x66unc\x12#\n\routput_schema\x18\x04 \x01(\tR\x0coutputSchema\x12!\n\x0cstate_schema\x18\x05 \x01(\tR\x0bstateSchema\x12\x1f\n\x0boutput_mode\x18\x06 \x01(\tR\noutputMode\x12!\n\x0ctimeout_conf\x18\x07 \x01(\tR\x0btimeoutConf"\xf4\x01\n$CommonInlineUserDefinedTableFunction\x12#\n\rfunction_name\x18\x01 \x01(\tR\x0c\x66unctionName\x12$\n\rdeterministic\x18\x02 \x01(\x08R\rdeterministic\x12\x37\n\targuments\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\targuments\x12<\n\x0bpython_udtf\x18\x04 \x01(\x0b\x32\x19.spark.connect.PythonUDTFH\x00R\npythonUdtfB\n\n\x08\x66unction"\xb1\x01\n\nPythonUDTF\x12=\n\x0breturn_type\x18\x01 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\nreturnType\x88\x01\x01\x12\x1b\n\teval_type\x18\x02 \x01(\x05R\x08\x65valType\x12\x18\n\x07\x63ommand\x18\x03 \x01(\x0cR\x07\x63ommand\x12\x1d\n\npython_ver\x18\x04 \x01(\tR\tpythonVerB\x0e\n\x0c_return_type"\x97\x01\n!CommonInlineUserDefinedDataSource\x12\x12\n\x04name\x18\x01 \x01(\tR\x04name\x12O\n\x12python_data_source\x18\x02 \x01(\x0b\x32\x1f.spark.connect.PythonDataSourceH\x00R\x10pythonDataSourceB\r\n\x0b\x64\x61ta_source"K\n\x10PythonDataSource\x12\x18\n\x07\x63ommand\x18\x01 \x01(\x0cR\x07\x63ommand\x12\x1d\n\npython_ver\x18\x02 \x01(\tR\tpythonVer"\x88\x01\n\x0e\x43ollectMetrics\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x12\n\x04name\x18\x02 \x01(\tR\x04name\x12\x33\n\x07metrics\x18\x03 \x03(\x0b\x32\x19.spark.connect.ExpressionR\x07metrics"\x84\x03\n\x05Parse\x12-\n\x05input\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x05input\x12\x38\n\x06\x66ormat\x18\x02 \x01(\x0e\x32 .spark.connect.Parse.ParseFormatR\x06\x66ormat\x12\x34\n\x06schema\x18\x03 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x06schema\x88\x01\x01\x12;\n\x07options\x18\x04 \x03(\x0b\x32!.spark.connect.Parse.OptionsEntryR\x07options\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01"X\n\x0bParseFormat\x12\x1c\n\x18PARSE_FORMAT_UNSPECIFIED\x10\x00\x12\x14\n\x10PARSE_FORMAT_CSV\x10\x01\x12\x15\n\x11PARSE_FORMAT_JSON\x10\x02\x42\t\n\x07_schema"\xdb\x03\n\x08\x41sOfJoin\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12\x37\n\nleft_as_of\x18\x03 
\x01(\x0b\x32\x19.spark.connect.ExpressionR\x08leftAsOf\x12\x39\n\x0bright_as_of\x18\x04 \x01(\x0b\x32\x19.spark.connect.ExpressionR\trightAsOf\x12\x36\n\tjoin_expr\x18\x05 \x01(\x0b\x32\x19.spark.connect.ExpressionR\x08joinExpr\x12#\n\rusing_columns\x18\x06 \x03(\tR\x0cusingColumns\x12\x1b\n\tjoin_type\x18\x07 \x01(\tR\x08joinType\x12\x37\n\ttolerance\x18\x08 \x01(\x0b\x32\x19.spark.connect.ExpressionR\ttolerance\x12.\n\x13\x61llow_exact_matches\x18\t \x01(\x08R\x11\x61llowExactMatches\x12\x1c\n\tdirection\x18\n \x01(\tR\tdirection"\xe6\x01\n\x0bLateralJoin\x12+\n\x04left\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationR\x04left\x12-\n\x05right\x18\x02 \x01(\x0b\x32\x17.spark.connect.RelationR\x05right\x12@\n\x0ejoin_condition\x18\x03 \x01(\x0b\x32\x19.spark.connect.ExpressionR\rjoinCondition\x12\x39\n\tjoin_type\x18\x04 \x01(\x0e\x32\x1c.spark.connect.Join.JoinTypeR\x08joinTypeB6\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' ) _globals = globals() @@ -81,169 +81,169 @@ _globals["_RELATION"]._serialized_start = 224 _globals["_RELATION"]._serialized_end = 3964 _globals["_MLRELATION"]._serialized_start = 3967 - _globals["_MLRELATION"]._serialized_end = 4451 - _globals["_MLRELATION_TRANSFORM"]._serialized_start = 4179 - _globals["_MLRELATION_TRANSFORM"]._serialized_end = 4414 - _globals["_FETCH"]._serialized_start = 4454 - _globals["_FETCH"]._serialized_end = 4785 - _globals["_FETCH_METHOD"]._serialized_start = 4570 - _globals["_FETCH_METHOD"]._serialized_end = 4785 - _globals["_FETCH_METHOD_ARGS"]._serialized_start = 4658 - _globals["_FETCH_METHOD_ARGS"]._serialized_end = 4785 - _globals["_UNKNOWN"]._serialized_start = 4787 - _globals["_UNKNOWN"]._serialized_end = 4796 - _globals["_RELATIONCOMMON"]._serialized_start = 4799 - _globals["_RELATIONCOMMON"]._serialized_end = 4941 - _globals["_SQL"]._serialized_start = 4944 - _globals["_SQL"]._serialized_end = 5422 - _globals["_SQL_ARGSENTRY"]._serialized_start = 5238 - _globals["_SQL_ARGSENTRY"]._serialized_end = 5328 - _globals["_SQL_NAMEDARGUMENTSENTRY"]._serialized_start = 5330 - _globals["_SQL_NAMEDARGUMENTSENTRY"]._serialized_end = 5422 - _globals["_WITHRELATIONS"]._serialized_start = 5424 - _globals["_WITHRELATIONS"]._serialized_end = 5541 - _globals["_READ"]._serialized_start = 5544 - _globals["_READ"]._serialized_end = 6207 - _globals["_READ_NAMEDTABLE"]._serialized_start = 5722 - _globals["_READ_NAMEDTABLE"]._serialized_end = 5914 - _globals["_READ_NAMEDTABLE_OPTIONSENTRY"]._serialized_start = 5856 - _globals["_READ_NAMEDTABLE_OPTIONSENTRY"]._serialized_end = 5914 - _globals["_READ_DATASOURCE"]._serialized_start = 5917 - _globals["_READ_DATASOURCE"]._serialized_end = 6194 - _globals["_READ_DATASOURCE_OPTIONSENTRY"]._serialized_start = 5856 - _globals["_READ_DATASOURCE_OPTIONSENTRY"]._serialized_end = 5914 - _globals["_PROJECT"]._serialized_start = 6209 - _globals["_PROJECT"]._serialized_end = 6326 - _globals["_FILTER"]._serialized_start = 6328 - _globals["_FILTER"]._serialized_end = 6440 - _globals["_JOIN"]._serialized_start = 6443 - _globals["_JOIN"]._serialized_end = 7104 - _globals["_JOIN_JOINDATATYPE"]._serialized_start = 6782 - _globals["_JOIN_JOINDATATYPE"]._serialized_end = 6874 - _globals["_JOIN_JOINTYPE"]._serialized_start = 6877 - _globals["_JOIN_JOINTYPE"]._serialized_end = 7085 - _globals["_SETOPERATION"]._serialized_start = 7107 - _globals["_SETOPERATION"]._serialized_end = 7586 - _globals["_SETOPERATION_SETOPTYPE"]._serialized_start = 7423 - _globals["_SETOPERATION_SETOPTYPE"]._serialized_end 
= 7537 - _globals["_LIMIT"]._serialized_start = 7588 - _globals["_LIMIT"]._serialized_end = 7664 - _globals["_OFFSET"]._serialized_start = 7666 - _globals["_OFFSET"]._serialized_end = 7745 - _globals["_TAIL"]._serialized_start = 7747 - _globals["_TAIL"]._serialized_end = 7822 - _globals["_AGGREGATE"]._serialized_start = 7825 - _globals["_AGGREGATE"]._serialized_end = 8591 - _globals["_AGGREGATE_PIVOT"]._serialized_start = 8240 - _globals["_AGGREGATE_PIVOT"]._serialized_end = 8351 - _globals["_AGGREGATE_GROUPINGSETS"]._serialized_start = 8353 - _globals["_AGGREGATE_GROUPINGSETS"]._serialized_end = 8429 - _globals["_AGGREGATE_GROUPTYPE"]._serialized_start = 8432 - _globals["_AGGREGATE_GROUPTYPE"]._serialized_end = 8591 - _globals["_SORT"]._serialized_start = 8594 - _globals["_SORT"]._serialized_end = 8754 - _globals["_DROP"]._serialized_start = 8757 - _globals["_DROP"]._serialized_end = 8898 - _globals["_DEDUPLICATE"]._serialized_start = 8901 - _globals["_DEDUPLICATE"]._serialized_end = 9141 - _globals["_LOCALRELATION"]._serialized_start = 9143 - _globals["_LOCALRELATION"]._serialized_end = 9232 - _globals["_CACHEDLOCALRELATION"]._serialized_start = 9234 - _globals["_CACHEDLOCALRELATION"]._serialized_end = 9306 - _globals["_CACHEDREMOTERELATION"]._serialized_start = 9308 - _globals["_CACHEDREMOTERELATION"]._serialized_end = 9363 - _globals["_SAMPLE"]._serialized_start = 9366 - _globals["_SAMPLE"]._serialized_end = 9639 - _globals["_RANGE"]._serialized_start = 9642 - _globals["_RANGE"]._serialized_end = 9787 - _globals["_SUBQUERYALIAS"]._serialized_start = 9789 - _globals["_SUBQUERYALIAS"]._serialized_end = 9903 - _globals["_REPARTITION"]._serialized_start = 9906 - _globals["_REPARTITION"]._serialized_end = 10048 - _globals["_SHOWSTRING"]._serialized_start = 10051 - _globals["_SHOWSTRING"]._serialized_end = 10193 - _globals["_HTMLSTRING"]._serialized_start = 10195 - _globals["_HTMLSTRING"]._serialized_end = 10309 - _globals["_STATSUMMARY"]._serialized_start = 10311 - _globals["_STATSUMMARY"]._serialized_end = 10403 - _globals["_STATDESCRIBE"]._serialized_start = 10405 - _globals["_STATDESCRIBE"]._serialized_end = 10486 - _globals["_STATCROSSTAB"]._serialized_start = 10488 - _globals["_STATCROSSTAB"]._serialized_end = 10589 - _globals["_STATCOV"]._serialized_start = 10591 - _globals["_STATCOV"]._serialized_end = 10687 - _globals["_STATCORR"]._serialized_start = 10690 - _globals["_STATCORR"]._serialized_end = 10827 - _globals["_STATAPPROXQUANTILE"]._serialized_start = 10830 - _globals["_STATAPPROXQUANTILE"]._serialized_end = 10994 - _globals["_STATFREQITEMS"]._serialized_start = 10996 - _globals["_STATFREQITEMS"]._serialized_end = 11121 - _globals["_STATSAMPLEBY"]._serialized_start = 11124 - _globals["_STATSAMPLEBY"]._serialized_end = 11433 - _globals["_STATSAMPLEBY_FRACTION"]._serialized_start = 11325 - _globals["_STATSAMPLEBY_FRACTION"]._serialized_end = 11424 - _globals["_NAFILL"]._serialized_start = 11436 - _globals["_NAFILL"]._serialized_end = 11570 - _globals["_NADROP"]._serialized_start = 11573 - _globals["_NADROP"]._serialized_end = 11707 - _globals["_NAREPLACE"]._serialized_start = 11710 - _globals["_NAREPLACE"]._serialized_end = 12006 - _globals["_NAREPLACE_REPLACEMENT"]._serialized_start = 11865 - _globals["_NAREPLACE_REPLACEMENT"]._serialized_end = 12006 - _globals["_TODF"]._serialized_start = 12008 - _globals["_TODF"]._serialized_end = 12096 - _globals["_WITHCOLUMNSRENAMED"]._serialized_start = 12099 - _globals["_WITHCOLUMNSRENAMED"]._serialized_end = 12481 - 
_globals["_WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY"]._serialized_start = 12343 - _globals["_WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY"]._serialized_end = 12410 - _globals["_WITHCOLUMNSRENAMED_RENAME"]._serialized_start = 12412 - _globals["_WITHCOLUMNSRENAMED_RENAME"]._serialized_end = 12481 - _globals["_WITHCOLUMNS"]._serialized_start = 12483 - _globals["_WITHCOLUMNS"]._serialized_end = 12602 - _globals["_WITHWATERMARK"]._serialized_start = 12605 - _globals["_WITHWATERMARK"]._serialized_end = 12739 - _globals["_HINT"]._serialized_start = 12742 - _globals["_HINT"]._serialized_end = 12874 - _globals["_UNPIVOT"]._serialized_start = 12877 - _globals["_UNPIVOT"]._serialized_end = 13204 - _globals["_UNPIVOT_VALUES"]._serialized_start = 13134 - _globals["_UNPIVOT_VALUES"]._serialized_end = 13193 - _globals["_TRANSPOSE"]._serialized_start = 13206 - _globals["_TRANSPOSE"]._serialized_end = 13328 - _globals["_UNRESOLVEDTABLEVALUEDFUNCTION"]._serialized_start = 13330 - _globals["_UNRESOLVEDTABLEVALUEDFUNCTION"]._serialized_end = 13455 - _globals["_TOSCHEMA"]._serialized_start = 13457 - _globals["_TOSCHEMA"]._serialized_end = 13563 - _globals["_REPARTITIONBYEXPRESSION"]._serialized_start = 13566 - _globals["_REPARTITIONBYEXPRESSION"]._serialized_end = 13769 - _globals["_MAPPARTITIONS"]._serialized_start = 13772 - _globals["_MAPPARTITIONS"]._serialized_end = 14004 - _globals["_GROUPMAP"]._serialized_start = 14007 - _globals["_GROUPMAP"]._serialized_end = 14857 - _globals["_TRANSFORMWITHSTATEINFO"]._serialized_start = 14860 - _globals["_TRANSFORMWITHSTATEINFO"]._serialized_end = 15083 - _globals["_COGROUPMAP"]._serialized_start = 15086 - _globals["_COGROUPMAP"]._serialized_end = 15612 - _globals["_APPLYINPANDASWITHSTATE"]._serialized_start = 15615 - _globals["_APPLYINPANDASWITHSTATE"]._serialized_end = 15972 - _globals["_COMMONINLINEUSERDEFINEDTABLEFUNCTION"]._serialized_start = 15975 - _globals["_COMMONINLINEUSERDEFINEDTABLEFUNCTION"]._serialized_end = 16219 - _globals["_PYTHONUDTF"]._serialized_start = 16222 - _globals["_PYTHONUDTF"]._serialized_end = 16399 - _globals["_COMMONINLINEUSERDEFINEDDATASOURCE"]._serialized_start = 16402 - _globals["_COMMONINLINEUSERDEFINEDDATASOURCE"]._serialized_end = 16553 - _globals["_PYTHONDATASOURCE"]._serialized_start = 16555 - _globals["_PYTHONDATASOURCE"]._serialized_end = 16630 - _globals["_COLLECTMETRICS"]._serialized_start = 16633 - _globals["_COLLECTMETRICS"]._serialized_end = 16769 - _globals["_PARSE"]._serialized_start = 16772 - _globals["_PARSE"]._serialized_end = 17160 - _globals["_PARSE_OPTIONSENTRY"]._serialized_start = 5856 - _globals["_PARSE_OPTIONSENTRY"]._serialized_end = 5914 - _globals["_PARSE_PARSEFORMAT"]._serialized_start = 17061 - _globals["_PARSE_PARSEFORMAT"]._serialized_end = 17149 - _globals["_ASOFJOIN"]._serialized_start = 17163 - _globals["_ASOFJOIN"]._serialized_end = 17638 - _globals["_LATERALJOIN"]._serialized_start = 17641 - _globals["_LATERALJOIN"]._serialized_end = 17871 + _globals["_MLRELATION"]._serialized_end = 4343 + _globals["_MLRELATION_TRANSFORM"]._serialized_start = 4097 + _globals["_MLRELATION_TRANSFORM"]._serialized_end = 4332 + _globals["_FETCH"]._serialized_start = 4346 + _globals["_FETCH"]._serialized_end = 4677 + _globals["_FETCH_METHOD"]._serialized_start = 4462 + _globals["_FETCH_METHOD"]._serialized_end = 4677 + _globals["_FETCH_METHOD_ARGS"]._serialized_start = 4550 + _globals["_FETCH_METHOD_ARGS"]._serialized_end = 4677 + _globals["_UNKNOWN"]._serialized_start = 4679 + _globals["_UNKNOWN"]._serialized_end = 4688 + 
_globals["_RELATIONCOMMON"]._serialized_start = 4691 + _globals["_RELATIONCOMMON"]._serialized_end = 4833 + _globals["_SQL"]._serialized_start = 4836 + _globals["_SQL"]._serialized_end = 5314 + _globals["_SQL_ARGSENTRY"]._serialized_start = 5130 + _globals["_SQL_ARGSENTRY"]._serialized_end = 5220 + _globals["_SQL_NAMEDARGUMENTSENTRY"]._serialized_start = 5222 + _globals["_SQL_NAMEDARGUMENTSENTRY"]._serialized_end = 5314 + _globals["_WITHRELATIONS"]._serialized_start = 5316 + _globals["_WITHRELATIONS"]._serialized_end = 5433 + _globals["_READ"]._serialized_start = 5436 + _globals["_READ"]._serialized_end = 6099 + _globals["_READ_NAMEDTABLE"]._serialized_start = 5614 + _globals["_READ_NAMEDTABLE"]._serialized_end = 5806 + _globals["_READ_NAMEDTABLE_OPTIONSENTRY"]._serialized_start = 5748 + _globals["_READ_NAMEDTABLE_OPTIONSENTRY"]._serialized_end = 5806 + _globals["_READ_DATASOURCE"]._serialized_start = 5809 + _globals["_READ_DATASOURCE"]._serialized_end = 6086 + _globals["_READ_DATASOURCE_OPTIONSENTRY"]._serialized_start = 5748 + _globals["_READ_DATASOURCE_OPTIONSENTRY"]._serialized_end = 5806 + _globals["_PROJECT"]._serialized_start = 6101 + _globals["_PROJECT"]._serialized_end = 6218 + _globals["_FILTER"]._serialized_start = 6220 + _globals["_FILTER"]._serialized_end = 6332 + _globals["_JOIN"]._serialized_start = 6335 + _globals["_JOIN"]._serialized_end = 6996 + _globals["_JOIN_JOINDATATYPE"]._serialized_start = 6674 + _globals["_JOIN_JOINDATATYPE"]._serialized_end = 6766 + _globals["_JOIN_JOINTYPE"]._serialized_start = 6769 + _globals["_JOIN_JOINTYPE"]._serialized_end = 6977 + _globals["_SETOPERATION"]._serialized_start = 6999 + _globals["_SETOPERATION"]._serialized_end = 7478 + _globals["_SETOPERATION_SETOPTYPE"]._serialized_start = 7315 + _globals["_SETOPERATION_SETOPTYPE"]._serialized_end = 7429 + _globals["_LIMIT"]._serialized_start = 7480 + _globals["_LIMIT"]._serialized_end = 7556 + _globals["_OFFSET"]._serialized_start = 7558 + _globals["_OFFSET"]._serialized_end = 7637 + _globals["_TAIL"]._serialized_start = 7639 + _globals["_TAIL"]._serialized_end = 7714 + _globals["_AGGREGATE"]._serialized_start = 7717 + _globals["_AGGREGATE"]._serialized_end = 8483 + _globals["_AGGREGATE_PIVOT"]._serialized_start = 8132 + _globals["_AGGREGATE_PIVOT"]._serialized_end = 8243 + _globals["_AGGREGATE_GROUPINGSETS"]._serialized_start = 8245 + _globals["_AGGREGATE_GROUPINGSETS"]._serialized_end = 8321 + _globals["_AGGREGATE_GROUPTYPE"]._serialized_start = 8324 + _globals["_AGGREGATE_GROUPTYPE"]._serialized_end = 8483 + _globals["_SORT"]._serialized_start = 8486 + _globals["_SORT"]._serialized_end = 8646 + _globals["_DROP"]._serialized_start = 8649 + _globals["_DROP"]._serialized_end = 8790 + _globals["_DEDUPLICATE"]._serialized_start = 8793 + _globals["_DEDUPLICATE"]._serialized_end = 9033 + _globals["_LOCALRELATION"]._serialized_start = 9035 + _globals["_LOCALRELATION"]._serialized_end = 9124 + _globals["_CACHEDLOCALRELATION"]._serialized_start = 9126 + _globals["_CACHEDLOCALRELATION"]._serialized_end = 9198 + _globals["_CACHEDREMOTERELATION"]._serialized_start = 9200 + _globals["_CACHEDREMOTERELATION"]._serialized_end = 9255 + _globals["_SAMPLE"]._serialized_start = 9258 + _globals["_SAMPLE"]._serialized_end = 9531 + _globals["_RANGE"]._serialized_start = 9534 + _globals["_RANGE"]._serialized_end = 9679 + _globals["_SUBQUERYALIAS"]._serialized_start = 9681 + _globals["_SUBQUERYALIAS"]._serialized_end = 9795 + _globals["_REPARTITION"]._serialized_start = 9798 + 
_globals["_REPARTITION"]._serialized_end = 9940 + _globals["_SHOWSTRING"]._serialized_start = 9943 + _globals["_SHOWSTRING"]._serialized_end = 10085 + _globals["_HTMLSTRING"]._serialized_start = 10087 + _globals["_HTMLSTRING"]._serialized_end = 10201 + _globals["_STATSUMMARY"]._serialized_start = 10203 + _globals["_STATSUMMARY"]._serialized_end = 10295 + _globals["_STATDESCRIBE"]._serialized_start = 10297 + _globals["_STATDESCRIBE"]._serialized_end = 10378 + _globals["_STATCROSSTAB"]._serialized_start = 10380 + _globals["_STATCROSSTAB"]._serialized_end = 10481 + _globals["_STATCOV"]._serialized_start = 10483 + _globals["_STATCOV"]._serialized_end = 10579 + _globals["_STATCORR"]._serialized_start = 10582 + _globals["_STATCORR"]._serialized_end = 10719 + _globals["_STATAPPROXQUANTILE"]._serialized_start = 10722 + _globals["_STATAPPROXQUANTILE"]._serialized_end = 10886 + _globals["_STATFREQITEMS"]._serialized_start = 10888 + _globals["_STATFREQITEMS"]._serialized_end = 11013 + _globals["_STATSAMPLEBY"]._serialized_start = 11016 + _globals["_STATSAMPLEBY"]._serialized_end = 11325 + _globals["_STATSAMPLEBY_FRACTION"]._serialized_start = 11217 + _globals["_STATSAMPLEBY_FRACTION"]._serialized_end = 11316 + _globals["_NAFILL"]._serialized_start = 11328 + _globals["_NAFILL"]._serialized_end = 11462 + _globals["_NADROP"]._serialized_start = 11465 + _globals["_NADROP"]._serialized_end = 11599 + _globals["_NAREPLACE"]._serialized_start = 11602 + _globals["_NAREPLACE"]._serialized_end = 11898 + _globals["_NAREPLACE_REPLACEMENT"]._serialized_start = 11757 + _globals["_NAREPLACE_REPLACEMENT"]._serialized_end = 11898 + _globals["_TODF"]._serialized_start = 11900 + _globals["_TODF"]._serialized_end = 11988 + _globals["_WITHCOLUMNSRENAMED"]._serialized_start = 11991 + _globals["_WITHCOLUMNSRENAMED"]._serialized_end = 12373 + _globals["_WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY"]._serialized_start = 12235 + _globals["_WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY"]._serialized_end = 12302 + _globals["_WITHCOLUMNSRENAMED_RENAME"]._serialized_start = 12304 + _globals["_WITHCOLUMNSRENAMED_RENAME"]._serialized_end = 12373 + _globals["_WITHCOLUMNS"]._serialized_start = 12375 + _globals["_WITHCOLUMNS"]._serialized_end = 12494 + _globals["_WITHWATERMARK"]._serialized_start = 12497 + _globals["_WITHWATERMARK"]._serialized_end = 12631 + _globals["_HINT"]._serialized_start = 12634 + _globals["_HINT"]._serialized_end = 12766 + _globals["_UNPIVOT"]._serialized_start = 12769 + _globals["_UNPIVOT"]._serialized_end = 13096 + _globals["_UNPIVOT_VALUES"]._serialized_start = 13026 + _globals["_UNPIVOT_VALUES"]._serialized_end = 13085 + _globals["_TRANSPOSE"]._serialized_start = 13098 + _globals["_TRANSPOSE"]._serialized_end = 13220 + _globals["_UNRESOLVEDTABLEVALUEDFUNCTION"]._serialized_start = 13222 + _globals["_UNRESOLVEDTABLEVALUEDFUNCTION"]._serialized_end = 13347 + _globals["_TOSCHEMA"]._serialized_start = 13349 + _globals["_TOSCHEMA"]._serialized_end = 13455 + _globals["_REPARTITIONBYEXPRESSION"]._serialized_start = 13458 + _globals["_REPARTITIONBYEXPRESSION"]._serialized_end = 13661 + _globals["_MAPPARTITIONS"]._serialized_start = 13664 + _globals["_MAPPARTITIONS"]._serialized_end = 13896 + _globals["_GROUPMAP"]._serialized_start = 13899 + _globals["_GROUPMAP"]._serialized_end = 14749 + _globals["_TRANSFORMWITHSTATEINFO"]._serialized_start = 14752 + _globals["_TRANSFORMWITHSTATEINFO"]._serialized_end = 14975 + _globals["_COGROUPMAP"]._serialized_start = 14978 + _globals["_COGROUPMAP"]._serialized_end = 15504 + 
_globals["_APPLYINPANDASWITHSTATE"]._serialized_start = 15507 + _globals["_APPLYINPANDASWITHSTATE"]._serialized_end = 15864 + _globals["_COMMONINLINEUSERDEFINEDTABLEFUNCTION"]._serialized_start = 15867 + _globals["_COMMONINLINEUSERDEFINEDTABLEFUNCTION"]._serialized_end = 16111 + _globals["_PYTHONUDTF"]._serialized_start = 16114 + _globals["_PYTHONUDTF"]._serialized_end = 16291 + _globals["_COMMONINLINEUSERDEFINEDDATASOURCE"]._serialized_start = 16294 + _globals["_COMMONINLINEUSERDEFINEDDATASOURCE"]._serialized_end = 16445 + _globals["_PYTHONDATASOURCE"]._serialized_start = 16447 + _globals["_PYTHONDATASOURCE"]._serialized_end = 16522 + _globals["_COLLECTMETRICS"]._serialized_start = 16525 + _globals["_COLLECTMETRICS"]._serialized_end = 16661 + _globals["_PARSE"]._serialized_start = 16664 + _globals["_PARSE"]._serialized_end = 17052 + _globals["_PARSE_OPTIONSENTRY"]._serialized_start = 5748 + _globals["_PARSE_OPTIONSENTRY"]._serialized_end = 5806 + _globals["_PARSE_PARSEFORMAT"]._serialized_start = 16953 + _globals["_PARSE_PARSEFORMAT"]._serialized_end = 17041 + _globals["_ASOFJOIN"]._serialized_start = 17055 + _globals["_ASOFJOIN"]._serialized_end = 17530 + _globals["_LATERALJOIN"]._serialized_start = 17533 + _globals["_LATERALJOIN"]._serialized_end = 17763 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi index e1eb7945c19f0..beeeb712da762 100644 --- a/python/pyspark/sql/connect/proto/relations_pb2.pyi +++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi @@ -707,57 +707,28 @@ class MlRelation(google.protobuf.message.Message): TRANSFORM_FIELD_NUMBER: builtins.int FETCH_FIELD_NUMBER: builtins.int - MODEL_SUMMARY_DATASET_FIELD_NUMBER: builtins.int @property def transform(self) -> global___MlRelation.Transform: ... @property def fetch(self) -> global___Fetch: ... - @property - def model_summary_dataset(self) -> global___Relation: - """(Optional) the dataset for restoring the model summary""" def __init__( self, *, transform: global___MlRelation.Transform | None = ..., fetch: global___Fetch | None = ..., - model_summary_dataset: global___Relation | None = ..., ) -> None: ... def HasField( self, field_name: typing_extensions.Literal[ - "_model_summary_dataset", - b"_model_summary_dataset", - "fetch", - b"fetch", - "ml_type", - b"ml_type", - "model_summary_dataset", - b"model_summary_dataset", - "transform", - b"transform", + "fetch", b"fetch", "ml_type", b"ml_type", "transform", b"transform" ], ) -> builtins.bool: ... def ClearField( self, field_name: typing_extensions.Literal[ - "_model_summary_dataset", - b"_model_summary_dataset", - "fetch", - b"fetch", - "ml_type", - b"ml_type", - "model_summary_dataset", - b"model_summary_dataset", - "transform", - b"transform", + "fetch", b"fetch", "ml_type", b"ml_type", "transform", b"transform" ], ) -> None: ... - @typing.overload - def WhichOneof( - self, - oneof_group: typing_extensions.Literal["_model_summary_dataset", b"_model_summary_dataset"], - ) -> typing_extensions.Literal["model_summary_dataset"] | None: ... - @typing.overload def WhichOneof( self, oneof_group: typing_extensions.Literal["ml_type", b"ml_type"] ) -> typing_extensions.Literal["transform", "fetch"] | None: ... 
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/ml.proto b/sql/connect/common/src/main/protobuf/spark/connect/ml.proto index 3497284af4ab8..b66c0a186df39 100644 --- a/sql/connect/common/src/main/protobuf/spark/connect/ml.proto +++ b/sql/connect/common/src/main/protobuf/spark/connect/ml.proto @@ -38,7 +38,6 @@ message MlCommand { Evaluate evaluate = 6; CleanCache clean_cache = 7; GetCacheInfo get_cache_info = 8; - CreateSummary create_summary = 9; } // Command for estimator.fit(dataset) @@ -55,9 +54,6 @@ message MlCommand { // or summary evaluated by a model message Delete { repeated ObjectRef obj_refs = 1; - // if set `evict_only` to true, only evict the cached model from memory, - // but keep the offloaded model in Spark driver local disk. - optional bool evict_only = 2; } // Force to clean up all the ML cached objects @@ -102,13 +98,6 @@ message MlCommand { // (Required) the evaluating dataset Relation dataset = 3; } - - // This is for re-creating the model summary when the model summary is lost - // (model summary is lost when the model is offloaded and then loaded back) - message CreateSummary { - ObjectRef model_ref = 1; - Relation dataset = 2; - } } // The result of MlCommand diff --git a/sql/connect/common/src/main/protobuf/spark/connect/relations.proto b/sql/connect/common/src/main/protobuf/spark/connect/relations.proto index ccb674e812dc0..70a52a2111494 100644 --- a/sql/connect/common/src/main/protobuf/spark/connect/relations.proto +++ b/sql/connect/common/src/main/protobuf/spark/connect/relations.proto @@ -115,9 +115,6 @@ message MlRelation { Transform transform = 1; Fetch fetch = 2; } - // (Optional) the dataset for restoring the model summary - optional Relation model_summary_dataset = 3; - // Relation to represent transform(input) of the operator // which could be a cached model or a new transformer message Transform { diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala index b075187b7002f..ef1b17dc2221e 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala @@ -30,7 +30,7 @@ import org.apache.commons.io.FileUtils import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.ml.Model -import org.apache.spark.ml.util.{ConnectHelper, HasTrainingSummary, MLWritable, Summary} +import org.apache.spark.ml.util.{ConnectHelper, MLWritable, Summary} import org.apache.spark.sql.connect.config.Connect import org.apache.spark.sql.connect.service.SessionHolder @@ -115,12 +115,6 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging { } } - private[spark] def getModelOffloadingPath(refId: String): Path = { - val path = offloadedModelsDir.resolve(refId) - require(path.startsWith(offloadedModelsDir)) - path - } - /** * Cache an object into a map of MLCache, and return its key * @param obj @@ -143,14 +137,9 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging { } cachedModel.put(objectId, CacheItem(obj, sizeBytes)) if (getMemoryControlEnabled) { - val savePath = getModelOffloadingPath(objectId) + val savePath = offloadedModelsDir.resolve(objectId) + require(savePath.startsWith(offloadedModelsDir)) obj.asInstanceOf[MLWritable].write.saveToLocal(savePath.toString) - if (obj.isInstanceOf[HasTrainingSummary[_]] - && 
obj.asInstanceOf[HasTrainingSummary[_]].hasSummary) { - obj - .asInstanceOf[HasTrainingSummary[_]] - .saveSummary(savePath.resolve("summary").toString) - } Files.writeString(savePath.resolve(modelClassNameFile), obj.getClass.getName) totalMLCacheInMemorySizeBytes.addAndGet(sizeBytes) totalMLCacheSizeBytes.addAndGet(sizeBytes) @@ -187,7 +176,8 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging { verifyObjectId(refId) var obj: Object = Option(cachedModel.get(refId)).map(_.obj).getOrElse(null) if (obj == null && getMemoryControlEnabled) { - val loadPath = getModelOffloadingPath(refId) + val loadPath = offloadedModelsDir.resolve(refId) + require(loadPath.startsWith(offloadedModelsDir)) if (Files.isDirectory(loadPath)) { val className = Files.readString(loadPath.resolve(modelClassNameFile)) obj = MLUtils.loadTransformer( @@ -204,13 +194,14 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging { } } - def _removeModel(refId: String, evictOnly: Boolean): Boolean = { + def _removeModel(refId: String): Boolean = { verifyObjectId(refId) val removedModel = cachedModel.remove(refId) val removedFromMem = removedModel != null - val removedFromDisk = if (!evictOnly && removedModel != null && getMemoryControlEnabled) { + val removedFromDisk = if (removedModel != null && getMemoryControlEnabled) { totalMLCacheSizeBytes.addAndGet(-removedModel.sizeBytes) - val removePath = getModelOffloadingPath(refId) + val removePath = offloadedModelsDir.resolve(refId) + require(removePath.startsWith(offloadedModelsDir)) val offloadingPath = new File(removePath.toString) if (offloadingPath.exists()) { FileUtils.deleteDirectory(offloadingPath) @@ -229,8 +220,8 @@ private[connect] class MLCache(sessionHolder: SessionHolder) extends Logging { * @param refId * the key used to look up the corresponding object */ - def remove(refId: String, evictOnly: Boolean = false): Boolean = { - val modelIsRemoved = _removeModel(refId, evictOnly) + def remove(refId: String): Boolean = { + val modelIsRemoved = _removeModel(refId) modelIsRemoved } diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLException.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLException.scala index 847052be98a98..a017c719ed16e 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLException.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLException.scala @@ -51,9 +51,3 @@ private[spark] case class MLCacheSizeOverflowException(mlCacheMaxSize: Long) errorClass = "CONNECT_ML.ML_CACHE_SIZE_OVERFLOW_EXCEPTION", messageParameters = Map("mlCacheMaxSize" -> mlCacheMaxSize.toString), cause = null) - -private[spark] case class MLModelSummaryLostException(objectName: String) - extends SparkException( - errorClass = "CONNECT_ML.MODEL_SUMMARY_LOST", - messageParameters = Map("objectName" -> objectName), - cause = null) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala index 7220acb8feaca..2c6ccec75b667 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLHandler.scala @@ -42,6 +42,7 @@ private case class Method( /** * Helper function to get the attribute from an object by reflection + * c5 */ private class AttributeHelper( val sessionHolder: SessionHolder, @@ -115,6 
+116,7 @@ private object ModelAttributeHelper { } // MLHandler is a utility to group all ML operations +// 1 private[connect] object MLHandler extends Logging { val currentSessionHolder = new ThreadLocal[SessionHolder] { @@ -229,8 +231,11 @@ private[connect] object MLHandler extends Logging { if (obj != null && obj.isInstanceOf[HasTrainingSummary[_]] && methods(0).getMethod == "summary" && !obj.asInstanceOf[HasTrainingSummary[_]].hasSummary) { - throw MLModelSummaryLostException(objRefId) + throw MLCacheInvalidException( + objRefId, + sessionHolder.mlCache.getOffloadingTimeoutMinute) } + val helper = AttributeHelper(sessionHolder, objRefId, methods) val attrResult = helper.getAttribute attrResult match { @@ -262,13 +267,9 @@ private[connect] object MLHandler extends Logging { case proto.MlCommand.CommandCase.DELETE => val ids = mutable.ArrayBuilder.make[String] - val deleteCmd = mlCommand.getDelete - val evictOnly = if (deleteCmd.hasEvictOnly) { - deleteCmd.getEvictOnly - } else { false } - deleteCmd.getObjRefsList.asScala.toArray.foreach { objId => + mlCommand.getDelete.getObjRefsList.asScala.toArray.foreach { objId => if (!objId.getId.contains(".")) { - if (mlCache.remove(objId.getId, evictOnly)) { + if (mlCache.remove(objId.getId)) { ids += objId.getId } } @@ -402,29 +403,10 @@ private[connect] object MLHandler extends Logging { .setParam(LiteralValueProtoConverter.toLiteralProto(metric)) .build() - case proto.MlCommand.CommandCase.CREATE_SUMMARY => - val createSummaryCmd = mlCommand.getCreateSummary - createModelSummary(sessionHolder, createSummaryCmd) - case other => throw MlUnsupportedException(s"$other not supported") } } - private def createModelSummary( - sessionHolder: SessionHolder, - createSummaryCmd: proto.MlCommand.CreateSummary): proto.MlCommandResult = { - val refId = createSummaryCmd.getModelRef.getId - val model = sessionHolder.mlCache.get(refId).asInstanceOf[HasTrainingSummary[_]] - val dataset = MLUtils.parseRelationProto(createSummaryCmd.getDataset, sessionHolder) - val modelPath = sessionHolder.mlCache.getModelOffloadingPath(refId) - val summaryPath = modelPath.resolve("summary").toString - model.loadSummary(summaryPath, dataset) - proto.MlCommandResult - .newBuilder() - .setParam(LiteralValueProtoConverter.toLiteralProto(true)) - .build() - } - def transformMLRelation(relation: proto.MlRelation, sessionHolder: SessionHolder): DataFrame = { relation.getMlTypeCase match { // Ml transform @@ -454,26 +436,10 @@ private[connect] object MLHandler extends Logging { // Get the attribute from a cached object which could be a model or summary case proto.MlRelation.MlTypeCase.FETCH => - val objRefId = relation.getFetch.getObjRef.getId - val methods = relation.getFetch.getMethodsList.asScala.toArray - val obj = sessionHolder.mlCache.get(objRefId) - if (obj != null && obj.isInstanceOf[HasTrainingSummary[_]] - && methods(0).getMethod == "summary" - && !obj.asInstanceOf[HasTrainingSummary[_]].hasSummary) { - - if (relation.hasModelSummaryDataset) { - val dataset = - MLUtils.parseRelationProto(relation.getModelSummaryDataset, sessionHolder) - val modelPath = sessionHolder.mlCache.getModelOffloadingPath(objRefId) - val summaryPath = modelPath.resolve("summary").toString - obj.asInstanceOf[HasTrainingSummary[_]].loadSummary(summaryPath, dataset) - } else { - // For old Spark client backward compatibility. 
- throw MLModelSummaryLostException(objRefId) - } - } - - val helper = AttributeHelper(sessionHolder, objRefId, methods) + val helper = AttributeHelper( + sessionHolder, + relation.getFetch.getObjRef.getId, + relation.getFetch.getMethodsList.asScala.toArray) helper.getAttribute.asInstanceOf[DataFrame] case other => throw MlUnsupportedException(s"$other not supported")