猿问

如何在 Java Weka API 中使用类别不平衡处理技术(SMOTE)?

我正在尝试使用 Java Weka API 构建分类模型,但训练数据集存在类别不平衡问题。因此,我想使用 SMOTE 之类的类别不平衡处理技术来缓解这一问题。


源代码如下:


package classification;

import java.util.Random;

import weka.classifiers.Classifier;

import weka.classifiers.bayes.NaiveBayesMultinomial;

import weka.core.Instance;

import weka.core.Instances;

import weka.core.converters.ConverterUtils.DataSource;

import weka.filters.Filter;

import weka.filters.unsupervised.attribute.StringToWordVector;

public class questStackoverflow {

public static void main(String agrs[]) throws Exception{

String fileRootPath = "../file.arff"; //Dataset

    Instances strdata = DataSource.read(fileRootPath); //Load Dataset

    StringToWordVector filter = new StringToWordVector(10000);

    filter.setInputFormat(strdata);

    String[] options = { "-W", "10000", "-L", "-M", "1",

            "-stemmer", "weka.core.stemmers.IteratedLovinsStemmer", 

            "-stopwords-handler", "weka.core.stopwords.Rainbow", 

            "-tokenizer", "weka.core.tokenizers.AlphabeticTokenizer" 

            };

    filter.setOptions(options);

    filter.setIDFTransform(true);

    Instances data = Filter.useFilter(strdata,filter); //Apply filter

    data.setClassIndex(0); //set class index        

    double recall=0.0;

    double precision=0.0;

    double fmeasure=0.0;

    double tp, fp, fn, tn;


    Classifier classifier = null;

    classifier = new NaiveBayesMultinomial(); //classifer

                }

            }   

        }


不使用类别不平衡处理技术时,我的代码运行正常。但我需要借助这类技术来缓解类别不平衡问题,却不知道如何在 Java Weka API 中使用它。


慕田峪7331174
浏览 284回答 1
1回答

幕布斯7119047

您可以在代码中添加以下代码行:weka.filters.supervised.instance.SMOTESMOTE smote=new SMOTE();smote.setInputFormat(trains);&nbsp; &nbsp; &nbsp; &nbsp;Instances Trains_smote= Filter.useFilter(trains, smote);您的代码如下。package classification;import java.util.Random;import weka.classifiers.Classifier;import weka.classifiers.bayes.NaiveBayesMultinomial;import weka.core.Instance;import weka.core.Instances;import weka.core.converters.ConverterUtils.DataSource;import weka.filters.Filter;import weka.filters.unsupervised.attribute.StringToWordVector;weka.filters.supervised.instance.SMOTEpublic class questStackoverflow {public static void main(String agrs[]) throws Exception{String fileRootPath = "../file.arff"; //DatasetInstances strdata = DataSource.read(fileRootPath); //Load DatasetStringToWordVector filter = new StringToWordVector(10000);filter.setInputFormat(strdata);String[] options = { "-W", "10000", "-L", "-M", "1",&nbsp; &nbsp; &nbsp; &nbsp; "-stemmer", "weka.core.stemmers.IteratedLovinsStemmer",&nbsp;&nbsp; &nbsp; &nbsp; &nbsp; "-stopwords-handler", "weka.core.stopwords.Rainbow",&nbsp;&nbsp; &nbsp; &nbsp; &nbsp; "-tokenizer", "weka.core.tokenizers.AlphabeticTokenizer"&nbsp;&nbsp; &nbsp; &nbsp; &nbsp; };filter.setOptions(options);filter.setIDFTransform(true);Instances data = Filter.useFilter(strdata,filter); //Apply filterdata.setClassIndex(0); //set class index&nbsp; &nbsp; &nbsp; &nbsp;&nbsp;double recall=0.0;double precision=0.0;double fmeasure=0.0;double tp, fp, fn, tn;Classifier classifier = null;classifier = new NaiveBayesMultinomial(); //classiferint folds = 10;&nbsp; &nbsp; &nbsp; &nbsp; &nbsp;Random random = new Random(1);data.randomize(random);data.stratify(folds);tp = fp = fn = tn = 0;for (int i = 0; i < folds; i++) {&nbsp; &nbsp;Instances trains = data.trainCV(folds, i,random); //training dataset&nbsp; &nbsp;Instances tests = data.testCV(folds, i); //testing dataset&nbsp; &nbsp;SMOTE smote=new SMOTE();&nbsp; &nbsp;smote.setInputFormat(trains);&nbsp; &nbsp; &nbsp; 
&nbsp;&nbsp;&nbsp; &nbsp;Instances Trains_smote = Filter.useFilter(trains, smote);&nbsp; &nbsp; classifier.buildClassifier(Trains_smote);&nbsp; &nbsp; //build classifier&nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp;&nbsp; &nbsp; for (int j = 0; j < tests.numInstances(); j++) {&nbsp; &nbsp;&nbsp;&nbsp; &nbsp; &nbsp; &nbsp;Instance instance = tests.instance(j);&nbsp; &nbsp; &nbsp; &nbsp;double classValue = instance.classValue();&nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp;&nbsp; &nbsp; &nbsp; &nbsp;double result = classifier.classifyInstance(instance);&nbsp; &nbsp; &nbsp; &nbsp; if (result == 0.0 && classValue == 0.0) {&nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; tp++;&nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; } else if (result == 0.0 && classValue == 1.0) {&nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; fp++;&nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; } else if (result == 1.0 && classValue == 0.0) {&nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; fn++;&nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; } else if (result == 1.0 && classValue == 1.0) {&nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; tn++;&nbsp; &nbsp; &nbsp; &nbsp; &nbsp; &nbsp; }&nbsp; &nbsp; &nbsp; &nbsp; }&nbsp; &nbsp;&nbsp; &nbsp; }&nbsp; &nbsp; if (tn + fn > 0)&nbsp; &nbsp; &nbsp; &nbsp; precision = tn / (tn + fn);&nbsp; &nbsp; if (tn + fp > 0)&nbsp; &nbsp; &nbsp; &nbsp; recall = tn / (tn + fp);&nbsp; &nbsp; if (precision + recall > 0)&nbsp; &nbsp; &nbsp; &nbsp; fmeasure = 2 * precision * recall / (precision + recall);&nbsp; &nbsp; System.out.println("Precision: " + precision);&nbsp; &nbsp; System.out.println("Recall: " + recall);&nbsp; &nbsp; System.out.println("Fmeasure: " + fmeasure);}}
随时随地看视频慕课网APP

相关分类

Java
我要回答