
[SPARK-52482][SQL][CORE] ZStandard support for file data source reader #51182

Open · wants to merge 13 commits into base: master

@@ -20,12 +20,13 @@ package org.apache.spark.input
import com.google.common.io.{ByteStreams, Closeables}
import org.apache.hadoop.conf.{Configurable => HConfigurable, Configuration}
import org.apache.hadoop.io.Text
import org.apache.hadoop.io.compress.CompressionCodecFactory
import org.apache.hadoop.mapreduce.InputSplit
import org.apache.hadoop.mapreduce.RecordReader
import org.apache.hadoop.mapreduce.TaskAttemptContext
import org.apache.hadoop.mapreduce.lib.input.{CombineFileRecordReader, CombineFileSplit}

import org.apache.spark.io.HadoopCodecStreams

/**
* A trait to implement [[org.apache.hadoop.conf.Configurable Configurable]] interface.
*/
@@ -69,15 +70,8 @@ private[spark] class WholeTextFileRecordReader(

override def nextKeyValue(): Boolean = {
if (!processed) {
val conf = getConf
val factory = new CompressionCodecFactory(conf)
val codec = factory.getCodec(path) // infers from file ext.
val fileIn = fs.open(path)
val innerBuffer = if (codec != null) {
ByteStreams.toByteArray(codec.createInputStream(fileIn))
} else {
ByteStreams.toByteArray(fileIn)
}
val fileIn = HadoopCodecStreams.createInputStream(getConf, path)
val innerBuffer = ByteStreams.toByteArray(fileIn)

value = new Text(innerBuffer)
Closeables.close(fileIn, false)
94 changes: 94 additions & 0 deletions core/src/main/scala/org/apache/spark/io/HadoopCodecStreams.scala
@@ -0,0 +1,94 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.io

import java.io.InputStream
import java.util.Locale

import scala.collection.Seq

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.compress._

import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.io.{CompressionCodec => SparkCompressionCodec}

/**
* A utility object to look up Hadoop compression codecs and create input streams.
* In addition to the standard Hadoop codecs, it also supports Spark's Zstandard codec
* when Hadoop is not compiled with Zstandard support. Additionally, it supports
* non-standard file extensions like `.zstd` and `.gzip` for the Zstandard and Gzip codecs.
*/
object HadoopCodecStreams {
private val ZSTD_EXTENSIONS = Seq(".zstd", ".zst")

// get codec based on file name extension
def getDecompressionCodec(
config: Configuration,
file: Path): Option[CompressionCodec] = {
val factory = new CompressionCodecFactory(config)
Option(factory.getCodec(file)).orElse {
// Try some non-standard extensions for Zstandard and Gzip.
file.getName.toLowerCase() match {
case name if name.endsWith(".zstd") =>
Option(factory.getCodecByName(classOf[ZStandardCodec].getName))
case name if name.endsWith(".gzip") =>
Option(factory.getCodecByName(classOf[GzipCodec].getName))
case _ => None
}
}
}

def createZstdInputStream(
file: Path,
inputStream: InputStream): Option[InputStream] = {
val sparkConf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf)
val fileName = file.getName.toLowerCase(Locale.ROOT)

val isOpt = if (ZSTD_EXTENSIONS.exists(fileName.endsWith)) {
Some(
SparkCompressionCodec
.createCodec(sparkConf, SparkCompressionCodec.ZSTD)
.compressedInputStream(inputStream)
)
} else {
None
}
isOpt
}

def createInputStream(
config: Configuration,
file: Path): InputStream = {
val fs = file.getFileSystem(config)
val inputStream: InputStream = fs.open(file)

getDecompressionCodec(config, file)
.map { codec =>
try {
codec.createInputStream(inputStream)
} catch {
case e: RuntimeException =>
// createInputStream may fail for ZSTD if Hadoop is not compiled with ZSTD
// support. In that case, we fall back to Spark's Zstandard codec.
createZstdInputStream(file, inputStream).getOrElse(throw e)
}
}.getOrElse(inputStream)
}
}
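For reviewers, a minimal usage sketch of the new helper, assuming a plain local Hadoop `Configuration`; the file path is hypothetical and only illustrates the extension-based codec lookup:

```scala
import com.google.common.io.ByteStreams
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

import org.apache.spark.io.HadoopCodecStreams

// The codec is inferred from the file extension (.gz, .gzip, .zst, .zstd, ...).
// If Hadoop lacks native ZSTD support, the helper falls back to Spark's own ZSTD codec.
val conf = new Configuration()
val path = new Path("/tmp/data/part-00000.zst") // hypothetical input file
val in = HadoopCodecStreams.createInputStream(conf, path)
try {
  val bytes = ByteStreams.toByteArray(in)
  println(s"Read ${bytes.length} decompressed bytes")
} finally {
  in.close()
}
```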
@@ -26,6 +26,7 @@ import org.apache.hadoop.io.Text
import org.apache.hadoop.io.compress.{CompressionCodecFactory, GzipCodec}

import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.io.ZStdCompressionCodec

/**
* Tests the correctness of
@@ -36,6 +37,10 @@ class WholeTextFileRecordReaderSuite extends SparkFunSuite {
private var sc: SparkContext = _
private var factory: CompressionCodecFactory = _

def getSparkConf(): SparkConf = {
new SparkConf()
}

override def beforeAll(): Unit = {
// Hadoop's FileSystem caching does not use the Configuration as part of its cache key, which
// can cause Filesystem.get(Configuration) to return a cached instance created with a different
@@ -44,7 +49,7 @@ class WholeTextFileRecordReaderSuite extends SparkFunSuite {
// the new value of "fs.local.block.size" (see SPARK-5227 and SPARK-5679). To work around this,
// we disable FileSystem caching in this suite.
super.beforeAll()
val conf = new SparkConf().set("spark.hadoop.fs.file.impl.disable.cache", "true")
val conf = getSparkConf().set("spark.hadoop.fs.file.impl.disable.cache", "true")

sc = new SparkContext("local", "test", conf)

@@ -63,13 +68,25 @@ class WholeTextFileRecordReaderSuite extends SparkFunSuite {
}
}

private def createNativeFile(inputDir: File, fileName: String, contents: Array[Byte],
compress: Boolean) = {
val out = if (compress) {
import WholeTextFileRecordReaderSuite.CompressionType

def createNativeFile(inputDir: File, fileName: String, contents: Array[Byte],
compressionType: CompressionType.CompressionType = CompressionType.NONE): Unit = {
val out = if (compressionType == CompressionType.GZIP ||
compressionType == CompressionType.GZ) {
val codec = new GzipCodec
codec.setConf(new Configuration())
val path = s"${inputDir.toString}/$fileName${codec.getDefaultExtension}"
val extension = if (compressionType == CompressionType.GZIP) {
".gzip" // Try with non-standard extension
} else {
codec.getDefaultExtension
}
val path = s"${inputDir.toString}/$fileName${extension}"
codec.createOutputStream(new DataOutputStream(new FileOutputStream(path)))
} else if (compressionType == CompressionType.ZSTD || compressionType == CompressionType.ZST) {
val extension = if (compressionType == CompressionType.ZSTD) ".zstd" else ".zst"
val path = s"${inputDir.toString}/${fileName}${extension}"
new ZStdCompressionCodec(sc.conf).compressedOutputStream(new FileOutputStream(path))
} else {
val path = s"${inputDir.toString}/$fileName"
new DataOutputStream(new FileOutputStream(path))
@@ -86,48 +103,29 @@ class WholeTextFileRecordReaderSuite extends SparkFunSuite {
* 3) Whether the contents are the same.
*/
test("Correctness of WholeTextFileRecordReader.") {
withTempDir { dir =>
logInfo(s"Local disk address is ${dir.toString}.")

WholeTextFileRecordReaderSuite.files.foreach { case (filename, contents) =>
createNativeFile(dir, filename, contents, false)
}

val res = sc.wholeTextFiles(dir.toString, 3).collect()

assert(res.length === WholeTextFileRecordReaderSuite.fileNames.length,
"Number of files read out does not fit with the actual value.")

for ((filename, contents) <- res) {
val shortName = filename.split('/').last
assert(WholeTextFileRecordReaderSuite.fileNames.contains(shortName),
s"Missing file name $filename.")
assert(contents === new Text(WholeTextFileRecordReaderSuite.files(shortName)).toString,
s"file $filename contents can not match.")
}
}
}

test("Correctness of WholeTextFileRecordReader with GzipCodec.") {
withTempDir { dir =>
logInfo(s"Local disk address is ${dir.toString}.")

WholeTextFileRecordReaderSuite.files.foreach { case (filename, contents) =>
createNativeFile(dir, filename, contents, true)
}

val res = sc.wholeTextFiles(dir.toString, 3).collect()

assert(res.length === WholeTextFileRecordReaderSuite.fileNames.length,
"Number of files read out does not fit with the actual value.")

for ((filename, contents) <- res) {
val shortName = filename.split('/').last.split('.')(0)

assert(WholeTextFileRecordReaderSuite.fileNames.contains(shortName),
s"Missing file name $filename.")
assert(contents === new Text(WholeTextFileRecordReaderSuite.files(shortName)).toString,
s"file $filename contents can not match.")
CompressionType.values.foreach { compressionType =>
withTempDir { dir =>
logInfo(s"Local disk address is ${dir.toString}.")

WholeTextFileRecordReaderSuite.files.foreach { case (filename, contents) =>
createNativeFile(dir, filename, contents, compressionType)
}

val res = sc.wholeTextFiles(dir.toString, 3).collect()

assert(res.length === WholeTextFileRecordReaderSuite.fileNames.length,
"Number of files read out does not fit with the actual value.")

for ((filename, contents) <- res) {
val shortName = compressionType match {
case CompressionType.NONE => filename.split('/').last
case _ => filename.split('/').last.split('.').head
}
assert(WholeTextFileRecordReaderSuite.fileNames.contains(shortName),
s"Missing file name $filename.")
assert(contents === new Text(WholeTextFileRecordReaderSuite.files(shortName)).toString,
s"file $filename contents can not match.")
}
}
}
}
@@ -145,4 +143,9 @@ object WholeTextFileRecordReaderSuite {
private val files = fileLengths.zip(fileNames).map { case (upperBound, filename) =>
filename -> LazyList.continually(testWords.toList.to(LazyList)).flatten.take(upperBound).toArray
}.toMap

object CompressionType extends Enumeration {
type CompressionType = Value
val NONE, GZ, GZIP, ZST, ZSTD = Value
}
}
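Beyond the unit test above, a hedged end-to-end sketch of the behavior this enables: `wholeTextFiles` should now decompress `.zst`/`.zstd` inputs transparently, even when the local Hadoop build lacks native ZSTD support. The input directory is hypothetical.

```scala
import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("whole-text-zstd"))
// Hypothetical directory containing part-*.zst files.
val rdd = sc.wholeTextFiles("/tmp/input", minPartitions = 3)
rdd.collect().foreach { case (name, contents) =>
  println(s"$name -> ${contents.length} chars")
}
sc.stop()
```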
@@ -5949,6 +5949,17 @@ object SQLConf {
.createWithDefault(2)
}

val HADOOP_LINE_RECORD_READER_ENABLED =
buildConf("spark.sql.execution.datasources.hadoopLineRecordReader.enabled")
.internal()
.doc("Enable the imported Hadoop's LineRecordReader. This was imported and renamed to " +
"HadoopLineRecordReader to add support for compression option and other " +
"future codecs like ZSTD, etc. Setting the conf to false will use the LineRecordReader " +
"class from the hadoop jar instead of the imported one.")
.version("4.1.0")
.booleanConf
.createWithDefault(true)

/**
* Holds information about keys that have been deprecated.
*
@@ -7021,6 +7032,8 @@ class SQLConf extends Serializable with Logging with SqlApiConf {

def maxFlowRetryAttempts: Int = getConf(SQLConf.PIPELINES_MAX_FLOW_RETRY_ATTEMPTS)

def hadoopLineRecordReaderEnabled: Boolean = getConf(SQLConf.HADOOP_LINE_RECORD_READER_ENABLED)

/** ********************** SQLConf functionality methods ************ */

/** Set Spark SQL configuration properties. */
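For illustration, a hedged sketch of toggling the new flag from a session; the read path is hypothetical, and the flag only selects which line record reader implementation is used:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("zstd-read").getOrCreate()

// The flag defaults to true, meaning the imported HadoopLineRecordReader is used and,
// per this PR, line-based sources can read ZSTD-compressed input (the path is hypothetical).
val df = spark.read.text("/tmp/logs/events.txt.zst")
df.show(5, truncate = false)

// To fall back to the LineRecordReader class shipped in the Hadoop jar:
spark.conf.set("spark.sql.execution.datasources.hadoopLineRecordReader.enabled", "false")
```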