Package parsimony :: Package utils :: Module plot
[hide private]
[frames] | [no frames]

Source Code for Module parsimony.utils.plot

  1  # -*- coding: utf-8 -*- 
  2  """ 
  3  Created on Mon Jan  6 14:13:42 2014 
  4   
  5  Copyright (c) 2013-2014, CEA/DSV/I2BM/Neurospin. All rights reserved. 
  6   
  7  @author:  Edouard Duchesnay, Tommy Löfstedt 
  8  @email:   edouard.duchesnay@cea.fr, lofstedt.tommy@gmail.com 
  9  @license: BSD 3-clause. 
 10  """ 
 11  import numpy as np 
 12  import scipy.stats as ss 
 13   
 14  __all__ = ["plot_map2d", "plot_classes"] 
 15   
 16  COLORS = ["b", "g", "r", "c", "m", "y", "k", "w"] 
 17  COLORS_FULL = ["blue", "green", "red", "cyan", "magenta", "yellow", "black", 
 18                 "white"] 
 19  MARKERS = ["+", ".", "o", "*", "p", "s", "x", "D", "h", "^"] 
 20  LINE_STYLE = ["-", "--", "--.", ":"] 
 21   
 22   
def plot_map2d(map2d, plot=None, title=None, limits=None,
               center_cmap=True):
    """Display a 2D array as a colour-coded image with a colorbar.

    Parameters
    ----------
    map2d : array_like
        The map to draw. Singleton dimensions are squeezed away first; what
        remains must be two-dimensional.

    plot : matplotlib.pyplot or matplotlib.axes.Axes, optional
        The object whose ``matshow`` method draws the image. Defaults to
        ``matplotlib.pyplot`` (whose ``matshow`` opens a new figure).

    title : str, optional
        Title set on the axes that holds the image.

    limits : sequence of two numbers, optional
        Colour-scale limits. The order of the two values does not matter:
        the smaller is used as the lower and the larger as the upper limit.
        If not given (or not of size two), the data range of ``map2d`` is
        used.

    center_cmap : bool, optional
        If True (default), the colour scale is made symmetric around zero,
        using the largest absolute limit on both sides.

    Raises
    ------
    ValueError
        If ``map2d`` is not two-dimensional after squeezing.
    """
    map2d = np.asarray(map2d).squeeze()
    if map2d.ndim != 2:
        raise ValueError("input map is not 2D")

    # Imported only after validation so that bad input fails fast.
    import matplotlib.pyplot as plt

    if plot is None:
        plot = plt

    # Use ``==`` (the original used ``is``, an identity test on ints).
    limits_arr = np.asarray(limits)
    if limits_arr.size == 2:
        mi, mx = limits_arr.min(), limits_arr.max()
    else:
        mi, mx = map2d.min(), map2d.max()

    if center_cmap:
        mx = np.abs([mi, mx]).max()
        mi = -mx

    # Give the limits to the image itself so the image colours and the
    # colorbar agree. This replaces calling Colorbar.set_clim afterwards,
    # which is deprecated in newer matplotlib releases.
    image = plot.matshow(map2d, cmap=plt.cm.coolwarm, vmin=mi, vmax=mx)

    # Work on the axes that actually holds the image. When an Axes instance
    # is passed as ``plot``, it need not be pyplot's current axes.
    axes = image.axes
    axes.get_xaxis().set_visible(False)
    axes.get_yaxis().set_visible(False)
    axes.figure.colorbar(image, ax=axes)

    if title is not None:
        axes.set_title(title)
