Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Yuening Hu's Tree Topic Modeling #74

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ MALLET_DIR = $(shell pwd)

JAVAC = javac
JAVA_FLAGS = \
-classpath "$(MALLET_DIR)/class:$(MALLET_DIR)/lib/mallet-deps.jar:$(MALLET_DIR)/lib/jdom-1.0.jar:$(MALLET_DIR)/lib/grmm-deps.jar:$(MALLET_DIR)/lib/weka.jar " \
-classpath "$(MALLET_DIR)/lib/wordnet.jar:$(MALLET_DIR)/class:$(MALLET_DIR)/lib/mallet-deps.jar:$(MALLET_DIR)/lib/jdom-1.0.jar:$(MALLET_DIR)/lib/grmm-deps.jar:$(MALLET_DIR)/lib/weka.jar " \
-sourcepath "$(MALLET_DIR)/src" \
-g:lines,vars,source \
-d $(MALLET_DIR)/class \
Expand Down
Binary file added lib/.DS_Store
Binary file not shown.
Binary file added lib/wordnet.jar
Binary file not shown.
157 changes: 157 additions & 0 deletions src/cc/mallet/topics/tree/CorpusWriter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
package cc.mallet.topics.tree;

import java.io.BufferedReader;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

import gnu.trove.TIntArrayList;
import gnu.trove.TIntIntHashMap;
import cc.mallet.types.FeatureSequence;
import cc.mallet.types.Instance;
import cc.mallet.types.InstanceList;

/**
 * Utility for dumping a MALLET {@link InstanceList} whose instances carry
 * {@link FeatureSequence} data into plain-text files: either one line of
 * space-separated tokens per document ({@link #writeCorpus}) or one line of
 * tab-separated per-vocabulary-word counts per document
 * ({@link #writeCorpusMatrix}).
 */
public class CorpusWriter {

    /** A progress line is printed every this many documents. */
    private static final int PROGRESS_INTERVAL = 1000;

    /**
     * Writes every document as one line of its tokens, each followed by a single space.
     * Documents with no tokens are not written. The document name and text are also
     * echoed to standard output.
     *
     * @param training    instances whose data is a {@link FeatureSequence}
     * @param outfilename path of the file to create or overwrite
     * @param vocabname   path of the vocabulary file; it is loaded so that a missing or
     *                    malformed file is reported, but tokens are currently not filtered by it
     * @throws FileNotFoundException if {@code outfilename} cannot be opened for writing
     */
    public static void writeCorpus(InstanceList training, String outfilename, String vocabname) throws FileNotFoundException {

        // Vocabulary filtering is disabled; the file is still read for its diagnostics.
        loadVocab(vocabname);

        PrintStream out = new PrintStream(new File(outfilename));
        try {
            int count = 0;
            for (Instance instance : training) {
                reportProgress(count++);
                FeatureSequence sequence = (FeatureSequence) instance.getData();

                StringBuilder doc = new StringBuilder();
                for (int jj = 0; jj < sequence.getLength(); jj++) {
                    doc.append((String) sequence.getObjectAtPosition(jj)).append(' ');
                }
                emit(out, instance, doc.toString());
            }
        } finally {
            // Close the output even if an instance cannot be processed.
            out.close();
        }
    }

    /**
     * Writes every document as one line of {@code vocab.size()} counts, each followed by a
     * tab, where the i-th count is the number of occurrences of the i-th vocabulary word
     * in the document. Tokens that are not in the vocabulary are ignored. If the
     * vocabulary is empty nothing is written.
     *
     * @param training    instances whose data is a {@link FeatureSequence}
     * @param outfilename path of the file to create or overwrite
     * @param vocabname   path of the vocabulary file read by {@link #loadVocab(String)}
     * @throws FileNotFoundException if {@code outfilename} cannot be opened for writing
     */
    public static void writeCorpusMatrix(InstanceList training, String outfilename, String vocabname) throws FileNotFoundException {

        ArrayList<String> vocab = loadVocab(vocabname);
        Map<String, Integer> vocabIndex = indexVocab(vocab);

        PrintStream out = new PrintStream(new File(outfilename));
        try {
            int count = 0;
            for (Instance instance : training) {
                reportProgress(count++);
                FeatureSequence sequence = (FeatureSequence) instance.getData();

                int[] counts = new int[vocab.size()];
                for (int jj = 0; jj < sequence.getLength(); jj++) {
                    Integer index = vocabIndex.get((String) sequence.getObjectAtPosition(jj));
                    // Out-of-vocabulary words have no column; skip them instead of indexing at -1.
                    if (index != null) {
                        counts[index] += 1;
                    }
                }

                StringBuilder doc = new StringBuilder();
                for (int jj = 0; jj < counts.length; jj++) {
                    doc.append(counts[jj]).append('\t');
                }
                emit(out, instance, doc.toString());
            }
        } finally {
            // Close the output even if an instance cannot be processed.
            out.close();
        }
    }

