ModErn Text Analysis
META Enumerates Textual Applications
crf.h
Go to the documentation of this file.
1 
10 #ifndef META_SEQUENCE_CRF_H_
11 #define META_SEQUENCE_CRF_H_
12 
13 #include "sequence/observation.h"
14 #include "sequence/sequence.h"
16 #include "sequence/trellis.h"
17 #include "util/dense_matrix.h"
18 #include "util/disk_vector.h"
19 #include "util/optional.h"
20 #include "util/range.h"
21 
22 namespace meta
23 {
24 namespace sequence
25 {
26 
// Declares `crf_feature_id`, the index type `crf` uses to address individual
// feature weights (see obs_weight()/trans_weight()).
// NOTE(review): MAKE_NUMERIC_IDENTIFIER is defined outside this file;
// presumably it wraps `uint64_t` in a distinct identifier type -- confirm.
27 MAKE_NUMERIC_IDENTIFIER(crf_feature_id, uint64_t)
28 
29 
40 class crf
41 {
42  public:
48  struct parameters
49  {
53  double c2 = 1;
54 
60  double delta = 1e-5;
61 
65  uint64_t period = 10;
66 
71  double lambda = 0;
72 
81  double t0 = 0;
82 
87  uint64_t max_iters = 1000;
88 
93  double calibration_eta = 0.1;
94 
98  double calibration_rate = 2.0;
99 
103  uint64_t calibration_samples = 1000;
104 
109  uint64_t calibration_trials = 10;
110  };
111 
117  class tagger;
118 
126  crf(const std::string& prefix);
127 
137  double train(parameters params, const std::vector<sequence>& examples);
138 
145  tagger make_tagger() const;
146 
150  uint64_t num_labels() const;
151 
152  private:
153 
159 
164 
165  class scorer;
166  // grant the scorer access to the model weights
167  friend scorer;
168  class viterbi_scorer;
169 
179  void initialize(const std::vector<sequence>& examples);
180 
184  void load_model();
185 
189  void reset();
190 
202  double calibrate(parameters params, const std::vector<uint64_t>& indices,
203  const std::vector<sequence>& examples);
204 
209  const double& obs_weight(crf_feature_id idx) const;
210 
215  double& obs_weight(crf_feature_id idx);
216 
221  const double& trans_weight(crf_feature_id idx) const;
222 
227  double& trans_weight(crf_feature_id idx);
228 
234  feature_range obs_range(feature_id fid) const;
235 
241  feature_range trans_range(label_id lbl) const;
242 
247  label_id observation(crf_feature_id idx) const;
248 
254  label_id transition(crf_feature_id idx) const;
255 
267  double epoch(parameters params, printing::progress& progress,
268  uint64_t iter, const std::vector<uint64_t>& indices,
269  const std::vector<sequence>& examples, scorer& scorer);
270 
281  double iteration(parameters params, uint64_t iter, const sequence& seq,
282  scorer& scorer);
283 
291  void gradient_observation_expectation(const sequence& seq, double gain);
292 
302  void gradient_model_expectation(const sequence& seq, double gain,
303  const scorer& scr);
304 
308  double l2norm() const;
309 
314  void rescale();
315 
325 
335 
343 
350 
356 
362 
364  double scale_;
366  uint64_t num_labels_;
368  const std::string& prefix_;
369 };
370 }
371 }
372 #endif
A class for representing optional values.
Definition: vocabulary_map.h:21
util::optional< util::disk_vector< crf_feature_id > > transition_ranges_
Analogous to the observation range, but for transitions.
Definition: crf.h:334
util::optional< util::disk_vector< crf_feature_id > > observation_ranges_
Represents the feature id range for a given observation: observation_ranges_[i] gives the start of a ...
Definition: crf.h:324
util::optional< util::disk_vector< label_id > > transitions_
Represents the destination label for a given transition feature.
Definition: crf.h:349
Internal class that holds scoring information for sequences under the current model.
Definition: scorer.h:24
Represents a tagged sequence of observations.
Definition: sequence.h:24
Linear-chain conditional random field for POS tagging and chunking applications.
Definition: crf.h:40
Wrapper to represent the parameters used during learning.
Definition: crf.h:48
uint64_t num_labels_
the number of allowed labels
Definition: crf.h:366
double scale_
the current decay factor applied to all of the weights
Definition: crf.h:364
The ModErn Text Analysis toolkit is a suite of natural language processing, classification, information retrieval, data mining, and other applications of text processing.
Definition: analyzer.h:24
util::optional< util::disk_vector< double > > transition_weights_
Weights for all of the transition features.
Definition: crf.h:361
Scorer for performing viterbi-based tagging.
Definition: viterbi_scorer.h:23
util::optional< util::disk_vector< double > > observation_weights_
The weights for all of the node-observation features.
Definition: crf.h:355
Represents an observation in a tagged sequence.
Definition: observation.h:32
util::optional< util::disk_vector< label_id > > observations_
Represents the state that fired for a given observation feature.
Definition: crf.h:342
Definition: tagger.h:20
Implements a range that spans a loop's extension and termination conditions, most useful for iteratin...
Definition: range.h:27
Simple class for reporting progress of lengthy operations.
Definition: progress.h:27
const std::string & prefix_
the prefix (folder) where model files are to be stored
Definition: crf.h:368