Advertisement
supermario

PCA

Oct 7th, 2012
276
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.25 KB | None | 0 0
  1. import numpy as np
  2. import matplotlib.pylab as plt
  3. import h5py
  4. from scipy.ndimage.filters import median_filter
  5. from scipy.ndimage import rotate
  6.  
  7. from sdp.correction.function import subtract_dark_frame, subtract_flat_field, copy_metadata, get_dark_frame, get_flat_field, make_scan
  8. from sdp.correction.corrector import Corrector
  9.  
  10. from sdp.mathlib.spectrum  import power_spectrum_1d
  11. from scipy.linalg import svd
  12. from numpy.fft import rfft2, irfft2,fftfreq
  13.  
  14. #/////////////////////////////////////////////////////////////
  15.  
  class pca2d:
      """Principal component analysis over a stack of 2-D frames.

      ``data_d`` is indexed as ``data[frame, y, x]``: ``__init__`` reads
      ``shape[0]``, ``shape[1]`` and ``shape[2]`` as the frame count, the y
      size and the x size. It is presumably a numpy array or an h5py
      dataset (TODO confirm against callers).

      Once a covariance matrix is available (computed, loaded or set),
      ``svd()`` stores its ``scipy.linalg.svd`` factors in ``u_matrix``,
      ``s_matrix`` (the singular values) and ``vh_matrix``.
      """

      def __init__(self, data_d):
          self.cov_matrix = None
          self.data = data_d
          self.num_frames = self.data.shape[0]
          self.y_size = self.data.shape[1]
          self.x_size = self.data.shape[2]
          self.u_matrix = None
          self.s_matrix = None
          self.vh_matrix = None

      def calc_cov_matrix(self, init=0, fin=-1, step=1, excl_list=None,
                          summing='row', filename=None):
          """Compute the frame covariance matrix, optionally save it, then run svd().

          Frames ``range(init, fin, step)`` are used, minus any index in
          ``excl_list``. First their float64 mean frame is computed. Then the
          products of each mean-subtracted frame ``D`` are summed and divided
          by the number of frames used:

          * ``summing='row'``: ``D.T @ D``, giving an (x_size, x_size) matrix;
          * ``summing='column'``: ``D @ D.T``, giving a (y_size, y_size) matrix.

          :param init: first frame index.
          :param fin: stop index (exclusive); ``-1`` or ``None`` means all frames.
          :param step: stride through the frame range.
          :param excl_list: frame indices to skip (default: none).
          :param summing: ``'row'`` or ``'column'`` (see above).
          :param filename: if given, the matrix is written to this HDF5 file
              as dataset ``'cov_matrix'``. The file is truncated first.
          :raises ValueError: if ``summing`` is unknown or no frames remain.
          """
          # Compare by value: `is` on str/int literals is identity, not equality.
          if summing not in ('row', 'column'):
              raise ValueError("summing must be 'row' or 'column', got %r"
                               % (summing,))
          if fin is None or fin == -1:
              fin = self.num_frames
          excluded = set(excl_list) if excl_list is not None else set()
          frames = [j for j in range(init, fin, step) if j not in excluded]
          # Count the frames actually visited. `fin - init` is wrong for step != 1.
          overall = len(frames)
          if overall < 1:
              raise ValueError('number of frame summed < 1')

          total = np.zeros((self.y_size, self.x_size), dtype='float64')
          for j in frames:
              # Read one frame at a time so `data` may be an on-disk dataset.
              total = total + self.data[j]
          aver_frame = total / overall

          size = self.y_size if summing == 'column' else self.x_size
          cov = np.zeros((size, size), dtype='float64')
          for j in frames:
              diff = self.data[j] - aver_frame
              if summing == 'column':
                  cov = cov + np.dot(diff, diff.T)
              else:
                  cov = cov + np.dot(diff.T, diff)
          self.cov_matrix = cov / overall

          if filename is not None:
              # The context manager closes the file even if the write fails.
              with h5py.File(filename, 'w') as dsc:
                  dsc.create_dataset('cov_matrix', data=self.cov_matrix)
          self.svd()

      def load_cov_matrix(self, filename, dpath):
          """Load a 2-D covariance matrix from HDF5 dataset ``dpath`` in ``filename``, then run svd()."""
          with h5py.File(filename, 'r') as dsc:
              self.cov_matrix = dsc[dpath][:, :]
          self.svd()

      def set_cov_matrix(self, matrix):
          """Use ``matrix`` as the covariance matrix and run svd()."""
          self.cov_matrix = matrix
          self.svd()

      def get_cov_matrix(self):
          """Return the current covariance matrix, or ``None`` if none is set."""
          return self.cov_matrix

      def svd(self):
          """Factor ``cov_matrix`` with ``scipy.linalg.svd`` into u/s/vh.

          Does nothing when no covariance matrix is set.
          """
          if self.cov_matrix is None:
              # TODO: nothing to factor yet; kept as a no-op like the original.
              return
          # `svd` here is the module-level scipy.linalg.svd, not this method.
          u, s, vh = svd(self.cov_matrix)
          self.u_matrix = u
          self.s_matrix = s
          self.vh_matrix = vh

      def _require_svd(self):
          """Raise RuntimeError if svd() has not produced ``u_matrix`` yet."""
          if self.u_matrix is None:
              raise RuntimeError('no SVD available: call calc_cov_matrix, '
                                 'load_cov_matrix or set_cov_matrix first')

      def get_pca_vector(self, frame_number):
          """Return ``data[frame_number] @ u_matrix``, the frame projected on U's columns.

          NOTE(review): with ``summing='column'``, ``u_matrix`` is
          (y_size, y_size). In that case this product only matches shapes
          when y_size == x_size. Confirm the intended projection for that mode.

          :raises RuntimeError: if no SVD has been computed.
          """
          self._require_svd()
          return np.dot(self.data[frame_number], self.u_matrix)

      def get_original_recon(self, frame_number):
          """Reconstruct a frame from all modes: ``get_pca_vector(n) @ u_matrix.T``."""
          return np.dot(self.get_pca_vector(frame_number), self.u_matrix.T)

      def get_frame_recon(self, frame_number, max_mode_number, excl_list=None):
          """Reconstruct a frame from modes ``0..max_mode_number-1``, skipping ``excl_list``.

          The result is ``sum_j outer(v[:, j], u_matrix[:, j])`` over the
          selected modes ``j``, where ``v = get_pca_vector(frame_number)``.
          If no mode is selected, a float32 zero array is returned.
          """
          v = self.get_pca_vector(frame_number)
          excluded = set(excl_list) if excl_list is not None else set()
          modes = [j for j in range(max_mode_number) if j not in excluded]
          frame = np.zeros((v.shape[0], self.u_matrix.shape[1]), dtype='float32')
          if modes:
              # One matrix product equals the sum of per-mode outer products.
              frame = frame + np.dot(v[:, modes], self.u_matrix.T[modes])
          return frame
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement