I am trying to create a text classifier in Java using Weka. I have read several tutorials and am now trying to build my own classifier.
I have the following categories:
computer,sport,unknown
and the following training data, which has already been prepared:
cs belongs to computer
java -> computer
soccer -> sport
snowboard -> sport
So, for example, if the user wants to classify the word java, the classifier should return the category computer (no doubt about it: java exists only in this category!).
It compiles, but generates a strange output.
Conclusion:
But the first text to classify is java, and it only appears in the computer category, so it should be
[1.0 0.0 0.0]
and the second text should not be found in any category at all, so it should be classified as unknown:
[0.0 0.0 1.0].
Here is the code:
import java.io.FileNotFoundException;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Locale;
import weka.classifiers.Classifier;
import weka.classifiers.bayes.NaiveBayesMultinomialUpdateable;
import weka.core.Attribute;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.StringToWordVector;
/**
 * Small text classifier built on top of a Weka {@link Classifier} and a
 * {@link StringToWordVector} filter.
 *
 * <p>Usage follows a fixed life cycle:
 * <ol>
 *   <li>register every category with {@link #addCategory(String)};</li>
 *   <li>call {@link #setupAfterCategorysAdded()} exactly once;</li>
 *   <li>feed training messages with {@link #addData(String, String)};</li>
 *   <li>query with {@link #classifyMessage(String)}.</li>
 * </ol>
 * Calling these steps out of order raises {@link IllegalStateException}.
 *
 * <p>Messages and category names are lower-cased with {@link Locale#ROOT} so
 * that results do not depend on the platform's default locale.
 */
public class TextClassifier implements Serializable {

    private static final long serialVersionUID = -1397598966481635120L;

