猿问

如何获取 Spark DataFrame 中每行列表中最大值的索引?

我已经完成了 LDA 主题建模并将其存储在lda_model.


转换我的原始输入数据集后,我检索了一个 DataFrame。其中一列是 topicDistribution,其中该行属于 LDA 模型中每个主题的概率。因此,我想获取每行列表中最大值的索引。


df -- | 'list_of_words' | 'index ' | 'topicDistribution' | 

       ['product','...']     0       [0.08,0.2,0.4,0.0001]

          .....             ...         ........

我想转换 df 以便添加一个附加列,它是每行 topicDistribution 列表的 argmax。


df_transformed --  | 'list_of_words' | 'index' | 'topicDistribution' | 'topicID' |

                    ['product','...']     0     [0.08,0.2,0.4,0.0001]      2

                       ......            ....         .....              ....

我该怎么做?


开心每一天1111
浏览 352回答 1
1回答

qq_花开花谢_0

您可以创建一个用户定义的函数来获取最大值的索引from pyspark.sql import functions as ffrom pyspark.sql.types import IntegerTypemax_index = f.udf(lambda x: x.index(max(x)), IntegerType())df = df.withColumn("topicID", max_index("topicDistribution"))例子>>> from pyspark.sql import functions as f>>> from pyspark.sql.types import IntegerType >>> df = spark.createDataFrame([{"topicDistribution": [0.2, 0.3, 0.5]}])>>> df.show()+-----------------+|topicDistribution|+-----------------+|  [0.2, 0.3, 0.5]|+-----------------+>>> max_index = f.udf(lambda x: x.index(max(x)), IntegerType())>>> df.withColumn("topicID", max_index("topicDistribution")).show()+-----------------+-------+|topicDistribution|topicID|+-----------------+-------+|  [0.2, 0.3, 0.5]|      2|+-----------------+-------+编辑:由于您提到其中的列表topicDistribution是 numpy 数组,因此您可以更新max_index udf如下:max_index = f.udf(lambda x: x.tolist().index(max(x)), IntegerType())
随时随地看视频慕课网APP

相关分类

Python
我要回答