Package parsimony :: Package algorithms :: Module multiblock
[hide private]
[frames] | no frames]

Source Code for Module parsimony.algorithms.multiblock

  1  # -*- coding: utf-8 -*- 
  2  """ 
  3  The :mod:`parsimony.algorithms.multiblock` module includes several multiblock 
  4  algorithms. 
  5   
  6  Algorithms may not store states. I.e., if they are classes, do not keep 
  7  references to objects with state in the algorithm objects. It should be 
  8  possible to copy and share algorithms between e.g. estimators, and thus they 
  9  should not depend on any state. 
 10   
 11  Created on Thu Feb 20 22:12:00 2014 
 12   
 13  Copyright (c) 2013-2014, CEA/DSV/I2BM/Neurospin. All rights reserved. 
 14   
 15  @author:  Tommy Löfstedt 
 16  @email:   lofstedt.tommy@gmail.com 
 17  @license: BSD 3-clause. 
 18  """ 
 19  import numpy as np 
 20   
 21  try: 
 22      from . import bases  # Only works when imported as a package. 
 23  except ValueError: 
 24      import parsimony.algorithms.bases as bases  # When run as a program.import parsimony.utils.maths as maths 
 25  import parsimony.utils.consts as consts 
 26  from parsimony.algorithms.utils import Info 
 27  import parsimony.utils as utils 
 28  import parsimony.utils.maths as maths 
 29  import parsimony.functions.properties as properties 
 30  import parsimony.functions.multiblock.properties as multiblock_properties 
 31  import parsimony.functions.multiblock.losses as mb_losses 
 32  #from parsimony.algorithms.proximal import FISTA, ISTA 
 33   
 34  __all__ = ["MultiblockFISTA"] 
