-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathinference.h
166 lines (139 loc) · 4.11 KB
/
inference.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
/* inference.h - General inference technique base class
Adrian Groves & Michael Chappell, FMRIB Image Analysis Group & IBME QuBIc Group
Copyright (C) 2007-2015 University of Oxford */
/* CCOPYRIGHT */
#pragma once
#include "dist_mvn.h"
#include "easylog.h"
#include "factories.h"
#include "fwdmodel.h"
#include "noisemodel.h"
#include "rundata.h"
#include <iomanip>
#include <string>
#include <vector>
/**
 * Abstract base class for inference techniques.
 *
 * Subclasses must implement GetDescription(), GetVersion() and
 * DoCalculations(). Initialize() must be called before DoCalculations();
 * results can then be written out via SaveResults().
 */
class InferenceTechnique : public Loggable
{
public:
    /**
     * Static member function to return the names of all known
     * inference techniques.
     */
    static std::vector<std::string> GetKnown();

    /**
     * Static member function to pick an inference technique from a name.
     *
     * NOTE(review): ownership of the returned pointer is presumably
     * transferred to the caller - confirm against the implementation.
     *
     * @param name Name of the inference technique.
     * @return pointer to a new inference technique instance.
     */
    static InferenceTechnique *NewFromName(const std::string &name);

    /**
     * Write usage information for a named method.
     *
     * @param name   Name of the inference technique.
     * @param stream Stream the usage information is written to.
     */
    static void UsageFromName(const std::string &name, std::ostream &stream);

    /**
     * Create a new instance of this class.
     * @return pointer to new instance.
     */
    static InferenceTechnique *NewInstance();

    /**
     * Default constructor.
     */
    InferenceTechnique();

    /**
     * Get option descriptions for this inference method.
     *
     * The base implementation adds nothing; subclasses may override it
     * to describe their own options.
     *
     * @param opts Vector that option descriptions are added to.
     */
    virtual void GetOptions(std::vector<OptionSpec> &opts) const {}

    /**
     * @return human-readable description of the inference method.
     */
    virtual std::string GetDescription() const = 0;

    /**
     * Get the code version. There is no fixed format for this,
     * and it has no meaning other than by comparison with different
     * versions of the same inference method code.
     *
     * See fwdmodel.cc for an example of how to implement this to
     * return a CVS file version.
     *
     * @return a string indicating the inference code version.
     */
    virtual std::string GetVersion() const = 0;

    /**
     * Initialize a new instance to use the given forward model
     * and extract additional configuration from the given run data.
     *
     * @param fwd_model Forward model to be used. It is not deleted by
     *                  this class; that is the caller's responsibility.
     * @param rundata   Run data supplying additional configuration.
     */
    virtual void Initialize(FwdModel *fwd_model, FabberRunData &rundata);

    /**
     * Perform inference using the given model upon the given data.
     *
     * This method should only be called after Initialize().
     * Subclasses of InferenceTechnique must implement this method
     * to carry out their given inference calculations.
     *
     * @param rundata Run data containing the data to perform inference on.
     */
    virtual void DoCalculations(FabberRunData &rundata) = 0;

    /**
     * Save the results.
     *
     * @param rundata Run data that results are saved to.
     */
    virtual void SaveResults(FabberRunData &rundata) const;

    /**
     * Destructor. Virtual so that subclasses are destroyed correctly
     * through an InferenceTechnique pointer.
     */
    virtual ~InferenceTechnique();

protected:
    /**
     * Initialize MVN distributions from a parameter file.
     *
     * NOTE(review): exact file format and effect on resultMVNs are defined
     * in the implementation - confirm there.
     *
     * @param rundata       Run data.
     * @param paramFilename Name of the file to read.
     */
    void InitMVNFromFile(FabberRunData &rundata, std::string paramFilename);

    /**
     * Pointer to forward model, passed in to Initialize().
     *
     * Will not be deleted, that is the responsibility of
     * the caller.
     */
    FwdModel *m_model;

    /**
     * Number of model parameters.
     *
     * This is used regularly so it's sensible to keep a
     * copy around.
     */
    int m_num_params;

    /**
     * If true, stop if we get a numerical exception in any voxel. If false,
     * simply print a warning and continue.
     */
    bool m_halt_bad_voxel;

    /**
     * Results of the inference method.
     *
     * Vector of MVNDist, one for each voxel.
     * Each MVNDist contains the means and covariance/precisions for
     * the parameters in the model.
     */
    std::vector<MVNDist *> resultMVNs;

    /**
     * List of masked timepoints.
     *
     * Masked timepoints are indexed starting at 1 and are ignored
     * in the analysis and parameter updates.
     */
    std::vector<int> m_masked_tpoints;

    /**
     * Include very verbose debugging output.
     */
    bool m_debug;

private:
    /**
     * Private and deliberately left undefined to prevent assignment.
     *
     * Use from outside the class fails to compile (access control). Use from
     * inside the class fails to link. This holds in every build configuration,
     * whereas an asserting inline body would do nothing when NDEBUG is defined.
     *
     * NOTE(review): copy construction is not prevented. resultMVNs holds raw
     * MVNDist pointers; if the destructor deletes them, an accidental copy
     * would cause a double delete. Consider also declaring a private,
     * undefined copy constructor.
     */
    const InferenceTechnique &operator=(const InferenceTechnique &from);
};
/**
 * \ref SingletonFactory that returns pointers to
 * \ref InferenceTechnique.
 *
 * Shorthand name for the SingletonFactory specialised on
 * InferenceTechnique.
 */
typedef SingletonFactory<InferenceTechnique> InferenceTechniqueFactory;