Package parsimony :: Package utils :: Module check_arrays'
[hide private]
[frames] | no frames]

Source Code for Module parsimony.utils.check_arrays'

 1  # -*- coding: utf-8 -*- 
 2  """ 
 3  Created on Mon Jul 29 17:37:17 2013 
 4   
 5  Copyright (c) 2013-2014, CEA/DSV/I2BM/Neurospin. All rights reserved. 
 6   
 7  @author:  Edouard Duchesnay 
 8  @email:   edouard.duchesnay@cea.fr 
 9  @license: BSD 3-clause. 
10  """ 
11  import numpy as np 
12   
13   
def check_arrays(*arrays):
    """Convert inputs to float numpy arrays sharing the same first dimension.

    Each input is processed as follows:
        - It is converted to a numpy array (lists and tuples are accepted).
        - It is cast to float.
        - Its first dimension is checked against that of the first input.
        - If it is 1D, it is reshaped to a 2D column of shape (n, 1).
          Arrays with two or more dimensions are left as they are.

    Parameters
    ----------
    *arrays : Sequence of array-likes with the same shape[0]
        Anything that ``np.asarray`` can convert to a float array. Python
        lists or tuples occurring in arrays are converted to numpy arrays.

    Returns
    -------
    None if called without arguments, the checked array itself if called
    with exactly one argument, otherwise a list with one checked array per
    argument, in the order given.

    Raises
    ------
    ValueError
        If the first dimension of an array differs from the first dimension
        of the first array.

    Examples
    --------
    >>> import numpy as np
    >>> X, y, Z = check_arrays([1, 2], np.array([3, 4]),
    ...                        np.array([[1., 2.], [3., 4.]]))
    >>> X.shape, y.shape, Z.shape
    ((2, 1), (2, 1), (2, 2))
    >>> X.tolist()
    [[1.0], [2.0]]
    >>> check_arrays([1, 2]).tolist()
    [[1.0], [2.0]]
    """
    if len(arrays) == 0:
        return None

    n_samples = None
    checked_arrays = []
    for array in arrays:
        # Recast input as float array. The builtin ``float`` is used instead
        # of the ``np.float`` alias, which was deprecated in NumPy 1.20 and
        # removed in NumPy 1.24; both designate the same float64 dtype.
        array = np.asarray(array, dtype=float)

        # The first array seen defines the reference number of samples.
        if n_samples is None:
            n_samples = array.shape[0]
        if array.shape[0] != n_samples:
            raise ValueError("Found array with dim %d. Expected %d"
                             % (array.shape[0], n_samples))
        # Promote 1D vectors to (n, 1) column arrays.
        if len(array.shape) == 1:
            array = array[:, np.newaxis]

        checked_arrays.append(array)

    # A lone input is returned bare rather than wrapped in a list.
    return checked_arrays[0] if len(checked_arrays) == 1 else checked_arrays


if __name__ == "__main__":
    # When executed as a script, run the doctests embedded in this module.
    import doctest
    doctest.testmod()