class MultiblockFISTA(bases.ExplicitAlgorithm,
                      bases.IterativeAlgorithm,
                      bases.InformationAlgorithm):
    """The projected or proximal gradient algorithm with alternating
    minimisations in a multiblock setting.

    The blocks ``w[i]`` are updated in turn. For each block, a function is
    wrapped around that block (MultiblockFunctionWrapper) and an inner
    accelerated proximal gradient (FISTA) loop is run on it. The proximal
    operator is evaluated inexactly, with a precision that tightens as the
    block accumulates iterations. After every sweep over all blocks, one
    proximal gradient (ISTA) step is taken in each block. If no block moves by
    more than ``step * eps``, the algorithm has converged.

    Parameters
    ----------
    info : List or tuple of utils.Info. What, if any, extra run information
            should be stored. Default is an empty list (None means the same),
            which means that no run information is computed nor returned.

    eps : Positive float. Tolerance for the stopping criterion. Values below
            consts.FLOAT_EPSILON are raised to consts.FLOAT_EPSILON.

    max_iter : Non-negative integer. Maximum total allowed number of
            iterations, counted over all blocks.

    min_iter : Non-negative integer less than or equal to max_iter. Minimum
            number of inner iterations performed per block in each sweep.
            Default is 1.
    """
    INTERFACES = [multiblock_properties.MultiblockFunction,
                  multiblock_properties.MultiblockGradient,
                  multiblock_properties.MultiblockStepSize,
                  properties.OR(
                      multiblock_properties.MultiblockProjectionOperator,
                      multiblock_properties.MultiblockProximalOperator)]

    INFO_PROVIDED = [Info.ok,
                     Info.num_iter,
                     Info.time,
                     Info.func_val,
                     Info.smooth_func_val,
                     Info.converged]

    def __init__(self, info=None,
                 eps=consts.TOLERANCE,
                 max_iter=consts.MAX_ITER, min_iter=1):

        # Use a fresh list per instance instead of a shared mutable default.
        if info is None:
            info = []

        super(MultiblockFISTA, self).__init__(info=info,
                                              max_iter=max_iter,
                                              min_iter=min_iter)

        self.eps = max(consts.FLOAT_EPSILON, float(eps))

    @staticmethod
    def _inner_eps(num_block_iter, exp):
        """Precision for the inexact proximal operator of one block.

        Decreases as 1 / num_block_iter ** exp, but never goes below
        consts.FLOAT_EPSILON.
        """
        return max(consts.FLOAT_EPSILON, 1.0 / (num_block_iter ** exp))

    def _all_blocks_converged(self, function, w, block_iter, exp):
        """Return True if, in every block, a single proximal gradient (ISTA)
        step moves w[i] by at most step * self.eps.
        """
        for i in range(len(w)):
            # Wrap a function around the ith block.
            func = mb_losses.MultiblockFunctionWrapper(function, w, i)

            step = func.step(w[i])
            eps = self._inner_eps(block_iter[i], exp)
            w_tilde = func.prox(w[i] - step * func.grad(w[i]),
                                factor=step, eps=eps)

            if maths.norm(w[i] - w_tilde) > step * self.eps:
                return False

        return True

    @bases.force_reset
    @bases.check_compatibility
    def run(self, function, w):
        """Minimise a multiblock function by alternating over its blocks.

        Parameters
        ----------
        function : Multiblock function implementing the INTERFACES of this
                class.

        w : List of block vectors (presumably numpy arrays). The starting
                point. The list is updated in place and also returned.

        Returns
        -------
        w : The same list, holding the final iterate of every block.
        """
        # Not ok until the end.
        if self.info_requested(Info.ok):
            self.info_set(Info.ok, False)

        # Initialise info variables. Info variables have the prefix "_".
        if self.info_requested(Info.time):
            _t = []
        if self.info_requested(Info.func_val):
            _f = []
        if self.info_requested(Info.smooth_func_val):
            _fmu = []
        if self.info_requested(Info.converged):
            self.info_set(Info.converged, False)

        # Use accelerated (FISTA) inner steps. With plain ISTA steps, the
        # precision of the inexact prox is tightened more slowly.
        accelerated = True
        if accelerated:
            exp = 4.0 + consts.FLOAT_EPSILON
        else:
            exp = 2.0 + consts.FLOAT_EPSILON
        # Per-block iteration counters (starting at 1), used for the prox
        # precision.
        block_iter = [1] * len(w)

        while True:

            for i in range(len(w)):

                # Wrap a function around the ith block.
                func = mb_losses.MultiblockFunctionWrapper(function, w, i)

                # Run FISTA on block i. Always allow at least min_iter
                # iterations, otherwise use the remaining iteration budget.
                w_old = w[i]
                for k in range(1, max(self.min_iter + 1,
                                      self.max_iter - self.num_iter + 1)):

                    if self.info_requested(Info.time):
                        tm = utils.time_wall()

                    if accelerated:
                        # Take an interpolated (momentum) step.
                        z = w[i] + ((k - 2.0) / (k + 1.0)) * (w[i] - w_old)
                    else:
                        z = w[i]

                    step = func.step(z)
                    eps = self._inner_eps(block_iter[i], exp)

                    w_old = w[i]
                    # Take a proximal gradient step from z.
                    w[i] = func.prox(z - step * func.grad(z),
                                     factor=step, eps=eps)

                    # Store info variables.
                    if self.info_requested(Info.time):
                        _t.append(utils.time_wall() - tm)
                    if self.info_requested(Info.func_val):
                        _f.append(function.f(w))
                    if self.info_requested(Info.smooth_func_val):
                        _fmu.append(function.fmu(w))

                    # Update iteration counts.
                    self.num_iter += 1
                    block_iter[i] += 1

                    # Block stopping criterion: (1 / step) * ||w[i] - z|| < eps.
                    if maths.norm(w[i] - z) < step * self.eps \
                            and k >= self.min_iter:
                        break

            # Global stopping criterion: converged in all blocks.
            if self._all_blocks_converged(function, w, block_iter, exp):
                if self.info_requested(Info.converged):
                    self.info_set(Info.converged, True)
                break

            # Stop after maximum number of iterations.
            if self.num_iter >= self.max_iter:
                break

        # Store information.
        if self.info_requested(Info.num_iter):
            self.info_set(Info.num_iter, self.num_iter)
        if self.info_requested(Info.time):
            self.info_set(Info.time, _t)
        if self.info_requested(Info.func_val):
            self.info_set(Info.func_val, _f)
        if self.info_requested(Info.smooth_func_val):
            self.info_set(Info.smooth_func_val, _fmu)
        if self.info_requested(Info.ok):
            self.info_set(Info.ok, True)

        return w

