25 #ifndef KOHO_DECISION_TREE_H 26 #define KOHO_DECISION_TREE_H 71 Node(NodesIdx_t left_child,
72 NodesIdx_t right_child,
73 FeaturesIdx_t feature,
76 const std::vector<std::vector<Histogram_t>>& histogram,
105 Tree(OutputsIdx_t n_outputs,
106 std::vector<ClassesIdx_t> n_classes,
107 FeaturesIdx_t n_features);
115 Tree(OutputsIdx_t n_outputs,
116 ClassesIdx_t* n_classes_ptr,
117 FeaturesIdx_t n_features);
126 NodesIdx_t add_node(TreeDepthIdx_t depth,
127 NodesIdx_t parent_id,
132 const std::vector<std::vector<Histogram_t>>&
histogram,
148 void predict(Features_t* X,
149 SamplesIdx_t n_samples,
156 void calculate_feature_importances(
double* importances);
214 ClassesIdx_t* n_classes,
215 ClassesIdx_t n_classes_max,
216 SamplesIdx_t n_samples,
217 ClassWeights_t* class_weight);
220 void calculate_node_histogram(Classes_t* y,
221 std::vector<SamplesIdx_t>& samples,
227 double calculate_impurity(std::vector<Histogram_t>&
histogram);
233 void calculate_node_impurity();
239 void calculate_NA_histogram(Classes_t* y,
240 std::vector<SamplesIdx_t>& samples,
248 void calculate_NA_impurity();
256 double calculate_NA_impurity_improvement();
263 void init_threshold_histograms();
271 void init_threshold_values_histograms();
279 void update_threshold_histograms(Classes_t* y,
280 std::vector<SamplesIdx_t>& samples,
281 SamplesIdx_t new_pos);
287 void calculate_threshold_impurity();
295 void calculate_threshold_NA_impurity();
303 double calculate_threshold_impurity_improvement();
311 double calculate_threshold_values_impurity_improvement();
319 double calculate_threshold_NA_left_impurity_improvement();
327 double calculate_threshold_NA_right_impurity_improvement();
385 ClassesIdx_t* n_classes,
386 ClassesIdx_t n_classes_max,
387 FeaturesIdx_t n_features,
388 SamplesIdx_t n_samples,
389 ClassWeights_t* class_weight,
390 FeaturesIdx_t max_features,
391 unsigned long max_thresholds,
395 void init_node(Classes_t* y,
400 void split_feature(Features_t* X,
402 std::vector<SamplesIdx_t>& s,
411 void split_feature_extreme_random(Features_t* X,
413 std::vector<SamplesIdx_t>& s,
428 void split_node(Features_t* X,
495 ClassesIdx_t* n_classes,
496 ClassesIdx_t n_classes_max,
497 FeaturesIdx_t n_features,
498 SamplesIdx_t n_samples,
499 ClassWeights_t* class_weight,
500 TreeDepthIdx_t max_depth,
501 FeaturesIdx_t max_features,
502 unsigned long max_thresholds,
503 std::string missing_values,
514 void build(
Tree& tree,
517 SamplesIdx_t n_samples);
529 std::vector<std::vector<std::string>>
classes;
602 std::vector<std::string>
const& features,
603 std::string
const& class_balance =
"balanced",
604 TreeDepthIdx_t max_depth = 0,
605 FeaturesIdx_t max_features = 0,
606 unsigned long max_thresholds = 0,
607 std::string
const& missing_values =
"None",
608 long random_state_seed = 0);
615 void fit(std::vector<Features_t> & X,
616 std::vector<Classes_t> & y);
627 void predict_proba(Features_t* X,
628 SamplesIdx_t n_samples,
638 void predict(Features_t* X,
639 SamplesIdx_t n_samples,
650 double score(Features_t* X,
652 SamplesIdx_t n_samples);
661 void calculate_feature_importances(
double* importances);
686 void export_graphviz(std::string
const& filename,
bool rotate=
false);
689 std::string export_text();
695 void export_serialize(std::string
const& filename);
std::string missing_values
Definition: decision_tree.h:540
unsigned long NodesIdx_t
Definition: decision_tree.h:48
std::vector< double > node_impurity_threshold_right_NA
Definition: decision_tree.h:205
FeaturesIdx_t max_features
Definition: decision_tree.h:368
void serialize(std::ofstream &fout)
Serialize.
Definition: decision_tree.cpp:50
double ClassWeights_t
Definition: decision_tree.h:40
std::vector< Node > nodes
Definition: decision_tree.h:97
std::vector< std::vector< Histogram_t > > node_weighted_histogram_values
Definition: decision_tree.h:188
Definition: decision_forest.cpp:20
std::vector< Histogram_t > node_weighted_n_samples_threshold_right
Definition: decision_tree.h:201
unsigned long ClassesIdx_t
Definition: decision_tree.h:46
std::string class_balance
Definition: decision_tree.h:536
static Node deserialize(std::ifstream &fin)
Deserialize.
Definition: decision_tree.cpp:74
unsigned long TreeDepthIdx_t
Definition: decision_tree.h:49
std::vector< double > node_impurity_threshold_left_NA
Definition: decision_tree.h:198
FeaturesIdx_t n_features
Definition: decision_tree.h:93
std::vector< Histogram_t > node_weighted_n_samples_threshold_right_NA
Definition: decision_tree.h:204
double get_node_impurity_NA()
Definition: decision_tree.h:336
Tree()
Create a new tree without nodes for Python binding using pickle.
Definition: decision_tree.h:120
SamplesIdx_t node_pos_threshold
Definition: decision_tree.h:206
Splitter to find the best split for a node.
Definition: decision_tree.h:363
unsigned long FeaturesIdx_t
Definition: decision_tree.h:45
NodesIdx_t node_count
Definition: decision_tree.h:96
OutputsIdx_t n_outputs
Definition: decision_tree.h:90
std::vector< double > node_impurity_NA
Definition: decision_tree.h:186
Tree tree_
Definition: decision_tree.h:546
OutputsIdx_t n_outputs
Definition: decision_tree.h:172
std::vector< double > node_impurity_threshold_right
Definition: decision_tree.h:202
int NA
Definition: decision_tree.h:64
SamplesIdx_t n_samples
Definition: decision_tree.h:367
Random Number Generator module.
A random number generator.
Definition: random_number_generator.h:20
Node of a binary tree.
Definition: decision_tree.h:58
std::vector< std::vector< Histogram_t > > node_weighted_histogram_threshold_right
Definition: decision_tree.h:200
unsigned long OutputsIdx_t
Definition: decision_tree.h:47
ClassWeights_t * class_weight
Definition: decision_tree.h:176
RandomState random_state
Definition: decision_tree.h:543
SamplesIdx_t end
Definition: decision_tree.h:377
double impurity
Definition: decision_tree.h:67
std::vector< std::vector< Histogram_t > > node_weighted_histogram
Definition: decision_tree.h:180
double get_node_impurity_threshold_right()
Definition: decision_tree.h:351
double improvement
Definition: decision_tree.h:68
std::vector< std::vector< Histogram_t > > histogram
Definition: decision_tree.h:66
std::vector< SamplesIdx_t > samples
Definition: decision_tree.h:375
unsigned long max_thresholds
Definition: decision_tree.h:539
std::vector< std::vector< std::string > > classes
Definition: decision_tree.h:529
BestSplitter splitter
Definition: decision_tree.h:448
std::vector< double > node_impurity
Definition: decision_tree.h:182
OutputsIdx_t n_outputs
Definition: decision_tree.h:528
FeaturesIdx_t n_features
Definition: decision_tree.h:366
std::vector< std::string > features
Definition: decision_tree.h:532
double Histogram_t
Definition: decision_tree.h:42
double get_node_impurity_values()
Definition: decision_tree.h:341
TreeDepthIdx_t max_depth
Definition: decision_tree.h:445
double get_node_impurity_threshold_left()
Definition: decision_tree.h:346
FeaturesIdx_t n_features
Definition: decision_tree.h:533
long Classes_t
Definition: decision_tree.h:39
std::vector< Histogram_t > node_weighted_n_samples_threshold_left
Definition: decision_tree.h:194
std::vector< ClassesIdx_t > n_classes
Definition: decision_tree.h:530
ClassesIdx_t n_classes_max
Definition: decision_tree.h:174
Node(NodesIdx_t left_child, NodesIdx_t right_child, FeaturesIdx_t feature, int NA, Features_t threshold, const std::vector< std::vector< Histogram_t >> &histogram, double impurity, double improvement)
Create a new node.
Definition: decision_tree.cpp:31
Build a binary decision tree in depth-first order.
Definition: decision_tree.h:442
std::vector< Histogram_t > node_weighted_n_samples
Definition: decision_tree.h:181
SamplesIdx_t n_samples
Definition: decision_tree.h:175
std::vector< std::vector< Histogram_t > > node_weighted_histogram_threshold_left
Definition: decision_tree.h:193
NodesIdx_t right_child
Definition: decision_tree.h:62
double get_node_impurity()
Definition: decision_tree.h:331
std::vector< Histogram_t > node_weighted_n_samples_threshold_left_NA
Definition: decision_tree.h:197
SamplesIdx_t node_pos_NA
Definition: decision_tree.h:191
std::vector< double > node_impurity_threshold_left
Definition: decision_tree.h:195
std::vector< std::vector< Histogram_t > > get_node_weighted_histogram()
Definition: decision_tree.h:329
ClassesIdx_t * n_classes
Definition: decision_tree.h:173
ClassesIdx_t n_classes_max
Definition: decision_tree.h:531
RandomState random_state
Definition: decision_tree.h:370
FeaturesIdx_t feature
Definition: decision_tree.h:63
std::vector< Histogram_t > node_weighted_n_samples_values
Definition: decision_tree.h:189
SamplesIdx_t start
Definition: decision_tree.h:376
TreeDepthIdx_t max_depth
Definition: decision_tree.h:94
Gini Index impurity criterion.
Definition: decision_tree.h:169
GiniCriterion criterion
Definition: decision_tree.h:380
NodesIdx_t left_child
Definition: decision_tree.h:61
const double PRECISION_EQUAL
Definition: decision_tree.h:51
std::vector< Histogram_t > node_weighted_n_samples_NA
Definition: decision_tree.h:185
Features_t threshold
Definition: decision_tree.h:65
std::vector< ClassesIdx_t > n_classes
Definition: decision_tree.h:91
A decision tree classifier.
Definition: decision_tree.h:525
std::vector< std::vector< Histogram_t > > node_weighted_histogram_NA
Definition: decision_tree.h:184
ClassesIdx_t n_classes_max
Definition: decision_tree.h:92
FeaturesIdx_t max_features
Definition: decision_tree.h:538
unsigned long max_thresholds
Definition: decision_tree.h:369
TreeDepthIdx_t max_depth
Definition: decision_tree.h:537
double Features_t
Definition: decision_tree.h:38
std::vector< double > node_impurity_values
Definition: decision_tree.h:190
Binary tree structure built up of nodes.
Definition: decision_tree.h:87
std::string missing_values
Definition: decision_tree.h:446
unsigned long SamplesIdx_t
Definition: decision_tree.h:44