    /**
     * Reads a vocabulary file in which every line has at least two tab-separated fields
     * and the second field is the word. Lines with fewer fields are reported on standard
     * output and skipped. I/O problems are reported on standard output and whatever was
     * read up to that point is returned.
     *
     * @param vocabFile path of the vocabulary file
     * @return the words in file order; empty if the file could not be read
     */
    public static ArrayList<String> loadVocab(String vocabFile) {

        ArrayList<String> vocab = new ArrayList<String>();
        BufferedReader br = null;

        try {
            br = new BufferedReader(new InputStreamReader(new FileInputStream(vocabFile)));

            String strLine;
            while ((strLine = br.readLine()) != null) {
                strLine = strLine.trim();
                String[] str = strLine.split("\t");
                if (str.length > 1) {
                    vocab.add(str[1]);
                } else {
                    System.out.println("Error! " + strLine);
                }
            }
        } catch (FileNotFoundException e) {
            System.out.println("No vocab file Found!");
        } catch (IOException e) {
            System.out.println("Error reading vocab file " + vocabFile + ": " + e.getMessage());
        } finally {
            if (br != null) {
                try {
                    br.close();
                } catch (IOException ignored) {
                    // Nothing useful can be done about a failed close of a read-only stream.
                }
            }
        }
        return vocab;
    }

    /**
     * Maps each vocabulary word to the position of its first occurrence, matching what
     * {@code vocab.indexOf(word)} would return but with constant-time lookup.
     */
    private static Map<String, Integer> indexVocab(ArrayList<String> vocab) {
        Map<String, Integer> index = new HashMap<String, Integer>(Math.max(16, vocab.size() * 2));
        for (int ii = 0; ii < vocab.size(); ii++) {
            String word = vocab.get(ii);
            if (!index.containsKey(word)) {
                index.put(word, ii);
            }
        }
        return index;
    }

    /** Prints a progress line when {@code count} is a multiple of {@link #PROGRESS_INTERVAL}. */
    private static void reportProgress(int count) {
        if (count % PROGRESS_INTERVAL == 0) {
            System.out.println("Processed " + count + " number of documents!");
        }
    }

    /** Echoes the document name and text to standard output and writes non-empty text to {@code out}. */
    private static void emit(PrintStream out, Instance instance, String doc) {
        // String.valueOf tolerates instances that were created without a name.
        System.out.println(String.valueOf(instance.getName()));
        System.out.println(doc);

        if (!doc.equals("")) {
            out.println(doc);
        }
    }

