9 #ifndef META_CLASSIFY_SGD_H_
10 #define META_CLASSIFY_SGD_H_
/**
 * Constructs an sgd classifier that distinguishes `positive` from
 * `negative` class labels over documents in a forward index.
 *
 * NOTE(review): the numeric prefixes on several lines of this declaration
 * (e.g. `54`, `55`, `56`) look like fused line numbers from text
 * extraction; they are not valid C++ and should be confirmed/removed
 * against the original header.
 *
 * @param prefix   a path/prefix string — presumably where model state is
 *                 stored; TODO confirm against the implementation
 * @param idx      forward index, held with shared ownership (shared_ptr)
 * @param positive class label treated as the positive class
 * @param negative class label treated as the negative class
 * @param loss     loss function; ownership is transferred to this object
 *                 (unique_ptr taken by value)
 * @param alpha    tuning parameter (defaults to default_alpha); presumably
 *                 a learning-rate-like constant — verify
 * @param gamma    tuning parameter (defaults to default_gamma); meaning
 *                 not visible here — verify
 * @param bias     initial/bias value (defaults to default_bias) — verify
 * @param lambda   tuning parameter (defaults to default_lambda); presumably
 *                 a regularization weight — verify
 * @param max_iter iteration cap (defaults to default_max_iter)
 */
54 sgd(
const std::string& prefix, std::shared_ptr<index::forward_index> idx,
55 class_label positive, class_label negative,
56 std::unique_ptr<loss::loss_function> loss,
double alpha = default_alpha,
57 double gamma = default_gamma,
double bias = default_bias,
58 double lambda = default_lambda,
size_t max_iter = default_max_iter);
/**
 * Computes a real-valued prediction for a single document.
 * Declared const: does not modify the classifier.
 * NOTE(review): how the score maps to `positive`/`negative` (e.g. by sign)
 * is not visible here — confirm against the implementation.
 * @param d_id the document to score
 * @return the prediction value for `d_id`
 */
68 double predict(doc_id d_id)
const;
/**
 * Trains the classifier on the given documents (overrides the base-class
 * train()).
 * NOTE(review): labels are presumably looked up from the index for each
 * doc_id — verify.
 * @param docs the documents to train on
 */
70 void train(
const std::vector<doc_id>& docs)
override;
/**
 * Overrides the base-class reset().
 * NOTE(review): presumably returns the model to an untrained state —
 * confirm against the implementation.
 */
72 void reset()
override;
/**
 * Static identifier string for this classifier type.
 * NOTE(review): likely used to select/register the classifier by name
 * (e.g. from configuration) — verify.
 */
77 const static std::string
id;
/// Loss function exclusively owned by this classifier (unique_ptr);
/// presumably the one supplied to the constructor — verify.
105 std::unique_ptr<loss::loss_function>
loss_;
/// Sequence of (term_id, value) pairs; presumably a sparse per-document
/// feature vector — verify against its uses.
110 using counts_t = std::vector<std::pair<term_id, double>>;
126 std::unique_ptr<binary_classifier>
128 std::shared_ptr<index::forward_index> idx,
129 class_label positive, class_label negative);