This note is a constant work in progress
This note will be useful to you if you are:
- A data engineer revising/updating their knowledge (me!)
- A junior data engineer building up a vocabulary of solutions to common problems
How to approach performance optimization?
Your operational loop is going to be:
- Determine the Service Level Agreement (SLA), in this case the time budget within which the pipeline has to finish
- Build the pipeline
- Measure metrics
- Find the bottlenecks
- Reduce the time lost to those bottlenecks
- Repeat the measure / find / reduce steps until the SLA is met with a good margin
- End
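One pass of this loop can be sketched as a small helper. It takes per-stage timings (collected however you like, e.g. read off the Spark UI or from your own timers), checks the total against the SLA with a safety margin, and names the slowest stage as the next bottleneck. The function name, the stage names and the 20% default margin are hypothetical choices for illustration, not part of any library.

```python
def next_bottleneck(stage_seconds, sla_seconds, margin=0.2):
    """Return (done, slowest_stage) for one pass of the optimization loop.

    done is True when the total runtime fits inside the SLA with the given
    safety margin, e.g. margin=0.2 means the pipeline must finish within 80% of the SLA.
    """
    total = sum(stage_seconds.values())
    done = total <= sla_seconds * (1 - margin)
    # The stage with the largest share of the runtime is the next thing to optimize
    slowest_stage = max(stage_seconds, key=stage_seconds.get)
    return done, slowest_stage


# Hypothetical timings (seconds) from one pipeline run
timings = {"jdbc_read": 1200.0, "joins": 300.0, "write": 150.0}
print(next_bottleneck(timings, sla_seconds=1800))
```

Picking the single slowest stage is deliberately simple. In practice you would also work out why that stage is slow before changing anything.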
Speed up reads
Slow JDBC query based ingestion
By default, Spark reads the result of a JDBC query with a single query, into a single partition. This is very slow because there is no parallelism. It also creates extreme skew: one task on one executor ends up holding the entire table.
To avoid this, you can tell the JDBC reader how to split the read into several queries that run in parallel.
Parallelize on a numeric column
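If the table has a numeric column (int, long) with fairly evenly spread values, such as an auto-increment primary key, this is quite straightforward. Date and timestamp columns are also supported as the partition column since Spark 2.4.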
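df = spark.read.jdbc(
    url=jdbc_url,
    table="transactions",
    column="id",          # numeric column used for partitioning
    lowerBound=1,         # lowerBound/upperBound only set the stride,
    upperBound=1000000,   # they do NOT filter out rows
    numPartitions=10,     # number of partitions = number of parallel queries
    properties=props
)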
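This fires 10 SQL queries in parallel, one per partition. lowerBound and upperBound only decide the stride and do not filter rows. The first partition has no lower bound and also picks up NULLs. The last partition has no upper bound. As a result, rows outside the range land in those two edge partitions. The boundaries below are approximate, because the exact stride arithmetic differs slightly between Spark versions:
SELECT * FROM transactions WHERE id < 100001 OR id IS NULL;
SELECT * FROM transactions WHERE id >= 100001 AND id < 200001;
...
SELECT * FROM transactions WHERE id >= 900001;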
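The boundaries come from a simple stride calculation. Below is a pure-Python model of how Spark turns column, lowerBound, upperBound and numPartitions into one WHERE clause per partition. It is a simplified sketch of the stride logic in older Spark versions (newer versions shift the boundaries slightly to balance the edge partitions). It assumes non-negative integer bounds, and the function name is hypothetical.

```python
def jdbc_partition_clauses(column, lower_bound, upper_bound, num_partitions):
    """Simplified model of Spark's JDBC range partitioning (non-negative int bounds).

    Returns one WHERE clause per partition; None would mean "no WHERE clause".
    """
    # Divide first, then subtract, so large bounds cannot overflow
    stride = upper_bound // num_partitions - lower_bound // num_partitions
    clauses = []
    current = lower_bound
    for i in range(num_partitions):
        # The first partition gets no lower bound
        lower = f"{column} >= {current}" if i != 0 else None
        current += stride
        # The last partition gets no upper bound
        upper = f"{column} < {current}" if i != num_partitions - 1 else None
        if upper is None:
            clauses.append(lower)
        elif lower is None:
            # The first partition also collects NULLs in the partition column
            clauses.append(f"{upper} OR {column} IS NULL")
        else:
            clauses.append(f"{lower} AND {upper}")
    return clauses


clauses = jdbc_partition_clauses("id", 1, 1000000, 10)
for clause in clauses:
    print(f"SELECT * FROM transactions WHERE {clause};")
```

In this model the edge partitions are open-ended. That is why rows outside [lowerBound, upperBound] are not lost, but they do pile up in the first and last partitions when the bounds are set badly.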
Parallelize on a string column
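For a low-cardinality string column, you have to supply the predicates yourself, one per partition. The predicates must not overlap, and together they must cover every row. Any row that matches no predicate (e.g. region IS NULL, or a region missing from the list) is silently dropped.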
predicates = [
    "region = 'NORTH'",
    "region = 'SOUTH'",
    "region = 'EAST'",
    "region = 'WEST'"
]
df = spark.read.jdbc(
    url=jdbc_url,
    table="customers",
    predicates=predicates,
    properties=props
)
Each string in predicates will be used in a WHERE clause as follows:
SELECT * FROM customers WHERE region = 'NORTH';
SELECT * FROM customers WHERE region = 'SOUTH';
SELECT * FROM customers WHERE region = 'EAST';
SELECT * FROM customers WHERE region = 'WEST';
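Rows that match no predicate are dropped, so it can help to generate the list in code and append a catch-all predicate for NULLs and unexpected values. This is a minimal sketch with a hypothetical helper name. It assumes the values are trusted constants: single quotes are doubled, but that is no substitute for proper input validation.

```python
def value_predicates(column, values):
    """One equality predicate per known value, plus a catch-all so no row is lost."""
    # Double any single quotes so each value is a valid SQL string literal
    quoted = ["'" + v.replace("'", "''") + "'" for v in values]
    predicates = [f"{column} = {q}" for q in quoted]
    # Catch-all partition: NULLs (NOT IN never matches NULL) and any value not listed
    predicates.append(f"{column} IS NULL OR {column} NOT IN ({', '.join(quoted)})")
    return predicates


predicates = value_predicates("region", ["NORTH", "SOUTH", "EAST", "WEST"])
for p in predicates:
    print(p)
```

The extra partition is usually small, but it makes the read complete even when new regions appear in the source table.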
Parallelize using a hashing function
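This method puts extra load on the database. Every partition's query has to compute the hash of every row, so the table is usually scanned once per partition, and an index on the column does not help. In return, rows are spread roughly evenly across the Spark job's executors. For high-cardinality non-numeric columns such as UUID strings, it is often the only practical option. The name and availability of a hash function (HASH() below) vary between databases.
For example, to split the workload across 4 queries: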
num_partitions = 4
# One predicate per bucket: bucket i gets the rows whose hash mod 4 equals i
predicates = [f"MOD(ABS(HASH(user_id)), {num_partitions}) = {i}" for i in range(num_partitions)]
df = spark.read.jdbc(
    url=jdbc_url,
    table="users",
    predicates=predicates,
    properties=props
)
This will create the following queries:
SELECT * FROM users WHERE MOD(ABS(HASH(user_id)), 4) = 0;
SELECT * FROM users WHERE MOD(ABS(HASH(user_id)), 4) = 1;
SELECT * FROM users WHERE MOD(ABS(HASH(user_id)), 4) = 2;
SELECT * FROM users WHERE MOD(ABS(HASH(user_id)), 4) = 3;
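The even spread comes from the hash, not from the data. The simulation below is plain Python. MD5 from hashlib stands in for the database's hash function, and deterministic fake UUIDs stand in for user IDs. It shows that MOD(hash, n) puts every row into exactly one of n buckets, and that the buckets have similar sizes.

```python
import hashlib
import uuid


def bucket(user_id, num_partitions):
    # Stand-in for MOD(ABS(HASH(user_id)), n): 8 bytes of an MD5 digest as a non-negative int
    digest = hashlib.md5(user_id.encode("utf-8")).digest()
    return int.from_bytes(digest[:8], "big") % num_partitions


def bucket_sizes(user_ids, num_partitions):
    sizes = [0] * num_partitions
    for uid in user_ids:
        sizes[bucket(uid, num_partitions)] += 1  # each row lands in exactly one bucket
    return sizes


# Deterministic fake UUIDs, so the result is repeatable between runs
user_ids = [str(uuid.uuid5(uuid.NAMESPACE_DNS, f"user-{i}")) for i in range(10_000)]
sizes = bucket_sizes(user_ids, 4)
print(sizes)
```

The same property holds in the database only if its hash function spreads values well. A weak hash, or a column with many duplicate values, still produces uneven partitions.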
Parallelize on timestamps
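This works well for reconciliation runs and other date-bounded loads, especially when created_at is indexed. Make sure the first range also catches NULL timestamps. Otherwise, rows without a timestamp are silently dropped.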
predicates = [
    "created_at < '2024-01-01' OR created_at IS NULL",  # also catch rows without a timestamp
    "created_at >= '2024-01-01' AND created_at < '2024-04-01'",
    "created_at >= '2024-04-01'"
]
df = spark.read.jdbc(
    url=jdbc_url,
    table="users",
    predicates=predicates,
    properties=props
)
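Hand-writing date ranges gets error-prone beyond a few partitions. Below is a minimal sketch of a generator (a hypothetical helper, plain Python) that turns sorted boundary dates into non-overlapping predicates. The predicates also cover NULLs and rows after the last boundary.

```python
from datetime import date


def range_predicates(column, boundaries):
    """Turn sorted boundary dates into non-overlapping predicates that cover every row."""
    first = boundaries[0].isoformat()
    # The first range also catches NULL timestamps, so those rows are not lost
    predicates = [f"{column} < '{first}' OR {column} IS NULL"]
    for lo, hi in zip(boundaries, boundaries[1:]):
        predicates.append(f"{column} >= '{lo.isoformat()}' AND {column} < '{hi.isoformat()}'")
    # The last range is open-ended, so rows after the final boundary are included
    predicates.append(f"{column} >= '{boundaries[-1].isoformat()}'")
    return predicates


# Hypothetical quarterly boundaries for 2024
quarters = [date(2024, m, 1) for m in (1, 4, 7, 10)]
quarterly = range_predicates("created_at", quarters)
for p in quarterly:
    print(p)
```

With two boundaries (2024-01-01 and 2024-04-01), this reproduces the three hand-written ranges, with the NULL catch-all on the first one.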
Slow transformations
Slow joins
(There are more details in the spark join strategies note)
One table is small enough to fit into memory
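Perform a broadcast hash join (usually by broadcasting a dimension table). The small table is collected to the driver, and a full copy is sent to every executor, so the big table does not need to be shuffled. This means the small table must fit comfortably in driver and executor memory. Spark picks this strategy automatically when a table's estimated size is below spark.sql.autoBroadcastJoinThreshold (10 MB by default). The broadcast() hint forces it.
from pyspark.sql.functions import broadcast
# Force Spark to send a full copy of small_df to every executor
broadcast_join = big_df.join(broadcast(small_df), big_df.big_id == small_df.small_id, "inner")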
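Conceptually, a broadcast hash join builds one hash table from the small side and probes it from every partition of the big side. The plain-Python model below illustrates the idea. Lists of dicts stand in for partitions and rows, and this is not Spark's implementation.

```python
from collections import defaultdict


def broadcast_hash_join(big_partitions, small_rows, big_key, small_key):
    # Build the hash table once from the small side: this is the copy that gets "broadcast"
    table = defaultdict(list)
    for row in small_rows:
        table[row[small_key]].append(row)

    joined = []
    # Every big-side partition probes the same table, so the big side is never shuffled
    for partition in big_partitions:
        for row in partition:
            for match in table.get(row[big_key], []):  # .get() does not insert missing keys
                joined.append({**row, **match})
    return joined


big_partitions = [
    [{"big_id": 1, "amount": 10}, {"big_id": 2, "amount": 20}],
    [{"big_id": 1, "amount": 30}, {"big_id": 3, "amount": 40}],
]
small_rows = [{"small_id": 1, "name": "a"}, {"small_id": 2, "name": "b"}]
result = broadcast_hash_join(big_partitions, small_rows, "big_id", "small_id")
print(result)
```

Only the small side is copied. The big side's partitions stay where they are, which is the whole point of broadcasting.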
Medium-sized datasets
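Use the shuffle hash join when the smaller table is too large to broadcast but each of its partitions fits in memory. Join hints like shuffle_hash are available from Spark 3.0. Both sides are shuffled by the join key. Each task then builds a hash table from its partition of the smaller side and probes it with the matching partition of the larger side.
(The hash table is what the rows are compared against, so each partition's hash table must fit into executor memory.)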
hash_join = big_df.hint("shuffle_hash").join(
    small_df.hint("shuffle_hash"),
    big_df.big_id == small_df.small_id,
    "inner"
)
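The plain-Python model below illustrates the same idea; it is not Spark's implementation. Python's built-in hash() stands in for Spark's hash partitioner, and the right-hand side is assumed to be the smaller (build) side. Each hash table only ever holds one partition's worth of rows.

```python
from collections import defaultdict


def shuffle_hash_join(left_rows, right_rows, left_key, right_key, num_partitions):
    # Shuffle: send each row to partition hash(key) % n, so equal keys from both sides meet
    left_parts = [[] for _ in range(num_partitions)]
    right_parts = [[] for _ in range(num_partitions)]
    for row in left_rows:
        left_parts[hash(row[left_key]) % num_partitions].append(row)
    for row in right_rows:
        right_parts[hash(row[right_key]) % num_partitions].append(row)

    joined = []
    for left_part, right_part in zip(left_parts, right_parts):
        # Build side: a hash table over ONE partition of the smaller side, not the whole table
        table = defaultdict(list)
        for row in right_part:
            table[row[right_key]].append(row)
        # Probe side: stream the matching partition of the larger side
        for row in left_part:
            for match in table.get(row[left_key], []):
                joined.append({**row, **match})
    return joined


left = [{"big_id": i, "amount": i * 10} for i in range(10)]
right = [{"small_id": i, "name": f"n{i}"} for i in range(0, 10, 2)]
result = shuffle_hash_join(left, right, "big_id", "small_id", num_partitions=3)
print(sorted(r["big_id"] for r in result))
```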
Chained withColumn() induced slowness
Chaining many withColumn() calls slows down Spark code. The cost is paid on the driver, while the query is being planned.
Reason? Every withColumn() returns a new DataFrame whose logical plan is the previous plan plus one more projection (Project) node. Because of lazy evaluation, no data is computed or held in memory at this point: the intermediate DataFrames are just plan objects on the driver. However, each new DataFrame is analyzed as soon as it is created. Spark re-resolves column references and re-runs type checking against a plan that grows by one node per call.
So n chained calls stack n projections, and the total analysis work grows roughly quadratically with n. With a handful of columns this is negligible. When dozens or hundreds of columns are added (typically in a loop), planning time becomes noticeable, and very deep plans can even fail with a stack overflow. The optimizer collapses adjacent projections before execution, so the executor-side work is usually unaffected.
Solution? Add all the new columns in a single select(), or in a single withColumns() call on Spark 3.3+. That adds one projection instead of n, so the plan is analyzed once.
from pyspark.sql.functions import col
# BEFORE OPTIMIZATION #########
# Each withColumn() stacks another projection on the logical plan
df_new_cols = (
    df
    .withColumn("new_col_a", col("existing_col_a") * 2)
    .withColumn("new_col_b", col("existing_col_b") * 2)
    .withColumn("new_col_c", col("existing_col_c") * 2)
)
# AFTER OPTIMIZATION ##########
# One select(): "*" keeps every existing column, then the new columns are appended
df_new_cols = df.select(
    "*",
    (col("existing_col_a") * 2).alias("new_col_a"),
    (col("existing_col_b") * 2).alias("new_col_b"),
    (col("existing_col_c") * 2).alias("new_col_c"),
)
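The toy model below makes the plan-growth argument concrete; it is not Spark's real analyzer. A plan is a chain of nodes, each chained call adds one projection, and the model assumes the analyzer walks the whole plan every time a new DataFrame is created. Real Spark may skip work on already-analyzed subtrees, so treat the numbers as an upper-bound illustration of the trend, not a measurement.

```python
def plan_depth_and_work(num_new_columns, single_select):
    """Toy model: return (plan depth, plan nodes visited by analysis) under the stated assumption."""
    depth = 1    # the base relation (e.g. the table scan)
    visited = 0
    if single_select:
        depth += 1        # one projection holding every new column
        visited += depth  # analyzed once
    else:
        for _ in range(num_new_columns):
            depth += 1        # each withColumn() stacks another projection
            visited += depth  # each new DataFrame re-walks the whole plan (assumption)
    return depth, visited


for n in (3, 100):
    print(n, plan_depth_and_work(n, single_select=False), plan_depth_and_work(n, single_select=True))
```

Under these assumptions, 100 chained calls give a plan 101 nodes deep and 5150 node visits. One select() gives a depth of 2 and 2 visits, whatever the number of columns.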
Random issues
Parquet specific issues
Parquet vectorized reader fails for decimal columns (until spark 3.1)
Spark enables the vectorized Parquet reader by default (spark.sql.parquet.enableVectorizedReader = true), and so does Databricks. It gives a large boost to Parquet decoding performance, but it can break reads of decimal columns. The vectorized reader ends up treating the decimal values as binary, and the read typically fails with a "Parquet column cannot be converted" error.
If you hit this issue, turn the vectorized reader off. This affects all Parquet reads in the session, so they will be slower:
spark.conf.set("spark.sql.parquet.enableVectorizedReader", "false")