我正在尝试在 Azure Synapse Pyspark 笔记本中重写一些 Oracle 代码,但是某些代码是作为递归 CTE 编写的。我不知道如何转换它。我已经做了好几天了,但没有运气。有谁知道重写这段代码的好方法吗? “recursiveCTE”在 CTE 本身内部被调用,这在 Synapse 中是不允许的。
-- Recursive CTE fragment (Oracle). The anchor member selects order lines of
-- types 1006/1011; the recursive member repeatedly joins tableOne back onto
-- the previous level (type 1007 children), walking the reference chain:
--   child.reference_line_id (or its fallbacks) = parent.columnTwo
recursiveCTE AS (
-- Anchor member: level-1 rows (columnFive = 1).
SELECT
ola.columnOne,
ola.columnTwo,
ola.columnThree,
-- Parent pointer with two fallbacks: reference_line_id, then
-- source_document_line_id, then return_attribute2 cast to a number.
nvl(ola.reference_line_id,nvl(ola.source_document_line_id,to_number(ola.return_attribute2) ) ) AS columnFour,
1 AS columnFive,
ola.line_id AS columnSix,
ola.columnSeven,
ola.columnEight
FROM
tableOne ola
WHERE
ola.columnOne IN (select header_ids.columnOne from header_ids)
AND ola.cancelled_flag = 'N'
AND ola.columnThree IN (
1006,
1011
)
UNION ALL
-- Recursive member: children (type 1007) of rows found at the previous level.
SELECT
ol2.columnOne,
ol2.columnTwo,
ol2.columnThree,
recursiveCTE.columnTwo,
-- NOTE(review): 'lvl' is referenced here but the column is named columnFive
-- above — looks like incomplete renaming/obfuscation; confirm against the
-- original query.
lvl + 1 AS columnFive,
-- NOTE(review): 'ref_orig_line_id' likewise does not match columnSix above —
-- presumably the same obfuscation gap.
recursiveCTE.ref_orig_line_id AS columnSix,
ol2.columnSeven,
ol2.fulfilled_quantity * -1 as columnEight
FROM
tableOne ol2,
recursiveCTE
WHERE
-- Same parent-pointer expression as the anchor, matched to the parent's key.
nvl(ol2.reference_line_id,nvl(ol2.source_document_line_id,to_number(ol2.return_attribute2) ) ) = recursiveCTE.columnTwo
AND ol2.columnOne IN (select header_ids.columnOne from header_ids)
AND ol2.cancelled_flag = 'N'
AND ol2.columnThree IN (
1007
)
),
Azure Synapse Analytics 不支持 SQL Server 或 Oracle 中的递归 CTE。因此,最好的方法是在 PySpark 笔记本中使用迭代方法重构递归 CTE。
from pyspark.sql import functions as F

# Iterative replacement for the Oracle recursive CTE.
#
# Strategy: start from the anchor rows (base_df), then repeatedly join ONLY the
# newest "frontier" of rows back onto tableOne to find the next level, union it
# into the accumulated result, and stop when a level comes back empty. Joining
# only the frontier (instead of the whole accumulated result) is what the SQL
# recursion actually does — re-joining everything each pass would re-derive and
# duplicate every previously found level.

# Anchor member of the recursive CTE (level 1).
base_df = spark.sql("""
    SELECT
        ola.columnOne,
        ola.columnTwo,
        ola.columnThree,
        nvl(ola.reference_line_id, nvl(ola.source_document_line_id, to_number(ola.return_attribute2))) AS columnFour,
        1 AS columnFive,
        ola.line_id AS columnSix,
        ola.columnSeven,
        ola.columnEight
    FROM
        tableOne ola
    WHERE
        ola.columnOne IN (select header_ids.columnOne from header_ids)
        AND ola.cancelled_flag = 'N'
        AND ola.columnThree IN (1006, 1011)
""")

# Hoist loop invariants: the joined table, and the header-id filter values.
# Column.isin() requires literal values, NOT a DataFrame — collect them once.
table_one = spark.table("tableOne")
header_id_values = [
    row.columnOne
    for row in spark.sql("select header_ids.columnOne from header_ids").collect()
]

result_df = base_df          # accumulated result (all levels found so far)
frontier_df = base_df        # rows discovered in the previous iteration only
max_iterations = 1000        # safety cap against cycles in the reference chain

for _ in range(max_iterations):
    # Recursive member: children (columnThree = 1007) of the current frontier.
    next_level = frontier_df.alias("recursiveCTE").join(
        table_one.alias("ol2"),
        F.expr(
            "nvl(ol2.reference_line_id, nvl(ol2.source_document_line_id, "
            "to_number(ol2.return_attribute2))) = recursiveCTE.columnTwo"
        )
        & F.col("ol2.columnOne").isin(header_id_values)
        & (F.col("ol2.cancelled_flag") == 'N')
        & (F.col("ol2.columnThree") == 1007),
        "inner",
    ).select(
        F.col("ol2.columnOne"),
        F.col("ol2.columnTwo"),
        F.col("ol2.columnThree"),
        # Positionally matches columnFour in the anchor; alias for clarity.
        F.col("recursiveCTE.columnTwo").alias("columnFour"),
        (F.col("recursiveCTE.columnFive") + 1).alias("columnFive"),
        F.col("recursiveCTE.columnSix").alias("columnSix"),
        F.col("ol2.columnSeven"),
        (F.col("ol2.columnEight") * -1).alias("columnEight"),
    )

    # Cache: the frontier is both counted (termination check) and re-joined
    # next iteration; without caching its lineage is recomputed each time.
    next_level = next_level.cache()

    # Terminate when the recursion produces no new rows.
    if next_level.count() == 0:
        break

    result_df = result_df.union(next_level)
    frontier_df = next_level

# result_df now contains all levels of the recursion.
result_df.show()