如何使用 pyspark 对多数类进行欠采样

我尝试像下面的代码一样解决数据,但是我还没有使用 groupy 和 udf 弄清楚它,并且还发现 udf 无法返回数据帧。


有什么办法可以通过spark或其他一些方法来实现这一点,可以处理不平衡的数据


ratio = 3

def balance_classes(grp):

    picked = grp.loc[grp.editorsSelection == True]

    n = round(picked.shape[0]*ratio)

    if n:        

        try:

            not_picked = grp.loc[grp.editorsSelection == False].sample(n)

        except: # In case, fewer than n comments with `editorsSelection == False`

            not_picked = grp.loc[grp.editorsSelection == False]

        balanced_grp = pd.concat([picked, not_picked])

        return balanced_grp

    else: # If no editor's pick for an article, dicard all comments from that article

        return None 


comments = comments.groupby('articleID').apply(balance_classes).reset_index(drop=True)


白猪掌柜的
浏览 373回答 1
1回答

德玛西亚99

我通常使用这个逻辑来欠采样:def resample(base_features,ratio,class_field,base_class):    pos = base_features.filter(col(class_field)==base_class)    neg = base_features.filter(col(class_field)!=base_class)    total_pos = pos.count()    total_neg = neg.count()    fraction=float(total_pos*ratio)/float(total_neg)    sampled = neg.sample(False,fraction)    return sampled.union(pos)base_feature 是具有特征的火花数据框。ratio 是正负之间的期望比率 class_field 是包含类的列的名称,base_class 是类的 id
打开App,查看更多内容
随时随地看视频慕课网APP

相关分类

Python