-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgirf_estimator_explicit.jl
131 lines (80 loc) · 2.81 KB
/
girf_estimator_explicit.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
using MRIReco, DSP, NIfTI, FiniteDifferences, PyPlot, Waveforms, Distributions, ImageFiltering, Flux, CUDA, NFFT
# Enable PyPlot's interactive GUI backend so the `figure()` windows created below are shown.
PyPlot.pygui(true)
"""
    filterGradientWaveForms(G, theta)

Convolve the waveform `G` with the kernel `theta` and return the filtered waveform.

`imfilter` computes a correlation; centering `theta` on its middle sample and reflecting
it turns that into a true convolution. Samples beyond the ends of `G` are treated as
zero (`Fill(0)`).
"""
function filterGradientWaveForms(G, theta)
    convKernel = reflect(centered(theta))
    return imfilter(G, convKernel, Fill(0))
end
"""
    simulatePerfectRecon(EH, b₀)

Apply the reconstruction operator `EH` (e.g. the adjoint of an encoding operator) to
the k-space data `b₀` and return the product, i.e. the image estimate.
"""
simulatePerfectRecon(EH, b₀) = EH * b₀
# Distortion kernel for the gradient waveforms: 30 random taps normalized to sum to 1.
# NOTE(review): `rand` is unseeded, so each run uses a different kernel.
ker = rand(30)
ker = ker ./ sum(ker)
# Alternative kernel: ker = ImageFiltering.Kernel.gaussian(30)
## Test Setting Up Simulation (forward sim)
N = 256
# Named `phantom` (not `I`) to avoid clashing with the conventional `LinearAlgebra.I`.
phantom = shepp_logan(N)
phantom = circularShutterFreq!(phantom, 1)
# simulation parameters
parameters = Dict{Symbol, Any}()
parameters[:simulation] = "explicit"
parameters[:trajName] = "Spiral"
parameters[:numProfiles] = 1
parameters[:numSamplingPerProfile] = N * N
parameters[:windings] = 128
parameters[:AQ] = 3.0e-2
# do simulation
acqData = simulation(phantom, parameters)
# Define Perfect Reconstruction Operation: adjoint of the NFFT operator on the
# undistorted trajectory. Image size follows `N` instead of a duplicated literal.
testOp = NFFTOp((N, N), acqData.traj[1])
testOp2 = adjoint(testOp)
recon1 = simulatePerfectRecon(testOp2, vec(acqData.kdata[1]))
# Distort the trajectory: differentiate the node positions into a gradient-like
# waveform, filter it with `ker`, then integrate back to node positions.
figure()
scatter(acqData.traj[1].nodes[1, :], acqData.traj[1].nodes[2, :])
# Row slices are copies, so these keep the undistorted positions after the overwrite below.
oldNodesX = acqData.traj[1].nodes[1, :]
oldNodesY = acqData.traj[1].nodes[2, :]
filteredWaveformX = filterGradientWaveForms(diff(oldNodesX), ker)
filteredWaveformY = filterGradientWaveForms(diff(oldNodesY), ker)
# Re-integrate starting from the original first node (rather than assuming k = 0), so
# the distorted trajectory has the same start point and length as the original.
newNodesX = [first(oldNodesX); first(oldNodesX) .+ cumsum(vec(filteredWaveformX))]
newNodesY = [first(oldNodesY); first(oldNodesY) .+ cumsum(vec(filteredWaveformY))]
acqData.traj[1].nodes[1, :] = newNodesX
acqData.traj[1].nodes[2, :] = newNodesY
scatter(acqData.traj[1].nodes[1, :], acqData.traj[1].nodes[2, :])
# Reconstruct the original k-space data using the distorted trajectory.
testOp3 = NFFTOp((N, N), acqData.traj[1])
testOp4 = adjoint(testOp3)
recon2 = testOp4 * vec(acqData.kdata[1])
# Per-sample squared error between the two reconstructions. `abs2` keeps it real for
# complex images. Not named `error`, which would shadow `Base.error`.
reconError = abs2.(recon1 .- recon2)
reconMSE = sum(reconError) / length(reconError)
@info "Reconstruction MSE (original vs. distorted trajectory)" reconMSE
figure()
plot(reconError)
# Change in each node's distance from the k-space origin caused by the distortion.
figure()
plot(@. sqrt(abs2(oldNodesX) + abs2(oldNodesY)) - sqrt(abs2(newNodesX) + abs2(newNodesY)))
# Learnable trajectory filter: a single-channel 1×200 convolution with `SamePad()` padding.
layer = Conv((1, 200), 1 => 1; pad = SamePad())
model = Chain(layer)
# Earlier experiment, kept for reference:
# layer = Conv((1,30),1=>30,identity; bias=true, pad=SamePad())
# testDat = reshape(oldNodesX,1,256*256,1,1)

