[SPARK-23539][SS] Add support for Kafka headers in Structured Streaming #22282

Closed · wants to merge 8 commits
55 changes: 55 additions & 0 deletions docs/structured-streaming-kafka-integration.md
@@ -27,6 +27,8 @@ For Scala/Java applications using SBT/Maven project definitions, link your appli
artifactId = spark-sql-kafka-0-10_{{site.SCALA_BINARY_VERSION}}
version = {{site.SPARK_VERSION_SHORT}}

Please note that to use the headers functionality, your Kafka client version must be 0.11.0.0 or newer.

For Python applications, you need to add the above library and its dependencies when deploying your
application. See the [Deploying](#deploying) subsection below.

@@ -50,6 +52,17 @@ val df = spark
df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
.as[(String, String)]

// Subscribe to 1 topic, with headers
val df = spark
.readStream
.format("kafka")
.option("kafka.bootstrap.servers", "host1:port1,host2:port2")
.option("subscribe", "topic1")
.option("includeHeaders", "true")
.load()
df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)", "headers")
  .as[(String, String, Array[(String, Array[Byte])])]

// Subscribe to multiple topics
val df = spark
.readStream
@@ -84,6 +97,16 @@ Dataset<Row> df = spark
.load();
df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)");

// Subscribe to 1 topic, with headers
Dataset<Row> df = spark
.readStream()
.format("kafka")
.option("kafka.bootstrap.servers", "host1:port1,host2:port2")
.option("subscribe", "topic1")
.option("includeHeaders", "true")
  .load();
df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)", "headers");

// Subscribe to multiple topics
Dataset<Row> df = spark
.readStream()
@@ -116,6 +139,16 @@ df = spark \
.load()
df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")

# Subscribe to 1 topic, with headers
df = spark \
.readStream \
.format("kafka") \
.option("kafka.bootstrap.servers", "host1:port1,host2:port2") \
.option("subscribe", "topic1") \
.option("includeHeaders", "true") \
.load()
df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)", "headers")

# Subscribe to multiple topics
df = spark \
.readStream \
@@ -286,6 +319,10 @@ Each row in the source has the following schema:
<td>timestampType</td>
<td>int</td>
</tr>
<tr>
<td>headers (optional)</td>
<td>array</td>
</tr>
</table>

The following options must be set for the Kafka source
@@ -425,6 +462,13 @@ The following configurations are optional:
issues, set the Kafka consumer session timeout (by setting option "kafka.session.timeout.ms") to
be very small. When this is set, option "groupIdPrefix" will be ignored.</td>
</tr>
<tr>
<td>includeHeaders</td>
<td>boolean</td>
<td>false</td>
<td>streaming and batch</td>
<td>Whether to include the Kafka headers in the row.</td>
</tr>
</table>
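
As a rough sketch of how the headers column can be consumed once "includeHeaders" is set (this example is not part of the patch's docs; the header key "trace-id" is a placeholder), note that each element of the headers array is a struct with a string key and a binary value:

import org.apache.spark.sql.functions._

val withHeaders = spark
  .readStream
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
  .option("subscribe", "topic1")
  .option("includeHeaders", "true")
  .load()

// Explode the headers array and pick out a single (placeholder) header key.
val traced = withHeaders
  .select(explode(col("headers")).as("header"), col("value"))
  .where(col("header.key") === "trace-id")
  .selectExpr("CAST(header.value AS STRING) AS traceId", "CAST(value AS STRING) AS payload")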

### Consumer Caching
@@ -522,6 +566,10 @@ The Dataframe being written to Kafka should have the following columns in schema
<td>value (required)</td>
<td>string or binary</td>
</tr>
<tr>
<td>headers (optional)</td>
<td>array</td>
</tr>
<tr>
<td>topic (*optional)</td>
<td>string</td>
@@ -559,6 +607,13 @@ The following configurations are optional:
<td>Sets the topic that all rows will be written to in Kafka. This option overrides any
topic column that may exist in the data.</td>
</tr>
<tr>
<td>includeHeaders</td>
<td>boolean</td>
<td>false</td>
<td>streaming and batch</td>
<td>Whether to include the Kafka headers in the row.</td>
</tr>
</table>
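
A minimal write-side sketch (again not from the patch itself): it assumes the sink picks up a "headers" column shaped as in the table above; the header key/value pair ("source" -> "spark") and the topic name are placeholders:

import org.apache.spark.sql.functions._
import spark.implicits._

// Attach one constant, placeholder header to every record.
val out = Seq(("1", "event-1"), ("2", "event-2")).toDF("key", "value")
  .withColumn("headers",
    array(struct(lit("source").as("key"), lit("spark".getBytes("UTF-8")).as("value"))))

out.write
  .format("kafka")
  .option("kafka.bootstrap.servers", "host1:port1,host2:port2")
  .option("topic", "topic1")
  .save()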

### Creating a Kafka Sink for Streaming Queries
@@ -31,7 +31,8 @@ private[kafka010] class KafkaBatch(
specifiedKafkaParams: Map[String, String],
failOnDataLoss: Boolean,
startingOffsets: KafkaOffsetRangeLimit,
endingOffsets: KafkaOffsetRangeLimit)
endingOffsets: KafkaOffsetRangeLimit,
includeHeaders: Boolean)
extends Batch with Logging {
assert(startingOffsets != LatestOffsetRangeLimit,
"Starting offset not allowed to be set to latest offsets.")
@@ -90,7 +91,7 @@ private[kafka010] class KafkaBatch(
KafkaSourceProvider.kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId)
offsetRanges.map { range =>
new KafkaBatchInputPartition(
range, executorKafkaParams, pollTimeoutMs, failOnDataLoss)
range, executorKafkaParams, pollTimeoutMs, failOnDataLoss, includeHeaders)
}.toArray
}

@@ -29,13 +29,14 @@ private[kafka010] case class KafkaBatchInputPartition(
offsetRange: KafkaOffsetRange,
executorKafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
failOnDataLoss: Boolean) extends InputPartition
failOnDataLoss: Boolean,
includeHeaders: Boolean) extends InputPartition

private[kafka010] object KafkaBatchReaderFactory extends PartitionReaderFactory {
override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
val p = partition.asInstanceOf[KafkaBatchInputPartition]
KafkaBatchPartitionReader(p.offsetRange, p.executorKafkaParams, p.pollTimeoutMs,
p.failOnDataLoss)
p.failOnDataLoss, p.includeHeaders)
}
}

@@ -44,12 +45,14 @@ private case class KafkaBatchPartitionReader(
offsetRange: KafkaOffsetRange,
executorKafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
failOnDataLoss: Boolean) extends PartitionReader[InternalRow] with Logging {
failOnDataLoss: Boolean,
includeHeaders: Boolean) extends PartitionReader[InternalRow] with Logging {

private val consumer = KafkaDataConsumer.acquire(offsetRange.topicPartition, executorKafkaParams)

private val rangeToRead = resolveRange(offsetRange)
private val converter = new KafkaRecordToUnsafeRowConverter
private val unsafeRowProjector = new KafkaRecordToRowConverter()
.toUnsafeRowProjector(includeHeaders)

private var nextOffset = rangeToRead.fromOffset
private var nextRow: UnsafeRow = _
@@ -58,7 +61,7 @@ private case class KafkaBatchPartitionReader(
if (nextOffset < rangeToRead.untilOffset) {
val record = consumer.get(nextOffset, rangeToRead.untilOffset, pollTimeoutMs, failOnDataLoss)
if (record != null) {
nextRow = converter.toUnsafeRow(record)
nextRow = unsafeRowProjector(record)
nextOffset = record.offset + 1
true
} else {
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.connector.read.InputPartition
import org.apache.spark.sql.connector.read.streaming.{ContinuousPartitionReader, ContinuousPartitionReaderFactory, ContinuousStream, Offset, PartitionOffset}
import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
import org.apache.spark.sql.kafka010.KafkaSourceProvider._
import org.apache.spark.sql.util.CaseInsensitiveStringMap

/**
@@ -56,6 +56,7 @@ class KafkaContinuousStream(

private[kafka010] val pollTimeoutMs =
options.getLong(KafkaSourceProvider.CONSUMER_POLL_TIMEOUT, 512)
private val includeHeaders = options.getBoolean(INCLUDE_HEADERS, false)

// Initialized when creating reader factories. If this diverges from the partitions at the latest
// offsets, we need to reconfigure.
@@ -88,7 +89,7 @@ class KafkaContinuousStream(
if (deletedPartitions.nonEmpty) {
val message = if (
offsetReader.driverKafkaParams.containsKey(ConsumerConfig.GROUP_ID_CONFIG)) {
s"$deletedPartitions are gone. ${KafkaSourceProvider.CUSTOM_GROUP_ID_ERROR_MESSAGE}"
s"$deletedPartitions are gone. ${CUSTOM_GROUP_ID_ERROR_MESSAGE}"
} else {
s"$deletedPartitions are gone. Some data may have been missed."
}
@@ -102,7 +103,7 @@ class KafkaContinuousStream(
startOffsets.toSeq.map {
case (topicPartition, start) =>
KafkaContinuousInputPartition(
topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss)
topicPartition, start, kafkaParams, pollTimeoutMs, failOnDataLoss, includeHeaders)
}.toArray
}

@@ -153,19 +154,22 @@ class KafkaContinuousStream(
* @param pollTimeoutMs The timeout for Kafka consumer polling.
* @param failOnDataLoss Flag indicating whether data reader should fail if some offsets
* are skipped.
* @param includeHeaders Flag indicating whether to include Kafka records' headers.
*/
case class KafkaContinuousInputPartition(
topicPartition: TopicPartition,
startOffset: Long,
kafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
failOnDataLoss: Boolean) extends InputPartition
topicPartition: TopicPartition,
startOffset: Long,
kafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
failOnDataLoss: Boolean,
includeHeaders: Boolean) extends InputPartition

object KafkaContinuousReaderFactory extends ContinuousPartitionReaderFactory {
override def createReader(partition: InputPartition): ContinuousPartitionReader[InternalRow] = {
val p = partition.asInstanceOf[KafkaContinuousInputPartition]
new KafkaContinuousPartitionReader(
p.topicPartition, p.startOffset, p.kafkaParams, p.pollTimeoutMs, p.failOnDataLoss)
p.topicPartition, p.startOffset, p.kafkaParams, p.pollTimeoutMs,
p.failOnDataLoss, p.includeHeaders)
}
}

@@ -184,9 +188,11 @@ class KafkaContinuousPartitionReader(
startOffset: Long,
kafkaParams: ju.Map[String, Object],
pollTimeoutMs: Long,
failOnDataLoss: Boolean) extends ContinuousPartitionReader[InternalRow] {
failOnDataLoss: Boolean,
includeHeaders: Boolean) extends ContinuousPartitionReader[InternalRow] {
private val consumer = KafkaDataConsumer.acquire(topicPartition, kafkaParams)
private val converter = new KafkaRecordToUnsafeRowConverter

[Review thread]
Contributor: I'd have KafkaRecordToRowProjector (either a class or an object, but an object would be fine) instead, and move every projector newly added in KafkaOffsetReader there.
Contributor: So +1 on your proposal. The proposed name is just my 2 cents, and I'm not sure which name fits best.
Contributor (author): @HeartSaVioR Great. Let's continue discussing which name would be best.

private val unsafeRowProjector = new KafkaRecordToRowConverter()
.toUnsafeRowProjector(includeHeaders)

private var nextKafkaOffset = startOffset
private var currentRecord: ConsumerRecord[Array[Byte], Array[Byte]] = _
@@ -225,7 +231,7 @@ class KafkaContinuousPartitionReader(
}

override def get(): UnsafeRow = {
converter.toUnsafeRow(currentRecord)
unsafeRowProjector(currentRecord)
}

override def getOffset(): KafkaSourcePartitionOffset = {
@@ -29,7 +29,7 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.read.{InputPartition, PartitionReaderFactory}
import org.apache.spark.sql.connector.read.streaming.{MicroBatchStream, Offset}
import org.apache.spark.sql.execution.streaming.sources.RateControlMicroBatchStream
import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
import org.apache.spark.sql.kafka010.KafkaSourceProvider._
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.UninterruptibleThread

@@ -64,6 +64,8 @@ private[kafka010] class KafkaMicroBatchStream(
private[kafka010] val maxOffsetsPerTrigger = Option(options.get(
KafkaSourceProvider.MAX_OFFSET_PER_TRIGGER)).map(_.toLong)

private val includeHeaders = options.getBoolean(INCLUDE_HEADERS, false)

private val rangeCalculator = KafkaOffsetRangeCalculator(options)

private var endPartitionOffsets: KafkaSourceOffset = _
@@ -112,7 +114,7 @@ private[kafka010] class KafkaMicroBatchStream(
if (deletedPartitions.nonEmpty) {
val message =
if (kafkaOffsetReader.driverKafkaParams.containsKey(ConsumerConfig.GROUP_ID_CONFIG)) {
s"$deletedPartitions are gone. ${KafkaSourceProvider.CUSTOM_GROUP_ID_ERROR_MESSAGE}"
s"$deletedPartitions are gone. ${CUSTOM_GROUP_ID_ERROR_MESSAGE}"
} else {
s"$deletedPartitions are gone. Some data may have been missed."
}
@@ -146,7 +148,8 @@ private[kafka010] class KafkaMicroBatchStream(

// Generate factories based on the offset ranges
offsetRanges.map { range =>
KafkaBatchInputPartition(range, executorKafkaParams, pollTimeoutMs, failOnDataLoss)
KafkaBatchInputPartition(range, executorKafkaParams, pollTimeoutMs,
failOnDataLoss, includeHeaders)
}.toArray
}

@@ -31,7 +31,6 @@ import org.apache.kafka.common.TopicPartition

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.types._
import org.apache.spark.util.{ThreadUtils, UninterruptibleThread}

/**
@@ -421,16 +420,3 @@ private[kafka010] class KafkaOffsetReader(
_consumer = null // will automatically get reinitialized again
}
}

private[kafka010] object KafkaOffsetReader {

def kafkaSchema: StructType = StructType(Seq(
StructField("key", BinaryType),
StructField("value", BinaryType),
StructField("topic", StringType),
StructField("partition", IntegerType),
StructField("offset", LongType),
StructField("timestamp", TimestampType),
StructField("timestampType", IntegerType)
))
}
@@ -0,0 +1,93 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.kafka010

import java.sql.Timestamp

import scala.collection.JavaConverters._

import org.apache.kafka.clients.consumer.ConsumerRecord

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

/** A simple class for converting Kafka ConsumerRecord to InternalRow/UnsafeRow */
private[kafka010] class KafkaRecordToRowConverter {
import KafkaRecordToRowConverter._

private val toUnsafeRowWithoutHeaders = UnsafeProjection.create(schemaWithoutHeaders)
private val toUnsafeRowWithHeaders = UnsafeProjection.create(schemaWithHeaders)

val toInternalRowWithoutHeaders: Record => InternalRow =
(cr: Record) => InternalRow(
cr.key, cr.value, UTF8String.fromString(cr.topic), cr.partition, cr.offset,
DateTimeUtils.fromJavaTimestamp(new Timestamp(cr.timestamp)), cr.timestampType.id
)

val toInternalRowWithHeaders: Record => InternalRow =
(cr: Record) => InternalRow(
cr.key, cr.value, UTF8String.fromString(cr.topic), cr.partition, cr.offset,
DateTimeUtils.fromJavaTimestamp(new Timestamp(cr.timestamp)), cr.timestampType.id,
if (cr.headers.iterator().hasNext) {
new GenericArrayData(cr.headers.iterator().asScala
.map(header =>
InternalRow(UTF8String.fromString(header.key()), header.value())
).toArray)
} else {
null
}
)

def toUnsafeRowWithoutHeadersProjector: Record => UnsafeRow =
(cr: Record) => toUnsafeRowWithoutHeaders(toInternalRowWithoutHeaders(cr))

def toUnsafeRowWithHeadersProjector: Record => UnsafeRow =
(cr: Record) => toUnsafeRowWithHeaders(toInternalRowWithHeaders(cr))

def toUnsafeRowProjector(includeHeaders: Boolean): Record => UnsafeRow = {
if (includeHeaders) toUnsafeRowWithHeadersProjector else toUnsafeRowWithoutHeadersProjector
}
}

private[kafka010] object KafkaRecordToRowConverter {
type Record = ConsumerRecord[Array[Byte], Array[Byte]]

val headersType = ArrayType(StructType(Array(
StructField("key", StringType),
StructField("value", BinaryType))))

private val schemaWithoutHeaders = new StructType(Array(
StructField("key", BinaryType),
StructField("value", BinaryType),
StructField("topic", StringType),
StructField("partition", IntegerType),
StructField("offset", LongType),
StructField("timestamp", TimestampType),
StructField("timestampType", IntegerType)
))

private val schemaWithHeaders =
new StructType(schemaWithoutHeaders.fields :+ StructField("headers", headersType))

def kafkaSchema(includeHeaders: Boolean): StructType = {
if (includeHeaders) schemaWithHeaders else schemaWithoutHeaders
}
}
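
To illustrate how the converter above is meant to be used, here is a hedged, test-style sketch (not part of the patch) that builds a ConsumerRecord by hand, adds a header, and projects it to an UnsafeRow:

import java.nio.charset.StandardCharsets.UTF_8

import org.apache.kafka.clients.consumer.ConsumerRecord

// Projector that includes the headers column (8 fields instead of 7).
val project = new KafkaRecordToRowConverter().toUnsafeRowProjector(includeHeaders = true)

// Minimal record; headers() returns a mutable Headers instance we can append to.
val record = new ConsumerRecord[Array[Byte], Array[Byte]](
  "topic1", 0, 0L, "k".getBytes(UTF_8), "v".getBytes(UTF_8))
record.headers().add("trace-id", "abc".getBytes(UTF_8))

val row = project(record)
assert(row.numFields == 8)  // key, value, topic, partition, offset, timestamp, timestampType, headers
assert(!row.isNullAt(7))    // the headers array is populated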

This file was deleted.

@@ -24,10 +24,9 @@ import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.sources.{BaseRelation, TableScan}
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.types.UTF8String


private[kafka010] class KafkaRelation(
@@ -36,6 +35,7 @@ private[kafka010] class KafkaRelation(
sourceOptions: CaseInsensitiveMap[String],
specifiedKafkaParams: Map[String, String],
failOnDataLoss: Boolean,
includeHeaders: Boolean,
startingOffsets: KafkaOffsetRangeLimit,
endingOffsets: KafkaOffsetRangeLimit)
extends BaseRelation with TableScan with Logging {
@@ -49,7 +49,9 @@ private[kafka010] class KafkaRelation(
(sqlContext.sparkContext.conf.get(NETWORK_TIMEOUT) * 1000L).toString
).toLong

override def schema: StructType = KafkaOffsetReader.kafkaSchema
private val converter = new KafkaRecordToRowConverter()

override def schema: StructType = KafkaRecordToRowConverter.kafkaSchema(includeHeaders)

override def buildScan(): RDD[Row] = {
// Each running query should use its own group id. Otherwise, the query may be only assigned
@@ -100,18 +102,14 @@ private[kafka010] class KafkaRelation(
// Create an RDD that reads from Kafka and get the (key, value) pair as byte arrays.
val executorKafkaParams =
KafkaSourceProvider.kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId)
val toInternalRow = if (includeHeaders) {
converter.toInternalRowWithHeaders
} else {
converter.toInternalRowWithoutHeaders
}
val rdd = new KafkaSourceRDD(
sqlContext.sparkContext, executorKafkaParams, offsetRanges,
pollTimeoutMs, failOnDataLoss).map { cr =>
InternalRow(
cr.key,
cr.value,
UTF8String.fromString(cr.topic),
cr.partition,
cr.offset,
DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(cr.timestamp)),
cr.timestampType.id)
}
pollTimeoutMs, failOnDataLoss).map(toInternalRow)
sqlContext.internalCreateDataFrame(rdd.setName("kafka"), schema).rdd
}

@@ -31,12 +31,11 @@ import org.apache.spark.internal.config.Network.NETWORK_TIMEOUT
import org.apache.spark.scheduler.ExecutorCacheTaskLocation
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils}
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.kafka010.KafkaSource._
import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE}
import org.apache.spark.sql.kafka010.KafkaSourceProvider._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

/**
* A [[Source]] that reads data from Kafka using the following design.
@@ -84,13 +83,15 @@ private[kafka010] class KafkaSource(

private val sc = sqlContext.sparkContext

private val pollTimeoutMs = sourceOptions.getOrElse(
KafkaSourceProvider.CONSUMER_POLL_TIMEOUT,
(sc.conf.get(NETWORK_TIMEOUT) * 1000L).toString
).toLong
private val pollTimeoutMs =
sourceOptions.getOrElse(CONSUMER_POLL_TIMEOUT, (sc.conf.get(NETWORK_TIMEOUT) * 1000L).toString)
.toLong

private val maxOffsetsPerTrigger =
sourceOptions.get(KafkaSourceProvider.MAX_OFFSET_PER_TRIGGER).map(_.toLong)
sourceOptions.get(MAX_OFFSET_PER_TRIGGER).map(_.toLong)

private val includeHeaders =
sourceOptions.getOrElse(INCLUDE_HEADERS, "false").toBoolean

/**
* Lazily initialize `initialPartitionOffsets` to make sure that `KafkaConsumer.poll` is only
@@ -113,7 +114,9 @@ private[kafka010] class KafkaSource(

private var currentPartitionOffsets: Option[Map[TopicPartition, Long]] = None

override def schema: StructType = KafkaOffsetReader.kafkaSchema
private val converter = new KafkaRecordToRowConverter()

override def schema: StructType = KafkaRecordToRowConverter.kafkaSchema(includeHeaders)

/** Returns the maximum available offset for this source. */
override def getOffset: Option[Offset] = {
@@ -223,7 +226,7 @@ private[kafka010] class KafkaSource(
val deletedPartitions = fromPartitionOffsets.keySet.diff(untilPartitionOffsets.keySet)
if (deletedPartitions.nonEmpty) {
val message = if (kafkaReader.driverKafkaParams.containsKey(ConsumerConfig.GROUP_ID_CONFIG)) {
s"$deletedPartitions are gone. ${KafkaSourceProvider.CUSTOM_GROUP_ID_ERROR_MESSAGE}"
s"$deletedPartitions are gone. ${CUSTOM_GROUP_ID_ERROR_MESSAGE}"
} else {
s"$deletedPartitions are gone. Some data may have been missed."
}
@@ -267,16 +270,14 @@ private[kafka010] class KafkaSource(
}.toArray

// Create an RDD that reads from Kafka and get the (key, value) pair as byte arrays.
val rdd = new KafkaSourceRDD(
sc, executorKafkaParams, offsetRanges, pollTimeoutMs, failOnDataLoss).map { cr =>
InternalRow(
cr.key,
cr.value,
UTF8String.fromString(cr.topic),
cr.partition,
cr.offset,
DateTimeUtils.fromJavaTimestamp(new java.sql.Timestamp(cr.timestamp)),
cr.timestampType.id)
val rdd = if (includeHeaders) {
new KafkaSourceRDD(
sc, executorKafkaParams, offsetRanges, pollTimeoutMs, failOnDataLoss)
.map(converter.toInternalRowWithHeaders)
} else {
new KafkaSourceRDD(
sc, executorKafkaParams, offsetRanges, pollTimeoutMs, failOnDataLoss)
.map(converter.toInternalRowWithoutHeaders)
}

logInfo("GetBatch generating RDD of offset range: " +
@@ -69,7 +69,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
val caseInsensitiveParameters = CaseInsensitiveMap(parameters)
validateStreamOptions(caseInsensitiveParameters)
require(schema.isEmpty, "Kafka source has a fixed schema and cannot be set with a custom one")
(shortName(), KafkaOffsetReader.kafkaSchema)
val includeHeaders = caseInsensitiveParameters.getOrElse(INCLUDE_HEADERS, "false").toBoolean
(shortName(), KafkaRecordToRowConverter.kafkaSchema(includeHeaders))
}

override def createSource(
@@ -107,7 +108,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
}

override def getTable(options: CaseInsensitiveStringMap): KafkaTable = {
new KafkaTable
val includeHeaders = options.getBoolean(INCLUDE_HEADERS, false)
new KafkaTable(includeHeaders)
}

/**
@@ -131,12 +133,15 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
caseInsensitiveParameters, ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit)
assert(endingRelationOffsets != EarliestOffsetRangeLimit)

val includeHeaders = caseInsensitiveParameters.getOrElse(INCLUDE_HEADERS, "false").toBoolean

new KafkaRelation(
sqlContext,
strategy(caseInsensitiveParameters),
sourceOptions = caseInsensitiveParameters,
specifiedKafkaParams = specifiedKafkaParams,
failOnDataLoss = failOnDataLoss(caseInsensitiveParameters),
includeHeaders = includeHeaders,
startingOffsets = startingRelationOffsets,
endingOffsets = endingRelationOffsets)
}
@@ -359,11 +364,11 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
}
}

class KafkaTable extends Table with SupportsRead with SupportsWrite {
class KafkaTable(includeHeaders: Boolean) extends Table with SupportsRead with SupportsWrite {

override def name(): String = "KafkaTable"

override def schema(): StructType = KafkaOffsetReader.kafkaSchema
override def schema(): StructType = KafkaRecordToRowConverter.kafkaSchema(includeHeaders)

override def capabilities(): ju.Set[TableCapability] = {
import TableCapability._
@@ -403,8 +408,11 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
}

class KafkaScan(options: CaseInsensitiveStringMap) extends Scan {
val includeHeaders = options.getBoolean(INCLUDE_HEADERS, false)

override def readSchema(): StructType = KafkaOffsetReader.kafkaSchema
override def readSchema(): StructType = {
KafkaRecordToRowConverter.kafkaSchema(includeHeaders)
}

override def toBatch(): Batch = {
val caseInsensitiveOptions = CaseInsensitiveMap(options.asScala.toMap)
@@ -423,7 +431,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister
specifiedKafkaParams,
failOnDataLoss(caseInsensitiveOptions),
startingRelationOffsets,
endingRelationOffsets)
endingRelationOffsets,
includeHeaders)
}

override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = {
@@ -498,6 +507,7 @@ private[kafka010] object KafkaSourceProvider extends Logging {
private[kafka010] val FETCH_OFFSET_RETRY_INTERVAL_MS = "fetchoffset.retryintervalms"
private[kafka010] val CONSUMER_POLL_TIMEOUT = "kafkaconsumer.polltimeoutms"
private val GROUP_ID_PREFIX = "groupidprefix"
private[kafka010] val INCLUDE_HEADERS = "includeheaders"

val TOPIC_OPTION_KEY = "topic"

@@ -19,9 +19,13 @@ package org.apache.spark.sql.kafka010

import java.{util => ju}

import scala.collection.JavaConverters._

import org.apache.kafka.clients.producer.{Callback, KafkaProducer, ProducerRecord, RecordMetadata}
import org.apache.kafka.common.header.Header
import org.apache.kafka.common.header.internals.RecordHeader

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection}
import org.apache.spark.sql.types.{BinaryType, StringType}

@@ -88,7 +92,17 @@ private[kafka010] abstract class KafkaRowWriter(
throw new NullPointerException(s"null topic present in the data. Use the " +
s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.")
}
val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value)
val record = if (projectedRow.isNullAt(3)) {
new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, null, key, value)
} else {
val headerArray = projectedRow.getArray(3)
val headers = (0 until headerArray.numElements()).map { i =>
val struct = headerArray.getStruct(i, 2)
new RecordHeader(struct.getUTF8String(0).toString, struct.getBinary(1))
.asInstanceOf[Header]
}
new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, null, key, value, headers.asJava)
}
producer.send(record, callback)
}

@@ -131,9 +145,26 @@ private[kafka010] abstract class KafkaRowWriter(
throw new IllegalStateException(s"${KafkaWriter.VALUE_ATTRIBUTE_NAME} " +
s"attribute unsupported type ${t.catalogString}")
}
val headersExpression = inputSchema
.find(_.name == KafkaWriter.HEADERS_ATTRIBUTE_NAME).getOrElse(
Literal(CatalystTypeConverters.convertToCatalyst(null),

[Review thread]
Member: Should this be indented further?
Contributor (author): I tried, but the formatter reverts the indentation to the current state.
Member: The style checker or something else?
Contributor (author): The code formatter of the IDE. Also, it passes the style checker of mvn.

KafkaRecordToRowConverter.headersType)
)
headersExpression.dataType match {
case KafkaRecordToRowConverter.headersType => // good
case t =>
throw new IllegalStateException(s"${KafkaWriter.HEADERS_ATTRIBUTE_NAME} " +
s"attribute unsupported type ${t.catalogString}")
}
UnsafeProjection.create(
Seq(topicExpression, Cast(keyExpression, BinaryType),
Cast(valueExpression, BinaryType)), inputSchema)
Seq(
topicExpression,
Cast(keyExpression, BinaryType),
Cast(valueExpression, BinaryType),
headersExpression
),
inputSchema
)
}
}

@@ -21,9 +21,10 @@ import java.{util => ju}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.{QueryExecution, SQLExecution}
import org.apache.spark.sql.types.{BinaryType, StringType}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.types.{BinaryType, MapType, StringType}
import org.apache.spark.util.Utils

/**
@@ -39,6 +40,7 @@ private[kafka010] object KafkaWriter extends Logging {
val TOPIC_ATTRIBUTE_NAME: String = "topic"
val KEY_ATTRIBUTE_NAME: String = "key"
val VALUE_ATTRIBUTE_NAME: String = "value"
val HEADERS_ATTRIBUTE_NAME: String = "headers"

override def toString: String = "KafkaWriter"

@@ -75,6 +77,15 @@ private[kafka010] object KafkaWriter extends Logging {
throw new AnalysisException(s"$VALUE_ATTRIBUTE_NAME attribute type " +
s"must be a ${StringType.catalogString} or ${BinaryType.catalogString}")
}
schema.find(_.name == HEADERS_ATTRIBUTE_NAME).getOrElse(
Literal(CatalystTypeConverters.convertToCatalyst(null),
KafkaRecordToRowConverter.headersType)
).dataType match {
case KafkaRecordToRowConverter.headersType => // good
case _ =>
throw new AnalysisException(s"$HEADERS_ATTRIBUTE_NAME attribute type " +
s"must be a ${KafkaRecordToRowConverter.headersType.catalogString}")
}
}

def write(
@@ -17,6 +17,7 @@

package org.apache.spark.sql.kafka010

import java.nio.charset.StandardCharsets
import java.util.concurrent.{Executors, TimeUnit}

import scala.collection.JavaConverters._
@@ -91,7 +92,7 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester
test("new KafkaDataConsumer instance in case of Task retry") {
try {
val kafkaParams = getKafkaParams()
val key = new CacheKey(groupId, topicPartition)
val key = CacheKey(groupId, topicPartition)

val context1 = new TaskContextImpl(0, 0, 0, 0, 0, null, null, null)
TaskContext.setTaskContext(context1)
@@ -137,7 +138,8 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester
}

test("SPARK-23623: concurrent use of KafkaDataConsumer") {
val data: immutable.IndexedSeq[String] = prepareTestTopicHavingTestMessages(topic)
val data: immutable.IndexedSeq[(String, Seq[(String, Array[Byte])])] =
prepareTestTopicHavingTestMessages(topic)

val topicPartition = new TopicPartition(topic, 0)
val kafkaParams = getKafkaParams()
@@ -157,10 +159,22 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester
try {
val range = consumer.getAvailableOffsetRange()
val rcvd = range.earliest until range.latest map { offset =>
val bytes = consumer.get(offset, Long.MaxValue, 10000, failOnDataLoss = false).value()
new String(bytes)
val record = consumer.get(offset, Long.MaxValue, 10000, failOnDataLoss = false)
val value = new String(record.value(), StandardCharsets.UTF_8)
val headers = record.headers().toArray.map(header => (header.key(), header.value())).toSeq
(value, headers)
}
data.zip(rcvd).foreach { case (expected, actual) =>
// value
assert(expected._1 === actual._1)
// headers
expected._2.zip(actual._2).foreach { case (l, r) =>
// header key
assert(l._1 === r._1)
// header value
assert(l._2 === r._2)
}
}
assert(rcvd == data)
} catch {
case e: Throwable =>
error = e
@@ -307,9 +321,9 @@ class KafkaDataConsumerSuite extends SharedSparkSession with PrivateMethodTester
}

private def prepareTestTopicHavingTestMessages(topic: String) = {
val data = (1 to 1000).map(_.toString)
val data = (1 to 1000).map(i => (i.toString, Seq[(String, Array[Byte])]()))
testUtils.createTopic(topic, 1)
testUtils.sendMessages(topic, data.toArray)
testUtils.sendMessages(topic, data.toArray, None)
data
}

@@ -17,6 +17,7 @@

package org.apache.spark.sql.kafka010

import java.nio.charset.StandardCharsets.UTF_8
import java.util.Locale
import java.util.concurrent.atomic.AtomicInteger

@@ -70,7 +71,8 @@ abstract class KafkaRelationSuiteBase extends QueryTest with SharedSparkSession
protected def createDF(
topic: String,
withOptions: Map[String, String] = Map.empty[String, String],
brokerAddress: Option[String] = None) = {
brokerAddress: Option[String] = None,
includeHeaders: Boolean = false) = {
val df = spark
.read
.format("kafka")
@@ -80,7 +82,13 @@ abstract class KafkaRelationSuiteBase extends QueryTest with SharedSparkSession
withOptions.foreach {
case (key, value) => df.option(key, value)
}
df.load().selectExpr("CAST(value AS STRING)")
if (includeHeaders) {
df.option("includeHeaders", "true")
df.load()
.selectExpr("CAST(value AS STRING)", "headers")
} else {
df.load().selectExpr("CAST(value AS STRING)")
}
}

test("explicit earliest to latest offsets") {
@@ -147,6 +155,27 @@ abstract class KafkaRelationSuiteBase extends QueryTest with SharedSparkSession
checkAnswer(df, (0 to 30).map(_.toString).toDF)
}

test("default starting and ending offsets with headers") {
val topic = newTopic()
testUtils.createTopic(topic, partitions = 3)
testUtils.sendMessage(
topic, ("1", Seq()), Some(0)
)
testUtils.sendMessage(
topic, ("2", Seq(("a", "b".getBytes(UTF_8)), ("c", "d".getBytes(UTF_8)))), Some(1)
)
testUtils.sendMessage(
topic, ("3", Seq(("e", "f".getBytes(UTF_8)), ("e", "g".getBytes(UTF_8)))), Some(2)
)

// Implicit offset values, should default to earliest and latest
val df = createDF(topic, includeHeaders = true)
// Test that we default to "earliest" and "latest"
checkAnswer(df, Seq(("1", null),
("2", Seq(("a", "b".getBytes(UTF_8)), ("c", "d".getBytes(UTF_8)))),
("3", Seq(("e", "f".getBytes(UTF_8)), ("e", "g".getBytes(UTF_8))))).toDF)
}

test("reuse same dataframe in query") {
// This test ensures that we do not cache the Kafka Consumer in KafkaRelation
val topic = newTopic()
@@ -17,6 +17,7 @@

package org.apache.spark.sql.kafka010

import java.nio.charset.StandardCharsets.UTF_8
import java.util.Locale
import java.util.concurrent.atomic.AtomicInteger

@@ -32,7 +33,7 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming._
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{BinaryType, DataType}
import org.apache.spark.sql.types.{BinaryType, DataType, StringType, StructField, StructType}

abstract class KafkaSinkSuiteBase extends QueryTest with SharedSparkSession with KafkaTest {
protected var testUtils: KafkaTestUtils = _
@@ -59,13 +60,14 @@ abstract class KafkaSinkSuiteBase extends QueryTest with SharedSparkSession with

protected def newTopic(): String = s"topic-${topicId.getAndIncrement()}"

protected def createKafkaReader(topic: String): DataFrame = {
protected def createKafkaReader(topic: String, includeHeaders: Boolean = false): DataFrame = {
spark.read
.format("kafka")
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
.option("startingOffsets", "earliest")
.option("endingOffsets", "latest")
.option("subscribe", topic)
.option("includeHeaders", includeHeaders.toString)
.load()
}
}
@@ -368,15 +370,51 @@ abstract class KafkaSinkBatchSuiteBase extends KafkaSinkSuiteBase {
test("batch - write to kafka") {
val topic = newTopic()
testUtils.createTopic(topic)
val df = Seq("1", "2", "3", "4", "5").map(v => (topic, v)).toDF("topic", "value")
val data = Seq(
Row(topic, "1", Seq(
[Review comment] Contributor: ditto: StandardCharsets.UTF_8
Row("a", "b".getBytes(UTF_8))
)),
Row(topic, "2", Seq(
Row("c", "d".getBytes(UTF_8)),
Row("e", "f".getBytes(UTF_8))
)),
Row(topic, "3", Seq(
Row("g", "h".getBytes(UTF_8)),
Row("g", "i".getBytes(UTF_8))
)),
Row(topic, "4", null),
Row(topic, "5", Seq(
Row("j", "k".getBytes(UTF_8)),
Row("j", "l".getBytes(UTF_8)),
Row("m", "n".getBytes(UTF_8))
))
)

val df = spark.createDataFrame(
spark.sparkContext.parallelize(data),
StructType(Seq(StructField("topic", StringType), StructField("value", StringType),
StructField("headers", KafkaRecordToRowConverter.headersType)))
)

df.write
.format("kafka")
.option("kafka.bootstrap.servers", testUtils.brokerAddress)
.option("topic", topic)
.save()
checkAnswer(
createKafkaReader(topic).selectExpr("CAST(value as STRING) value"),
Row("1") :: Row("2") :: Row("3") :: Row("4") :: Row("5") :: Nil)
createKafkaReader(topic, includeHeaders = true).selectExpr(
"CAST(value as STRING) value", "headers"
),
Row("1", Seq(Row("a", "b".getBytes(UTF_8)))) ::
Row("2", Seq(Row("c", "d".getBytes(UTF_8)), Row("e", "f".getBytes(UTF_8)))) ::
Row("3", Seq(Row("g", "h".getBytes(UTF_8)), Row("g", "i".getBytes(UTF_8)))) ::
Row("4", null) ::
Row("5", Seq(
Row("j", "k".getBytes(UTF_8)),
Row("j", "l".getBytes(UTF_8)),
Row("m", "n".getBytes(UTF_8)))) ::
Nil
)
}

test("batch - null topic field value, and no topic option") {
@@ -41,6 +41,8 @@ import org.apache.kafka.clients.consumer.KafkaConsumer
import org.apache.kafka.clients.producer._
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.config.SaslConfigs
import org.apache.kafka.common.header.Header
import org.apache.kafka.common.header.internals.RecordHeader
import org.apache.kafka.common.network.ListenerName
import org.apache.kafka.common.security.auth.SecurityProtocol.{PLAINTEXT, SASL_PLAINTEXT}
import org.apache.kafka.common.serialization.{StringDeserializer, StringSerializer}
@@ -369,17 +371,36 @@ class KafkaTestUtils(
topic: String,
messages: Array[String],
partition: Option[Int]): Seq[(String, RecordMetadata)] = {
sendMessages(topic, messages.map(m => (m, Seq())), partition)
}

/** Send record to the Kafka broker with headers using specified partition */
def sendMessage(topic: String,
record: (String, Seq[(String, Array[Byte])]),
partition: Option[Int]): Seq[(String, RecordMetadata)] = {
sendMessages(topic, Array(record).toSeq, partition)
}

/** Send the array of records to the Kafka broker with headers using specified partition */
def sendMessages(topic: String,
records: Seq[(String, Seq[(String, Array[Byte])])],
partition: Option[Int]): Seq[(String, RecordMetadata)] = {
producer = new KafkaProducer[String, String](producerConfiguration)
val offsets = try {
messages.map { m =>
records.map { case (value, header) =>
val headers = header.map { case (k, v) =>
new RecordHeader(k, v).asInstanceOf[Header]
}
val record = partition match {
case Some(p) => new ProducerRecord[String, String](topic, p, null, m)
case None => new ProducerRecord[String, String](topic, m)
case Some(p) =>
new ProducerRecord[String, String](topic, p, null, value, headers.asJava)
case None =>
new ProducerRecord[String, String](topic, null, null, value, headers.asJava)
}
val metadata =
producer.send(record).get(10, TimeUnit.SECONDS)
logInfo(s"\tSent $m to partition ${metadata.partition}, offset ${metadata.offset}")
(m, metadata)
val metadata = producer.send(record).get(10, TimeUnit.SECONDS)
logInfo(s"\tSent ($value, $header) to partition ${metadata.partition}," +
" offset ${metadata.offset}")
(value, metadata)
}
} finally {
if (producer != null) {