The basic summary is: A decision tree has each branch node represent a choice between a number of alternatives, and each leaf node represents a decision or (classification). This makes decision tree another supervised machine learning algorithm useful in classifying information.
The main problem it overcome in defining a decision tree is how to identify the best split of the data points. To find this you need to go through all the sets of data, and identify which will give you the best split (gain) and start from there.
For some more technical information about this split / gain:
http://en.wikipedia.org/wiki/Information_gain_in_decision_trees
The biggest advantages I see in using a decision tree are:
It's easy it is to interpret and visualise.
Data didn't need to be normalised or something between -1 and 1.
Decision trees however cant be effectively used on large datasets with a large number of results.
As with my previous Classifiers post, I ended up using SQLite in memory db as it's such a pleasure to use. I did venture into using LambdaJ, but it actually ended up being such an ugly line of code I left it and simply did it manually. I have not looked at the Java 8 implementation of lambdas yet, I just hope it doesn't end in code like (with a whole bunch of static imports):
falseList.add(filter(not(having(on(List.class).get(col).toString(), equalTo((String) value))), asList(rows)));
So my java implementation of the PCI decision tree ended up looking like (All code in Github) :
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package net.briandupreez.pci.chapter7; | |
import org.javatuples.Pair; | |
import java.sql.*; | |
import java.util.*; | |
/** | |
* TreePredict. | |
* User: bdupreez | |
* Date: 2013/08/09 | |
* Time: 7:15 AM | |
*/ | |
public class TreePredict { | |
protected Connection connection; | |
/** | |
* create some data, using in memory db | |
*/ | |
public TreePredict() { | |
try { | |
Class.forName("org.sqlite.JDBC"); | |
} catch (final ClassNotFoundException e) { | |
e.printStackTrace(); | |
} | |
try { | |
connection = DriverManager.getConnection("jdbc:sqlite::memory:"); | |
final Statement statement = connection.createStatement(); | |
statement.setQueryTimeout(30); // set timeout to 30 sec. | |
statement.executeUpdate("DROP TABLE IF EXISTS userData"); | |
statement.executeUpdate("CREATE TABLE IF NOT EXISTS userData(referrer,location,readFAQ,pagesViewed, serviceChosen)"); | |
statement.execute("INSERT INTO userData VALUES ('slashdot', 'USA', 'yes', 18, 'None')"); | |
statement.execute("INSERT INTO userData VALUES ('google', 'France', 'yes', 23, 'Premium')"); | |
statement.execute("INSERT INTO userData VALUES ('digg', 'USA', 'yes', 24, 'Basic')"); | |
statement.execute("INSERT INTO userData VALUES ('kiwitobes', 'France', 'yes', 23, 'Basic')"); | |
statement.execute("INSERT INTO userData VALUES ('google', 'UK', 'no', 21, 'Premium')"); | |
statement.execute("INSERT INTO userData VALUES ('(direct)', 'New Zealand', 'no', 12, 'None')"); | |
statement.execute("INSERT INTO userData VALUES ('(direct)', 'UK', 'no', 21, 'Basic')"); | |
statement.execute("INSERT INTO userData VALUES ('google', 'USA', 'no', 24, 'Premium')"); | |
statement.execute("INSERT INTO userData VALUES ('slashdot', 'France', 'yes', 19, 'None')"); | |
statement.execute("INSERT INTO userData VALUES ('digg', 'USA', 'no', 18, 'None')"); | |
statement.execute("INSERT INTO userData VALUES ('google', 'UK', 'no', 18, 'None')"); | |
statement.execute("INSERT INTO userData VALUES ('kiwitobes', 'UK', 'no', 19, 'None')"); | |
statement.execute("INSERT INTO userData VALUES ('digg', 'New Zealand', 'yes', 12, 'Basic')"); | |
statement.execute("INSERT INTO userData VALUES ('slashdot', 'UK', 'no', 21, 'None')"); | |
statement.execute("INSERT INTO userData VALUES ('google', 'UK', 'yes', 18, 'Basic')"); | |
statement.execute("INSERT INTO userData VALUES ('kiwitobes', 'France', 'yes', 19, 'Basic')"); | |
} catch (final SQLException e) { | |
e.printStackTrace(); | |
} | |
} | |
/** | |
* Divide into true and false | |
* | |
* @param rows the data | |
* @param col one which column | |
* @param value with value | |
* @return tuple, true and false | |
*/ | |
@SuppressWarnings("unchecked") | |
public Pair<Object[][], Object[][]> divideSet(final Object[][] rows, final int col, final Object value) { | |
final ArrayList<Object[]> trueList = new ArrayList<>(); | |
final ArrayList<Object[]> falseList = new ArrayList<>(); | |
//So wished we had lambdas in java | |
/* | |
when trying this with LambdaJ.... lol... ugliest line of java ever... and still didn't work | |
falseList.add(filter(not(having(on(List.class).get(col).toString(), equalTo((String) value))), asList(rows))); | |
*/ | |
for (final Object[] row : rows) { | |
if (value instanceof Integer) { | |
if (((Integer) row[col]) >= ((Integer) value)) { | |
trueList.add(row); | |
} else { | |
falseList.add(row); | |
} | |
} else { | |
if (row[col].equals(value)) { | |
trueList.add(row); | |
} else { | |
falseList.add(row); | |
} | |
} | |
} | |
Object[][] trueMatrix = new Object[trueList.size()][]; | |
trueMatrix = trueList.toArray(trueMatrix); | |
Object[][] falseMatrix = new Object[falseList.size()][]; | |
falseMatrix = falseList.toArray(falseMatrix); | |
return new Pair<>(trueMatrix, falseMatrix); | |
} | |
public Map<String, Double> uniqueCounts(final Object[][] rows) { | |
final Map<String, Double> counts = new HashMap<>(); | |
for (final Object[] row : rows) { | |
final String resultVal = (String) row[row.length - 1]; | |
if (counts.containsKey(resultVal)) { | |
counts.put(resultVal, counts.get(resultVal) + 1); | |
} else { | |
counts.put(resultVal, 1.0); | |
} | |
} | |
return counts; | |
} | |
/** | |
* Probability that a randomly placed item will | |
* be in the wrong category | |
* | |
* @return probability | |
*/ | |
public double calculateGiniImpurity(final Object[][] rows) { | |
final double total = total(); | |
final Map<String, Double> countsMap = uniqueCounts(rows); | |
double prob = 0.0; | |
for (Map.Entry<String, Double> entry : countsMap.entrySet()) { | |
double p1 = entry.getValue() / total; | |
for (Map.Entry<String, Double> entry2 : countsMap.entrySet()) { | |
if (entry.getKey().equals(entry2.getKey())) { | |
continue; | |
} | |
double p2 = (entry2.getValue()) / total; | |
prob += p1 * p2; | |
} | |
} | |
return prob; | |
} | |
/** | |
* Entropy - the amount of disorder in a set | |
* Harsher of mixed sets. | |
* Entropy is the sum of p(x)log(p(x)) across all | |
* the different possible results | |
* | |
* @return | |
*/ | |
public double calculateEntropy(final Object[][] rows) { | |
final Map<String, Double> countsMap = uniqueCounts(rows); | |
double ent = 0.0; | |
for (final Map.Entry<String, Double> entry : countsMap.entrySet()) { | |
double p = entry.getValue() / rows.length; | |
ent += -p * (Math.log(p) / Math.log(2)); | |
} | |
return ent; | |
} | |
/** | |
* Add it up. | |
* | |
* @return total | |
*/ | |
public double total() { | |
double total = 0; | |
try { | |
final Statement statement = connection.createStatement(); | |
statement.setFetchSize(1); | |
final ResultSet rs = statement.executeQuery("SELECT count(serviceChosen)FROM userData"); | |
total = (double) rs.getInt("count(serviceChosen)"); | |
rs.close(); | |
} catch (final SQLException e) { | |
e.printStackTrace(); | |
} | |
return total; | |
} | |
/** | |
* Recursive scoring and building of the tree | |
* | |
* @param rows the data | |
* @return top node | |
*/ | |
@SuppressWarnings("unchecked") | |
public DecisionNode buildTree(final Object[][] rows) { | |
double currentScore = calculateEntropy(rows); | |
double bestGain = 0.0; | |
Pair<Integer, Object> bestCriteria = null; | |
Pair<Object[][], Object[][]> bestSets = null; | |
for (int col = 0; col < rows[0].length - 1; col++) { | |
//generate a list of different values in the column | |
Map<String, Integer> columnValues = new TreeMap<>(); | |
for (final Object[] row : rows) { | |
columnValues.put(row[col].toString(), 1); | |
} | |
//then try divideSet for each value | |
for (final Object columnValue : columnValues.keySet()) { | |
final Pair<Object[][], Object[][]> pair = divideSet(rows, col, columnValue); | |
//check the gain | |
double p = ((double) pair.getValue0().length) / rows.length; | |
double gain = currentScore - p * calculateEntropy(pair.getValue0()) - ((1 - p) * calculateEntropy(pair.getValue1())); | |
if (gain > bestGain && pair.getValue0().length > 0 && pair.getValue1().length > 0) { | |
bestGain = gain; | |
bestCriteria = new Pair(col, columnValue); | |
bestSets = pair; | |
} | |
} | |
} | |
if (bestGain > 0 && bestSets != null) { | |
final DecisionNode trueBranch = buildTree(bestSets.getValue0()); | |
final DecisionNode falseBranch = buildTree(bestSets.getValue1()); | |
final DecisionNode returnNode = new DecisionNode(); | |
returnNode.setCol(bestCriteria.getValue0()); | |
returnNode.setValue(bestCriteria.getValue1()); | |
returnNode.setTrueBranch(trueBranch); | |
returnNode.setFalseBranch(falseBranch); | |
return returnNode; | |
} else { | |
final DecisionNode returnNode = new DecisionNode(); | |
returnNode.setResults(uniqueCounts(rows)); | |
return returnNode; | |
} | |
} | |
/** | |
* Print to console. recursive | |
* | |
* @param node the node | |
* @param indent the indent | |
*/ | |
public void printTree(final DecisionNode node, final String indent) { | |
if (node.getResults() != null) { | |
System.out.println(node.getResults()); | |
} else { | |
System.out.println(node.getCol() + ":" + node.getValue().toString() + "? "); | |
System.out.print(indent + "T->"); | |
printTree(node.getTrueBranch(), indent + "\t"); | |
System.out.print(indent + "F->"); | |
printTree(node.getFalseBranch(), indent + "\t"); | |
} | |
} | |
/** | |
* get teh data a matrix | |
* | |
* @return matrix | |
*/ | |
public Object[][] retrieveDataAsMatrix() { | |
final Object[][] arrays = new Object[16][5]; | |
try { | |
final Statement statement = connection.createStatement(); | |
final ResultSet rs = statement.executeQuery("SELECT * FROM userData"); | |
int i = 0; | |
while (rs.next()) { | |
arrays[i][0] = rs.getString(1); | |
arrays[i][1] = rs.getString(2); | |
arrays[i][2] = rs.getString(3); | |
arrays[i][3] = rs.getInt(4); | |
arrays[i][4] = rs.getString(5); | |
i++; | |
} | |
rs.close(); | |
} catch (final SQLException e) { | |
e.printStackTrace(); | |
} | |
return arrays; | |
} | |
/** | |
* Classify new data. | |
* | |
* @param row the data (minus the result) | |
* @param tree the trained tree | |
* @return the result based on teh tree | |
*/ | |
public Map<String, Double> classify(final Object[] row, final DecisionNode tree) { | |
if (tree.getResults() != null) { | |
return tree.getResults(); | |
} else { | |
final Object value = row[tree.getCol()]; | |
DecisionNode branch = null; | |
if (value instanceof Integer) { | |
if ((Integer) value >= (Integer) tree.getValue()) { | |
branch = tree.getTrueBranch(); | |
} else { | |
branch = tree.getFalseBranch(); | |
} | |
} else { | |
if (value.toString().equals(tree.getValue().toString())) { | |
branch = tree.getTrueBranch(); | |
} else { | |
branch = tree.getFalseBranch(); | |
} | |
} | |
return classify(row, branch); | |
} | |
} | |
/** | |
* Classify with missing data | |
* | |
* @param observation the input | |
* @param tree the tree | |
* @return a map with all probabilities from missing data | |
*/ | |
public Map<String, Double> missingDataClassify(final Object[] observation, final DecisionNode tree) { | |
if (tree.getResults() != null) { | |
return tree.getResults(); | |
} else { | |
final Object value = observation[tree.getCol()]; | |
if (value == null) { | |
final Map<String, Double> trueResult = missingDataClassify(observation, tree.getTrueBranch()); | |
final Map<String, Double> falseResult = missingDataClassify(observation, tree.getFalseBranch()); | |
double trueCount = 0.0; | |
for (final Double tv : trueResult.values()) { | |
trueCount += tv; | |
} | |
double falseCount = 0.0; | |
for (final Double fv : falseResult.values()) { | |
falseCount += fv; | |
} | |
double trueWeight = trueCount / (trueCount + falseCount); | |
double falseWeight = falseCount / (trueCount + falseCount); | |
final Map<String, Double> reslt = new HashMap<>(); | |
for (Map.Entry<String, Double> trueEntry : trueResult.entrySet()) { | |
reslt.put(trueEntry.getKey(), trueEntry.getValue() * trueWeight); | |
} | |
for (Map.Entry<String, Double> falseEntry : falseResult.entrySet()) { | |
reslt.put(falseEntry.getKey(), falseEntry.getValue() * falseWeight); | |
} | |
return reslt; | |
} else { | |
DecisionNode branch = null; | |
if (value instanceof Integer) { | |
if ((Integer) value >= (Integer) tree.getValue()) { | |
branch = tree.getTrueBranch(); | |
} else { | |
branch = tree.getFalseBranch(); | |
} | |
} else { | |
if (value.toString().equals(tree.getValue().toString())) { | |
branch = tree.getTrueBranch(); | |
} else { | |
branch = tree.getFalseBranch(); | |
} | |
} | |
return missingDataClassify(observation, branch); | |
} | |
} | |
} | |
/** | |
* Trim down nodes with delta less than min gain. | |
* recursive | |
* | |
* @param tree the node | |
* @param minGain minimum gain | |
*/ | |
public void prune(final DecisionNode tree, final double minGain) { | |
//if branches aren't leaves remove them. | |
if (tree.getTrueBranch().getResults() == null) { | |
prune(tree.getTrueBranch(), minGain); | |
} | |
if (tree.getFalseBranch().getResults() == null) { | |
prune(tree.getFalseBranch(), minGain); | |
} | |
//id both the branches are leaves, can we merge them | |
if (tree.getTrueBranch().getResults() != null && tree.getFalseBranch().getResults() != null) { | |
final ArrayList<Object[]> trueList = new ArrayList<>(); | |
for (final Map.Entry<String, Double> entry : tree.getTrueBranch().getResults().entrySet()) { | |
for (int i = 0; i < entry.getValue(); i++) { | |
trueList.add(new Object[]{entry.getKey()}); | |
} | |
} | |
Object[][] trueMatrix = new Object[trueList.size()][]; | |
trueMatrix = trueList.toArray(trueMatrix); | |
final ArrayList<Object[]> falseList = new ArrayList<>(); | |
for (final Map.Entry<String, Double> entry : tree.getFalseBranch().getResults().entrySet()) { | |
for (int i = 0; i < entry.getValue(); i++) { | |
falseList.add(new Object[]{entry.getKey()}); | |
} | |
} | |
Object[][] falseMatrix = new Object[falseList.size()][]; | |
falseMatrix = falseList.toArray(falseMatrix); | |
final Object[][] combined = new Object[falseList.size() + trueList.size()][]; | |
System.arraycopy(trueMatrix, 0, combined, 0, trueList.size()); | |
System.arraycopy(falseMatrix, 0, combined, trueList.size(), falseList.size()); | |
double delta = calculateEntropy(combined) - (calculateEntropy(trueMatrix) + calculateEntropy(falseMatrix)) / 2; | |
if (delta < minGain) { | |
tree.setFalseBranch(null); | |
tree.setTrueBranch(null); | |
tree.setResults(uniqueCounts(combined)); | |
} | |
} | |
} | |
} |
This story is very important.
ReplyDeleteĐặt vé máy bay tại Aivivu, tham khảo
ReplyDeleteVe may bay di My
vé máy bay từ mỹ về vn
thời gian bay từ nhật về vn
vé máy bay từ đức về việt nam
đăng ký bay từ canada về Việt Nam
dat ve may bay tu han quoc ve viet nam
chuyến bay chuyên gia về việt nam
i am glad to discover this page : i have to thank you for the time i spent on this especially great reading !! i really liked each part and also bookmarked you for new information on your site.
ReplyDeletemachine learning institute in delhi