class MultiblockCONESTA(bases.ExplicitAlgorithm,
                        bases.IterativeAlgorithm,
                        bases.InformationAlgorithm):
    """An alternating minimising multiblock algorithm that utilises CONESTA in
    the inner minimisation.

    In every outer iteration, each block w[i] is minimised in turn:
      - blocks for which function.has_nesterov_function(i) is true are handled
        by NaiveCONESTA;
      - all other blocks are handled by FISTA.

    The outer loop stops when one proximal gradient step changes no block by
    more than eps (measured as ||w[i] - w_tilde|| / step), or after
    outer_iter outer iterations.

    Parameters
    ----------
    mu_start : Passed unchanged to the inner NaiveCONESTA algorithm.

    mu_min : Passed unchanged to the inner NaiveCONESTA algorithm.

    tau : Passed unchanged to the inner NaiveCONESTA algorithm.

    outer_iter : Non-negative integer. Maximum allowed number of outer loop
            iterations. Default is 20.

    info : List or tuple of utils.consts.Info. What, if any, extra run
            information should be stored. Default is an empty list (None
            means the same), which means that no run information is computed
            nor returned.

    eps : Positive float. Tolerance for the outer stopping criterion. Also
            passed to the inner FISTA and NaiveCONESTA algorithms.

    max_iter : Non-negative integer. Maximum allowed number of iterations,
            passed to the inner FISTA and NaiveCONESTA algorithms.

    min_iter : Non-negative integer. Number of required iterations, passed to
            the inner algorithms. Default is 1.
    """
    INTERFACES = [multiblock_properties.MultiblockFunction,
                  multiblock_properties.MultiblockGradient,
                  multiblock_properties.MultiblockProjectionOperator,
                  multiblock_properties.MultiblockStepSize]

    INFO_PROVIDED = [Info.ok,
                     Info.num_iter,
                     Info.time,
                     Info.fvalue,
                     Info.converged]

    def __init__(self, mu_start=None, mu_min=consts.TOLERANCE,
                 tau=0.5, outer_iter=20,
                 info=None, eps=consts.TOLERANCE,
                 max_iter=consts.MAX_ITER, min_iter=1):

        # Use a fresh list per instance instead of a shared mutable default.
        if info is None:
            info = []

        super(MultiblockCONESTA, self).__init__(info=info,
                                                max_iter=max_iter,
                                                min_iter=min_iter)

        self.outer_iter = outer_iter
        self.eps = eps

        # Imported here rather than at module level (presumably to avoid a
        # circular import between the algorithm modules).
        from parsimony.algorithms.proximal import FISTA
        from parsimony.algorithms.primaldual import NaiveCONESTA

        # Forward the requested info keys that FISTA can provide to the inner
        # algorithms. Always request convergence information from them.
        alg_info = [nfo for nfo in self.info_copy()
                    if nfo in FISTA.INFO_PROVIDED]
        if Info.converged not in alg_info:
            alg_info.append(Info.converged)

        self.fista = FISTA(info=alg_info,
                           eps=self.eps,
                           max_iter=self.max_iter,
                           min_iter=self.min_iter)
        self.conesta = NaiveCONESTA(mu_start=mu_start,
                                    mu_min=mu_min,
                                    tau=tau,
                                    eps=self.eps,
                                    info=alg_info,
                                    max_iter=self.max_iter,
                                    min_iter=self.min_iter)

    def _all_blocks_converged(self, function, w):
        """Return True if, in every block, one proximal gradient (ISTA) step
        satisfies (1 / step) * ||w[i] - w_tilde|| <= self.eps.
        """
        for i in range(len(w)):
            step = function.step(w, i)
            w_tilde = function.prox(w[:i] +
                                    [w[i] - step * function.grad(w, i)] +
                                    w[i + 1:], i, step)

            if (1.0 / step) * maths.norm(w[i] - w_tilde) > self.eps:
                return False

        return True

    @bases.force_reset
    @bases.check_compatibility
    def run(self, function, w):
        """Minimise a multiblock function by alternating over its blocks.

        Parameters
        ----------
        function : Multiblock function implementing the INTERFACES of this
                class.

        w : List of block vectors (presumably numpy arrays). The starting
                point. The list is updated in place and also returned.

        Returns
        -------
        w : The same list, holding the final iterate of every block.
        """
        if self.info_requested(Info.ok):
            self.info_set(Info.ok, False)
        if self.info_requested(Info.time):
            t = []
        if self.info_requested(Info.fvalue):
            f = []
        if self.info_requested(Info.converged):
            self.info_set(Info.converged, False)

        # Inner iterations used in each block, summed over outer iterations.
        num_iter = [0] * len(w)

        for _ in range(self.outer_iter):

            for i in range(len(w)):

                if function.has_nesterov_function(i):
                    func = mb_losses.MultiblockNesterovFunctionWrapper(
                        function, w, i)
                    algorithm = self.conesta
                else:
                    func = mb_losses.MultiblockFunctionWrapper(function, w, i)
                    algorithm = self.fista

                w[i] = algorithm.run(func, w[i])

                # Only collect values the inner algorithm actually recorded.
                # Reading them unconditionally used unbound or stale locals.
                if algorithm.info_requested(Info.num_iter):
                    num_iter[i] += algorithm.info_get(Info.num_iter)
                if self.info_requested(Info.time) \
                        and algorithm.info_requested(Info.time):
                    t.extend(algorithm.info_get(Info.time))
                if self.info_requested(Info.fvalue) \
                        and algorithm.info_requested(Info.fvalue):
                    f.extend(algorithm.info_get(Info.fvalue))

            if self._all_blocks_converged(function, w):
                if self.info_requested(Info.converged):
                    self.info_set(Info.converged, True)
                break

        if self.info_requested(Info.num_iter):
            self.info_set(Info.num_iter, num_iter)
        if self.info_requested(Info.time):
            self.info_set(Info.time, t)
        if self.info_requested(Info.fvalue):
            self.info_set(Info.fvalue, f)
        if self.info_requested(Info.ok):
            self.info_set(Info.ok, True)

        return w


if __name__ == "__main__":
    # When executed directly, run the doctests embedded in this module.
    from doctest import testmod
    testmod()