koho.cpp  1.0.0
decision_tree.h
Go to the documentation of this file.
1 
22 // Author: AI Werkstatt (TM)
23 // (C) Copyright 2019, AI Werkstatt (TM) www.aiwerkstatt.com. All rights reserved.
24 
25 #ifndef KOHO_DECISION_TREE_H
26 #define KOHO_DECISION_TREE_H
27 
28 #include <fstream>
29 #include <stack>
30 #include <algorithm>
31 
32 #include <cstring> // memset for C-style data from Cython
33 
35 
36 namespace koho {
37 
38  typedef double Features_t; // X provided by Python, needs to be double
39  typedef long Classes_t; // y provided by Python, needs to be long
40  typedef double ClassWeights_t; // class weights provided by Python, needs to be double
41 
42  typedef double Histogram_t; // weighted (number of classes) * number of samples
43 
    // Index / count types; every one of them is an unsigned long.
44  typedef unsigned long SamplesIdx_t; // number of samples
45  typedef unsigned long FeaturesIdx_t; // number of features
46  typedef unsigned long ClassesIdx_t; // number of classes
47  typedef unsigned long NodesIdx_t; // number of nodes
48  typedef unsigned long TreeDepthIdx_t; // maximum tree depth (2 ** 31) - 1
49 
50  const double PRECISION_EQUAL = 1e-7; // tolerance for float comparisons: float == 0.0 is tested as float <= PRECISION_EQUAL
51 
52 // =============================================================================
53 // Tree
54 // =============================================================================
55 
57  class Node { // Node of a binary tree.
58  public:
59 
60  NodesIdx_t left_child; // left child node id (presumably an index into Tree::nodes; TODO confirm)
61  NodesIdx_t right_child; // right child node id (presumably an index into Tree::nodes; TODO confirm)
62  FeaturesIdx_t feature; // presumably the feature this node splits on; TODO confirm
63  int NA; // NA as part of the split criterion, -1:NA, 0:left, 1:right
64  Features_t threshold; // presumably the split threshold for `feature`; TODO confirm
65  std::vector<Histogram_t> histogram; // weighted number of samples per class
66  double impurity; // for inspection (e.g. graphviz visualization)
67  double improvement; // for feature importances
68 
    // Create a new node. The parameters carry the same names and types as the
    // public data members declared above.
70  Node(NodesIdx_t left_child,
71  NodesIdx_t right_child,
72  FeaturesIdx_t feature,
73  int NA,
74  Features_t threshold,
75  const std::vector<Histogram_t>& histogram,
76  double impurity,
77  double improvement);
78 
    // Serialize this node to the output file stream `fout`.
    // The on-disk format is defined in decision_tree.cpp and is not visible here.
80  void serialize(std::ofstream& fout);
    // Deserialize a node from the input file stream `fin` and return it by value.
82  static Node deserialize(std::ifstream& fin);
83  };
84 
86  class Tree { // Binary tree structure built up of nodes.
87  public:
88 
89  ClassesIdx_t n_classes; // number of classes
90  FeaturesIdx_t n_features; // number of features
91  TreeDepthIdx_t max_depth; // NOTE(review): presumably the depth of the built tree; confirm in decision_tree.cpp
92  // Nodes
93  NodesIdx_t node_count; // number of nodes
94  std::vector<Node> nodes; // node storage
95 
97 
    // Create a tree for the given number of classes and features.
101  Tree(ClassesIdx_t n_classes,
102  FeaturesIdx_t n_features);
103 
105 
    // Add a node to the tree and return a NodesIdx_t (presumably the id of the
    // new node; TODO confirm). `parent_id` and `is_left` presumably say where the
    // node is attached; the remaining arguments match the Node constructor's.
108  NodesIdx_t add_node(TreeDepthIdx_t depth,
109  NodesIdx_t parent_id,
110  bool is_left,
111  FeaturesIdx_t feature,
112  int NA,
113  Features_t threshold,
114  const std::vector<Histogram_t>& histogram,
115  double impurity,
116  double improvement);
117 
119 
    // Predict for `n_samples` samples in `X`; results are written to the
    // caller-provided buffer `y_prob`.
    // NOTE(review): assumes X holds n_samples * n_features values and y_prob
    // n_samples * n_classes values; verify against decision_tree.cpp.
124  // Using 1d array addressing for X and y_prob
125  // to support efficient Cython bindings to Python using memory views.
126  void predict(Features_t* X,
127  SamplesIdx_t n_samples,
128  double* y_prob);
129 
131 
    // Write the feature importances into the caller-provided buffer `importances`
    // (presumably n_features entries; TODO confirm).
134  void calculate_feature_importances(double* importances);
135 
    // Serialize this tree to the output file stream `fout`.
137  void serialize(std::ofstream& fout);
    // Deserialize from the input file stream `fin` (non-static, returns void;
    // presumably fills this object; TODO confirm).
139  void deserialize(std::ifstream& fin);
140  };
141 
142 // =============================================================================
143 // Impurity Criterion
144 // =============================================================================
145 
148 
149  protected: // NOTE(review): the `class GiniCriterion {` line (decision_tree.h:147, "Gini Index impurity criterion.") is missing from this listing
150  ClassesIdx_t n_classes; // number of classes
151  SamplesIdx_t n_samples; // number of samples
152  ClassWeights_t* class_weight; // class weights; raw pointer, ownership is not visible here
    // NOTE(review): several scalar members (e.g. node_weighted_n_samples,
    // node_impurity and their _NA / _values / _threshold_* variants) appear only
    // in the cross-reference list at the end of this page, not in the listing below.
153  // Histograms
154  // vectors are created in initialization list
155  // - all samples
156  std::vector<Histogram_t> node_weighted_histogram;
159  // - samples with missing values
160  std::vector<Histogram_t> node_weighted_histogram_NA;
163  // - samples with values
164  std::vector<Histogram_t> node_weighted_histogram_values;
167  SamplesIdx_t node_pos_NA;
168  // - samples with values smaller than threshold (assigned to left child)
169  std::vector<Histogram_t> node_weighted_histogram_threshold_left;
172  // -- plus missing values (assigned to left child)
175  // - samples with values greater than threshold (assigned to right child)
176  std::vector<Histogram_t> node_weighted_histogram_threshold_right;
179  // -- plus missing values (assigned to right child)
182  SamplesIdx_t node_pos_threshold;
183 
184  public:
186 
    // Create the criterion. Per the inline notes, both n_classes and n_samples
    // must be at least 2.
189  GiniCriterion(ClassesIdx_t n_classes, // required: 2 <= n_classes
190  SamplesIdx_t n_samples, // required: 2 <= n_samples
191  ClassWeights_t* class_weight);
192 
    // Presumably builds the node histogram from the labels `y`, addressed through
    // `samples` between `start` and `end` (bound inclusivity not visible here; TODO confirm).
194  void calculate_node_histogram(Classes_t* y,
195  std::vector<SamplesIdx_t>& samples,
196  SamplesIdx_t start,
197  SamplesIdx_t end);
198 
200 
    // Returns the impurity of the given histogram as a double.
201  double calculate_impurity(std::vector<Histogram_t>& histogram);
202 
204 
    // Takes no arguments and returns void; presumably stores the impurity of the
    // current node in a member (TODO confirm).
207  void calculate_node_impurity();
208 
210 
    // Missing-value (NA) counterpart of the node histogram; `pos` is presumably
    // the position in `samples` separating NA samples from samples with values (TODO confirm).
213  void calculate_NA_histogram(Classes_t* y,
214  std::vector<SamplesIdx_t>& samples,
215  SamplesIdx_t pos);
216 
218 
222  void calculate_NA_impurity();
223 
226 
    // Returns an impurity improvement as a double; the exact formula lives in
    // decision_tree.cpp and is not visible here.
230  double calculate_NA_impurity_improvement();
231 
234 
    // Threshold search: the two init functions and update_threshold_histograms
    // presumably maintain the threshold_left / threshold_right histograms declared
    // above as the candidate split position advances to `new_pos` (TODO confirm).
237  void init_threshold_histograms();
238 
241 
245  void init_threshold_values_histograms();
246 
249 
253  void update_threshold_histograms(Classes_t* y,
254  std::vector<SamplesIdx_t>& samples,
255  SamplesIdx_t new_pos);
256 
258 
261  void calculate_threshold_impurity();
262 
265 
269  void calculate_threshold_NA_impurity();
270 
274 
    // The four functions below each return an impurity improvement as a double
    // for one variant of the threshold split (plain, values only, NA assigned
    // left, NA assigned right, judging by their names; TODO confirm).
277  double calculate_threshold_impurity_improvement();
278 
282 
285  double calculate_threshold_values_impurity_improvement();
286 
290 
293  double calculate_threshold_NA_left_impurity_improvement();
294 
298 
301  double calculate_threshold_NA_right_impurity_improvement();
302 
303  std::vector<Histogram_t> get_node_weighted_histogram() { // NOTE(review): inline getter bodies (header lines 304-314) are truncated in this listing
305  double get_node_impurity() {
315  };
316 
317 // =============================================================================
318 // Node Splitter
319 // =============================================================================
320 
322  class BestSplitter { // Splitter to find the best split for a node.
323 
324  protected:
325  FeaturesIdx_t n_features; // number of features
326  SamplesIdx_t n_samples; // number of samples
327  FeaturesIdx_t max_features; // see constructor note: 0 < max_features <= n_features
328  unsigned long max_thresholds; // see constructor note: 0, 1
    // NOTE(review): the cross-reference list places a `RandomState random_state`
    // member at header line 329; that line is not shown in this listing.
330  // Samples
331  // samples[start, end] is a LUT to the training data X, y
332  // to handle the recursive partitioning and
333  // the sorting of the data efficiently.
334  std::vector<SamplesIdx_t> samples; // vector created in initialization list
335  SamplesIdx_t start;
336  SamplesIdx_t end;
337  public:
338  // Gini Criterion
339  GiniCriterion criterion; // nested object created in initialization list
340 
341  public:
    // Create the splitter; the requirements on each argument are given inline.
343  BestSplitter(ClassesIdx_t n_classes, // required: 2 <= n_classes
344  FeaturesIdx_t n_features, // required: 1 <= n_features
345  SamplesIdx_t n_samples, // required: 2 <= n_samples
346  ClassWeights_t* class_weight,
347  FeaturesIdx_t max_features, // required: 0 < max_features <= n_features
348  unsigned long max_thresholds, // required: 0, 1
349  RandomState const& random_state);
350 
    // Presumably prepares the splitter for the node covering samples between
    // `start` and `end` (TODO confirm in decision_tree.cpp).
352  void init_node(Classes_t* y,
353  SamplesIdx_t start,
354  SamplesIdx_t end);
355 
    // Evaluate a split on one `feature`. NA, threshold, pos and improvement are
    // non-const references, i.e. in/out parameters; presumably they receive the
    // split that was found (TODO confirm).
357  void split_feature(Features_t* X,
358  Classes_t* y,
359  std::vector<SamplesIdx_t>& s,
360  FeaturesIdx_t feature,
361  int& NA,
362  Features_t& threshold,
363  SamplesIdx_t& pos,
364  double& improvement);
365 
    // Same signature as split_feature; judging by the name this is the
    // randomized-threshold variant (TODO confirm).
368  void split_feature_extreme_random(Features_t* X,
369  Classes_t* y,
370  std::vector<SamplesIdx_t>& s,
371  FeaturesIdx_t feature,
372  int& NA,
373  Features_t& threshold,
374  SamplesIdx_t& pos,
375  double& improvement);
376 
378 
    // Split the current node. Here `feature` is also a non-const reference,
    // alongside NA, threshold, pos and improvement.
385  void split_node(Features_t* X,
386  Classes_t* y,
387  FeaturesIdx_t& feature,
388  int& NA,
389  Features_t& threshold,
390  SamplesIdx_t& pos,
391  double& improvement);
392  };
393 
394 // =============================================================================
395 // Tree Builder
396 // =============================================================================
397 
400 
401  protected: // NOTE(review): the `class DepthFirstTreeBuilder {` line (decision_tree.h:399, "Build a binary decision tree in depth-first order.") is missing from this listing
402  TreeDepthIdx_t max_depth; // maximum tree depth
403  std::string missing_values; // missing-values setting, kept as a string (accepted values not visible here)
404  // Best Splitter (and Gini Criterion)
405  BestSplitter splitter; // nested object created in initialization list
406 
407  public:
409 
    // Create the builder. Apart from max_depth and missing_values, the argument
    // list matches the BestSplitter constructor's (presumably forwarded to
    // `splitter`; TODO confirm).
449  DepthFirstTreeBuilder(ClassesIdx_t n_classes,
450  FeaturesIdx_t n_features,
451  SamplesIdx_t n_samples,
452  ClassWeights_t* class_weight,
453  TreeDepthIdx_t max_depth,
454  FeaturesIdx_t max_features,
455  unsigned long max_thresholds,
456  std::string missing_values,
457  RandomState const& random_state);
458 
460 
    // Build into the caller-provided `tree` (passed by non-const reference) from
    // the training data X, y with n_samples samples.
466  // Using 1d array addressing for X and y to support efficient Cython bindings to Python using memory views.
467  void build(Tree& tree,
468  Features_t* X,
469  Classes_t* y,
470  SamplesIdx_t n_samples);
471  };
472 
473 // =============================================================================
474 // Decision Tree Classifier
475 // =============================================================================
476 
479 
480  protected: // NOTE(review): the `class DecisionTreeClassifier {` line (decision_tree.h:478, "A decision tree classifier.") is missing from this listing
481  std::vector<std::string> classes; // class names
482  ClassesIdx_t n_classes; // number of classes
483  std::vector<std::string> features; // feature names
484  FeaturesIdx_t n_features; // number of features
485 
486  // Hyperparameters
487  std::string class_balance;
488  TreeDepthIdx_t max_depth;
489  FeaturesIdx_t max_features;
490  unsigned long max_thresholds;
491  std::string missing_values;
492 
493  // Random Number Generator
    // NOTE(review): the cross-reference list places `RandomState random_state`
    // at header line 494; that line is not shown in this listing.
495 
496  // Model
497  Tree tree_; // underlying estimator
498 
499  public:
501 
    // Create a classifier. Defaults: class_balance = "balanced", max_depth = 0,
    // max_features = 0, max_thresholds = 0, missing_values = "None",
    // random_state_seed = 0.
    // NOTE(review): the meaning of the 0 defaults (presumably "no limit" /
    // "choose automatically") is not visible here; confirm in decision_tree.cpp.
554  DecisionTreeClassifier(std::vector<std::string> classes,
555  ClassesIdx_t n_classes,
556  std::vector<std::string> features,
557  FeaturesIdx_t n_features,
558  std::string const& class_balance = "balanced",
559  TreeDepthIdx_t max_depth = 0,
560  FeaturesIdx_t max_features = 0,
561  unsigned long max_thresholds = 0,
562  std::string const& missing_values = "None",
563  long random_state_seed = 0);
564 
566 
    // Fit the classifier on training data X, y with n_samples samples.
571  void fit(Features_t* X,
572  Classes_t* y,
573  SamplesIdx_t n_samples);
574 
576 
    // Predict for n_samples samples in X; results are written to the
    // caller-provided buffer y_prob (presumably n_samples * n_classes
    // probabilities; TODO confirm).
581  void predict_proba(Features_t* X,
582  SamplesIdx_t n_samples,
583  double* y_prob);
584 
586 
    // Predict for n_samples samples in X; results are written to the
    // caller-provided buffer y (one Classes_t per sample, presumably; TODO confirm).
591  void predict(Features_t* X,
592  SamplesIdx_t n_samples,
593  Classes_t* y);
594 
596 
    // Returns a score as a double for X against the labels y
    // (presumably accuracy; TODO confirm).
602  double score(Features_t* X,
603  Classes_t* y,
604  SamplesIdx_t n_samples);
605 
607 
    // Write the feature importances into the caller-provided buffer `importances`.
613  void calculate_feature_importances(double* importances);
614 
616 
    // Export the tree for graphviz to `filename`; `rotate` defaults to false.
638  void export_graphviz(std::string const& filename, bool rotate=false);
639 
    // Returns a text export of the tree as a std::string.
641  std::string export_text();
642 
644 
    // Serialize the classifier to the file named `filename`.
647  void export_serialize(std::string const& filename);
648 
650 
    // Deserialize a classifier from the file named `filename` and return it by value.
653  static DecisionTreeClassifier import_deserialize(std::string const& filename);
654 
    // Stream-level counterparts of export_serialize / import_deserialize.
656  void serialize(std::ofstream& fout);
658  static DecisionTreeClassifier deserialize(std::ifstream& fin);
659  };
660 
661 } // namespace koho
662 
663 #endif
std::string missing_values
Definition: decision_tree.h:491
unsigned long NodesIdx_t
Definition: decision_tree.h:47
FeaturesIdx_t max_features
Definition: decision_tree.h:327
void serialize(std::ofstream &fout)
Serialize.
Definition: decision_tree.cpp:49
double ClassWeights_t
Definition: decision_tree.h:40
std::vector< Histogram_t > node_weighted_histogram_NA
Definition: decision_tree.h:160
std::vector< Node > nodes
Definition: decision_tree.h:94
Definition: decision_forest.cpp:20
ClassesIdx_t n_classes
Definition: decision_tree.h:89
unsigned long ClassesIdx_t
Definition: decision_tree.h:46
std::string class_balance
Definition: decision_tree.h:487
Histogram_t node_weighted_n_samples_threshold_left_NA
Definition: decision_tree.h:173
static Node deserialize(std::ifstream &fin)
Deserialize.
Definition: decision_tree.cpp:70
Histogram_t node_weighted_n_samples_NA
Definition: decision_tree.h:161
unsigned long TreeDepthIdx_t
Definition: decision_tree.h:48
FeaturesIdx_t n_features
Definition: decision_tree.h:90
double get_node_impurity_NA()
Definition: decision_tree.h:307
SamplesIdx_t node_pos_threshold
Definition: decision_tree.h:182
double node_impurity_threshold_right
Definition: decision_tree.h:178
Splitter to find the best split for a node.
Definition: decision_tree.h:322
unsigned long FeaturesIdx_t
Definition: decision_tree.h:45
NodesIdx_t node_count
Definition: decision_tree.h:93
Tree tree_
Definition: decision_tree.h:497
int NA
Definition: decision_tree.h:63
SamplesIdx_t n_samples
Definition: decision_tree.h:326
Random Number Generator module.
A random number generator.
Definition: random_number_generator.h:20
Node of a binary tree.
Definition: decision_tree.h:57
Histogram_t node_weighted_n_samples
Definition: decision_tree.h:157
std::vector< Histogram_t > node_weighted_histogram_threshold_right
Definition: decision_tree.h:176
Histogram_t node_weighted_n_samples_values
Definition: decision_tree.h:165
std::vector< Histogram_t > node_weighted_histogram_threshold_left
Definition: decision_tree.h:169
ClassWeights_t * class_weight
Definition: decision_tree.h:152
RandomState random_state
Definition: decision_tree.h:494
double node_impurity
Definition: decision_tree.h:158
SamplesIdx_t end
Definition: decision_tree.h:336
double impurity
Definition: decision_tree.h:66
double node_impurity_NA
Definition: decision_tree.h:162
double get_node_impurity_threshold_right()
Definition: decision_tree.h:313
double improvement
Definition: decision_tree.h:67
std::vector< SamplesIdx_t > samples
Definition: decision_tree.h:334
ClassesIdx_t n_classes
Definition: decision_tree.h:150
unsigned long max_thresholds
Definition: decision_tree.h:490
std::vector< Histogram_t > node_weighted_histogram_values
Definition: decision_tree.h:164
Node(NodesIdx_t left_child, NodesIdx_t right_child, FeaturesIdx_t feature, int NA, Features_t threshold, const std::vector< Histogram_t > &histogram, double impurity, double improvement)
Create a new node.
Definition: decision_tree.cpp:30
BestSplitter splitter
Definition: decision_tree.h:405
FeaturesIdx_t n_features
Definition: decision_tree.h:325
double node_impurity_values
Definition: decision_tree.h:166
std::vector< std::string > features
Definition: decision_tree.h:483
std::vector< std::string > classes
Definition: decision_tree.h:481
double Histogram_t
Definition: decision_tree.h:42
double get_node_impurity_values()
Definition: decision_tree.h:309
TreeDepthIdx_t max_depth
Definition: decision_tree.h:402
double get_node_impurity_threshold_left()
Definition: decision_tree.h:311
FeaturesIdx_t n_features
Definition: decision_tree.h:484
long Classes_t
Definition: decision_tree.h:39
Build a binary decision tree in depth-first order.
Definition: decision_tree.h:399
SamplesIdx_t n_samples
Definition: decision_tree.h:151
NodesIdx_t right_child
Definition: decision_tree.h:61
double get_node_impurity()
Definition: decision_tree.h:305
SamplesIdx_t node_pos_NA
Definition: decision_tree.h:167
std::vector< Histogram_t > histogram
Definition: decision_tree.h:65
RandomState random_state
Definition: decision_tree.h:329
ClassesIdx_t n_classes
Definition: decision_tree.h:482
FeaturesIdx_t feature
Definition: decision_tree.h:62
SamplesIdx_t start
Definition: decision_tree.h:335
TreeDepthIdx_t max_depth
Definition: decision_tree.h:91
Gini Index impurity criterion.
Definition: decision_tree.h:147
GiniCriterion criterion
Definition: decision_tree.h:339
NodesIdx_t left_child
Definition: decision_tree.h:60
const double PRECISION_EQUAL
Definition: decision_tree.h:50
Histogram_t node_weighted_n_samples_threshold_right
Definition: decision_tree.h:177
double node_impurity_threshold_right_NA
Definition: decision_tree.h:181
Features_t threshold
Definition: decision_tree.h:64
A decision tree classifier.
Definition: decision_tree.h:478
std::vector< Histogram_t > node_weighted_histogram
Definition: decision_tree.h:156
FeaturesIdx_t max_features
Definition: decision_tree.h:489
unsigned long max_thresholds
Definition: decision_tree.h:328
TreeDepthIdx_t max_depth
Definition: decision_tree.h:488
double node_impurity_threshold_left
Definition: decision_tree.h:171
double node_impurity_threshold_left_NA
Definition: decision_tree.h:174
double Features_t
Definition: decision_tree.h:38
Binary tree structure built up of nodes.
Definition: decision_tree.h:86
std::string missing_values
Definition: decision_tree.h:403
Histogram_t node_weighted_n_samples_threshold_left
Definition: decision_tree.h:170
Histogram_t node_weighted_n_samples_threshold_right_NA
Definition: decision_tree.h:180
std::vector< Histogram_t > get_node_weighted_histogram()
Definition: decision_tree.h:303
unsigned long SamplesIdx_t
Definition: decision_tree.h:44