Pyspark:多个数组的交集

我有以下测试数据,必须借助pyspark检查以下语句(数据实际上非常大:700000笔交易,每笔交易有10+个产品):


import pandas as pd

import datetime


data = {'date': ['2014-01-01', '2014-01-02', '2014-01-03', '2014-01-04', '2014-01-05', '2014-01-06'],

     'customerid': [1, 2, 2, 3, 4, 3], 'productids': ['A;B', 'D;E', 'H;X', 'P;Q;G', 'S;T;U', 'C;G']}

data = pd.DataFrame(data)

data['date'] = pd.to_datetime(data['date'])

“某个客户 ID 在 x 天内存在的交易的特征是购物车中至少有一件相同的产品。”


到目前为止,我有以下方法(例如 x = 2):


spark = SparkSession.builder \

    .master('local[*]') \

    .config("spark.driver.memory", "500g") \

    .appName('my-pandasToSparkDF-app') \

    .getOrCreate()

spark.conf.set("spark.sql.execution.arrow.enabled", "true")

spark.sparkContext.setLogLevel("OFF")


df=spark.createDataFrame(data)


x = 2


win = Window().partitionBy('customerid').orderBy(F.col("date").cast("long")).rangeBetween(-(86400*x), Window.currentRow)

test = df.withColumn("productids", F.array_distinct(F.split("productids", "\;")))\

    .withColumn("flat_col", F.array_distinct(F.flatten((F.collect_list("productids").over(win))))).orderBy(F.col("date"))


test = test.toPandas()

因此,从我们查看过去 2 天的每笔交易中,按 customerid 分组,相应的产品汇总在“flat_col”列中。

但我真正需要的是相同ID的购物篮的交集。只有这样我才能判断是否有常见的产品。

因此,“flat_col”的第五行中应该有 ['G'],而不是 ['P', 'Q', 'G', 'C']。同样,[] 应该出现在“flat_col”的所有其他行中。

太感谢了!


侃侃尔雅
浏览 114回答 2
2回答

互换的青春

您可以在不使用in 的情况下实现这一点self-join(因为连接shuffle在大数据中是昂贵的操作)。使用的功能。higher order functionsspark 2.4filter,transform,aggregatedf=spark.createDataFrame(data)x = 2win = Window().partitionBy('customerid').orderBy(F.col("date").cast("long")).rangeBetween(-(86400*x), Window.currentRow)test = df.withColumn("productids", F.array_distinct(F.split("productids", "\;")))\    .withColumn("flat_col", F.flatten(F.collect_list("productids").over(win)))\    .withColumn("occurances", F.expr("""filter(transform(productids, x->\     IF(aggregate(flat_col, 0,(acc,t)->acc+IF(t=x,1,0))>1,x,null)),y->y!='null')"""))\    .drop("flat_col").orderBy("date").show()+-------------------+----------+----------+----------+|               date|customerid|productids|occurances|+-------------------+----------+----------+----------+|2014-01-01 00:00:00|         1|    [A, B]|        []||2014-01-02 00:00:00|         2|    [D, E]|        []||2014-01-03 00:00:00|         2|    [H, X]|        []||2014-01-04 00:00:00|         3| [P, Q, G]|        []||2014-01-05 00:00:00|         4| [S, T, U]|        []||2014-01-06 00:00:00|         3|    [C, G]|       [G]|+-------------------+----------+----------+----------+

呼唤远方

自加入是有史以来最好的把戏from pyspark.sql.functions import concat_ws, collect_listspark.createDataFrame(data).registerTempTable("df")sql("SELECT date, customerid, explode(split(productids, ';')) productid FROM df").registerTempTable("altered")df = sql("SELECT al.date, al.customerid, al.productid productids, altr.productid flat_col FROM altered al left join altered altr on altr.customerid = al.customerid and al.productid = altr.productid and al.date != altr.date and datediff(al.date,altr.date) <=2 and datediff(al.date,altr.date) >=-2")df.groupBy("date", "customerid").agg(concat_ws(",", collect_list("productids")).alias('productids'), concat_ws(",", collect_list("flat_col")).alias('flat_col')).show()火花输出
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python