koho.cpp  1.1.0
decision_tree.h
Go to the documentation of this file.
1 
22 // Author: AI Werkstatt (TM)
23 // (C) Copyright 2019, AI Werkstatt (TM) www.aiwerkstatt.com. All rights reserved.
24 
25 #ifndef KOHO_DECISION_TREE_H
26 #define KOHO_DECISION_TREE_H
27 
28 #include <fstream>
29 #include <stack>
30 #include <algorithm>
31 
32 #include <cstring> // memset for C-style data from Cython
33 
35 
36 namespace koho {
37 
38  typedef double Features_t; // X provided by Python, needs to be double
39  typedef long Classes_t; // y provided by Python, needs to be long
40  typedef double ClassWeights_t; // class weights provided by Python, needs to be double
41 
42  typedef double Histogram_t; // weighted (number of classes) * number of samples
43 
44  typedef unsigned long SamplesIdx_t; // number of samples; also used for sample positions (start, end, pos)
45  typedef unsigned long FeaturesIdx_t; // number of features; also used for a feature index (e.g. Node::feature)
46  typedef unsigned long ClassesIdx_t; // number of classes
47  typedef unsigned long OutputsIdx_t; // number of outputs
48  typedef unsigned long NodesIdx_t; // number of nodes; also used for node ids (e.g. Node::left_child, parent_id)
49  typedef unsigned long TreeDepthIdx_t; // maximum tree depth (2 ** 31) - 1; NOTE(review): the type itself does not enforce this bound
50 
51  const double PRECISION_EQUAL = 1e-7; // tolerance for float comparisons: float == 0.0 as float <= PRECISION_EQUAL
52 
53 // =============================================================================
54 // Tree
55 // =============================================================================
56 
58  class Node { // Node of a binary tree; a Tree stores its nodes in Tree::nodes.
59  public:
60 
61  NodesIdx_t left_child; // id of the left child node (presumably an index into Tree::nodes — TODO confirm)
62  NodesIdx_t right_child; // id of the right child node (presumably an index into Tree::nodes — TODO confirm)
63  FeaturesIdx_t feature; // index of the feature this node splits on
64  int NA; // NA as part of the split criterion, -1:NA, 0:left, 1:right
65  Features_t threshold; // split threshold applied to `feature`
66  std::vector<std::vector<Histogram_t>> histogram; // weighted number of samples per class per output
67  double impurity; // for inspection (e.g. graphviz visualization)
68  double improvement; // for feature importances
69 
71  Node(NodesIdx_t left_child, // Create a new node (defined in decision_tree.cpp); arguments map 1:1 to the members above.
72  NodesIdx_t right_child,
73  FeaturesIdx_t feature,
74  int NA,
75  Features_t threshold,
76  const std::vector<std::vector<Histogram_t>>& histogram,
77  double impurity,
78  double improvement);
79 
81  void serialize(std::ofstream& fout); // Serialize this node to the output file stream.
83  static Node deserialize(std::ifstream& fin); // Deserialize and return a node read from the input file stream.
84  };
85 
87  class Tree { // Binary tree structure built up of nodes.
88  public:
89 
90  OutputsIdx_t n_outputs; // number of outputs
91  std::vector<ClassesIdx_t> n_classes; // number of classes, presumably one entry per output — TODO confirm
92  ClassesIdx_t n_classes_max; // just for convenience
93  FeaturesIdx_t n_features; // number of features
94  TreeDepthIdx_t max_depth; // not a constructor argument; presumably the depth reached via add_node(depth, ...) — TODO confirm
95  // Nodes
96  NodesIdx_t node_count; // number of nodes
97  std::vector<Node> nodes; // node storage
98 
100 
105  Tree(OutputsIdx_t n_outputs, // Create a new tree; n_classes passed as a std::vector.
106  std::vector<ClassesIdx_t> n_classes,
107  FeaturesIdx_t n_features);
108 
110 
115  Tree(OutputsIdx_t n_outputs, // Create a new tree; n_classes passed as a C-style array (presumably n_outputs entries, for Cython callers — TODO confirm).
116  ClassesIdx_t* n_classes_ptr,
117  FeaturesIdx_t n_features);
118 
120  Tree() {} // Create a new tree without nodes for Python binding using pickle. NOTE(review): scalar members are left uninitialized until deserialize() runs.
121 
123 
126  NodesIdx_t add_node(TreeDepthIdx_t depth, // Add a node below parent_id (left child if is_left); presumably returns the new node's id — TODO confirm.
127  NodesIdx_t parent_id,
128  bool is_left,
129  FeaturesIdx_t feature,
130  int NA,
131  Features_t threshold,
132  const std::vector<std::vector<Histogram_t>>& histogram,
133  double impurity,
134  double improvement);
135 
137 
148  void predict(Features_t* X, // Predict class probabilities; X presumably uses the same 1d array addressing as DepthFirstTreeBuilder::build — TODO confirm.
149  SamplesIdx_t n_samples,
150  double* y_prob); // caller-provided output buffer; layout presumably n_samples x n_outputs x n_classes_max — TODO confirm
151 
153 
156  void calculate_feature_importances(double* importances); // caller-provided output buffer; presumably derived from Node::improvement — TODO confirm
157 
159  void serialize(std::ofstream& fout); // Serialize the tree to the output file stream.
161  void deserialize(std::ifstream& fin); // Deserialize the tree in place from the input file stream.
162  };
163 
164 // =============================================================================
165 // Impurity Criterion
166 // =============================================================================
167 
170  // Gini Index impurity criterion. NOTE(review): the class header (listing line 169) is missing from this extract.
171  protected:
172  OutputsIdx_t n_outputs;
173  ClassesIdx_t* n_classes; // raw pointer, presumably non-owning (caller-owned per-output class counts) — TODO confirm lifetime
174  ClassesIdx_t n_classes_max;
175  SamplesIdx_t n_samples;
176  ClassWeights_t* class_weight; // raw pointer, presumably non-owning — TODO confirm lifetime
177  // Histograms
178  // vectors are created in initialization list
179  // - all samples
180  std::vector<std::vector<Histogram_t>> node_weighted_histogram;
181  std::vector<Histogram_t> node_weighted_n_samples;
182  std::vector<double> node_impurity;
183  // - samples with missing values
184  std::vector<std::vector<Histogram_t>> node_weighted_histogram_NA;
185  std::vector<Histogram_t> node_weighted_n_samples_NA;
186  std::vector<double> node_impurity_NA;
187  // - samples with values
188  std::vector<std::vector<Histogram_t>> node_weighted_histogram_values;
189  std::vector<Histogram_t> node_weighted_n_samples_values;
190  std::vector<double> node_impurity_values;
191  SamplesIdx_t node_pos_NA;
192  // - samples with values smaller than threshold (assigned to left child)
193  std::vector<std::vector<Histogram_t>> node_weighted_histogram_threshold_left;
194  std::vector<Histogram_t> node_weighted_n_samples_threshold_left;
195  std::vector<double> node_impurity_threshold_left;
196  // -- plus missing values (assigned to left child); NOTE(review): listing line 197 (node_weighted_n_samples_threshold_left_NA) is missing from this extract
198  std::vector<double> node_impurity_threshold_left_NA;
199  // - samples with values greater than threshold (assigned to right child)
200  std::vector<std::vector<Histogram_t>> node_weighted_histogram_threshold_right;
201  std::vector<Histogram_t> node_weighted_n_samples_threshold_right;
202  std::vector<double> node_impurity_threshold_right;
203  // -- plus missing values (assigned to right child); NOTE(review): listing line 204 (node_weighted_n_samples_threshold_right_NA) is missing from this extract
205  std::vector<double> node_impurity_threshold_right_NA;
206  SamplesIdx_t node_pos_threshold;
207 
208  public:
210 
213  GiniCriterion(OutputsIdx_t n_outputs,
214  ClassesIdx_t* n_classes, // required: 2 <= n_classes[o]
215  ClassesIdx_t n_classes_max,
216  SamplesIdx_t n_samples, // required: 2 <= n_samples
217  ClassWeights_t* class_weight);
218 
220  void calculate_node_histogram(Classes_t* y, // presumably fills the node_* histograms from samples[start..end] — TODO confirm bounds
221  std::vector<SamplesIdx_t>& samples,
222  SamplesIdx_t start,
223  SamplesIdx_t end);
224 
226 
227  double calculate_impurity(std::vector<Histogram_t>& histogram); // presumably the Gini index of a single class histogram — TODO confirm
228 
230 
233  void calculate_node_impurity();
234 
236 
239  void calculate_NA_histogram(Classes_t* y,
240  std::vector<SamplesIdx_t>& samples,
241  SamplesIdx_t pos);
242 
244 
248  void calculate_NA_impurity();
249 
252 
256  double calculate_NA_impurity_improvement();
257 
260 
263  void init_threshold_histograms();
264 
267 
271  void init_threshold_values_histograms();
272 
275 
279  void update_threshold_histograms(Classes_t* y,
280  std::vector<SamplesIdx_t>& samples,
281  SamplesIdx_t new_pos);
282 
284 
287  void calculate_threshold_impurity();
288 
291 
295  void calculate_threshold_NA_impurity();
296 
300 
303  double calculate_threshold_impurity_improvement();
304 
308 
311  double calculate_threshold_values_impurity_improvement();
312 
316 
319  double calculate_threshold_NA_left_impurity_improvement();
320 
324 
327  double calculate_threshold_NA_right_impurity_improvement();
328 
329  std::vector<std::vector<Histogram_t>> get_node_weighted_histogram() { // NOTE(review): the body (listing line 330) is missing from this extract
331  double get_node_impurity() { // impurity of the node averaged over all outputs
332  return accumulate(GiniCriterion::node_impurity.begin(), // NOTE(review): std::accumulate is declared in <numeric>, which is not among the visible includes; this presumably relies on ADL plus a transitive include
333  GiniCriterion::node_impurity.end(), 0.0) /
334  GiniCriterion::n_outputs; // average
335  }
337  return accumulate(GiniCriterion::node_impurity_NA.begin(), // get_node_impurity_NA(): header (listing line 336) missing from this extract; averages node_impurity_NA over outputs
338  GiniCriterion::node_impurity_NA.end(), 0.0) /
339  GiniCriterion::n_outputs; // average
340  }
342  return accumulate(GiniCriterion::node_impurity_values.begin(), // get_node_impurity_values(): listing lines 341 and 343 missing from this extract
344  GiniCriterion::n_outputs; // average
345  }
347  return accumulate(GiniCriterion::node_impurity_threshold_left.begin(), // get_node_impurity_threshold_left(): listing lines 346 and 348 missing from this extract
349  GiniCriterion::n_outputs; // average
350  }
352  return accumulate(GiniCriterion::node_impurity_threshold_right.begin(), // get_node_impurity_threshold_right(): listing lines 351 and 353 missing from this extract
354  GiniCriterion::n_outputs; // average
355  }
356  };
357 
358 // =============================================================================
359 // Node Splitter
360 // =============================================================================
361 
363  class BestSplitter { // Splitter to find the best split for a node.
364 
365  protected:
366  FeaturesIdx_t n_features;
367  SamplesIdx_t n_samples;
368  FeaturesIdx_t max_features;
369  unsigned long max_thresholds; // NOTE(review): listing line 370 (RandomState random_state) is missing from this extract
371  // Samples
372  // samples[start, end] is a LUT to the training data X, y
373  // to handle the recursive partitioning and
374  // the sorting of the data efficiently.
375  std::vector<SamplesIdx_t> samples; // vector created in initialization list
376  SamplesIdx_t start;
377  SamplesIdx_t end;
378  public: // the criterion is public, unlike the members above
379  // Gini Criterion
380  GiniCriterion criterion; // nested object created in initialization list
381 
382  public:
384  BestSplitter(OutputsIdx_t n_outputs,
385  ClassesIdx_t* n_classes, // required: 2 <= n_classes[o]
386  ClassesIdx_t n_classes_max,
387  FeaturesIdx_t n_features, // required: 1 <= n_features
388  SamplesIdx_t n_samples, // required: 2 <= n_samples
389  ClassWeights_t* class_weight,
390  FeaturesIdx_t max_features, // required: 0 < max_features <= n_features
391  unsigned long max_thresholds, // required: 0, 1
392  RandomState const& random_state);
393 
395  void init_node(Classes_t* y, // set up the splitter for the node's sample range [start, end] — TODO confirm inclusive/exclusive end
396  SamplesIdx_t start,
397  SamplesIdx_t end);
398 
400  void split_feature(Features_t* X, // evaluate splits of the node on one feature; results are written to the out-parameters NA, threshold, pos, improvement
401  Classes_t* y,
402  std::vector<SamplesIdx_t>& s,
403  FeaturesIdx_t feature,
404  int& NA,
405  Features_t& threshold,
406  SamplesIdx_t& pos,
407  double& improvement);
408 
411  void split_feature_extreme_random(Features_t* X, // same interface as split_feature; presumably draws a random threshold (max_thresholds == 1) — TODO confirm
412  Classes_t* y,
413  std::vector<SamplesIdx_t>& s,
414  FeaturesIdx_t feature,
415  int& NA,
416  Features_t& threshold,
417  SamplesIdx_t& pos,
418  double& improvement);
419 
421 
428  void split_node(Features_t* X, // find the split for the current node; chosen feature, NA, threshold, pos and improvement are returned via out-parameters
429  Classes_t* y,
430  FeaturesIdx_t& feature,
431  int& NA,
432  Features_t& threshold,
433  SamplesIdx_t& pos,
434  double& improvement);
435  };
436 
437 // =============================================================================
438 // Tree Builder
439 // =============================================================================
440 
443  // Build a binary decision tree in depth-first order. NOTE(review): the class header (listing line 442) is missing from this extract.
444  protected:
445  TreeDepthIdx_t max_depth;
446  std::string missing_values; // missing-value handling mode as a string (DecisionTreeClassifier defaults it to "None")
447  // Best Splitter (and Gini Criterion)
448  BestSplitter splitter; // nested object created in initialization list
449 
450  public:
452 
494  DepthFirstTreeBuilder(OutputsIdx_t n_outputs,
495  ClassesIdx_t* n_classes,
496  ClassesIdx_t n_classes_max,
497  FeaturesIdx_t n_features,
498  SamplesIdx_t n_samples,
499  ClassWeights_t* class_weight,
500  TreeDepthIdx_t max_depth,
501  FeaturesIdx_t max_features,
502  unsigned long max_thresholds,
503  std::string missing_values,
504  RandomState const& random_state);
505 
507 
513  // Using 1d array addressing for X and y to support efficient Cython bindings to Python using memory views.
514  void build(Tree& tree, // builds the nodes into `tree` from training data X, y
515  Features_t* X,
516  Classes_t* y,
517  SamplesIdx_t n_samples);
518  };
519 
520 // =============================================================================
521 // Decision Tree Classifier
522 // =============================================================================
523 
526  // A decision tree classifier. NOTE(review): the class header (listing line 525) is missing from this extract.
527  protected:
528  OutputsIdx_t n_outputs;
529  std::vector<std::vector<std::string>> classes; // class labels, presumably one list per output — TODO confirm
530  std::vector<ClassesIdx_t> n_classes;
531  ClassesIdx_t n_classes_max; // just for convenience
532  std::vector<std::string> features; // feature names
533  FeaturesIdx_t n_features;
534 
535  // Hyperparameters
536  std::string class_balance;
537  TreeDepthIdx_t max_depth;
538  FeaturesIdx_t max_features;
539  unsigned long max_thresholds;
540  std::string missing_values;
541 
542  // Random Number Generator
544  // NOTE(review): listing line 543 (RandomState random_state) is missing from this extract
545  // Model
546  Tree tree_; // underlying estimator
547 
548  public:
550 
601  DecisionTreeClassifier(std::vector<std::vector<std::string>> const& classes,
602  std::vector<std::string> const& features,
603  std::string const& class_balance = "balanced",
604  TreeDepthIdx_t max_depth = 0, // presumably 0 = no depth limit — TODO confirm
605  FeaturesIdx_t max_features = 0, // presumably 0 = use all features — TODO confirm
606  unsigned long max_thresholds = 0, // BestSplitter accepts 0 or 1
607  std::string const& missing_values = "None",
608  long random_state_seed = 0);
609 
611 
615  void fit(std::vector<Features_t> & X, // X presumably flattened to 1d (see DepthFirstTreeBuilder::build) — TODO confirm layout
616  std::vector<Classes_t> & y);
617 
619 
627  void predict_proba(Features_t* X, // class probabilities written to the caller-provided y_prob buffer
628  SamplesIdx_t n_samples,
629  double* y_prob);
630 
632 
638  void predict(Features_t* X, // predicted classes written to the caller-provided y buffer
639  SamplesIdx_t n_samples,
640  Classes_t* y);
641 
643 
650  double score(Features_t* X, // presumably mean accuracy on (X, y) — TODO confirm
651  Classes_t* y,
652  SamplesIdx_t n_samples);
653 
655 
661  void calculate_feature_importances(double* importances); // written to the caller-provided importances buffer
662 
664 
686  void export_graphviz(std::string const& filename, bool rotate=false); // write a Graphviz description of the tree to `filename`
687 
689  std::string export_text(); // return a text representation of the tree
690 
692 
695  void export_serialize(std::string const& filename); // serialize the classifier to the file `filename`
696 
698 
701  static DecisionTreeClassifier import_deserialize(std::string const& filename); // create a classifier from the file `filename`
702 
704  void serialize(std::ofstream& fout); // serialize to an already-open output file stream
706  static DecisionTreeClassifier deserialize(std::ifstream& fin); // deserialize from an already-open input file stream
707 
708  };
709 
710 } // namespace koho
711 
712 #endif
std::string missing_values
Definition: decision_tree.h:540
unsigned long NodesIdx_t
Definition: decision_tree.h:48
std::vector< double > node_impurity_threshold_right_NA
Definition: decision_tree.h:205
FeaturesIdx_t max_features
Definition: decision_tree.h:368
void serialize(std::ofstream &fout)
Serialize.
Definition: decision_tree.cpp:50
double ClassWeights_t
Definition: decision_tree.h:40
std::vector< Node > nodes
Definition: decision_tree.h:97
std::vector< std::vector< Histogram_t > > node_weighted_histogram_values
Definition: decision_tree.h:188
Definition: decision_forest.cpp:20
std::vector< Histogram_t > node_weighted_n_samples_threshold_right
Definition: decision_tree.h:201
unsigned long ClassesIdx_t
Definition: decision_tree.h:46
std::string class_balance
Definition: decision_tree.h:536
static Node deserialize(std::ifstream &fin)
Deserialize.
Definition: decision_tree.cpp:74
unsigned long TreeDepthIdx_t
Definition: decision_tree.h:49
std::vector< double > node_impurity_threshold_left_NA
Definition: decision_tree.h:198
FeaturesIdx_t n_features
Definition: decision_tree.h:93
std::vector< Histogram_t > node_weighted_n_samples_threshold_right_NA
Definition: decision_tree.h:204
double get_node_impurity_NA()
Definition: decision_tree.h:336
Tree()
Create a new tree without nodes for Python binding using pickle.
Definition: decision_tree.h:120
SamplesIdx_t node_pos_threshold
Definition: decision_tree.h:206
Splitter to find the best split for a node.
Definition: decision_tree.h:363
unsigned long FeaturesIdx_t
Definition: decision_tree.h:45
NodesIdx_t node_count
Definition: decision_tree.h:96
OutputsIdx_t n_outputs
Definition: decision_tree.h:90
std::vector< double > node_impurity_NA
Definition: decision_tree.h:186
Tree tree_
Definition: decision_tree.h:546
OutputsIdx_t n_outputs
Definition: decision_tree.h:172
std::vector< double > node_impurity_threshold_right
Definition: decision_tree.h:202
int NA
Definition: decision_tree.h:64
SamplesIdx_t n_samples
Definition: decision_tree.h:367
Random Number Generator module.
A random number generator.
Definition: random_number_generator.h:20
Node of a binary tree.
Definition: decision_tree.h:58
std::vector< std::vector< Histogram_t > > node_weighted_histogram_threshold_right
Definition: decision_tree.h:200
unsigned long OutputsIdx_t
Definition: decision_tree.h:47
ClassWeights_t * class_weight
Definition: decision_tree.h:176
RandomState random_state
Definition: decision_tree.h:543
SamplesIdx_t end
Definition: decision_tree.h:377
double impurity
Definition: decision_tree.h:67
std::vector< std::vector< Histogram_t > > node_weighted_histogram
Definition: decision_tree.h:180
double get_node_impurity_threshold_right()
Definition: decision_tree.h:351
double improvement
Definition: decision_tree.h:68
std::vector< std::vector< Histogram_t > > histogram
Definition: decision_tree.h:66
std::vector< SamplesIdx_t > samples
Definition: decision_tree.h:375
unsigned long max_thresholds
Definition: decision_tree.h:539
std::vector< std::vector< std::string > > classes
Definition: decision_tree.h:529
BestSplitter splitter
Definition: decision_tree.h:448
std::vector< double > node_impurity
Definition: decision_tree.h:182
OutputsIdx_t n_outputs
Definition: decision_tree.h:528
FeaturesIdx_t n_features
Definition: decision_tree.h:366
std::vector< std::string > features
Definition: decision_tree.h:532
double Histogram_t
Definition: decision_tree.h:42
double get_node_impurity_values()
Definition: decision_tree.h:341
TreeDepthIdx_t max_depth
Definition: decision_tree.h:445
double get_node_impurity_threshold_left()
Definition: decision_tree.h:346
FeaturesIdx_t n_features
Definition: decision_tree.h:533
long Classes_t
Definition: decision_tree.h:39
std::vector< Histogram_t > node_weighted_n_samples_threshold_left
Definition: decision_tree.h:194
std::vector< ClassesIdx_t > n_classes
Definition: decision_tree.h:530
ClassesIdx_t n_classes_max
Definition: decision_tree.h:174
Node(NodesIdx_t left_child, NodesIdx_t right_child, FeaturesIdx_t feature, int NA, Features_t threshold, const std::vector< std::vector< Histogram_t >> &histogram, double impurity, double improvement)
Create a new node.
Definition: decision_tree.cpp:31
Build a binary decision tree in depth-first order.
Definition: decision_tree.h:442
std::vector< Histogram_t > node_weighted_n_samples
Definition: decision_tree.h:181
SamplesIdx_t n_samples
Definition: decision_tree.h:175
std::vector< std::vector< Histogram_t > > node_weighted_histogram_threshold_left
Definition: decision_tree.h:193
NodesIdx_t right_child
Definition: decision_tree.h:62
double get_node_impurity()
Definition: decision_tree.h:331
std::vector< Histogram_t > node_weighted_n_samples_threshold_left_NA
Definition: decision_tree.h:197
SamplesIdx_t node_pos_NA
Definition: decision_tree.h:191
std::vector< double > node_impurity_threshold_left
Definition: decision_tree.h:195
std::vector< std::vector< Histogram_t > > get_node_weighted_histogram()
Definition: decision_tree.h:329
ClassesIdx_t * n_classes
Definition: decision_tree.h:173
ClassesIdx_t n_classes_max
Definition: decision_tree.h:531
RandomState random_state
Definition: decision_tree.h:370
FeaturesIdx_t feature
Definition: decision_tree.h:63
std::vector< Histogram_t > node_weighted_n_samples_values
Definition: decision_tree.h:189
SamplesIdx_t start
Definition: decision_tree.h:376
TreeDepthIdx_t max_depth
Definition: decision_tree.h:94
Gini Index impurity criterion.
Definition: decision_tree.h:169
GiniCriterion criterion
Definition: decision_tree.h:380
NodesIdx_t left_child
Definition: decision_tree.h:61
const double PRECISION_EQUAL
Definition: decision_tree.h:51
std::vector< Histogram_t > node_weighted_n_samples_NA
Definition: decision_tree.h:185
Features_t threshold
Definition: decision_tree.h:65
std::vector< ClassesIdx_t > n_classes
Definition: decision_tree.h:91
A decision tree classifier.
Definition: decision_tree.h:525
std::vector< std::vector< Histogram_t > > node_weighted_histogram_NA
Definition: decision_tree.h:184
ClassesIdx_t n_classes_max
Definition: decision_tree.h:92
FeaturesIdx_t max_features
Definition: decision_tree.h:538
unsigned long max_thresholds
Definition: decision_tree.h:369
TreeDepthIdx_t max_depth
Definition: decision_tree.h:537
double Features_t
Definition: decision_tree.h:38
std::vector< double > node_impurity_values
Definition: decision_tree.h:190
Binary tree structure built up of nodes.
Definition: decision_tree.h:87
std::string missing_values
Definition: decision_tree.h:446
unsigned long SamplesIdx_t
Definition: decision_tree.h:44