Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ import org.apache.spark.sql.execution.auron.plan.NativeWindowBase
import org.apache.spark.sql.execution.auron.plan.NativeWindowExec
import org.apache.spark.sql.execution.auron.shuffle.{AuronBlockStoreShuffleReaderBase, AuronRssShuffleManagerBase, RssPartitionWriterBase}
import org.apache.spark.sql.execution.datasources.PartitionedFile
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNestedLoopJoinExec, ShuffledHashJoinExec}
import org.apache.spark.sql.execution.joins.auron.plan.NativeBroadcastJoinExec
Expand Down Expand Up @@ -301,6 +302,45 @@ class ShimsImpl extends Shims with Logging {
child: SparkPlan): NativeGenerateBase =
NativeGenerateExec(generator, requiredChildOutput, outer, generatorOutput, child)

@sparkver("3.0 / 3.1")
override def copyBatchScanExecWithRuntimeFilters(
exec: BatchScanExec,
runtimeFilters: Seq[Expression]): BatchScanExec =
exec.copy(exec.output, exec.scan)

@sparkver("3.2")
override def copyBatchScanExecWithRuntimeFilters(
exec: BatchScanExec,
runtimeFilters: Seq[Expression]): BatchScanExec =
exec.copy(exec.output, exec.scan, runtimeFilters)

@sparkver("3.3")
override def copyBatchScanExecWithRuntimeFilters(
exec: BatchScanExec,
runtimeFilters: Seq[Expression]): BatchScanExec =
exec.copy(exec.output, exec.scan, runtimeFilters, exec.keyGroupedPartitioning)

@sparkver("3.4")
override def copyBatchScanExecWithRuntimeFilters(
exec: BatchScanExec,
runtimeFilters: Seq[Expression]): BatchScanExec =
exec.copy(
exec.output,
exec.scan,
runtimeFilters,
exec.keyGroupedPartitioning,
exec.ordering,
exec.table,
exec.commonPartitionValues,
exec.applyPartialClustering,
exec.replicatePartitions)

@sparkver("3.5 / 4.0 / 4.1")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This groups 4.1 with 3.5/4.0 on the assumption that Spark 4.1's BatchScanExec constructor is still (output, scan, runtimeFilters, ordering, table, spjParams). The shims module compiles for 4.1 even though iceberg doesn't build there, so if 4.1 changed that constructor the 4.1 profile would fail to compile rather than fail a test. Was the 4.1 branch actually built against a 4.1 profile, or is this optimistic grouping ahead of 4.1 GA? If it hasn't been compiled against 4.1 yet, would it be safer to split 4.1 into its own branch (or drop it from the group) until the signature is confirmed?

override def copyBatchScanExecWithRuntimeFilters(
exec: BatchScanExec,
runtimeFilters: Seq[Expression]): BatchScanExec =
exec.copy(exec.output, exec.scan, runtimeFilters, exec.ordering, exec.table, exec.spjParams)

@sparkver("3.4 / 3.5 / 4.0 / 4.1")
private def effectiveLimit(rawLimit: Int): Int =
if (rawLimit == -1) Int.MaxValue else rawLimit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ import org.apache.spark.sql.execution.auron.plan.NativeBroadcastJoinBase
import org.apache.spark.sql.execution.auron.plan.NativeSortMergeJoinBase
import org.apache.spark.sql.execution.auron.shuffle.RssPartitionWriterBase
import org.apache.spark.sql.execution.datasources.PartitionedFile
import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeLike, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, ShuffledHashJoinExec}
import org.apache.spark.sql.execution.metric.SQLMetric
Expand Down Expand Up @@ -125,6 +126,10 @@ abstract class Shims {
generatorOutput: Seq[Attribute],
child: SparkPlan): NativeGenerateBase

def copyBatchScanExecWithRuntimeFilters(
exec: BatchScanExec,
runtimeFilters: Seq[Expression]): BatchScanExec

def getLimitAndOffset(plan: GlobalLimitExec): (Int, Int) = (plan.limit, 0)

