[SPARK-52495] Include partition columns in the single variant column #51206

Draft · wants to merge 2 commits into base: master
@@ -69,6 +69,14 @@ public static Variant parseJson(JsonParser parser, boolean allowDuplicateKeys)
return builder.result();
}

public static VariantBuilder parseJsonAndReturnBuilder(
JsonParser parser,
boolean allowDuplicateKeys) throws IOException {
VariantBuilder builder = new VariantBuilder(allowDuplicateKeys);
builder.buildJson(parser);
return builder;
}

// Build the variant metadata from `dictionaryKeys` and return the variant result.
public Variant result() {
int numKeys = dictionaryKeys.size();
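For context, a minimal sketch (illustrative only; it assumes VariantBuilder lives in org.apache.spark.types.variant and uses a plain Jackson JsonFactory) of how the new parseJsonAndReturnBuilder entry point differs from parseJson: it hands back the live builder instead of a finished Variant, so the caller can keep working with it before calling result().

    import com.fasterxml.jackson.core.JsonFactory
    import org.apache.spark.types.variant.VariantBuilder

    val parser = new JsonFactory().createParser("""{"make": "Ford"}""")
    parser.nextToken() // position the parser on the first token, as the callers below do
    // parseJson would return a finished Variant here; the new method returns the builder
    val builder = VariantBuilder.parseJsonAndReturnBuilder(parser, /* allowDuplicateKeys = */ false)
    // ... the caller can keep appending through the builder before materializing ...
    val v = builder.result()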
@@ -53,7 +53,9 @@ class UnivocityParser(
dataSchema: StructType,
requiredSchema: StructType,
val options: CSVOptions,
filters: Seq[Filter]) extends Logging {
filters: Seq[Filter],
partitionSchema: StructType = StructType(Seq.empty),
partitionValues: InternalRow = InternalRow.empty) extends Logging {
require(requiredSchema.toSet.subsetOf(dataSchema.toSet),
s"requiredSchema (${requiredSchema.catalogString}) should be the subset of " +
s"dataSchema (${dataSchema.catalogString}).")
@@ -369,6 +371,31 @@ class UnivocityParser(
fields.add(new VariantBuilder.FieldEntry(key, id, builder.getWritePos - start))
singleVariantFieldConverters(i).convertInput(builder, tokens(i))
}

// Add the partition columns to the variant object
if (partitionSchema.nonEmpty) {
partitionSchema.zipWithIndex.foreach { case (field, index) =>
val value = partitionValues.get(index, field.dataType)
if (value != null) {
val id = builder.addKey(field.name)
fields.add(new VariantBuilder.FieldEntry(field.name, id, builder.getWritePos - start))
field.dataType match {
case LongType => builder.appendLong(value.toString.toLong)
case _: DecimalType => builder.appendDecimal(
decimalParser(value.toString)
)
case DateType => builder.appendDate(dateFormatter.parse(value.toString))
case TimestampNTZType =>
builder.appendTimestampNtz(timestampNTZFormatter.parse(value.toString))
case TimestampType =>
builder.appendTimestamp(timestampFormatter.parse(value.toString))
case BooleanType => builder.appendBoolean(value.toString.toBoolean)
case StringType => builder.appendString(value.toString)
}
}
}
}

builder.finishWritingObject(start, fields)
val v = builder.result()
row(0) = new VariantVal(v.getValue, v.getMetadata)
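A hedged end-to-end sketch of what the CSV change above enables, assuming an active spark session and a hypothetical layout /tmp/sales/year=2021/month=01/part-0.csv with a make column: the partition values from the path become fields of the single variant column.

    import org.apache.spark.sql.functions.{col, variant_get}

    val df = spark.read
      .format("csv")
      .option("header", "true")
      .option("singleVariantColumn", "var")
      .load("/tmp/sales")

    df.select(
      variant_get(col("var"), "$.make", "string"), // data column from the CSV file
      variant_get(col("var"), "$.year", "int"),    // partition column from the path
      variant_get(col("var"), "$.month", "int"))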
@@ -49,7 +49,9 @@ class JacksonParser(
schema: DataType,
val options: JSONOptions,
allowArrayAsStructs: Boolean,
filters: Seq[Filter] = Seq.empty) extends Logging {
filters: Seq[Filter] = Seq.empty,
partitionSchema: StructType = StructType(Seq.empty),
partitionValues: InternalRow = InternalRow.empty) extends Logging {

import JacksonUtils._
import com.fasterxml.jackson.core.JsonToken._
@@ -130,7 +132,30 @@ parser.nextToken()
parser.nextToken()
}
try {
val v = VariantBuilder.parseJson(parser, variantAllowDuplicateKeys)
val builder = VariantBuilder.parseJsonAndReturnBuilder(parser, variantAllowDuplicateKeys)
// Append the partition column values to the variant being built
if (partitionSchema.nonEmpty) {
partitionSchema.fields.zipWithIndex.foreach { case (field, i) =>
val value = partitionValues.get(i, field.dataType)
if (value != null) {
builder.addKey(field.name)
field.dataType match {
case LongType => builder.appendLong(value.toString.toLong)
case _: DecimalType => builder.appendDecimal(
decimalParser(value.toString)
)
case DateType => builder.appendDate(dateFormatter.parse(value.toString))
case TimestampNTZType =>
builder.appendTimestampNtz(timestampNTZFormatter.parse(value.toString))
case TimestampType =>
builder.appendTimestamp(timestampFormatter.parse(value.toString))
case BooleanType => builder.appendBoolean(value.toString.toBoolean)
case StringType => builder.appendString(value.toString)
}
}
}
}
val v = builder.result()
new VariantVal(v.getValue, v.getMetadata)
} catch {
case _: VariantSizeLimitException =>
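The JSON path follows the same pattern. A small hedged sketch (directory layout and field names are assumptions) showing that the reader exposes only the variant column while the partition value is still reachable inside it:

    import org.apache.spark.sql.functions.{col, variant_get}

    // Assumed layout: /tmp/events/year=2021/part-0.json, each line like {"id": 7}
    val df = spark.read
      .format("json")
      .option("singleVariantColumn", "var")
      .load("/tmp/events")

    assert(df.columns.sameElements(Array("var")))      // no separate `year` column
    df.select(
      variant_get(col("var"), "$.id", "long"),         // field from the JSON body
      variant_get(col("var"), "$.year", "int")).show() // field injected from the path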
@@ -54,7 +54,9 @@ import org.apache.spark.unsafe.types.{UTF8String, VariantVal}

class StaxXmlParser(
schema: StructType,
val options: XmlOptions) extends Logging {
val options: XmlOptions,
partitionSchema: StructType = StructType(Seq.empty),
partitionValues: InternalRow = InternalRow.empty) extends Logging {

private lazy val timestampFormatter = TimestampFormatter(
options.timestampFormatInRead,
@@ -146,7 +148,7 @@ class StaxXmlParser(
options.singleVariantColumn match {
case Some(_) =>
// If the singleVariantColumn is specified, parse the entire xml string as a Variant
val v = StaxXmlParser.parseVariant(xml, options)
val v = StaxXmlParser.parseVariant(xml, options, partitionSchema, partitionValues)
Some(InternalRow(v))
case _ =>
// Otherwise, parse the xml string as Structs
@@ -928,10 +930,14 @@ object StaxXmlParser {
/**
* Parse the input XML string as a Variant value
*/
def parseVariant(xml: String, options: XmlOptions): VariantVal = {
def parseVariant(
xml: String,
options: XmlOptions,
partitionSchema: StructType = StructType(Seq.empty),
partitionValues: InternalRow = InternalRow.empty): VariantVal = {
val parser = StaxXmlParserUtils.filteredReader(xml)
val rootAttributes = StaxXmlParserUtils.gatherRootAttributes(parser)
val v = convertVariant(parser, rootAttributes, options)
val v = convertVariant(parser, rootAttributes, options, partitionSchema, partitionValues)
parser.close()
v
}
@@ -944,20 +950,26 @@
* @param parser The XML event stream reader positioned after the start element
* @param attributes The attributes of the current XML element to be included in the Variant
* @param options Configuration options that control how XML is parsed into Variants
* @param partitionSchema The schema of the partition columns, if any
* @param partitionValues The values of the partition columns, if any
* @return A Variant representing the XML element with its attributes and child content
*/
def convertVariant(
parser: XMLEventReader,
attributes: Array[Attribute],
options: XmlOptions): VariantVal = {
val v = convertVariantInternal(parser, attributes, options)
options: XmlOptions,
partitionSchema: StructType = StructType(Seq.empty),
partitionValues: InternalRow = InternalRow.empty): VariantVal = {
val v = convertVariantInternal(parser, attributes, options, partitionSchema, partitionValues)
new VariantVal(v.getValue, v.getMetadata)
}

private def convertVariantInternal(
parser: XMLEventReader,
attributes: Array[Attribute],
options: XmlOptions): Variant = {
options: XmlOptions,
partitionSchema: StructType = StructType(Seq.empty),
partitionValues: InternalRow = InternalRow.empty): Variant = {
// The variant builder for the root startElement
val rootBuilder = new VariantBuilder(false)
val start = rootBuilder.getWritePos
@@ -1006,6 +1018,24 @@
variants.add(builder.result())

case _: EndElement =>
// At the end of the element, append the partition values if they exist
if (partitionSchema.nonEmpty) {
partitionSchema.fields.zipWithIndex.foreach { case (field, i) =>
val value = partitionValues.get(i, field.dataType)
if (value != null) {
val builder = new VariantBuilder(true)
appendXMLCharacterToVariant(builder, value.toString, options)
val variants = fieldToVariants.getOrElseUpdate(
field.name,
new java.util.ArrayList[Variant]()
)
// Override the data values if the partition column overlaps with any data column
variants.clear()
variants.add(builder.result())
}
}
}

if (fieldToVariants.nonEmpty) {
val onlyValueTagField = fieldToVariants.keySet.forall(_ == options.valueTag)
if (onlyValueTagField) {
@@ -1036,6 +1066,8 @@
*
* @param builder The variant builder to write to
* @param fieldToVariants A map of field names to their corresponding variant values of the object
* The map is sorted by field name; the ordering is determined by the configured case sensitivity.
*/
private def writeVariantObject(
builder: VariantBuilder,
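A hedged sketch of the override rule noted above (file layout, rowTag, and element names are assumptions): when a data field collides with a partition column, the value from the path replaces the value parsed from the XML.

    import org.apache.spark.sql.functions.{col, variant_get}

    // Assumed file /tmp/cars/year=2021/data.xml:
    //   <ROWS><ROW><year>1999</year><make>Ford</make></ROW></ROWS>
    val df = spark.read
      .format("xml")
      .option("rowTag", "ROW")
      .option("singleVariantColumn", "var")
      .load("/tmp/cars")

    df.select(
      variant_get(col("var"), "$.year", "int"),   // 2021 from the path, not 1999 from the data
      variant_get(col("var"), "$.make", "string"))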
@@ -213,7 +213,7 @@ object XmlOptions extends DataSourceOptions {
val INDENT = newOption("indent")
val PREFERS_DECIMAL = newOption("prefersDecimal")
val VALIDATE_NAME = newOption("validateName")
val SINGLE_VARIANT_COLUMN = newOption("singleVariantColumn")
val SINGLE_VARIANT_COLUMN = newOption(DataSourceOptions.SINGLE_VARIANT_COLUMN)
// Options with alternative
val ENCODING = "encoding"
val CHARSET = "charset"
@@ -24,7 +24,7 @@ import org.apache.hadoop.mapreduce.Job

import org.apache.spark.paths.SparkPath
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.{DataSourceOptions, InternalRow}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
@@ -137,6 +137,12 @@ trait FileFormat {
val dataReader = buildReader(
sparkSession, dataSchema, partitionSchema, requiredSchema, filters, options, hadoopConf)

if (options.contains(DataSourceOptions.SINGLE_VARIANT_COLUMN)) {
// In singleVariantColumn mode, the partition values are pushed down to the parser and
// included in the variant column of the output rows, so they do not need to be appended here.
return (file: PartitionedFile) => dataReader(file)
}

new (PartitionedFile => Iterator[InternalRow]) with Serializable {
private val fullSchema = toAttributes(requiredSchema) ++ toAttributes(partitionSchema)

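For contrast, a rough stand-in (not the actual Spark wrapper, which also runs an unsafe projection over requiredSchema ++ partitionSchema) for what the anonymous wrapper above normally does: each row from the underlying reader is widened with the file's partition values. With singleVariantColumn set, that widening is skipped because the parser has already folded the partition values into the variant.

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.JoinedRow
    import org.apache.spark.sql.execution.datasources.PartitionedFile

    // Simplified sketch: join each data row with the partition values of its file.
    def appendPartitionValues(
        dataReader: PartitionedFile => Iterator[InternalRow])
      : PartitionedFile => Iterator[InternalRow] = { file =>
      val joined = new JoinedRow()
      dataReader(file).map(dataRow => joined(dataRow, file.partitionValues))
    }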
@@ -18,14 +18,12 @@
package org.apache.spark.sql.execution.datasources

import java.util.Locale

import scala.collection.mutable

import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.{NUM_PRUNED, POST_SCAN_FILTERS, PUSHED_FILTERS, TOTAL}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.{DataSourceOptions, expressions}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.ScanOperation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -166,8 +164,14 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging {
filters.filter(_.deterministic), l.output)

val partitionColumns =
l.resolve(
fsRelation.partitionSchema, fsRelation.sparkSession.sessionState.analyzer.resolver)
if (fsRelation.options.contains(DataSourceOptions.SINGLE_VARIANT_COLUMN)) {
Seq.empty
} else {
l.resolve(
fsRelation.partitionSchema,
fsRelation.sparkSession.sessionState.analyzer.resolver
)
}
val partitionSet = AttributeSet(partitionColumns)

// this partitionKeyFilters should be the same with the ones being executed in
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.DataSourceOptions
import org.apache.spark.sql.catalyst.catalog.BucketSpec
import org.apache.spark.sql.execution.FileRelation
import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister}
@@ -54,8 +55,15 @@ case class HadoopFsRelation(
// schema respects the order of the data schema for the overlapping columns, and it
// respects the data types of the partition schema.
val (schema: StructType, overlappedPartCols: Map[String, StructField]) =
PartitioningUtils.mergeDataAndPartitionSchema(dataSchema,
partitionSchema, sparkSession.sessionState.conf.caseSensitiveAnalysis)
if (options.contains(DataSourceOptions.SINGLE_VARIANT_COLUMN)) {
(dataSchema, Map.empty)
} else {
PartitioningUtils.mergeDataAndPartitionSchema(
dataSchema,
partitionSchema,
sparkSession.sessionState.conf.caseSensitiveAnalysis
)
}

override def toString: String = {
fileFormat match {
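An illustrative comparison of the schema decision above (paths and field names are hypothetical): without the option, the partition column is merged into the relation schema; with it, only the data schema, i.e. the single variant column, remains.

    // Assumed layout: /tmp/events/year=2021/part-0.json
    val plain = spark.read.format("json").load("/tmp/events")
    plain.printSchema()       // e.g. id, year  (partition column merged in)

    val variantOnly = spark.read.format("json")
      .option("singleVariantColumn", "var")
      .load("/tmp/events")
    variantOnly.printSchema() // var only       (dataSchema used as-is)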
@@ -139,7 +139,9 @@ case class CSVFileFormat() extends TextBasedFileFormat with DataSourceRegister {
actualDataSchema,
actualRequiredSchema,
parsedOptions,
actualFilters)
actualFilters,
partitionSchema,
file.partitionValues)
// Use column pruning when specified by Catalyst, except when one or more columns have
// existence default value(s), since in that case we instruct the CSV parser to disable column
// pruning and instead read each entire row in order to correctly assign the default value(s).
@@ -117,7 +117,9 @@ case class JsonFileFormat() extends TextBasedFileFormat with DataSourceRegister
actualSchema,
parsedOptions,
allowArrayAsStructs = true,
filters)
filters,
partitionSchema,
file.partitionValues)
JsonDataSource(parsedOptions).readFile(
broadcastedHadoopConf.value.value,
file,
@@ -121,7 +121,9 @@ case class XmlFileFormat() extends TextBasedFileFormat with DataSourceRegister {
(file: PartitionedFile) => {
val parser = new StaxXmlParser(
actualRequiredSchema,
xmlOptions)
xmlOptions,
partitionSchema,
file.partitionValues)
XmlDataSource(xmlOptions).readFile(
broadcastedHadoopConf.value.value,
file,
@@ -933,4 +933,48 @@ class XmlVariantSuite extends QueryTest with SharedSparkSession with TestXmlData
.map(_.getString(0).replaceAll("\\s+", ""))
assert(xmlResult.head === xmlStr)
}

// ==========================================================
// ====== SingleVariantColumn with partition schema =========
// ==========================================================

test("SingleVariantColumn with partition columns in the source file path") {
withTempDir { dir =>
// Create a partitioned directory structure and copy the test file into it
val path = s"${dir.getCanonicalPath}/year=2021/month=01"
val partitionDir = new java.io.File(path)
partitionDir.mkdirs()

// Copy cars.xml into the partition directory
val srcPath = getTestResourcePath(resDir + "cars.xml")
val destPath = s"${path}/data.xml"

// Use the Hadoop FileSystem API to copy the file
val fs = org.apache.hadoop.fs.FileSystem.get(spark.sessionState.newHadoopConf())
fs.copyFromLocalFile(new org.apache.hadoop.fs.Path(srcPath),
new org.apache.hadoop.fs.Path(destPath))

// Read the files with partitioning info
val df = spark.read
.format("xml")
.option("singleVariantColumn", "var")
.options(baseOptions)
.load(dir.getCanonicalPath)

// Check that the DataFrame only contains the variant column
assert(df.columns.length == 1)

// Verify that the partition columns are present in the variant column
checkAnswer(
df.select(
variant_get(col("var"), "$.year", "int"),
variant_get(col("var"), "$.month", "int"),
variant_get(col("var"), "$.make", "string").as("make")
).orderBy("make"),
Seq(
Row(2021, 1, "Chevy"), Row(2021, 1, "Ford"), Row(2021, 1, "Tesla")
)
)
}
}
}