# -*- coding: utf-8 -*-
"""
The :mod:`parsimony.algorithms.multiblock` module includes several multiblock
algorithms.

Algorithms must not store state. I.e., if they are classes, they should not
keep references to stateful objects in the algorithm objects. It should be
possible to copy and share algorithms between e.g. estimators, and thus they
must not depend on any state.

Created on Thu Feb 20 22:12:00 2014

Copyright (c) 2013-2014, CEA/DSV/I2BM/Neurospin. All rights reserved.

@author:  Tommy Löfstedt
@email:   lofstedt.tommy@gmail.com
@license: BSD 3-clause.
"""
import numpy as np

try:
    from . import bases  # Only works when imported as a package.
except ValueError:
    import parsimony.algorithms.bases as bases  # When run as a program.
import parsimony.utils.consts as consts
from parsimony.algorithms.utils import Info
import parsimony.utils as utils
import parsimony.utils.maths as maths
import parsimony.functions.properties as properties
import parsimony.functions.multiblock.properties as multiblock_properties
import parsimony.functions.multiblock.losses as mb_losses

__all__ = ["MultiblockFISTA"]


class MultiblockFISTA(bases.ExplicitAlgorithm,
                      bases.IterativeAlgorithm,
                      bases.InformationAlgorithm):
    """The projected or proximal gradient algorithm with alternating
    minimisations in a multiblock setting.

    Parameters
    ----------
    info : List or tuple of utils.Info. What, if any, extra run information
            should be stored. Default is an empty list, which means that no
            run information is computed or returned.

    eps : Positive float. Tolerance for the stopping criterion.

    max_iter : Non-negative integer. Maximum total allowed number of
            iterations.

    min_iter : Non-negative integer less than or equal to max_iter. Minimum
            number of iterations that must be performed. Default is 1.
    """
    INTERFACES = [multiblock_properties.MultiblockFunction,
                  multiblock_properties.MultiblockGradient,
                  multiblock_properties.MultiblockStepSize,
                  properties.OR(
                      multiblock_properties.MultiblockProjectionOperator,
                      multiblock_properties.MultiblockProximalOperator)]

    INFO_PROVIDED = [Info.ok,
                     Info.num_iter,
                     Info.time,
                     Info.func_val,
                     Info.smooth_func_val,
                     Info.converged]

    def __init__(self, info=[], eps=consts.TOLERANCE,
                 max_iter=consts.MAX_ITER, min_iter=1):
        # NOTE: Only the default of min_iter is documented above; the
        # defaults of eps and max_iter are assumed to follow the
        # library-wide consts.TOLERANCE and consts.MAX_ITER.

        super(MultiblockFISTA, self).__init__(info=info,
                                              max_iter=max_iter,
                                              min_iter=min_iter)

        self.eps = eps
    @bases.force_reset
    @bases.check_compatibility
    def run(self, function, w):
        """Find the minimiser of the given function, starting at w.

        Parameters
        ----------
        function : MultiblockFunction. The function to minimise.

        w : List of numpy arrays. The weight vectors, one per block.
        """
        if self.info_requested(Info.ok):
            self.info_set(Info.ok, False)

        if self.info_requested(Info.time):
            _t = []
        if self.info_requested(Info.func_val):
            _f = []
        if self.info_requested(Info.smooth_func_val):
            _fmu = []
        if self.info_requested(Info.converged):
            self.info_set(Info.converged, False)

        FISTA = True
        if FISTA:
            exp = 4.0 + consts.FLOAT_EPSILON
        else:
            exp = 2.0 + consts.FLOAT_EPSILON

        # Number of inner iterations performed for each block.
        block_iter = [1] * len(w)

        it = 0
        while True:

            # Alternate over the blocks, minimising over one block at a time.
            for i in xrange(len(w)):
                print "it: %d, i: %d" % (it, i)

                # Wrap a single-block function around block i, with all other
                # blocks fixed at their current values.
                func = mb_losses.MultiblockFunctionWrapper(function, w, i)

                w_old = w[i]
                for k in xrange(1, max(self.min_iter + 1,
                                       self.max_iter - self.num_iter + 1)):

                    if self.info_requested(Info.time):
                        time = utils.time_wall()

                    if FISTA:
                        # FISTA inertial (momentum) step.
                        z = w[i] + ((k - 2.0) / (k + 1.0)) * (w[i] - w_old)
                    else:
                        # ISTA: no momentum.
                        z = w[i]

                    step = func.step(z)

                    # Tolerance for the inexact proximal operator. It
                    # decreases faster for FISTA (exp = 4) than for ISTA
                    # (exp = 2) in order to retain the convergence rate.
                    eps = max(consts.FLOAT_EPSILON,
                              1.0 / (block_iter[i] ** exp))

                    w_old = w[i]

                    # Gradient step followed by a proximal step.
                    w[i] = func.prox(z - step * func.grad(z),
                                     factor=step, eps=eps)

                    if self.info_requested(Info.time):
                        _t.append(utils.time_wall() - time)
                    if self.info_requested(Info.func_val):
                        _f.append(function.f(w))
                    if self.info_requested(Info.smooth_func_val):
                        _fmu.append(function.fmu(w))

                    self.num_iter += 1
                    block_iter[i] += 1

                    # Stop if the current block has converged.
                    if maths.norm(w[i] - z) < step * self.eps \
                            and k >= self.min_iter:
                        break

            # Global convergence check: every block must be (close to) a
            # fixed point of its proximal gradient map.
            all_converged = True
            for i in xrange(len(w)):

                func = mb_losses.MultiblockFunctionWrapper(function, w, i)

                step = func.step(w[i])

                eps = max(consts.FLOAT_EPSILON,
                          1.0 / (block_iter[i] ** exp))

                w_tilde = func.prox(w[i] - step * func.grad(w[i]),
                                    factor=step, eps=eps)

                if maths.norm(w[i] - w_tilde) > step * self.eps:
                    all_converged = False
                    break

            if all_converged:
                if self.info_requested(Info.converged):
                    self.info_set(Info.converged, True)

                break

            # Stop if the maximum total number of iterations is reached.
            if self.num_iter >= self.max_iter:
                break

            it += 1

        if self.info_requested(Info.num_iter):
            self.info_set(Info.num_iter, self.num_iter)
        if self.info_requested(Info.time):
            self.info_set(Info.time, _t)
        if self.info_requested(Info.func_val):
            self.info_set(Info.func_val, _f)
        if self.info_requested(Info.smooth_func_val):
            self.info_set(Info.smooth_func_val, _fmu)
        if self.info_requested(Info.ok):
            self.info_set(Info.ok, True)

        return w


class MultiblockCONESTA(bases.ExplicitAlgorithm,
                        bases.IterativeAlgorithm,
                        bases.InformationAlgorithm):
    """An alternating minimising multiblock algorithm that utilises CONESTA
    in the inner minimisation.

    Parameters
    ----------
    info : List or tuple of utils.consts.Info. What, if any, extra run
            information should be stored. Default is an empty list, which
            means that no run information is computed or returned.

    eps : Positive float. Tolerance for the stopping criterion.

    outer_iter : Non-negative integer. Maximum allowed number of outer loop
            iterations.

    max_iter : Non-negative integer. Maximum allowed number of iterations.

    min_iter : Non-negative integer. Minimum number of iterations that must
            be performed. Default is 1.
    """
    INTERFACES = [multiblock_properties.MultiblockFunction,
                  multiblock_properties.MultiblockGradient,
                  multiblock_properties.MultiblockProjectionOperator,
                  multiblock_properties.MultiblockStepSize]

    INFO_PROVIDED = [Info.ok,
                     Info.num_iter,
                     Info.time,
                     Info.fvalue,
                     Info.converged]

    def __init__(self, mu_start=None, mu_min=consts.TOLERANCE, tau=0.5,
                 info=[], eps=consts.TOLERANCE, outer_iter=20,
                 max_iter=consts.MAX_ITER, min_iter=1):
        # NOTE: The parameter names follow their use below; the default
        # values of mu_start, mu_min, tau, eps, outer_iter and max_iter are
        # assumed.

        super(MultiblockCONESTA, self).__init__(info=info,
                                                max_iter=max_iter,
                                                min_iter=min_iter)

        self.outer_iter = outer_iter
        self.eps = eps

        # Copy the info keys that the inner algorithms can provide.
        from parsimony.algorithms.proximal import FISTA
        from parsimony.algorithms.primaldual import NaiveCONESTA
        alg_info = []
        for nfo in self.info_copy():
            if nfo in FISTA.INFO_PROVIDED:
                alg_info.append(nfo)

        # Always ask the inner algorithms for convergence information.
        if Info.converged not in alg_info:
            alg_info.append(Info.converged)

        self.fista = FISTA(info=alg_info,
                           eps=self.eps,
                           max_iter=self.max_iter,
                           min_iter=self.min_iter)

        self.conesta = NaiveCONESTA(mu_start=mu_start,
                                    mu_min=mu_min,
                                    tau=tau,
                                    eps=self.eps,
                                    info=alg_info,
                                    max_iter=self.max_iter,
                                    min_iter=self.min_iter)
    @bases.force_reset
    @bases.check_compatibility
    def run(self, function, w):
        """Find the minimiser of the given function, starting at w.

        Parameters
        ----------
        function : MultiblockFunction. The function to minimise.

        w : List of numpy arrays. The weight vectors, one per block.
        """
        if self.info_requested(Info.ok):
            self.info_set(Info.ok, False)
        if self.info_requested(Info.time):
            t = []
        if self.info_requested(Info.fvalue):
            f = []
        if self.info_requested(Info.converged):
            self.info_set(Info.converged, False)

        print "len(w):", len(w)
        print "max_iter:", self.max_iter

        num_iter = [0] * len(w)

        for it in xrange(1, self.outer_iter + 1):

            all_converged = True

            # Minimise over one block at a time, keeping the others fixed.
            for i in xrange(len(w)):
                print "it: %d, i: %d" % (it, i)

                # Use CONESTA if the block has a Nesterov (smoothed)
                # function, otherwise plain FISTA.
                if function.has_nesterov_function(i):
                    print "Block %d has a Nesterov function!" % (i,)
                    func = mb_losses.MultiblockNesterovFunctionWrapper(
                        function, w, i)
                    algorithm = self.conesta
                else:
                    func = mb_losses.MultiblockFunctionWrapper(function, w, i)
                    algorithm = self.fista

                w[i] = algorithm.run(func, w[i])

                if algorithm.info_requested(Info.num_iter):
                    num_iter[i] += algorithm.info_get(Info.num_iter)
                if algorithm.info_requested(Info.time):
                    tval = algorithm.info_get(Info.time)
                if algorithm.info_requested(Info.fvalue):
                    fval = algorithm.info_get(Info.fvalue)

                if self.info_requested(Info.time):
                    t = t + tval
                if self.info_requested(Info.fvalue):
                    f = f + fval

                print "l0 :", maths.norm0(w[i]), \
                    ", l1 :", maths.norm1(w[i]), \
                    ", l2²:", maths.norm(w[i]) ** 2.0

                print "f:", fval[-1]

            # Global convergence check: every block must be (close to) a
            # fixed point of its proximal gradient map.
            for i in xrange(len(w)):

                step = function.step(w, i)
                w_tilde = function.prox(w[:i] +
                                        [w[i] - step * function.grad(w, i)] +
                                        w[i + 1:], i, step)

                print "err:", maths.norm(w[i] - w_tilde) * (1.0 / step)
                if (1.0 / step) * maths.norm(w[i] - w_tilde) > self.eps:
                    all_converged = False
                    break

            if all_converged:
                print "All converged!"

                if self.info_requested(Info.converged):
                    self.info_set(Info.converged, True)

                break

        if self.info_requested(Info.num_iter):
            self.info_set(Info.num_iter, num_iter)
        if self.info_requested(Info.time):
            self.info_set(Info.time, t)
        if self.info_requested(Info.fvalue):
            self.info_set(Info.fvalue, f)
        if self.info_requested(Info.ok):
            self.info_set(Info.ok, True)

        return w


if __name__ == "__main__":
    import doctest
    doctest.testmod()
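
    # ------------------------------------------------------------------
    # Illustrative sketch (not part of the library API): the block-wise
    # FISTA update used by MultiblockFISTA, written with plain numpy on a
    # toy two-block problem
    #
    #     min over (x, y) of  0.5 * ||A.x + B.y - b||^2 + lam * ||x||_1.
    #
    # The toy problem, the soft-thresholding prox and the fixed step sizes
    # are all assumptions made for this example only.
    # ------------------------------------------------------------------
    np.random.seed(42)
    n, p, q = 50, 10, 5
    A = np.random.randn(n, p)
    B = np.random.randn(n, q)
    b = np.random.randn(n, 1)
    lam = 0.1  # l1 regularisation on the first block.

    def soft_threshold(v, t):
        # Proximal operator of t * ||.||_1.
        return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)

    x = np.zeros((p, 1))
    y = np.zeros((q, 1))
    step_x = 1.0 / np.linalg.norm(A, 2) ** 2.0  # 1 / Lipschitz constant.
    step_y = 1.0 / np.linalg.norm(B, 2) ** 2.0

    for outer in xrange(10):

        # Block 1: FISTA on x with y fixed.
        x_old = x
        for k in xrange(1, 50):
            z = x + ((k - 2.0) / (k + 1.0)) * (x - x_old)
            x_old = x
            grad = A.T.dot(A.dot(z) + B.dot(y) - b)
            x = soft_threshold(z - step_x * grad, step_x * lam)

        # Block 2: FISTA on y with x fixed (no penalty, so the proximal
        # operator is the identity).
        y_old = y
        for k in xrange(1, 50):
            z = y + ((k - 2.0) / (k + 1.0)) * (y - y_old)
            y_old = y
            grad = B.T.dot(A.dot(x) + B.dot(z) - b)
            y = z - step_y * grad

    print "toy objective:", \
        0.5 * np.sum((A.dot(x) + B.dot(y) - b) ** 2.0) \
        + lam * np.sum(np.abs(x))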
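
    # ------------------------------------------------------------------
    # Second sketch (also only an illustration): the smoothing-plus-
    # continuation idea behind CONESTA-type solvers. The l1 norm is
    # replaced by its Nesterov (Huber-like) smoothing with parameter mu,
    # FISTA is run on the smooth surrogate, and mu is then decreased by a
    # factor tau. The toy problem, tau and the iteration counts are
    # assumptions; this is not the parsimony NaiveCONESTA implementation.
    # ------------------------------------------------------------------
    def smoothed_l1_grad(v, mu):
        # Gradient of the Huber/Nesterov smoothing of ||.||_1.
        return np.clip(v / mu, -1.0, 1.0)

    np.random.seed(13)
    n2, p2 = 40, 15
    A2 = np.random.randn(n2, p2)
    b2 = np.random.randn(n2, 1)
    lam2 = 0.1

    x2 = np.zeros((p2, 1))
    mu = 1.0
    tau = 0.5
    L_A2 = np.linalg.norm(A2, 2) ** 2.0

    for outer in xrange(8):  # Continuation: decrease mu every round.
        # The gradient of the smoothed objective is Lipschitz with
        # constant ||A2||_2^2 + lam2 / mu.
        step = 1.0 / (L_A2 + lam2 / mu)
        x2_old = x2
        for k in xrange(1, 100):
            z = x2 + ((k - 2.0) / (k + 1.0)) * (x2 - x2_old)
            x2_old = x2
            grad = A2.T.dot(A2.dot(z) - b2) + lam2 * smoothed_l1_grad(z, mu)
            x2 = z - step * grad
        mu = tau * mu

    print "smoothed toy objective:", \
        0.5 * np.sum((A2.dot(x2) - b2) ** 2.0) + lam2 * np.sum(np.abs(x2))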