25 #ifndef KOHO_DECISION_TREE_H 26 #define KOHO_DECISION_TREE_H 70 Node(NodesIdx_t left_child,
71 NodesIdx_t right_child,
72 FeaturesIdx_t feature,
75 const std::vector<Histogram_t>& histogram,
101 Tree(ClassesIdx_t n_classes,
102 FeaturesIdx_t n_features);
108 NodesIdx_t add_node(TreeDepthIdx_t depth,
109 NodesIdx_t parent_id,
114 const std::vector<Histogram_t>&
histogram,
126 void predict(Features_t* X,
127 SamplesIdx_t n_samples,
134 void calculate_feature_importances(
double* importances);
190 SamplesIdx_t n_samples,
191 ClassWeights_t* class_weight);
194 void calculate_node_histogram(Classes_t* y,
195 std::vector<SamplesIdx_t>& samples,
201 double calculate_impurity(std::vector<Histogram_t>&
histogram);
207 void calculate_node_impurity();
213 void calculate_NA_histogram(Classes_t* y,
214 std::vector<SamplesIdx_t>& samples,
222 void calculate_NA_impurity();
230 double calculate_NA_impurity_improvement();
237 void init_threshold_histograms();
245 void init_threshold_values_histograms();
253 void update_threshold_histograms(Classes_t* y,
254 std::vector<SamplesIdx_t>& samples,
255 SamplesIdx_t new_pos);
261 void calculate_threshold_impurity();
269 void calculate_threshold_NA_impurity();
277 double calculate_threshold_impurity_improvement();
285 double calculate_threshold_values_impurity_improvement();
293 double calculate_threshold_NA_left_impurity_improvement();
301 double calculate_threshold_NA_right_impurity_improvement();
344 FeaturesIdx_t n_features,
345 SamplesIdx_t n_samples,
346 ClassWeights_t* class_weight,
347 FeaturesIdx_t max_features,
348 unsigned long max_thresholds,
352 void init_node(Classes_t* y,
357 void split_feature(Features_t* X,
359 std::vector<SamplesIdx_t>& s,
368 void split_feature_extreme_random(Features_t* X,
370 std::vector<SamplesIdx_t>& s,
385 void split_node(Features_t* X,
450 FeaturesIdx_t n_features,
451 SamplesIdx_t n_samples,
452 ClassWeights_t* class_weight,
453 TreeDepthIdx_t max_depth,
454 FeaturesIdx_t max_features,
455 unsigned long max_thresholds,
456 std::string missing_values,
467 void build(
Tree& tree,
470 SamplesIdx_t n_samples);
555 ClassesIdx_t n_classes,
556 std::vector<std::string> features,
557 FeaturesIdx_t n_features,
558 std::string
const& class_balance =
"balanced",
559 TreeDepthIdx_t max_depth = 0,
560 FeaturesIdx_t max_features = 0,
561 unsigned long max_thresholds = 0,
562 std::string
const& missing_values =
"None",
563 long random_state_seed = 0);
571 void fit(Features_t* X,
573 SamplesIdx_t n_samples);
581 void predict_proba(Features_t* X,
582 SamplesIdx_t n_samples,
591 void predict(Features_t* X,
592 SamplesIdx_t n_samples,
602 double score(Features_t* X,
604 SamplesIdx_t n_samples);
613 void calculate_feature_importances(
double* importances);
638 void export_graphviz(std::string
const& filename,
bool rotate=
false);
641 std::string export_text();
647 void export_serialize(std::string
const& filename);
std::string missing_values
Definition: decision_tree.h:491
unsigned long NodesIdx_t
Definition: decision_tree.h:47
FeaturesIdx_t max_features
Definition: decision_tree.h:327
void serialize(std::ofstream &fout)
Serialize.
Definition: decision_tree.cpp:49
double ClassWeights_t
Definition: decision_tree.h:40
std::vector< Histogram_t > node_weighted_histogram_NA
Definition: decision_tree.h:160
std::vector< Node > nodes
Definition: decision_tree.h:94
Definition: decision_forest.cpp:20
ClassesIdx_t n_classes
Definition: decision_tree.h:89
unsigned long ClassesIdx_t
Definition: decision_tree.h:46
std::string class_balance
Definition: decision_tree.h:487
Histogram_t node_weighted_n_samples_threshold_left_NA
Definition: decision_tree.h:173
static Node deserialize(std::ifstream &fin)
Deserialize.
Definition: decision_tree.cpp:70
Histogram_t node_weighted_n_samples_NA
Definition: decision_tree.h:161
unsigned long TreeDepthIdx_t
Definition: decision_tree.h:48
FeaturesIdx_t n_features
Definition: decision_tree.h:90
double get_node_impurity_NA()
Definition: decision_tree.h:307
SamplesIdx_t node_pos_threshold
Definition: decision_tree.h:182
double node_impurity_threshold_right
Definition: decision_tree.h:178
Splitter to find the best split for a node.
Definition: decision_tree.h:322
unsigned long FeaturesIdx_t
Definition: decision_tree.h:45
NodesIdx_t node_count
Definition: decision_tree.h:93
Tree tree_
Definition: decision_tree.h:497
int NA
Definition: decision_tree.h:63
SamplesIdx_t n_samples
Definition: decision_tree.h:326
Random Number Generator module.
A random number generator.
Definition: random_number_generator.h:20
Node of a binary tree.
Definition: decision_tree.h:57
Histogram_t node_weighted_n_samples
Definition: decision_tree.h:157
std::vector< Histogram_t > node_weighted_histogram_threshold_right
Definition: decision_tree.h:176
Histogram_t node_weighted_n_samples_values
Definition: decision_tree.h:165
std::vector< Histogram_t > node_weighted_histogram_threshold_left
Definition: decision_tree.h:169
ClassWeights_t * class_weight
Definition: decision_tree.h:152
RandomState random_state
Definition: decision_tree.h:494
double node_impurity
Definition: decision_tree.h:158
SamplesIdx_t end
Definition: decision_tree.h:336
double impurity
Definition: decision_tree.h:66
double node_impurity_NA
Definition: decision_tree.h:162
double get_node_impurity_threshold_right()
Definition: decision_tree.h:313
double improvement
Definition: decision_tree.h:67
std::vector< SamplesIdx_t > samples
Definition: decision_tree.h:334
ClassesIdx_t n_classes
Definition: decision_tree.h:150
unsigned long max_thresholds
Definition: decision_tree.h:490
std::vector< Histogram_t > node_weighted_histogram_values
Definition: decision_tree.h:164
Node(NodesIdx_t left_child, NodesIdx_t right_child, FeaturesIdx_t feature, int NA, Features_t threshold, const std::vector< Histogram_t > &histogram, double impurity, double improvement)
Create a new node.
Definition: decision_tree.cpp:30
BestSplitter splitter
Definition: decision_tree.h:405
FeaturesIdx_t n_features
Definition: decision_tree.h:325
double node_impurity_values
Definition: decision_tree.h:166
std::vector< std::string > features
Definition: decision_tree.h:483
std::vector< std::string > classes
Definition: decision_tree.h:481
double Histogram_t
Definition: decision_tree.h:42
double get_node_impurity_values()
Definition: decision_tree.h:309
TreeDepthIdx_t max_depth
Definition: decision_tree.h:402
double get_node_impurity_threshold_left()
Definition: decision_tree.h:311
FeaturesIdx_t n_features
Definition: decision_tree.h:484
long Classes_t
Definition: decision_tree.h:39
Build a binary decision tree in depth-first order.
Definition: decision_tree.h:399
SamplesIdx_t n_samples
Definition: decision_tree.h:151
NodesIdx_t right_child
Definition: decision_tree.h:61
double get_node_impurity()
Definition: decision_tree.h:305
SamplesIdx_t node_pos_NA
Definition: decision_tree.h:167
std::vector< Histogram_t > histogram
Definition: decision_tree.h:65
RandomState random_state
Definition: decision_tree.h:329
ClassesIdx_t n_classes
Definition: decision_tree.h:482
FeaturesIdx_t feature
Definition: decision_tree.h:62
SamplesIdx_t start
Definition: decision_tree.h:335
TreeDepthIdx_t max_depth
Definition: decision_tree.h:91
Gini Index impurity criterion.
Definition: decision_tree.h:147
GiniCriterion criterion
Definition: decision_tree.h:339
NodesIdx_t left_child
Definition: decision_tree.h:60
const double PRECISION_EQUAL
Definition: decision_tree.h:50
Histogram_t node_weighted_n_samples_threshold_right
Definition: decision_tree.h:177
double node_impurity_threshold_right_NA
Definition: decision_tree.h:181
Features_t threshold
Definition: decision_tree.h:64
A decision tree classifier.
Definition: decision_tree.h:478
std::vector< Histogram_t > node_weighted_histogram
Definition: decision_tree.h:156
FeaturesIdx_t max_features
Definition: decision_tree.h:489
unsigned long max_thresholds
Definition: decision_tree.h:328
TreeDepthIdx_t max_depth
Definition: decision_tree.h:488
double node_impurity_threshold_left
Definition: decision_tree.h:171
double node_impurity_threshold_left_NA
Definition: decision_tree.h:174
double Features_t
Definition: decision_tree.h:38
Binary tree structure build up of nodes.
Definition: decision_tree.h:86
std::string missing_values
Definition: decision_tree.h:403
Histogram_t node_weighted_n_samples_threshold_left
Definition: decision_tree.h:170
Histogram_t node_weighted_n_samples_threshold_right_NA
Definition: decision_tree.h:180
std::vector< Histogram_t > get_node_weighted_histogram()
Definition: decision_tree.h:303
unsigned long SamplesIdx_t
Definition: decision_tree.h:44