    /**
     * Demo entry point: trains on four one-word messages and prints the
     * distribution returned for a known and an unknown word.
     */
    public static void main(String[] args) {
        try {
            TextClassifier cl = new TextClassifier(new NaiveBayesMultinomialUpdateable());
            cl.addCategory("computer");
            cl.addCategory("sport");
            cl.addCategory("unknown");
            cl.setupAfterCategorysAdded();
            cl.addData("cs", "computer");
            cl.addData("java", "computer");
            cl.addData("soccer", "sport");
            cl.addData("snowboard", "sport");
            double[] result = cl.classifyMessage("java");
            System.out.println("====== RESULT ====== \tCLASSIFIED AS:\t" + Arrays.toString(result));
            result = cl.classifyMessage("asdasdasd");
            System.out.println("====== RESULT ======\tCLASSIFIED AS:\t" + Arrays.toString(result));
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /** Raw (unfiltered) training instances: one string attribute plus the class. */
    private Instances trainingData;
    /** Converts the string attribute into word attributes. */
    private final StringToWordVector filter;
    /** The wrapped Weka classifier. */
    private final Classifier classifier;
    /** False whenever training data was added since the last build. */
    private boolean upToDate;
    /** Category labels collected before setup. */
    private final FastVector classValues;
    /** Attribute definitions: "text" first, "class" appended during setup. */
    private final FastVector attributes;
    /** True once {@link #setupAfterCategorysAdded()} has run. */
    private boolean setup;
    /** Output of the filter over {@link #trainingData}; rebuilt on demand. */
    private Instances filteredData;

    /**
     * Creates a classifier wrapper with room for 10 categories initially.
     *
     * @param classifier the Weka classifier to train and query
     * @throws FileNotFoundException declared for source compatibility with
     *         existing callers; nothing in this constructor opens a file
     */
    public TextClassifier(Classifier classifier) throws FileNotFoundException {
        this(classifier, 10);
    }

    /**
     * Creates a classifier wrapper.
     *
     * @param classifier the Weka classifier to train and query
     * @param startSize  initial capacity of the category list
     * @throws FileNotFoundException declared for source compatibility with
     *         existing callers; nothing in this constructor opens a file
     */
    public TextClassifier(Classifier classifier, int startSize) throws FileNotFoundException {
        this.filter = new StringToWordVector();
        this.classifier = classifier;
        this.attributes = new FastVector(2);
        // A null value list makes "text" a string attribute.
        this.attributes.addElement(new Attribute("text", (FastVector) null));
        this.classValues = new FastVector(startSize);
        this.setup = false;
    }

    /**
     * Registers a category (stored lower-cased).
     *
     * @param category the category label
     * @throws IllegalStateException if called after
     *         {@link #setupAfterCategorysAdded()}, because the class attribute
     *         has already been created from the category list at that point
     */
    public void addCategory(String category) {
        if (setup) {
            throw new IllegalStateException("Categories must be added before setup");
        }
        // FastVector.addElement grows the backing array itself, so no manual
        // capacity management is needed here.
        classValues.addElement(category.toLowerCase(Locale.ROOT));
    }

    /**
     * Adds one labelled training message and marks the model as stale.
     *
     * @param message    the training text (lower-cased before use)
     * @param classValue the category of the message (lower-cased before use)
     * @throws IllegalStateException if setup has not been performed yet
     */
    public void addData(String message, String classValue) throws IllegalStateException {
        if (!setup) {
            throw new IllegalStateException("Must use setup first");
        }
        message = message.toLowerCase(Locale.ROOT);
        classValue = classValue.toLowerCase(Locale.ROOT);
        Instance instance = makeInstance(message, trainingData);
        instance.setClassValue(classValue);
        trainingData.add(instance);
        upToDate = false;
    }

    /**
     * Re-initialises the filter on the current training data and rebuilds the
     * classifier, but only if data was added since the last build.
     */
    private void buildIfNeeded() throws Exception {
        if (upToDate) {
            return;
        }
        filter.setInputFormat(trainingData);
        filteredData = Filter.useFilter(trainingData, filter);
        classifier.buildClassifier(filteredData);
        upToDate = true;
    }

    /**
     * Classifies a message, (re)building the model first when necessary.
     *
     * @param message the text to classify (lower-cased before use)
     * @return the distribution reported by the classifier for the message
     *         (presumably one entry per category, in the order they were added)
     * @throws IllegalStateException if setup has not been performed or no
     *         training data has been added
     * @throws Exception if the Weka filter or classifier fails
     */
    public double[] classifyMessage(String message) throws Exception {
        message = message.toLowerCase(Locale.ROOT);
        if (!setup) {
            throw new IllegalStateException("Must use setup first");
        }
        if (trainingData.numInstances() == 0) {
            throw new IllegalStateException("No classifier available.");
        }
        buildIfNeeded();
        // Use an empty copy of the header so the test message is not added to
        // the training data's own string attribute.
        Instances testset = trainingData.stringFreeStructure();
        Instance testInstance = makeInstance(message, testset);
        // Push the instance through the already-initialised filter.
        filter.input(testInstance);
        Instance filteredInstance = filter.output();
        return classifier.distributionForInstance(filteredInstance);
    }

    /**
     * Builds a two-attribute instance holding {@code text} in the "text"
     * attribute of {@code data}; the class value is left unset.
     */
    private Instance makeInstance(String text, Instances data) {
        Instance instance = new Instance(2);
        Attribute messageAtt = data.attribute("text");
        instance.setValue(messageAtt, messageAtt.addStringValue(text));
        instance.setDataset(data);
        return instance;
    }

    /**
     * Finalises the category list: creates the nominal "class" attribute and
     * the empty training set. Must be called once, after all categories have
     * been added and before any data is added or classified.
     *
     * @throws IllegalStateException if setup has already been performed; a
     *         second call would append another "class" attribute while
     *         instances are still created with two values
     */
    public void setupAfterCategorysAdded() {
        if (setup) {
            throw new IllegalStateException("Setup has already been performed");
        }
        attributes.addElement(new Attribute("class", classValues));
        trainingData = new Instances("MessageClassificationProblem", attributes, 100);
        trainingData.setClassIndex(trainingData.numAttributes() - 1);
        setup = true;
    }
}
By the way, I found a nice page:
http://www.hakank.org/weka/TextClassifierApplet3.html