60 61
def plot_classes(X, classes, title=None, xlabel=None, ylabel=None, show=True):
    """Scatter-plot samples coloured by their class label.

    Parameters
    ----------
    X : numpy.ndarray
        Samples in rows. Points are only drawn when ``X.shape[1] == 2``.
        For other widths nothing is plotted, but the title, axis labels
        and ``show`` are still applied.

    classes : sequence or numpy.ndarray
        One class label per row of ``X``. Arrays are flattened.

    title, xlabel, ylabel : str, optional
        Text for the plot title and the axis labels.

    show : bool, optional
        If True (default), call ``matplotlib.pyplot.show()`` at the end.
    """
    import matplotlib.pyplot as plot

    if isinstance(classes, np.ndarray):
        classes = classes.ravel().tolist()

    # Deterministic class order, and hence colour assignment: sorted when
    # the labels are mutually comparable, otherwise first-appearance order.
    # (Iterating a set gave an arbitrary order, e.g. for string labels.)
    try:
        cls = sorted(set(classes))
    except TypeError:
        cls = list(dict.fromkeys(classes))

    # TODO: Add the other cases.
    if X.shape[1] == 2:
        labels = np.array(classes)  # built once, not once per class
        for i, c in enumerate(cls):
            mask = labels == c
            plot.plot(X[mask, 0], X[mask, 1],
                      color=COLORS[i % len(COLORS)],
                      marker='.',
                      markersize=15,
                      linestyle="None")

    if title is not None:
        plot.title(title, fontsize=22)

    if xlabel is not None:
        plot.xlabel(xlabel, fontsize=16)
    if ylabel is not None:
        plot.ylabel(ylabel, fontsize=16)

    if show:
        plot.show()
96 97
def plot_errorbars(X, classes=None, means=None, alpha=0.05,
                   title=None, xlabel=None, ylabel=None,
                   colors=None,
                   show=True, latex=True):
    """Plot the column means of ``X`` with t-based error bars.

    Parameters
    ----------
    X : numpy.ndarray, shape (B, n)
        B rows of observations for n variables. One error bar is drawn per
        column, at x positions 1, ..., n.

    classes : sequence of length n, optional
        Group label for each column. Columns with the same label share a
        colour. By default all columns form one group.

    means : sequence of length n, optional
        Reference values, drawn as black stars at the same x positions.

    alpha : float, optional
        Significance level. Each bar is
        ``mean +/- t_{B-1}(1 - alpha / 2) * std / sqrt(B)``.

    title, xlabel, ylabel : str, optional
        Text for the plot title and the axis labels.

    colors : sequence of matplotlib colours, optional
        Colours cycled over the groups. Defaults to COLORS. Any matplotlib
        colour specification works, not only single-letter codes.

    show : bool, optional
        If True (default), call ``matplotlib.pyplot.show()`` at the end.

    latex : bool, optional
        If True (default), switch on LaTeX text rendering and a serif font
        via ``rc``. This changes matplotlib's global settings.
    """
    import matplotlib.pyplot as plot

    B, n = X.shape
    if classes is None:
        classes = np.array([1] * n)
    # One label per column; fails if the number of labels is not n.
    classes = np.array(classes).reshape(n)

    if colors is None:
        colors = COLORS

    data_mu = np.mean(X, axis=0)
    # NOTE(review): np.std defaults to the population std (ddof=0). A
    # sample std (ddof=1) is more usual for a t interval; confirm intent.
    data_sd = np.std(X, axis=0)
    ci = ss.t.ppf(1.0 - alpha / 2.0, B - 1) * data_sd / np.sqrt(B)

    x = np.arange(1, n + 1)
    labels = np.unique(classes)

    if latex:
        plot.rc('text', usetex=True)
        plot.rc('font', family='serif')
    if means is not None:
        plot.plot(x, means, '*',
                  markerfacecolor="black", markeredgecolor="black",
                  markersize=10)

    for i, label in enumerate(labels):
        ind = np.flatnonzero(classes == label)
        color = colors[i % len(colors)]
        # The colour is passed via keywords only. Building fmt as
        # 'o' + color failed for anything but single-letter colour codes.
        plot.errorbar(x[ind],
                      data_mu[ind],
                      yerr=ci[ind],
                      fmt='o',
                      color=color,
                      ecolor=color,
                      elinewidth=2,
                      markeredgewidth=2,
                      markeredgecolor=color,
                      capsize=5)

    plot.xlim((0, n + 1))
    mn = np.min(data_mu - ci)
    mx = np.max(data_mu + ci)
    d = mx - mn
    # Pad the y range by 5%. If all bars collapse to a single value, use a
    # fixed pad instead of passing identical limits to ylim.
    pad = d * 0.05 if d > 0 else 0.5
    plot.ylim((mn - pad, mx + pad))

    if xlabel is not None:
        plot.xlabel(xlabel)
    if ylabel is not None:
        plot.ylabel(ylabel)
    if title is not None:
        plot.title(title)
    if show:
        plot.show()
161