def createNativeGlobalLimitExec(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ class IcebergConvertProvider extends AuronConvertProvider with Logging {
case e: BatchScanExec =>
IcebergScanSupport.plan(e) match {
case Some(plan) =>
AuronConverters.addRenameColumnsExec(NativeIcebergTableScanExec(e, plan))
AuronConverters.addRenameColumnsExec(
NativeIcebergTableScanExec(e, plan, e.runtimeFilters))
case None =>
IcebergScanSupport.fallbackReason(e) match {
case Some(reason) => throw new AssertionError(reason)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ package org.apache.spark.sql.auron.iceberg
import scala.collection.JavaConverters._
import scala.util.control.NonFatal

import org.apache.commons.lang3.reflect.MethodUtils
import org.apache.iceberg.{AddedRowsScanTask, ChangelogOperation, ChangelogScanTask, FileFormat, FileScanTask, MetadataColumns, ScanTask}
import org.apache.iceberg.expressions.{And => IcebergAnd, BoundPredicate, Expression => IcebergExpression, Not => IcebergNot, Or => IcebergOr, UnboundPredicate}
import org.apache.iceberg.spark.source.AuronIcebergSourceUtil
import org.apache.spark.internal.Logging
import org.apache.spark.sql.auron.NativeConverters
import org.apache.spark.sql.auron.{NativeConverters, Shims}
import org.apache.spark.sql.catalyst.expressions.{And => SparkAnd, AttributeReference, EqualTo, Expression => SparkExpression, GreaterThan, GreaterThanOrEqual, In, IsNaN, IsNotNull, IsNull, LessThan, LessThanOrEqual, Literal, Not => SparkNot, Or => SparkOr}
import org.apache.spark.sql.catalyst.trees.TreeNodeTag
import org.apache.spark.sql.connector.read.{InputPartition, Scan}
Expand Down Expand Up @@ -55,6 +56,8 @@ final case class IcebergScanPlan(
object IcebergScanSupport extends Logging {
private val scanPlanTag: TreeNodeTag[Option[IcebergScanPlan]] = TreeNodeTag(
"auron.iceberg.scan.plan")
private val runtimeFilteredScanPlanTag: TreeNodeTag[Option[IcebergScanPlan]] = TreeNodeTag(
"auron.iceberg.runtime.filtered.scan.plan")

private val SparkChangelogScanClassName =
"org.apache.iceberg.spark.source.SparkChangelogScan"
Expand Down Expand Up @@ -82,35 +85,54 @@ object IcebergScanSupport extends Logging {
}
}

def plan(exec: BatchScanExec): Option[IcebergScanPlan] = {
exec.getTagValue(scanPlanTag) match {
def plan(exec: BatchScanExec, useRuntimeFilters: Boolean = false): Option[IcebergScanPlan] = {
val tag =
if (useRuntimeFilters && exec.runtimeFilters.nonEmpty) {
runtimeFilteredScanPlanTag
} else {
scanPlanTag
}
exec.getTagValue(tag) match {
case Some(cached) => cached
case None =>
val planned = planUncached(exec)
exec.setTagValue(scanPlanTag, planned)
val planned = planUncached(exec, useRuntimeFilters)
exec.setTagValue(tag, planned)
planned
}
}

private def planUncached(exec: BatchScanExec): Option[IcebergScanPlan] = {
def withRuntimeFilters(
exec: BatchScanExec,
runtimeFilters: Seq[SparkExpression]): BatchScanExec = {
if (exec.runtimeFilters == runtimeFilters) {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This guard exec.runtimeFilters == runtimeFilters looks like it's always true, so the Shims.get.copyBatchScanExecWithRuntimeFilters(...) else-branch is never taken. NativeIcebergTableScanExec is only ever constructed at IcebergConvertProvider.scala:59 as NativeIcebergTableScanExec(e, plan, e.runtimeFilters), so its runtimeFilters field is always the same object as basedScan.runtimeFilters (same e). Both call sites of withRuntimeFiltersNativeIcebergTableScanExec.scala:68 and :250 — pass withRuntimeFilters(basedScan, runtimeFilters) with runtimeFilters eq basedScan.runtimeFilters, and the node (a LeafExecNode) is never rebuilt with different filters anywhere in the tree.

Two things follow from that. First, all five @sparkver overloads of the new shim in ShimsImpl.scala (plus the abstract method in Shims.scala) are never invoked at runtime, so the new integration tests can't exercise them — a wrong version-specific copy(...) argument list would surface only as a compile error on that profile, never as a test failure. Second, doCanonicalize at NativeIcebergTableScanExec.scala:249-250 reduces to the previous basedScan.canonicalized (the wrapper returns basedScan unchanged), so the new comment there — "first make sure it sees the top-level runtime filters" — describes a transformation that doesn't currently happen.

Is this intentional groundwork for a future path that builds the node with filters different from basedScan (in which case a comment saying so, plus a test that takes the copy branch, would make the shim's ~40 version-specific lines defensible), or could withRuntimeFilters and the shim be dropped in favor of using basedScan directly? Since the field is always basedScan.runtimeFilters, I'm curious which direction you had in mind.

exec
} else {
Shims.get.copyBatchScanExecWithRuntimeFilters(exec, runtimeFilters)
}
}

private def planUncached(
exec: BatchScanExec,
useRuntimeFilters: Boolean): Option[IcebergScanPlan] = {
val scan = exec.scan
val scanClassName = scan.getClass.getName
// Only handle Iceberg scans; other sources must stay on Spark's path.
if (scanClassName == SparkChangelogScanClassName) {
return planChangelogScan(exec, scan)
return planChangelogScan(exec, scan, useRuntimeFilters)
}

if (!AuronIcebergSourceUtil.getClassOfSparkBatchQueryScan.isInstance(scan)) {
return None
}

planFileScan(exec, scan, scanClassName)
planFileScan(exec, scan, scanClassName, useRuntimeFilters)
}

private def planFileScan(
exec: BatchScanExec,
scan: Scan,
scanClassName: String): Option[IcebergScanPlan] = {
scanClassName: String,
useRuntimeFilters: Boolean): Option[IcebergScanPlan] = {
val readSchema = scan.readSchema
val schemas = supportedSchemas(readSchema, isChangelogScan = false)
if (schemas.isEmpty) {
Expand Down Expand Up @@ -143,7 +165,7 @@ object IcebergScanSupport extends Logging {
missingFieldIds.isEmpty,
s"Missing Iceberg field ids for columns: ${missingFieldIds.mkString(", ")}")

val partitions = inputPartitions(exec)
val partitions = inputPartitions(exec, useRuntimeFilters)
// Empty scan (e.g. empty table) should still build a plan to return no rows.
if (partitions.isEmpty) {
logWarning(s"Native Iceberg scan planned with empty partitions for $scanClassName.")
Expand Down Expand Up @@ -203,15 +225,18 @@ object IcebergScanSupport extends Logging {
fieldIdsByName))
}

private def planChangelogScan(exec: BatchScanExec, scan: Scan): Option[IcebergScanPlan] = {
private def planChangelogScan(
exec: BatchScanExec,
scan: Scan,
useRuntimeFilters: Boolean): Option[IcebergScanPlan] = {
val readSchema = scan.readSchema
val schemas = supportedSchemas(readSchema, isChangelogScan = true)
if (schemas.isEmpty) {
return None
}
val (fileSchema, partitionSchema) = schemas.get

val partitions = inputPartitions(exec)
val partitions = inputPartitions(exec, useRuntimeFilters)
if (partitions.isEmpty) {
return Some(
IcebergScanPlan(
Expand Down Expand Up @@ -326,7 +351,16 @@ object IcebergScanSupport extends Logging {
private def deletesEmpty(deletes: java.util.List[_]): Boolean =
deletes == null || deletes.isEmpty

private def inputPartitions(exec: BatchScanExec): Seq[InputPartition] = {
private def inputPartitions(
exec: BatchScanExec,
useRuntimeFilters: Boolean): Seq[InputPartition] = {
if (useRuntimeFilters) {
runtimeFilteredPartitions(exec) match {
case Some(partitions) => return partitions
case None =>
}
}

// Prefer DataSource V2 batch API; if not available, fallback to exec methods via reflection.
val fromBatch =
try {
Expand Down Expand Up @@ -382,6 +416,40 @@ object IcebergScanSupport extends Logging {
}
}

private def runtimeFilteredPartitions(exec: BatchScanExec): Option[Seq[InputPartition]] = {
if (exec.runtimeFilters.isEmpty) {
return None
}

try {
MethodUtils.invokeMethod(exec, true, "prepare")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prepare() is a public SparkPlan method and exec is a typed BatchScanExec, so exec.prepare() compiles directly — only waitForSubqueries (protected) and filteredPartitions genuinely need reflective access. Calling exec.prepare() here would drop one reflection call and leave a single reflection helper (invokeDeclaredMethod) for the two methods that actually need it. Minor.

MethodUtils.invokeMethod(exec, true, "waitForSubqueries")
invokeDeclaredMethod(exec, "filteredPartitions") match {
case Some(seq: scala.collection.Seq[_]) =>
Some(flattenPartitions(seq))
case _ =>
None
}
} catch {
case NonFatal(t) =>

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This NonFatal catch wraps prepare / waitForSubqueries / filteredPartitions and returns None, which flows up to inputPartitions and full-scans all partitions. That also swallows a genuine DPP subquery or broadcast execution failure and turns it into a silent full-table scan, where vanilla Spark's BatchScanExec would have surfaced the failure. Partitioning errors here isn't wrong for correctness (the join re-filters), but it changes error semantics — a real failure in waitForSubqueries gets masked. Should a waitForSubqueries failure be swallowed and fall through to a full scan at all, or would you rather let it propagate so a real subquery/broadcast failure stays visible? (It does log a warning, so this is about whether full-scanning past a genuine failure is the right default, not about missing logs.)

logWarning(
s"Failed to obtain runtime-filtered input partitions for ${exec.getClass.getName}.",
t)
None
}
}

private def flattenPartitions(seq: scala.collection.Seq[_]): Seq[InputPartition] = {
seq.flatMap {
case partition: InputPartition =>
Seq(partition)
case nested: scala.collection.Seq[_] =>
flattenPartitions(nested)
case _ =>
Seq.empty
}.toSeq
}

private case class IcebergPartitionView(tasks: Seq[ScanTask])

private def icebergPartition(partition: InputPartition): Option[IcebergPartitionView] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.auron.{EmptyNativeRDD, NativeConverters, NativeHelper, NativeRDD, NativeSupports, Shims}
import org.apache.spark.sql.auron.iceberg.{IcebergNativeScanTask, IcebergScanPlan}
import org.apache.spark.sql.auron.iceberg.{IcebergNativeScanTask, IcebergScanPlan, IcebergScanSupport}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, Literal}
import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow, Literal}
import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan, SQLExecution}
import org.apache.spark.sql.execution.datasources.{FilePartition, PartitionedFile}
Expand All @@ -47,7 +47,10 @@ import org.apache.auron.{protobuf => pb}
import org.apache.auron.jni.JniBridge
import org.apache.auron.metric.SparkMetricNode

case class NativeIcebergTableScanExec(basedScan: BatchScanExec, plan: IcebergScanPlan)
case class NativeIcebergTableScanExec(
basedScan: BatchScanExec,
staticPlan: IcebergScanPlan,
runtimeFilters: Seq[Expression])
extends LeafExecNode
with NativeSupports
with Logging {
Expand All @@ -60,6 +63,15 @@ case class NativeIcebergTableScanExec(basedScan: BatchScanExec, plan: IcebergSca
override val output = basedScan.output
override val outputPartitioning = basedScan.outputPartitioning

private lazy val plan: IcebergScanPlan = {
if (runtimeFilters.nonEmpty) {
val filteredScan = IcebergScanSupport.withRuntimeFilters(basedScan, runtimeFilters)
IcebergScanSupport.plan(filteredScan, useRuntimeFilters = true).getOrElse(staticPlan)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When runtime-filtered planning returns None, this falls back to staticPlan — the unfiltered plan over all partitions — with no log line. Correctness is safe here (the enclosing join re-applies the predicate, and prune-to-empty returns Some(Seq.empty) rather than None, so it doesn't hit this branch), so this is really an observability question: someone debugging "did DPP actually apply?" gets no signal that the query quietly scanned everything. Would a logWarning on the getOrElse(staticPlan) branch help — something noting that runtime-filtered planning was unavailable and all partitions are being scanned?

} else {
staticPlan
}
}

private lazy val fileSchema: StructType = plan.fileSchema
private lazy val partitionSchema: StructType = plan.partitionSchema
private lazy val projectableSchema: StructType =
Expand Down Expand Up @@ -213,8 +225,29 @@ case class NativeIcebergTableScanExec(basedScan: BatchScanExec, plan: IcebergSca

override val nodeName: String = "NativeIcebergTableScan"

// Delegate canonicalization to the original scan to keep plan equivalence checks consistent.
override protected def doCanonicalize(): SparkPlan = basedScan.canonicalized
override def simpleString(maxFields: Int): String = {
val runtimeFiltersString =
if (runtimeFilters.nonEmpty) {
s", runtimeFilters=${runtimeFilters.mkString("[", ", ", "]")}"
} else {
""
}
s"$nodeName (${basedScan.simpleString(maxFields)}$runtimeFiltersString)"
}

override def verboseStringWithOperatorId(): String = {
s"""
|$formattedNodeName
|Output: ${output.mkString("[", ", ", "]")}
|${basedScan.scan.description()}
|RuntimeFilters: ${runtimeFilters.mkString("[", ", ", "]")}
|""".stripMargin
}

// Keep canonicalization aligned with Spark's BatchScanExec, but first make sure it sees
// the top-level runtime filters carried by this native scan.
override protected def doCanonicalize(): SparkPlan =
IcebergScanSupport.withRuntimeFilters(basedScan, runtimeFilters).canonicalized

private def buildFileSizes(): Map[String, Long] = {
// Map file path to full file size; tasks may split a file into multiple ranges.
Expand Down
Loading
Loading