Skip to content

Commit

Permalink
version 0.1
Browse files Browse the repository at this point in the history
  • Loading branch information
visze committed Aug 16, 2016
1 parent a2498a9 commit 21287a4
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 13 deletions.
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
<groupId>de.charite.compbio</groupId>
<artifactId>weka-GWAVA</artifactId>
<packaging>jar</packaging>
<version>0.1-SNAPSHOT</version>
<version>0.1</version>
<name>weka-GWAVA</name>
<url>https://charite.github.io/weka-GWAVA/</url>
<licenses>
Expand Down
18 changes: 12 additions & 6 deletions src/main/java/weka/classifiers/meta/GWAVABagging.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,31 @@
import weka.filters.Filter;
import weka.filters.supervised.instance.SpreadSubsample;

/**
* Extended class for bagging. Important that for every training set the majority class must be subsampled to the
* minority class.
*
* @author <a href="mailto:[email protected]">Max Schubach</a>
*
*/
public class GWAVABagging extends Bagging {

/**
* for serialization
*/
private static final long serialVersionUID = -3101726478201686871L;



@Override
protected synchronized Instances getTrainingSet(int iteration) throws Exception {
Instances bagData = super.getTrainingSet(iteration);
Instances bagData = super.getTrainingSet(iteration);

SpreadSubsample subsample = new SpreadSubsample();

subsample.setRandomSeed(m_Seed);
subsample.setDistributionSpread(1);
subsample.setInputFormat(bagData);
return Filter.useFilter(bagData, subsample);

}

}
8 changes: 7 additions & 1 deletion src/main/java/weka/classifiers/trees/GWAVARandomForest.java
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
package weka.classifiers.trees;

import weka.classifiers.trees.RandomForest;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.supervised.instance.SpreadSubsample;

/**
* The GWAVA random forest. Important that for every training set the majority class must be subsampled to the minority
* class.
*
* @author <a href="mailto:[email protected]">Max Schubach</a>
*
*/
public class GWAVARandomForest extends RandomForest {

/**
Expand Down
17 changes: 17 additions & 0 deletions src/test/java/weka/classifiers/meta/GWAVABaggingTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@
import weka.classifiers.trees.J48;
import weka.core.Instances;

/**
*
* test class for {@link GWAVABagging}
*
* @author <a href="mailto:[email protected]">Max Schubach</a>
*
*/
public class GWAVABaggingTest {

private static Instances data;
Expand All @@ -25,6 +32,11 @@ public class GWAVABaggingTest {
private static int seed = 42;
private int folds = 10;

/**
* set up
*
* @throws Exception
*/
@Before
public void setUp() throws Exception {
File file = new File(Resources.getResource(diabetesFile).getPath());
Expand All @@ -38,6 +50,11 @@ public void setUp() throws Exception {
randData.randomize(rand);
}

/**
* test the tree
*
* @throws Exception
*/
@Test
public void classifyJ48Test() throws Exception {

Expand Down
38 changes: 33 additions & 5 deletions src/test/java/weka/classifiers/trees/GWAVARandomForestTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@
import weka.classifiers.Evaluation;
import weka.core.Instances;

/**
*
* test class for {@link GWAVARandomForest}
*
* @author <a href="mailto:[email protected]">Max Schubach</a>
*
*/
public class GWAVARandomForestTest {

private Instances randDiabetesData;
Expand All @@ -30,6 +37,12 @@ public class GWAVARandomForestTest {
private int seed = 42;
private int folds = 10;

/**
*
* setup (load data etc.)
*
* @throws Exception
*/
@Before
public void setUp() throws Exception {
File file = new File(Resources.getResource(diabetesFile).getPath());
Expand Down Expand Up @@ -60,6 +73,11 @@ public void setUp() throws Exception {
randGeneratedImbalancedData.randomize(rand);
}

/**
* test GWAVA vs tree
*
* @throws Exception
*/
@Test
public void classifyJ48Test() throws Exception {

Expand All @@ -69,7 +87,7 @@ public void classifyJ48Test() throws Exception {
eval.crossValidateModel(gwava, randDiabetesData, folds, new Random(seed));

double prcGWAVA = eval.areaUnderPRC(1);
double rocGWAVA= eval.areaUnderROC(1);
double rocGWAVA = eval.areaUnderROC(1);

eval = new Evaluation(randDiabetesData);
eval.crossValidateModel(new J48(), randDiabetesData, folds, new Random(seed));
Expand All @@ -80,6 +98,11 @@ public void classifyJ48Test() throws Exception {
assertThat(rocGWAVA, Matchers.greaterThan(rocJ48));
}

/**
* test gwava vs tree bin data
*
* @throws Exception
*/
@Test
public void classifyRFRandomBinDataTest() throws Exception {

Expand All @@ -89,17 +112,22 @@ public void classifyRFRandomBinDataTest() throws Exception {

Evaluation eval = new Evaluation(randGeneratedImbalancedBinData);
eval.crossValidateModel(gwava, randGeneratedImbalancedBinData, folds, new Random(seed));
double prcHyperSMURF = eval.areaUnderPRC(1);
double rocHyperSMURF = eval.areaUnderROC(1);
double prcGWAVA = eval.areaUnderPRC(1);
double rocGWAVA = eval.areaUnderROC(1);
eval = new Evaluation(randGeneratedImbalancedBinData);
eval.crossValidateModel(new J48(), randGeneratedImbalancedBinData, folds, new Random(seed));
double prcJ48 = eval.areaUnderPRC(1);
double rocJ48 = eval.areaUnderROC(1);

assertThat(prcHyperSMURF, Matchers.greaterThan(prcJ48));
assertThat(rocHyperSMURF, Matchers.greaterThan(rocJ48));
assertThat(prcGWAVA, Matchers.greaterThan(prcJ48));
assertThat(rocGWAVA, Matchers.greaterThan(rocJ48));
}

/**
* test gwava vst tree using random data
*
* @throws Exception
*/
@Test
public void classifyRFRandomDataTest() throws Exception {

Expand Down

0 comments on commit 21287a4

Please sign in to comment.