# Training sample: the current (already filtered) trajectory plus its k-space data.
# Target: the reconstruction from `testOp2`, the adjoint operator built before the
# nodes were filtered.
trajRef = deepcopy(acqData.traj[1])
dataRef = copy(vec(acqData.kdata[1]))
reconRef = testOp2 * vec(acqData.kdata[1])
"""
    getPrediction(x; shape=(256, 256))

Filter the x-coordinates of trajectory `x` with the learnable global `model`, build an
NFFT operator with image size `shape` on the filtered trajectory, and return the
adjoint reconstruction of the global k-space data `dataRef`. The y-coordinates are
passed through unchanged.

`x` is copied before its nodes are overwritten, so the caller's trajectory (e.g. the
training sample `trajRef`) is left untouched and repeated calls do not compound the filter.
"""
function getPrediction(x; shape=(256, 256))
    traj = deepcopy(x)
    nodesX = reshapeNodes(traj.nodes[1, :])
    # NOTE(review): in-place assignment and NFFTOp are likely not differentiable by
    # Zygote; confirm before relying on gradients from `Flux.train!`.
    traj.nodes[1, :] = vec(model(nodesX))
    op = NFFTOp(shape, traj)
    # Use the adjoint, matching how the training target `reconRef` is computed.
    # `op \ dataRef` would compare a different reconstruction against that target.
    return adjoint(op) * dataRef
end
"""
    reshapeNodes(x)

Return `x` reshaped to a `1 × length(x) × 1 × 1` array, i.e. the 4-D
width × height × channels × batch layout taken by Flux `Conv` layers.
"""
reshapeNodes(x) = reshape(x, (1, length(x), 1, 1))
"""
    loss(x, y)

Mean absolute error between the reconstruction predicted from trajectory `x`
(via `getPrediction`) and the reference image `y`.
"""
function loss(x, y)
    prediction = getPrediction(x)
    return Flux.Losses.mae(prediction, y)
end
# One training pass over the single (trajectory, reference image) pair.
# Named `ps` so the simulation-settings Dict `parameters` is not overwritten.
ps = Flux.params(model)
opt = Descent()
Flux.train!(loss, ps, [(trajRef, reconRef)], opt)
# Standalone NFFT adjoint/forward smoke test on random nodes in [-0.5, 0.5).
# Separate size names so the image size `N` used by the experiment is not overwritten.
nfftM, nfftN = 1024, 16
x = rand(2, nfftM) .- 0.5
fHat = randn(ComplexF64, nfftM)
p = plan_nfft(x, (nfftN, nfftN))
# NOTE(review): `nfft_adjoint(p, ⋅)` / `nfft(p, ⋅)` is the older plan-based NFFT.jl API;
# newer releases use `adjoint(p) * fHat` and `p * f`. Confirm against the installed version.
f = nfft_adjoint(p, fHat)
g = nfft(p, f)
# NOTE: The current NFFTOp implementation relies on FFTW calls, so it is not
# auto-differentiable without a custom adjoint (e.g. a ChainRules `rrule`).
# TODO: discuss with Jon.