import numpy,math

#1 ###########################################################################################################################
class histo:                                                                         #########################################
#1 ###########################################################################################################################

    #2 ####################################################################################        
    def __init__(self):
        """ initialisation of variable """
        self.nb_point=0
        self.value={}
        self.bin=[]
        self.len=0
        
    #2 ####################################################################################        
    def init_dict(self,data):
        """ define 0 entry for each bin  """

        for bin,value in data.items():
            self.bin.append('%f' % float(bin))
            self.value['%f' % float(bin)]=0
        self.len=len(data)

    #2 ####################################################################################        			
    def add(self,value):
    
    	if self.nb_point==0:
    		self.init_dict(value)
    	self.nb_point+=1    		
    	for bin in self.bin:
            self.value[bin]+=float(value[bin])

    #2 ####################################################################################        			
    def write(self,filename):
    
	    ff=open(filename,'w')
	    for i in range(0,self.len):
		    ff.writelines(str(i)+'\t'+str(self.value[i])+'\n')
		    
    #2 ####################################################################################        			
    def out(self):
    	return self.value
    	
    #2 ####################################################################################        
    def charge_file(self,name,auto=1):
        """ read file and return a dict of type {bin:value} """

        out={}
        for line in file(name):
            data=line.split()
            out['%f' % float(data[0])] = float(data[1])
        
        if auto==1:
            self.add(out)
            
        return out


#1 ###########################################################################################################################
class Error_in_histo(histo):                                                         #########################################
#1 ###########################################################################################################################

    #2 ####################################################################################
    def __init__(self):
        """ initialisation of variable """
        self.nb_point=0
        self.sum={}
        self.sum_square={}
        self.quad_term={}
        
    #2 ####################################################################################        
    def init_dict(self,data):
        """ define 0 entry for each bin  """

        if isinstance(data,dict):
            data=data.keys()

        for bin in data:
            self.sum['%f' % float(bin)]=0
            self.sum_square['%f' % float(bin)]=0
            self.quad_term['%f' % float(bin)]={}
            list_bin2=[ bin2 for bin2 in data if float(bin2)<float(bin)]
            for bin2 in list_bin2:
                self.quad_term['%f'% float(bin)]['%f' % float(bin2)]=0
        
    #2 ####################################################################################        
    def charge_file(self,name,auto=1):
        """ read file and return a dict of type {bin:value} """
        
        data=histo.charge_file(self,name,auto=0)
        
        if auto==1:
            self.update_value(data)
		
        return data
        
    #2 ####################################################################################        
    def update_value(self,data):
        """ update self.sum,self.sum_square,self.quad_term,...
            data is a dict of type {bin:value}
            doesn't return anything
        """
        if self.nb_point==0 and self.sum=={}:
            self.init_dict(data)

        self.nb_point+=1
        prov= data.keys()
        prov.sort()

        for bin,value in data.items():
            self.sum['%f' % float(bin)] += float(value)
            self.sum_square['%f' % float(bin)] += float(value)**2
            list_bin2=[ bin2 for bin2 in data.keys() if float(bin2)<float(bin)]
            for bin2 in list_bin2:
                self.quad_term['%f' % float(bin)]['%f' % float(bin2)]+=float(value)*float(data[bin2])

    #2 ####################################################################################        
    def compute_result(self):
        """ compute the actual status of the average,variance and correlation
            return self.average(dict),self.variance(dict),self.correlation(dict^2)
        """
        average={}
        variance={}
        covariance={}
        correlation={}

        key_list=[float(value) for value in self.sum.keys()]
        key_list.sort()
        key_list=['%f' % value for value in key_list]
        for bin in key_list:
            #average
            average[bin]= self.sum[bin]/self.nb_point
            #variance
            try:
                variance[bin]=math.sqrt((self.sum_square[bin]-self.nb_point*average[bin]**2)/(self.nb_point))#add -1 to have the other estimator
            except ValueError:
                print (self.sum_square[bin]-self.nb_point*average[bin]),'/',(self.nb_point-1)
                variance[bin]=0.0
            #covariance
            covariance[bin]={}
            correlation[bin]={}
            list_bin2=[bin2 for bin2 in self.sum.keys() if float(bin2)<float(bin)]
            for bin2 in list_bin2:
                covariance[bin][bin2]=self.quad_term[bin][bin2]/self.nb_point-average[bin]*average[bin2]
                
                if variance[bin] and variance[bin2]: correlation[bin][bin2]=covariance[bin][bin2]/(variance[bin]*variance[bin2])
                else:                                correlation[bin][bin2]=float('nan')

        self.average=average
        self.variance=variance
        #self.covariance=covariance
        self.correlation=correlation
        return average,variance,correlation
        
    #2 ####################################################################################        
    def write_result(self,filename1,filename2,filename3,n=1):
        """ write result in a file """

        key_list=[ float(value) for value in self.sum.keys()]
        key_list.sort()
        key_list=['%f' % value for value in key_list]

        ff=open(filename1,'w')
        for bin in key_list:
            ff.writelines(str(bin)+'\t'+str(n*self.average[bin])+'\n')
        ff.close()

        ff=open(filename2,'w')
        for bin in key_list:
            ff.writelines(str(bin)+'\t'+str(math.sqrt(n)*self.variance[bin])+'\n')
        ff.close()


        ff=open(filename3,'w')
        #ff.writelines('\t'.join([' ']+[str(name) for name in key_list])+'\n')
        for i in range(0,len(key_list)):
        	line=''
        	for j in range(0,len(key_list)): 
        		if j<i:    line+=str(self.correlation[key_list[i]][key_list[j]])+'\t'
        		elif j==i: line+='1.'+'0'*12+'\t'
        		else:      line+=str(self.correlation[key_list[j]][key_list[i]])+'\t'
        	ff.writelines(line+'\n')
        ff.close()


