koho.cpp  1.1.0
decision_forest.h
Go to the documentation of this file.
1 
10 // Author: AI Werkstatt (TM)
11 // (C) Copyright 2019, AI Werkstatt (TM) www.aiwerkstatt.com. All rights reserved.
12 
13 #ifndef KOHO_DECISION_FOREST_H
14 #define KOHO_DECISION_FOREST_H
15 
16 #include "decision_tree.h"
17 
18 namespace koho {
19 
20 // =============================================================================
21 // Decision Forest Classifier
22 // =============================================================================
23 
26 
27  protected:
29  std::vector<std::vector<std::string>> classes;
30  std::vector<ClassesIdx_t> n_classes;
31  ClassesIdx_t n_classes_max; // just for convenience
32  std::vector<std::string> features;
34 
35  // Hyperparameters
36  unsigned long n_estimators;
37  bool bootstrap;
38  bool oob_score;
39  std::string class_balance;
42  unsigned long max_thresholds;
43  std::string missing_values;
44 
45  // Random Number Generator
47 
48  // Model
49  std::vector<DecisionTreeClassifier> dtc_; // underlying sub-estimators
50 
51  // Performance Characteristics
52  double oob_score_; // Out_Of-Bag estimate
53 
54  public:
56 
115  DecisionForestClassifier(std::vector<std::vector<std::string>> const& classes,
116  std::vector<std::string> const& features,
117  unsigned long n_estimators = 100,
118  bool bootstrap = false,
119  bool oob_score = false,
120  std::string const& class_balance = "balanced",
121  TreeDepthIdx_t max_depth = 3,
122  FeaturesIdx_t max_features = 0,
123  unsigned long max_thresholds = 0,
124  std::string const& missing_values = "None",
125  long random_state_seed = 0);
126 
128 
132  void fit(std::vector<Features_t> & X,
133  std::vector<Classes_t> & y);
134 
136 
144  void predict_proba(Features_t* X,
145  SamplesIdx_t n_samples,
146  double* y_prob);
147 
149 
155  void predict(Features_t* X,
156  SamplesIdx_t n_samples,
157  Classes_t* y);
158 
160 
167  double score(Features_t* X,
168  Classes_t* y,
169  SamplesIdx_t n_samples);
170 
172 
177  void calculate_feature_importances(double* importances);
178 
180 
203  void export_graphviz(std::string const& filename, bool rotate=false);
204 
206 
215  void export_graphviz(std::string const& filename, unsigned long e, bool rotate);
216 
218 
221  std::string export_text(unsigned long e);
222 
225 
229  void export_serialize(std::string const& filename);
230 
233 
237  static DecisionForestClassifier import_deserialize(std::string const& filename);
238 
240  void serialize(std::ofstream& fout);
242  static DecisionForestClassifier deserialize(std::ifstream& fin);
243  };
244 
245 } // namespace koho
246 
247 #endif
void serialize(std::ofstream &fout)
Serialize.
Definition: decision_forest.cpp:473
Definition: decision_forest.cpp:20
void fit(std::vector< Features_t > &X, std::vector< Classes_t > &y)
Build a decision forest classifier from the training data.
Definition: decision_forest.cpp:118
unsigned long ClassesIdx_t
Definition: decision_tree.h:46
unsigned long TreeDepthIdx_t
Definition: decision_tree.h:49
unsigned long max_thresholds
Definition: decision_forest.h:42
RandomState random_state
Definition: decision_forest.h:46
unsigned long FeaturesIdx_t
Definition: decision_tree.h:45
std::string export_text(unsigned long e)
Export of a decision tree from a decision forest in a simple text format.
Definition: decision_forest.cpp:465
A decision forest classifier.
Definition: decision_forest.h:25
DecisionForestClassifier(std::vector< std::vector< std::string >> const &classes, std::vector< std::string > const &features, unsigned long n_estimators=100, bool bootstrap=false, bool oob_score=false, std::string const &class_balance="balanced", TreeDepthIdx_t max_depth=3, FeaturesIdx_t max_features=0, unsigned long max_thresholds=0, std::string const &missing_values="None", long random_state_seed=0)
Create and initialize a new decision forest classifier.
Definition: decision_forest.cpp:36
ClassesIdx_t n_classes_max
Definition: decision_forest.h:31
Decision Tree module.
A random number generator.
Definition: random_number_generator.h:20
std::vector< std::string > features
Definition: decision_forest.h:32
double oob_score_
Definition: decision_forest.h:52
std::vector< std::vector< std::string > > classes
Definition: decision_forest.h:29
unsigned long OutputsIdx_t
Definition: decision_tree.h:47
FeaturesIdx_t n_features
Definition: decision_forest.h:33
void calculate_feature_importances(double *importances)
Calculate feature importances from the decision forest.
Definition: decision_forest.cpp:423
static DecisionForestClassifier import_deserialize(std::string const &filename)
Definition: decision_forest.cpp:638
void predict_proba(Features_t *X, SamplesIdx_t n_samples, double *y_prob)
Predict classes probabilities for the test data.
Definition: decision_forest.cpp:340
void export_graphviz(std::string const &filename, bool rotate=false)
Export of a decision forest as individual decision trees in GraphViz dot format.
Definition: decision_forest.cpp:447
std::string class_balance
Definition: decision_forest.h:39
double score(Features_t *X, Classes_t *y, SamplesIdx_t n_samples)
Calculate score for the test data.
Definition: decision_forest.cpp:402
TreeDepthIdx_t max_depth
Definition: decision_forest.h:40
long Classes_t
Definition: decision_tree.h:39
bool bootstrap
Definition: decision_forest.h:37
bool oob_score
Definition: decision_forest.h:38
void export_serialize(std::string const &filename)
Definition: decision_forest.cpp:523
OutputsIdx_t n_outputs
Definition: decision_forest.h:28
std::vector< ClassesIdx_t > n_classes
Definition: decision_forest.h:30
void predict(Features_t *X, SamplesIdx_t n_samples, Classes_t *y)
Predict classes for the test data.
Definition: decision_forest.cpp:380
static DecisionForestClassifier deserialize(std::ifstream &fin)
Deserialize.
Definition: decision_forest.cpp:552
unsigned long n_estimators
Definition: decision_forest.h:36
double Features_t
Definition: decision_tree.h:38
std::string missing_values
Definition: decision_forest.h:43
std::vector< DecisionTreeClassifier > dtc_
Definition: decision_forest.h:49
unsigned long SamplesIdx_t
Definition: decision_tree.h:44
FeaturesIdx_t max_features
Definition: decision_forest.h:41