from pyspark.sql import SparkSession
from pyspark.sql.functions import col, broadcast
from pyspark.storagelevel import StorageLevel
from pyspark.sql.utils import AnalysisException
from py4j.protocol import Py4JJavaError
import logging
# Build the session first; defaultParallelism is only known once the
# SparkContext exists, so the shuffle-partition count is set afterwards.
spark = SparkSession.builder \
    .appName("EnhancedDataReconciliation") \
    .config("spark.dynamicAllocation.enabled", "true") \
    .getOrCreate()
spark.conf.set("spark.sql.shuffle.partitions", str(spark.sparkContext.defaultParallelism * 3))
# Source tables to reconcile; repartition on the join key and cache them,
# since each side is reused by several joins below.
dfA = spark.read.parquet("/path/to/tableA")
dfB = spark.read.parquet("/path/to/tableB")

num_partitions = spark.sparkContext.defaultParallelism * 3
dfA = dfA.repartition(num_partitions, "policy_name").persist(StorageLevel.MEMORY_AND_DISK)
dfB = dfB.repartition(num_partitions, "policy_name").persist(StorageLevel.MEMORY_AND_DISK)
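# Optional check (assumption, not in the original paste): confirm the
# repartitioning took effect before kicking off the joins.
print("dfA partitions:", dfA.rdd.getNumPartitions())
print("dfB partitions:", dfB.rdd.getNumPartitions())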
def safe_broadcast(df, max_size=10 * 1024 * 1024):
    """Broadcast df only if a rough size estimate stays under max_size bytes.

    The estimate (rows * columns * 8 bytes) is crude and requires a count()
    job; if the estimate cannot be computed, fall back to a plain DataFrame.
    """
    try:
        if df.count() * len(df.columns) * 8 < max_size:
            return broadcast(df)
        return df
    except AnalysisException:
        return df
# Apply the broadcast hint only when a side is small enough.
dfA_broadcasted = safe_broadcast(dfA)
dfB_broadcasted = safe_broadcast(dfB)

# Exact matches: rows identical on all five key columns.
exact_matches = dfA_broadcasted.join(
    dfB_broadcasted,
    on=["policy_name", "source_ip", "destination_ip", "port", "protocol"],
    how="inner"
)
# Partial matches: a different policy_name but at least one other field in
# common. Alias both sides so their identically named columns can be told
# apart after the cross join, and prefix them before writing out.
filtered_dfA = dfA.filter(col("policy_name").isNotNull()).alias("A")
filtered_dfB = dfB.filter(col("policy_name").isNotNull()).alias("B")
partial_matches = filtered_dfA.crossJoin(filtered_dfB).filter(
    (col("A.policy_name") != col("B.policy_name")) &
    (
        (col("A.source_ip") == col("B.source_ip")) |
        (col("A.destination_ip") == col("B.destination_ip")) |
        (col("A.port") == col("B.port")) |
        (col("A.protocol") == col("B.protocol"))
    )
).select(
    [col(f"A.{c}").alias(f"A_{c}") for c in dfA.columns] +
    [col(f"B.{c}").alias(f"B_{c}") for c in dfB.columns]
)
# Rows present in one table but absent from the other (on the full key).
non_matches_A = dfA.join(safe_broadcast(dfB),
                         on=["policy_name", "source_ip", "destination_ip", "port", "protocol"],
                         how="left_anti")
non_matches_B = dfB.join(safe_broadcast(dfA),
                         on=["policy_name", "source_ip", "destination_ip", "port", "protocol"],
                         how="left_anti")
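# Optional sanity check (assumption, not in the original paste): inspect the
# physical plans to see whether Spark actually chose broadcast joins for the
# exact-match and anti-join steps.
exact_matches.explain()
non_matches_A.explain()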
# Configure logging before reporting results.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Report counts with count() rather than collect(); pulling full result sets
# back to the driver can exhaust driver memory on large tables.
exact_matches_count = exact_matches.count()
partial_matches_count = partial_matches.count()
non_matches_A_count = non_matches_A.count()
non_matches_B_count = non_matches_B.count()

logger.info("Exact matches count: %d", exact_matches_count)
logger.info("Partial matches count: %d", partial_matches_count)
logger.info("Non-matches A count: %d", non_matches_A_count)
logger.info("Non-matches B count: %d", non_matches_B_count)

def save_parquet(df, path):
    """Write one result set, logging write failures instead of raising."""
    try:
        df.write.mode("overwrite").parquet(path)
    except (AnalysisException, Py4JJavaError) as e:
        logger.error("Error saving %s: %s", path, str(e))

save_parquet(exact_matches, "/path/to/exact_matches")
save_parquet(partial_matches, "/path/to/partial_matches")
save_parquet(non_matches_A, "/path/to/non_matches_A")
save_parquet(non_matches_B, "/path/to/non_matches_B")
def monitor_cache_hit_ratio(spark_session):
    """Log how many partitions of each persisted RDD are actually cached."""
    for rdd_info in spark_session.sparkContext._jsc.sc().getRDDStorageInfo():
        if rdd_info.numCachedPartitions() > 0:
            logger.info(
                "RDD %d (%s): %d of %d partitions cached",
                rdd_info.id(), rdd_info.name(),
                rdd_info.numCachedPartitions(), rdd_info.numPartitions(),
            )
# Inspect cache usage while dfA/dfB are still persisted, then release them.
monitor_cache_hit_ratio(spark)

dfA.unpersist()
dfB.unpersist()
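# Not in the original paste: stop the session once all outputs are written so
# the application releases its executors.
spark.stop()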