#1 ###########################################################################################################################
def MW_correlation(nb_event):                                                        #########################################
#1 ###########################################################################################################################
    """ compute sigma and correlation from a sample of event """

    stat_tool=Error_in_histo()
    j=0
    while j<nb_event:
        if j%250==0: print 'status',j
        step=numpy.random.poisson(1)
        hist=histo()
        for k in range(j,j+step):
            hist.charge_file('card_1event_'+str(j)+'.txt')
        stat_tool.update_value(hist.out())
        j+=step
    stat_tool.compute_result()
    stat_tool.write_result('../SM_PGS_mean_5','../SM_PGS_error_5','../SM_PGS_correlation_5')


#1 ###########################################################################################################################
def MW_correlation_mult(nb_event_by_dir, cur_var, bin):          #########################################
#1 ###########################################################################################################################
    """ compute sigma and correlation from a sample of event """

    #initialize data
    stat_tool = Error_in_histo()
    #data = stat_tool.charge_file(dir_pos(0)+'/0%d_card_1event_0.txt' % (cur_var), auto=0)
    stat_tool.init_dict(bin)

    cur_event = 0
    while 1:
        if cur_event%100 == 0: print 'status', cur_var, cur_event
        step=numpy.random.poisson(1)
        if cur_event + step > nb_event_by_dir:
            break
        hist=histo()
        for k in range(cur_event, cur_event+step):
            hist.charge_file(init_pos+'/0%d_card_1event_%s.txt' % (cur_var, k))
        stat_tool.update_value(hist.out())
        cur_event += step
                             
    stat_tool.compute_result()
    stat_tool.write_result('%s/mean_var_%s' % (output_dir, cur_var),'%s/error_var_%s' % (output_dir, cur_var),'%s/correlation_var_%s' % (output_dir, cur_var))

    

#1 ###########################################################################################################################  
class MW_chi_square:                                                                 #########################################  
#1 ###########################################################################################################################   

    def __init__(self,total_event=380000):
        self.read_event=0
        self.total_event=total_event

    def next_sample(self,N):
        step=numpy.random.poisson(N)
        if not self.read_event+step<self.total_event:
            print 'end of event'
            return {}
        hist=histo()
        for nb_data in range(self.read_event,self.read_event+step):
            print 'charge all/event_'+str(nb_data)
            hist.charge_file('all/event_'+str(nb_data))

        return hist.out()



bin = lambda min,gap: [ '%f' % float(min+i*gap) for i in range(0,50)]

if '__main__'==__name__:

    init_dir = 'where find the data ' #./Events/MYNAME/Graph/all_event
    output_dir 'where to write the results'



    MW_correlation_mult(5000, 1, bin(112, 24))
    MW_correlation_mult(5000, 2, bin(5, 10))
    MW_correlation_mult(5000, 3, bin(-0.98,0.04))