    /**
     * Converts a serialized instance list into a document-by-vocabulary count matrix.
     *
     * @param args optionally {@code <input.mallet> <output corpus> <vocab file>}; when
     *             exactly three arguments are not supplied the built-in synthetic-data
     *             paths are used
     */
    public static void main(String[] args) {
        String input = "input/synthetic/synthetic-topic-input.mallet";
        String corpus = "../../spectral/input/synthetic.dat";
        String vocab = "../../spectral/input/synthetic.voc";

        if (args != null && args.length == 3) {
            input = args[0];
            corpus = args[1];
            vocab = args[2];
        }

        try {
            InstanceList data = InstanceList.load(new File(input));
            writeCorpusMatrix(data, corpus, vocab);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

}
82 changes: 82 additions & 0 deletions src/cc/mallet/topics/tree/HIntIntDoubleHashMap.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package cc.mallet.topics.tree;

import java.io.Serializable;

import gnu.trove.TIntDoubleHashMap;
import gnu.trove.TIntIntHashMap;
import gnu.trove.TIntObjectHashMap;

/**
* This class defines a two level hashmap, so a value will be indexed by two keys.
* The value is double, and two keys are both int.
*
* @author Yuening Hu
*/

/**
 * Two-level map from an (int, int) key pair to a double, stored as a map from the
 * first key to a per-key map from the second key to the value.
 */
public class HIntIntDoubleHashMap implements Serializable{
    TIntObjectHashMap<TIntDoubleHashMap> data;

    /** Creates an empty map. */
    public HIntIntDoubleHashMap() {
        data = new TIntObjectHashMap<TIntDoubleHashMap> ();
    }

    /**
     * Stores {@code value} under (key1, key2), creating the inner map for
     * {@code key1} on first use and overwriting any previous value.
     */
    public void put(int key1, int key2, double value) {
        boolean firstKeyKnown = data.contains(key1);
        if (!firstKeyKnown) {
            data.put(key1, new TIntDoubleHashMap());
        }
        data.get(key1).put(key2, value);
    }

    /**
     * Returns whatever the outer map holds for {@code key1}.
     */
    public TIntDoubleHashMap get(int key1) {
        return data.get(key1);
    }

    /**
     * Returns the value stored under (key1, key2). When the pair is absent a
     * message is printed to standard output and -1 is returned.
     */
    public double get(int key1, int key2) {
        if (!data.contains(key1) || !data.get(key1).contains(key2)) {
            System.out.println("HIntIntDoubleHashMap: key does not exist!");
            return -1;
        }
        return data.get(key1).get(key2);
    }

    /**
     * Returns all first-level keys.
     */
    public int[] getKey1Set() {
        return data.keys();
    }

    /**
     * Tells whether {@code key1} is present as a first-level key.
     */
    public boolean contains(int key1) {
        return data.contains(key1);
    }

    /**
     * Tells whether a value is stored under the pair (key1, key2).
     */
    public boolean contains(int key1, int key2) {
        return data.contains(key1) && data.get(key1).contains(key2);
    }

}

121 changes: 121 additions & 0 deletions src/cc/mallet/topics/tree/HIntIntIntHashMap.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package cc.mallet.topics.tree;

import java.io.Serializable;

import gnu.trove.TIntDoubleHashMap;
import gnu.trove.TIntIntHashMap;
import gnu.trove.TIntObjectHashMap;

/**
* This class defines a two level hashmap, so a value will be indexed by two keys.
* The value is int, and two keys are both int.
*
* @author Yuening Hu
*/

/**
 * Two-level map from an (int, int) key pair to an int, stored as a map from the
 * first key to a per-key map from the second key to the value.
 */
public class HIntIntIntHashMap implements Serializable{

    TIntObjectHashMap<TIntIntHashMap> data;

    /** Creates an empty map. */
    public HIntIntIntHashMap() {
        data = new TIntObjectHashMap<TIntIntHashMap> ();
    }

    /**
     * Stores {@code value} under (key1, key2), creating the inner map for
     * {@code key1} on first use and overwriting any previous value.
     */
    public void put(int key1, int key2, int value) {
        boolean firstKeyKnown = data.contains(key1);
        if (!firstKeyKnown) {
            data.put(key1, new TIntIntHashMap());
        }
        data.get(key1).put(key2, value);
    }

    /**
     * Returns the inner map for {@code key1}, or {@code null} when {@code key1}
     * is not a first-level key.
     */
    public TIntIntHashMap get(int key1) {
        return this.contains(key1) ? data.get(key1) : null;
    }

    /**
     * Returns the value stored under (key1, key2). When the pair is absent a
     * message is printed to standard output and 0 is returned.
     */
    public int get(int key1, int key2) {
        if (!this.contains(key1, key2)) {
            System.out.println("HIntIntIntHashMap: key does not exist!");
            return 0;
        }
        return data.get(key1).get(key2);
    }

    /**
     * Returns all first-level keys.
     */
    public int[] getKey1Set() {
        return data.keys();
    }

    /**
     * Tells whether {@code key1} is present as a first-level key.
     */
    public boolean contains(int key1) {
        return data.contains(key1);
    }

    /**
     * Tells whether a value is stored under the pair (key1, key2).
     */
    public boolean contains(int key1, int key2) {
        return data.contains(key1) && data.get(key1).contains(key2);
    }

    /**
     * Adds {@code increment} to the result of {@code get(key1, key2)} and stores
     * the sum back under the same pair.
     */
    public void adjustValue(int key1, int key2, int increment) {
        this.put(key1, key2, this.get(key1, key2) + increment);
    }


    /**
     * Adds {@code increment} to the stored value when the pair (key1, key2) is
     * present; otherwise stores {@code newvalue} under it.
     */
    public void adjustOrPutValue(int key1, int key2, int increment, int newvalue) {
        int updated = this.contains(key1, key2)
                ? this.get(key1, key2) + increment
                : newvalue;
        this.put(key1, key2, updated);
    }

    /**
     * Removes {@code key1} together with its inner map.
     */
    public void removeKey1(int key1) {
        data.remove(key1);
    }

    /**
     * Removes {@code key2} from the inner map of {@code key1}; does nothing when
     * {@code key1} is not a first-level key.
     */
    public void removeKey2(int key1, int key2) {
        if (!data.contains(key1)) {
            return;
        }
        data.get(key1).remove(key2);
    }

}
Loading