ModErn Text Analysis
META Enumerates Textual Applications
sgd.h
Go to the documentation of this file.
1 
9 #ifndef META_CLASSIFY_SGD_H_
10 #define META_CLASSIFY_SGD_H_
11 
15 #include "util/disk_vector.h"
16 #include "meta.h"
17 
18 namespace meta
19 {
20 namespace classify
21 {
22 
28 class sgd : public binary_classifier
29 {
30  public:
32  const static constexpr double default_alpha = 0.001;
34  const static constexpr double default_gamma = 1e-6;
36  const static constexpr double default_bias = 1;
38  const static constexpr double default_lambda = 0.0001;
40  const static constexpr size_t default_max_iter = 50;
41 
54  sgd(const std::string& prefix, std::shared_ptr<index::forward_index> idx,
55  class_label positive, class_label negative,
56  std::unique_ptr<loss::loss_function> loss, double alpha = default_alpha,
57  double gamma = default_gamma, double bias = default_bias,
58  double lambda = default_lambda, size_t max_iter = default_max_iter);
59 
68  double predict(doc_id d_id) const;
69 
70  void train(const std::vector<doc_id>& docs) override;
71 
72  void reset() override;
73 
77  const static std::string id;
78 
79  private:
82 
84  double coeff_{1.0};
85 
87  const double alpha_;
88 
90  const double gamma_;
91 
93  double bias_;
94 
96  const double bias_weight_;
97 
99  const double lambda_;
100 
102  const size_t max_iter_;
103 
105  std::unique_ptr<loss::loss_function> loss_;
106 
110  using counts_t = std::vector<std::pair<term_id, double>>;
111 
119  double predict(const counts_t& doc) const;
120 };
121 
125 template <>
126 std::unique_ptr<binary_classifier>
127  make_binary_classifier<sgd>(const cpptoml::table& config,
128  std::shared_ptr<index::forward_index> idx,
129  class_label positive, class_label negative);
130 }
131 }
132 #endif
static const constexpr double default_gamma
The default $\gamma$ parameter.
Definition: sgd.h:34
double bias_
$b$, the bias.
Definition: sgd.h:93
const double alpha_
$\alpha$, the learning rate.
Definition: sgd.h:87
Contains top-level namespace documentation for the META toolkit.
const double gamma_
$\gamma$, the error threshold.
Definition: sgd.h:90
void train(const std::vector< doc_id > &docs) override
Creates a classification model based on training documents.
Definition: sgd.cpp:52
std::vector< std::pair< term_id, double >> counts_t
Typedef for the sparse vector training/test instances.
Definition: sgd.h:110
static const constexpr size_t default_max_iter
The default number of allowed iterations.
Definition: sgd.h:40
void reset() override
Clears any learning data associated with this classifier.
Definition: sgd.cpp:118
const size_t max_iter_
The maximum number of iterations for training.
Definition: sgd.h:102
static const constexpr double default_alpha
The default $\alpha$ parameter.
Definition: sgd.h:32
The ModErn Text Analysis toolkit is a suite of natural language processing, classification, information retrieval, data mining, and other applications of text processing.
Definition: analyzer.h:24
const double lambda_
$\lambda$, the regularization constant.
Definition: sgd.h:99
const double bias_weight_
The weight of the bias term for each document (defaults to 1)
Definition: sgd.h:96
util::disk_vector< double > weights_
The weights vector.
Definition: sgd.h:81
sgd(const std::string &prefix, std::shared_ptr< index::forward_index > idx, class_label positive, class_label negative, std::unique_ptr< loss::loss_function > loss, double alpha=default_alpha, double gamma=default_gamma, double bias=default_bias, double lambda=default_lambda, size_t max_iter=default_max_iter)
Definition: sgd.cpp:20
static const constexpr double default_bias
The default $b$ (bias) parameter.
Definition: sgd.h:36
Implements stochastic gradient descent for learning binary linear classifiers.
Definition: sgd.h:28
double coeff_
The scalar coefficient for the weights vector.
Definition: sgd.h:84
static const std::string id
The identifier for this classifier.
Definition: sgd.h:77
double predict(doc_id d_id) const
Returns the dot product with the current weight vector.
Definition: sgd.cpp:38
std::unique_ptr< binary_classifier > make_binary_classifier< sgd >(const cpptoml::table &config, std::shared_ptr< index::forward_index > idx, class_label positive, class_label negative)
Specialization of the factory method used to create sgd classifiers.
Definition: sgd.cpp:128
std::unique_ptr< loss::loss_function > loss_
The loss function to be used for the update.
Definition: sgd.h:105
A classifier which classifies documents as "positive" or "negative".
Definition: binary_classifier.h:24
static const constexpr double default_lambda
The default $\lambda$ parameter.
Definition: sgd.h:38