# Source: HCAL_project / architectures / utils / toolbox.py
# Revision: "First commit" by Davide Lancierini, 2 Dec 2018 (24 KB)
import numpy as np
import os 
import math
import tensorflow as tf
import pickle
import matplotlib.pyplot as plt

# Module-level seed constant. None of the functions visible in this file read
# it; the mini-batch helpers take their seed as an explicit argument instead.
rnd_seed=1

def conv_concat(X, y, y_dim):

    """
    Appends the conditioning values y to X along axis 3.

    y is reshaped to [batch, 1, 1, y_dim], repeated along axes 1 and 2 so
    that it matches the extent of X on those axes, and then concatenated
    to X on axis 3.

    Arguments:
    X -- tensor with at least the axes 0..3 used below
    y -- tensor that can be reshaped to [tf.shape(X)[0], 1, 1, y_dim]
    y_dim -- size of the conditioning vector of each example

    Returns:
    output -- X with y_dim extra entries on axis 3
    """

    batch_size = tf.shape(X)[0]
    cond = tf.reshape(y, [batch_size, 1, 1, y_dim])

    # repeat the conditioning vector at every position of axes 1 and 2
    repeats = [1, tf.shape(X)[1], tf.shape(X)[2], 1]
    cond = tf.tile(cond, repeats)

    return tf.concat([X, cond], 3)

def lin_concat(X, y, y_dim):

    """
    Appends the conditioning values y to X along axis 1.

    Arguments:
    X -- tensor whose axis 0 indexes the examples
    y -- tensor that can be reshaped to [tf.shape(X)[0], y_dim]
    y_dim -- size of the conditioning vector of each example

    Returns:
    output -- concatenation of X and the reshaped y on axis 1
    """

    batch_size = tf.shape(X)[0]
    cond = tf.reshape(y, [batch_size, y_dim])

    return tf.concat([X, cond], 1)

def lrelu(x, alpha=0.2):

    """
    Leaky ReLU written as an element-wise maximum.

    Returns max(alpha*x, x). For 0 <= alpha <= 1 this is x where x > 0
    and alpha*x where x < 0.
    """

    scaled = alpha*x
    return tf.maximum(scaled, x)

def evaluation(Y_pred, Y):

    """
    Returns the accuracy of the predictions Y_pred against the labels Y.

    For both tensors the position of the largest entry along axis 1 is
    taken; the accuracy is the mean (as float32) of the indicator that the
    two positions coincide.
    """

    predicted_class = tf.argmax(Y_pred, 1)
    true_class = tf.argmax(Y, 1)

    hits = tf.cast(tf.equal(predicted_class, true_class), tf.float32)

    return tf.reduce_mean(hits)

def supervised_random_mini_batches(X, Y, mini_batch_size, seed):

    """
    Creates a list of random mini_batches from (X, Y)

    X and Y are shuffled with the same permutation, then cut into
    consecutive chunks of mini_batch_size examples. The last chunk is
    smaller when the number of examples is not a multiple of
    mini_batch_size.

    Arguments:
    X -- input data, indexed by example along axis 0
    Y -- labels, indexed by example along axis 0 (same length as X)
    mini_batch_size -- size of the mini-batches, positive integer
    seed -- value passed to np.random.seed before shuffling

    Returns:
    mini_batches -- list of synchronous (mini_batch_X, mini_batch_Y)

    Raises:
    ValueError -- if X and Y have a different number of examples or if
                  mini_batch_size is smaller than 1
    """

    m = X.shape[0]        #number of examples in set

    # A longer Y used to be truncated silently and a shorter one failed with
    # an IndexError deep inside the fancy indexing; fail early instead.
    if Y.shape[0] != m:
        raise ValueError(
            'X and Y must hold the same number of examples, got {0} and {1}'.format(m, Y.shape[0]))
    if mini_batch_size < 1:
        raise ValueError(
            'mini_batch_size must be at least 1, got {0}'.format(mini_batch_size))

    # The global NumPy generator is reseeded on purpose (as before), so the
    # permutation is reproducible for a given seed.
    np.random.seed(seed)
    permutation = np.random.permutation(m)
    shuffled_X = X[permutation]
    shuffled_Y = Y[permutation]

    # Stepping by mini_batch_size yields every complete batch and, when
    # m is not a multiple of mini_batch_size, the shorter final batch.
    mini_batches = [
        (shuffled_X[start:start + mini_batch_size],
         shuffled_Y[start:start + mini_batch_size])
        for start in range(0, m, mini_batch_size)
    ]

    return mini_batches

def unsupervised_random_mini_batches(X, mini_batch_size, seed):

    """
    Shuffles X along its first axis and cuts it into mini-batches.

    Arguments:
    X -- input data with at least two axes, examples along axis 0
    mini_batch_size -- size of the mini-batches, integer
    seed -- value passed to np.random.seed before shuffling

    Returns:
    mini_batches -- list of mini_batch_X; the last one is shorter when the
                    number of examples is not a multiple of mini_batch_size
    """

    n_examples = X.shape[0]

    # reseed the global generator so the shuffle is reproducible
    np.random.seed(seed)
    order = list(np.random.permutation(n_examples))
    shuffled = X[order, :]

    n_full, n_left = divmod(n_examples, mini_batch_size)

    # all the complete mini-batches
    batches = [
        shuffled[b*mini_batch_size:(b + 1)*mini_batch_size, :]
        for b in range(n_full)
    ]

    # leftover examples form one last, smaller mini-batch
    if n_left != 0:
        batches.append(shuffled[n_full*mini_batch_size:n_examples, :])

    return batches

def unsupervised_random_mini_batches_labels(X, mini_batch_size, seed):

    """
    Shuffles X along its first axis and cuts it into mini-batches.

    Unlike unsupervised_random_mini_batches, only the first axis is
    indexed, so one-dimensional arrays are accepted as well.

    Arguments:
    X -- array to split, examples along axis 0
    mini_batch_size -- size of the mini-batches, integer
    seed -- value passed to np.random.seed before shuffling

    Returns:
    mini_batches -- list of mini_batch_X; the last one is shorter when the
                    number of examples is not a multiple of mini_batch_size
    """

    n_examples = X.shape[0]

    # reseed the global generator so the shuffle is reproducible
    np.random.seed(seed)
    order = list(np.random.permutation(n_examples))
    shuffled = X[order]

    n_full, n_left = divmod(n_examples, mini_batch_size)

    # all the complete mini-batches
    batches = [
        shuffled[b*mini_batch_size:(b + 1)*mini_batch_size]
        for b in range(n_full)
    ]

    # leftover examples form one last, smaller mini-batch
    if n_left != 0:
        batches.append(shuffled[n_full*mini_batch_size:n_examples])

    return batches

# def preprocess_true(true):

# 	mean_true=true[true!=0].mean()
# 	std_true=np.std(true[np.where(true!=0)],axis=0)

# 	true[true!=0]-=mean_true
# 	true=np.where(true==0,0,true/std_true)

# 	return true, mean_true, std_true
# def preprocess_reco(reco):

#   mean_reco=np.mean(reco,axis=0)
#   std_reco=np.std(reco,axis=0)

#   reco-=mean_reco
#   reco=np.where(reco==0,0,reco/std_reco)

#   return reco, mean_reco, std_reco

# def reconstruct(sample, mean, std):
#   return np.where(sample!=0,sample*std+mean,0)

def four_cells(img):

    """
    Returns the four largest entries of img, largest first.

    img is flattened first; when it holds fewer than four entries all of
    them are returned in descending order.
    """

    flat = img.flatten()
    ascending = flat.argsort()
    top_four = ascending[-4:]
    return flat[top_four[::-1]]

def normalise(X, norm_space=False):

    """
    Rescales X by its maximum and returns the factors needed to undo it.

    norm_space=True: X[:,0], X[:,1] and X[:,2] are each divided by their
    own maximum. Only the maximum of X[:,2] is returned.
    NOTE(review): denormalise rescales columns 0 and 1 with its own fixed
    factors, not with the maxima used here -- confirm the two agree for
    the data in use.

    norm_space=False: entries that are not > 0 are set to 0 and the result
    is divided by its global maximum.

    In both modes a maximum equal to 0 leaves the values undivided instead
    of producing NaN. The input array is never modified; the normalised
    data is returned as a new array.

    Arguments:
    X -- numpy array; with norm_space=True it needs at least 3 entries
         along axis 1
    norm_space -- selects the per-column mode described above

    Returns:
    X -- the normalised array
    min_X -- always 0
    max_X -- maximum of X[:,2] (norm_space=True) or of the clipped array
    """

    min_X = 0

    if norm_space:
        # Work on a floating point copy: the previous in-place assignment
        # overwrote the caller's array and truncated integer input.
        normalised = np.array(X)
        if not np.issubdtype(normalised.dtype, np.floating):
            normalised = normalised.astype(float)

        max_X = normalised[:,2].max()
        for column in (0, 1, 2):
            column_max = normalised[:,column].max()
            if column_max != 0:
                normalised[:,column] = normalised[:,column]/column_max

        return normalised, min_X, max_X

    clipped = np.where(X>0,X,0)
    max_X = np.max(clipped)

    # an all-zero (or all non-positive) input stays at zero
    if max_X != 0:
        clipped = clipped/max_X

    return clipped, min_X, max_X

def denormalise(X, min_X, max_X, norm_space=False, scale_0=52, scale_1=64):

    """
    Undoes the rescaling applied by normalise.

    norm_space=True: X[:,0] and X[:,1] are multiplied by scale_0 and
    scale_1 and truncated to integers, X[:,2] is multiplied by max_X. The
    result is written into a new array with the dtype of X.

    norm_space=False: every non-zero entry is multiplied by max_X, zeros
    stay zero.

    Arguments:
    X -- normalised numpy array
    min_X -- unused, kept so the signature mirrors the values returned by
             normalise
    max_X -- factor for X[:,2] (norm_space=True) or for the whole array
    norm_space -- selects the per-column mode described above
    scale_0 -- factor for X[:,0]; the default 52 is the value that used to
               be hard-coded (presumably the image height -- TODO confirm)
    scale_1 -- factor for X[:,1]; the default 64 is the value that used to
               be hard-coded (presumably the image width -- TODO confirm)

    Returns:
    the denormalised array
    """

    if not norm_space:
        return np.where(X!=0, X*max_X, 0)

    # only allocated in the branch that actually fills it
    denormalised = np.zeros_like(X)
    denormalised[:,0]=(X[:,0]*scale_0).astype(int)
    denormalised[:,1]=(X[:,1]*scale_1).astype(int)
    denormalised[:,2]=X[:,2]*max_X

    return denormalised

# def normalise(X):

#     X=np.where(X>12 ,X,0)
#     #X=np.where(X>12,np.log(X),0)

#     # E_max = X.max()
#     # E_min = np.min(X[X>0])
#     # X = np.where(X>0, X-(E_max+E_min)/2,0)
#     # X/=X.max()

#     E_min=np.min(X[X>0])
#     X=np.where(X>0,X-E_min,0)
#     E_max=np.max(X)
#     X=np.where(X!=0,X/E_max,0)
    
#     return X, E_max, E_min

# def denormalise(X, E_max, E_min):
#     # X=X*E_max-E_min)/2
#     # X=np.where(X!=0, X+(E_max+E_min)/2, 0)

#     X=np.where(X!=0,X*E_max,0)
#     X=np.where(X!=0, X+E_min, 0)
#     #X=np.where(X!=0, np.exp(X), 0)

#     return X

def delete_undetected_events_single(X):

    """
    Removes from X the events whose entries are all zero.

    Arguments:
    X -- numpy array, events along axis 0

    Returns:
    X_filtered -- copy of X without the all-zero events
    """

    empty_events = [
        idx for idx in range(len(X))
        if np.array_equal(X[idx], np.zeros_like(X[idx]))
    ]

    return np.delete(X, empty_events, axis=0)

def delete_undetected_events_double(true, reco):

    """
    Removes the events that are all zero in either true or reco.

    An index is dropped from both arrays as soon as the event at that
    index is entirely zero in one of them, so the two outputs stay aligned.

    Arguments:
    true -- numpy array, events along axis 0
    reco -- numpy array, events along axis 0

    Returns:
    true_filtered, reco_filtered -- copies without the dropped events
    """

    def is_blank(event):
        # True when every entry of the event equals zero
        return np.array_equal(event, np.zeros_like(event))

    dropped = []
    for idx in range(len(true)):
        if is_blank(reco[idx]) or is_blank(true[idx]):
            dropped.append(idx)

    kept_reco = np.delete(reco, dropped, axis=0)
    kept_true = np.delete(true, dropped, axis=0)

    assert len(kept_true)==len(kept_reco)

    return kept_true, kept_reco

def delete_undetected_events_triple(true_p, true_K, reco):

    """
    Removes the events that are all zero in any of the three arrays.

    An index is dropped from all three arrays as soon as the event at that
    index is entirely zero in reco, true_p or true_K, so the outputs stay
    aligned.

    Arguments:
    true_p, true_K, reco -- numpy arrays, events along axis 0

    Returns:
    true_p_filtered, true_K_filtered, reco_filtered -- copies without the
    dropped events
    """

    def is_blank(event):
        # True when every entry of the event equals zero
        return np.array_equal(event, np.zeros_like(event))

    dropped = [
        idx for idx in range(len(true_p))
        if any(is_blank(sample[idx]) for sample in (reco, true_p, true_K))
    ]

    kept_reco = np.delete(reco, dropped, axis=0)
    kept_p = np.delete(true_p, dropped, axis=0)
    kept_K = np.delete(true_K, dropped, axis=0)

    assert len(kept_p)==len(kept_reco)==len(kept_K)

    return kept_p, kept_K, kept_reco

def selection(true, reco, n_cells, energy_fraction):

    """
    Splits the events by how concentrated their reco content is.

    For each event the sum of its n_cells largest reco entries is divided
    by the sum of all its reco entries. Events whose ratio is below
    energy_fraction are rejected, all the others are selected.

    Arguments:
    true -- numpy array, events along axis 0
    reco -- numpy array, events along axis 0
    n_cells -- number of largest entries summed per event
    energy_fraction -- threshold on the ratio described above

    Returns:
    true_filtered, reco_filtered -- the selected events
    true_rejected, reco_rejected -- the rejected events
    """

    kept = []
    dropped = []

    for idx in range(len(reco)):
        event = reco[idx]
        total = event.sum()
        cells = event.flatten()
        leading = cells[cells.argsort()[-n_cells:][::-1]].sum()

        if leading/total<energy_fraction:
            dropped.append(idx)
        else:
            kept.append(idx)

    # deleting the rejected indices leaves the selected events, and vice versa
    reco_kept = np.delete(reco, dropped, axis=0)
    true_kept = np.delete(true, dropped, axis=0)

    reco_dropped = np.delete(reco, kept, axis=0)
    true_dropped = np.delete(true, kept, axis=0)

    assert len(true_kept)==len(reco_kept)

    return true_kept, reco_kept, true_dropped, reco_dropped


def load_batch(true_path, reco_path, i):

    """
    Loads the i-th pickled sample of the true and of the reco dataset.

    The file name 'sample<i>.pickle' is appended directly to each path, so
    both paths have to end with a path separator. The reco file is read
    first.

    NOTE: unpickling can execute arbitrary code; only use on trusted files.

    Arguments:
    true_path -- prefix of the true sample files
    reco_path -- prefix of the reco sample files
    i -- index of the sample file

    Returns:
    true, reco -- the two unpickled objects
    """

    file_name = 'sample{0}.pickle'.format(i)

    with open(reco_path+file_name, 'rb') as reco_file:
        reco = pickle.load(reco_file, encoding='latin1')

    with open(true_path+file_name, 'rb') as true_file:
        true = pickle.load(true_file, encoding='latin1')

    return true, reco

def load_conditional(true, reco):

    """
    Reduces every true event to the position and value of its first
    non-zero entry.

    For each event the non-zero entries are located with np.where (which
    reports them in row-major order) and the first one is kept: its index
    on axis 0, its index on axis 1, and the value stored at that position
    in channel 0.

    Arguments:
    true -- numpy array of three-axis events, events along axis 0; every
            event needs at least one non-zero entry
    reco -- returned unchanged

    Returns:
    ETs -- array of shape (len(true), 3, 1) holding the three values above
    reco -- the reco argument
    """

    n_events = len(true)
    ETs = np.zeros(shape=(n_events,3,1))

    for idx in range(n_events):

        event = true[idx]
        rows, cols, _ = np.where(event!=0)
        first_row = rows[0]
        first_col = cols[0]

        ETs[idx][0] = first_row
        ETs[idx][1] = first_col
        ETs[idx][2] = event[first_row, first_col][0]

    return ETs, reco

def load_data(true_path, reco_path, n_batches,
    select=False, n_cells=None, energy_fraction=0.0,
    preprocess=None, test_size=None):

    """
    Loads n_batches sample files, optionally filters and normalises them,
    and splits the result into a train and a test part.

    Every batch, the first one included, goes through selection when
    select is True. (Previously the filtered first batch was computed but
    discarded when n_batches > 1.)

    Arguments:
    true_path, reco_path -- path prefixes handed to load_batch
    n_batches -- number of sample files to read, starting from index 0
    select -- when True, keep only the events accepted by selection
    n_cells, energy_fraction -- forwarded to selection
    preprocess -- 'normalise' to rescale both datasets with normalise;
                  any other value leaves the data untouched
    test_size -- number of trailing events put in the test part; None is
                 treated as 0 (empty test part)

    Returns:
    with preprocess == 'normalise':
        train_true, test_true, min_true, max_true,
        train_reco, test_reco, min_reco, max_reco
    otherwise:
        train_true, test_true, train_reco, test_reco

    Raises:
    ValueError -- if n_batches is smaller than 1
    """

    if n_batches < 1:
        raise ValueError('n_batches must be at least 1, got {0}'.format(n_batches))

    true_parts = []
    reco_parts = []

    for i in range(n_batches):

        true_temp, reco_temp = load_batch(true_path, reco_path, i)

        #delete too noisy events
        if select:
            true_temp, reco_temp, _, _ = selection(true_temp, reco_temp, n_cells, energy_fraction)

        true_parts.append(true_temp)
        reco_parts.append(reco_temp)

    if n_batches == 1:
        true = true_parts[0]
        reco = reco_parts[0]
    else:
        # one concatenation at the end instead of re-copying the growing
        # array after every batch
        true = np.concatenate(true_parts, axis=0)
        reco = np.concatenate(reco_parts, axis=0)

    normalised = preprocess == 'normalise'
    if normalised:
        reco, min_reco, max_reco = normalise(reco)
        true, min_true, max_true = normalise(true)

    # the default test_size=None used to fail on the subtraction below
    if test_size is None:
        test_size = 0

    m = reco.shape[0]
    train_size = m - test_size

    train_true = true[0:train_size]
    test_true = true[train_size:m]

    train_reco = reco[0:train_size]
    test_reco = reco[train_size:m]

    if normalised:
        return train_true, test_true, min_true, max_true, train_reco, test_reco, min_reco, max_reco

    return train_true, test_true, train_reco, test_reco

def load_data_conditional(true_path, reco_path, n_batches, dim=None, preprocess=None, test_size=None):

    """
    Loads n_batches sample files for the conditional setting and splits
    them into a train and a test part.

    For every batch the events that are all zero in true or in reco are
    removed with delete_undetected_events_double, then load_conditional
    replaces each true event by its three summary values.

    Arguments:
    true_path, reco_path -- path prefixes handed to load_batch
    n_batches -- number of sample files to read, starting from index 0
    dim -- unused, kept for backward compatibility
    preprocess -- 'normalise' to rescale reco with normalise and the
                  summary values with normalise(..., norm_space=True);
                  any other value leaves the data untouched
    test_size -- number of trailing events put in the test part; None is
                 treated as 0 (empty test part)

    Returns:
    with preprocess == 'normalise':
        train_true, test_true, min_true, max_true,
        train_reco, test_reco, min_reco, max_reco
    otherwise:
        train_true, test_true, train_reco, test_reco

    Raises:
    ValueError -- if n_batches is smaller than 1
    """

    if n_batches < 1:
        raise ValueError('n_batches must be at least 1, got {0}'.format(n_batches))

    ET_parts = []
    reco_parts = []

    for i in range(n_batches):

        true_raw, reco_raw = load_batch(true_path, reco_path, i)

        #delete undetected particles
        true_temp, reco_temp = delete_undetected_events_double(true_raw, reco_raw)
        ETs_temp, reco_output_temp = load_conditional(true_temp, reco_temp)

        ET_parts.append(ETs_temp)
        reco_parts.append(reco_output_temp)

    if n_batches == 1:
        true = ET_parts[0]
        reco = reco_parts[0]
    else:
        # one concatenation at the end instead of re-copying the growing
        # array after every batch
        true = np.concatenate(ET_parts, axis=0)
        reco = np.concatenate(reco_parts, axis=0)

    normalised = preprocess == 'normalise'
    if normalised:
        reco, min_reco, max_reco = normalise(reco, norm_space=False)
        true, min_true, max_true = normalise(true, norm_space=True)

    # the default test_size=None used to fail on the subtraction below
    if test_size is None:
        test_size = 0

    m = reco.shape[0]
    train_size = m - test_size

    train_true = true[0:train_size]
    test_true = true[train_size:m]

    train_reco = reco[0:train_size]
    test_reco = reco[train_size:m]

    if normalised:
        return train_true, test_true, min_true, max_true, train_reco, test_reco, min_reco, max_reco

    return train_true, test_true, train_reco, test_reco

def draw_one_sample(train_true, train_reco, preprocess=None,
    min_true=None, max_true=None, min_reco=None, max_reco=None,
    save=False, PATH=None):

    """
    Plots one randomly chosen event: true image next to reco image.

    The event index is drawn with np.random.randint. With
    preprocess == 'normalise' both images are passed through denormalise
    first. Each panel title shows the sum of the image.

    Arguments:
    train_true, train_reco -- arrays of three-axis events along axis 0;
                              each event is reshaped to its first two axes
                              for plotting
    preprocess -- 'normalise' to denormalise before plotting
    min_true, max_true, min_reco, max_reco -- forwarded to denormalise
    save -- when true the figure is written to PATH, otherwise it is shown
    PATH -- directory for the saved figure
    """

    j = np.random.randint(len(train_true))

    true_img = train_true[j]
    reco_img = train_reco[j]

    if preprocess=='normalise':
        true_img = denormalise(true_img, min_true, max_true)
        reco_img = denormalise(reco_img, min_reco, max_reco)

    n_H_A, n_W_A, _ = true_img.shape
    n_H_B, n_W_B, _ = reco_img.shape

    # (subplot slot, image, plotted shape, title template)
    panels = (
        (1, true_img, (n_H_A, n_W_A), 'True E_T: {:.6g} MeV'),
        (2, reco_img, (n_H_B, n_W_B), 'Reco E_T: {:.6g} MeV'),
    )

    for slot, img, (height, width), title in panels:
        plt.subplot(2,2,slot)
        plt.imshow(img.reshape(height,width))
        plt.xlabel('X')
        plt.ylabel('Y')
        plt.title(title.format(img.sum()))
        plt.subplots_adjust(wspace=0.2,hspace=0.2)

    plt.suptitle('HCAL MC simulation\n ')
    fig = plt.gcf()
    fig.set_size_inches(11,4)

    if save:
        plt.savefig(PATH+'/HCAL_reconstruction_example_{0}.png'.format(j),dpi=80)
    else:
        plt.show()

def draw_one_sample_conditional(train_true, train_reco, preprocess=None,
    min_true=None, max_true=None, min_reco=None, max_reco=None,
    save=False, PATH=None):

    """
    Plots the reco image of one randomly chosen event, with the summary
    values of the matching true entry in the title.

    The event index is drawn with np.random.randint. With
    preprocess == 'normalise' the chosen event is passed through
    denormalise first; otherwise it is plotted as it is. (Previously any
    other value of preprocess failed because the plotted arrays were only
    defined in the 'normalise' branch.)

    Arguments:
    train_true -- array whose entries hold three values per event (the
                  layout built by load_conditional)
    train_reco -- array of three-axis events along axis 0
    preprocess -- 'normalise' to denormalise before plotting
    min_true, max_true, min_reco, max_reco -- forwarded to denormalise
    save -- when true the figure is written to PATH, otherwise it is shown
    PATH -- directory for the saved figure
    """

    j = np.random.randint(len(train_true))

    # One-event slices keep the leading event axis that denormalise indexes
    # with [:, k]; denormalise works element-wise, so only the chosen event
    # needs to be converted instead of the whole dataset.
    X_batch_A = train_true[j:j+1]
    X_batch_B = train_reco[j:j+1]

    if preprocess=='normalise':
        X_batch_A=denormalise(X_batch_A, min_true, max_true, norm_space=True)
        X_batch_B=denormalise(X_batch_B, min_reco, max_reco)

    X_batch_A = X_batch_A[0]
    X_batch_B = X_batch_B[0]

    n_H_B, n_W_B, n_C = X_batch_B.shape

    plt.imshow(X_batch_B.reshape(n_H_B,n_W_B))
    plt.xlabel('X')
    plt.ylabel('Y')
    # the title prints entry 1 as X and entry 0 as Y
    plt.title('HCAL MC simulation \n X: {1} Y: {0} \n True E_T: {2:.6g} MeV, Reco MC E_T: {3:.6g}'.format(X_batch_A[0].sum(), X_batch_A[1].sum(), X_batch_A[2].sum(), X_batch_B.sum()))
    fig = plt.gcf()
    fig.set_size_inches(11,4)
    if not save:
        plt.show()
    else:
        plt.savefig(PATH+'/HCAL_reconstruction_example_{0}.png'.format(j),dpi=80)

def draw_nn_sample(X_A, X_B, i, preprocess=False,
    min_true=None, max_true=None, min_reco=None, max_reco=None, f=None,
    save=True, is_training=False, total_iters=None, PATH=None):

    """
    Plots true input, MC reco and network output side by side.

    With i == 1 a single random event is used. With i > 1 a window of i
    consecutive events is summed into one image before it is given to the
    network. The window start is drawn so that the window fits inside the
    data whenever at least i events are available. (Previously the start
    could be so close to the end that fewer than i events were summed.)

    Arguments:
    X_A -- array of true events with four axes, events along axis 0
    X_B -- array of reco events with four axes, events along axis 0
    i -- number of events superimposed, at least 1
    preprocess -- 'normalise' to denormalise all three images for plotting
    min_true, max_true, min_reco, max_reco -- forwarded to denormalise
    f -- callable applied to the true image reshaped to
         (1, n_H_A, n_W_A, n_C); its output is reshaped to the reco image
         shape for plotting
    save -- when true the figure is written to PATH, otherwise it is shown
    is_training -- adds the iteration to the figure title and file name
    total_iters -- iteration number used when is_training is true
    PATH -- directory for the saved figure

    Raises:
    ValueError -- if i is smaller than 1
    """

    if i < 1:
        raise ValueError('i must be at least 1, got {0}'.format(i))

    _, n_H_A, n_W_A, n_C = X_A.shape
    _, n_H_B, n_W_B, _ = X_B.shape

    # last admissible start is len - i; for i == 1 this is the full range
    j = np.random.randint(max(len(X_A) - i + 1, 1))

    if i == 1:
        #draw the response for one particle
        X_A = X_A[j]
        X_B = X_B[j]
    else:
        #draw the response for i particles
        X_A = X_A[j:j+i].sum(axis=0)
        X_B = X_B[j:j+i].sum(axis=0)

    sample_nn = f(X_A.reshape(1, n_H_A, n_W_A, n_C))

    if preprocess=='normalise':

        X_A=denormalise(X_A, min_true, max_true)
        X_B=denormalise(X_B, min_reco, max_reco)
        sample_nn=denormalise(sample_nn, min_reco, max_reco)

    # (subplot slot, title template, image, plotted shape)
    panels = (
        (1, 'True ET {0:.6g}', X_A, (n_H_A, n_W_A)),
        (2, 'MC Reco ET {0:.6g}', X_B, (n_H_B, n_W_B)),
        (3, 'NN Reco ET {0:.6g}', sample_nn, (n_H_B, n_W_B)),
    )

    for slot, title, img, (height, width) in panels:
        plt.subplot(1,3,slot)
        plt.gca().set_title(title.format(img.sum()))
        plt.imshow(img.reshape(height,width))
        plt.xlabel('X')
        plt.ylabel('Y')
        plt.subplots_adjust(wspace=0.2,hspace=0.2)

    if is_training:
        plt.suptitle('At iter {0}'.format(total_iters))
    fig = plt.gcf()
    fig.set_size_inches(10,8)

    if save:
        if is_training:
            plt.savefig(PATH+'/sample_at_iter_{0}.png'.format(total_iters),dpi=80)
        else:
            plt.savefig(PATH+'/nn_reco_sample_{0}.png'.format(j),dpi=80)
    else:
        plt.show()

def draw_nn_sample_conditional(y, reco_MC, i, preprocess=False,
    min_true=None, max_true=None, min_reco=None, max_reco=None, f=None,
    save=True, is_training=False, total_iters=None, PATH=None):

    """
    Plots four consecutive events: MC reco on the top row, network output
    on the bottom row.

    The start of the four-event window is drawn so that the window fits
    inside the data whenever at least four events are available.
    (Previously a start close to the end left fewer than four events and
    the reshape failed.) With preprocess == 'normalise' the conditioning
    values and both image sets are denormalised for plotting; otherwise
    they are plotted as they are. (Previously the plotted arrays were only
    defined in the 'normalise' branch.)

    Arguments:
    y -- conditioning values, events along axis 0; the window is reshaped
         to (4, y.shape[1]) before it is given to f
    reco_MC -- array of reco events with four axes, events along axis 0
    i -- unused, kept for backward compatibility
    preprocess -- 'normalise' to denormalise before plotting
    min_true, max_true, min_reco, max_reco -- forwarded to denormalise
    f -- callable whose output can be reshaped to (4, n_H_B, n_W_B)
    save -- when true the figure is written to PATH, otherwise it is shown
    is_training -- selects the iteration-based file name
    total_iters -- iteration number used when is_training is true
    PATH -- directory for the saved figure
    """

    n_shown = 4

    # last admissible start is len - 4
    j = np.random.randint(max(len(reco_MC) - n_shown + 1, 1))

    _, n_H_B, n_W_B, _ = reco_MC.shape

    y = y[j:j+n_shown]
    reco_MC = reco_MC[j:j+n_shown]
    sample_nn = f(y.reshape(n_shown, y.shape[1])).reshape(n_shown, n_H_B, n_W_B)

    if preprocess=='normalise':

        Y=denormalise(y, min_true, max_true, norm_space=True)
        X_B_mc=denormalise(reco_MC, min_reco, max_reco)
        X_B_nn=denormalise(sample_nn, min_reco, max_reco)

    else:

        Y=y
        X_B_mc=reco_MC
        X_B_nn=sample_nn

    # loop variable named k so it does not shadow the parameter i
    for k in range(n_shown):

        plt.subplot(2,4,k+1)
        # the title prints entry 1 as X and entry 0 as Y
        plt.gca().set_title('X: {1}, Y: {0} \n True ET {2:.6g}, \n MC ET {3:.6g}\n'.format(Y[k,0].sum(), Y[k,1].sum(), Y[k,2].sum(), X_B_mc[k].sum()))
        plt.imshow(X_B_mc[k].reshape(n_H_B,n_W_B))
        plt.subplots_adjust(wspace=0.25,hspace=0.25)
        plt.xlabel('X')
        plt.ylabel('Y')
        plt.subplot(2,4,k+5)
        plt.gca().set_title('NN ET {0:.6g}'.format(X_B_nn[k].sum()))
        plt.imshow(X_B_nn[k].reshape(n_H_B,n_W_B))
        plt.xlabel('X')
        plt.ylabel('Y')
        plt.subplots_adjust(wspace=0.25,hspace=0.25)

    fig = plt.gcf()
    fig.set_size_inches(20,10)

    if save:
        if is_training:
            plt.savefig(PATH+'/sample_at_iter_{0}.png'.format(total_iters),dpi=80)
        else:
            plt.savefig(PATH+'/nn_reco_sample_{0}.png'.format(j),dpi=80)
    else:
        plt.show()



def get_inner_HCAL(reco):

    """
    Returns the fixed central window of every event.

    The window covers indices 12..39 on axis 1 and 16..47 on axis 2; the
    other two axes are kept whole.
    """

    rows = slice(12, 40)
    cols = slice(16, 48)
    return reco[:, rows, cols, :]

def get_outer_HCAL(reco):

    """
    Builds a half-resolution image of every event and blanks its centre.

    Every 2x2 patch (axes 1 and 2) of an event is summed, over all of its
    channels, into channel 0 of the matching output cell. Afterwards the
    output region 6..19 x 8..23 is set to zero on all channels.

    Arguments:
    reco -- array with four axes, events along axis 0

    Returns:
    outer_HCAL -- float array of shape (events, h//2, w//2, channels)
    """

    m_tot, h, w, c = reco.shape
    coarse = np.zeros(shape=(m_tot,h//2,w//2,c))

    for m in range(m_tot):
        event = reco[m]
        for top in range(0, h, 2):
            for left in range(0, w, 2):
                patch = event[top:top+2, left:left+2]
                coarse[m, top//2, left//2, 0] = patch.sum()

    # blank the central region
    coarse[:,6:20,8:24,:] = 0

    return coarse

def get_4_max_cells(img):

    """
    Finds the 2x2 window of img with the largest sum.

    All windows are scanned in row-major order of their top-left corner.
    A window only replaces the current best one when its sum is strictly
    larger, and the starting best sum is 0, so an image without any
    window of positive sum gives all-zero outputs.

    Arguments:
    img -- array indexed by [row, column] on its first two axes

    Returns:
    value -- (2, 2) array with the four cells of the best window
    pos -- (2, 1) array with the row and column of its top-left corner
    """

    best_sum = 0
    best_corner = None

    for row in range(img.shape[0]-1):
        for col in range(img.shape[1]-1):
            window_sum = img[row:row+2,col:col+2].sum()
            if window_sum > best_sum:
                best_sum = window_sum
                best_corner = (row, col)

    value = np.zeros(shape=(2,2))
    pos = np.zeros(shape=(2,1))

    if best_corner is not None:
        row, col = best_corner
        # copy the window cell by cell
        for d_row in (0, 1):
            for d_col in (0, 1):
                value[d_row,d_col] = img[row+d_row,col+d_col]
        pos[0] = row
        pos[1] = col

    return value, pos

def get_triggered_events(true, reco_inner, reco_outer, threshold=3680):

    """
    Selects the events whose best 2x2 window exceeds a threshold.

    An event is kept when the window found by get_4_max_cells sums to more
    than threshold in the inner image or in the outer image.

    Arguments:
    true -- array of true events, events along axis 0
    reco_inner -- array of inner reco events, events along axis 0
    reco_outer -- array of outer reco events, events along axis 0
    threshold -- minimum window sum (exclusive); the default 3680 is the
                 value that used to be hard-coded (presumably in the same
                 units as the images -- TODO confirm)

    Returns:
    l -- list with the indices of the kept events
    triggered_true -- total of each kept true event
    triggered_reco_inner -- total of each kept inner reco event
    triggered_reco_outer -- total of each kept outer reco event
    """

    l = []
    for m in range(len(reco_inner)):
        # only the window values matter here, the positions are not needed
        value_inner, _ = get_4_max_cells(reco_inner[m])
        value_outer, _ = get_4_max_cells(reco_outer[m])

        if value_inner.sum()>threshold or value_outer.sum()>threshold:
            l.append(m)

    triggered_true = np.array([true[m].sum() for m in l])
    triggered_reco_inner = np.array([reco_inner[m].sum() for m in l])
    triggered_reco_outer = np.array([reco_outer[m].sum() for m in l])

    return l, triggered_true, triggered_reco_inner, triggered_reco_outer

# def crop_conditional(true, reco, dim):

#     ETs=[]
#     assert len(reco)==len(true)

#     cropped_reco=np.zeros(shape=(reco.shape[0],2*dim+1,2*dim+1,1))
#     max_x = reco.shape[2]
#     max_y = reco.shape[1]
#     pos_rejected=[]
#     for i in range(len(reco)):
        
#         reco_y, reco_x, _ = np.where(reco[i]==reco[i].max())
        
#         #CENTER OF IMAGE
#         if 2*dim<reco_y[0]<=max_y-2*dim and 2*dim<reco_x[0]<=max_x-2*dim:
#             cropped_reco[i]=reco[i, reco_y[0]-dim:reco_y[0]+dim+1, reco_x[0]-dim:reco_x[0]+dim+1, :]
#             ETs.append(true[i][np.where(true[i]>0)][0])

#         else:
#             pos_rejected.append(i)
#     # 

#     # print(len(pos_rejected))
#     ETs=np.array(ETs)   
#     reco_rejected=np.delete(cropped_reco,pos_rejected,axis=0)
    
#     assert len(reco_rejected)==len(ETs)
     
#     return ETs, reco_rejected

# def crop_function(true, reco, dim):
    
#     assert len(reco)==len(true)
#     j=0
#     cropped_reco=np.zeros(shape=(reco.shape[0],2*dim,2*dim,1))
#     cropped_true=np.zeros(shape=(true.shape[0],2*dim,2*dim,1))
#     max_x = reco.shape[2]
#     max_y = reco.shape[1]
    
#     for i in range(len(reco)):
#         y , x , _= np.where(true[i]>0)

#         #CORNERS
        
#         #top left corner
#         # if y[0]<=2*dim  and x[0]<=2*dim:
#         #     #print(i, x, y)
#         #     cropped_reco[i]=reco[i,0:2*dim, 0:2*dim, :]
#         #     cropped_true[i]=true[i,0:2*dim, 0:2*dim, :]
#         #     j+=1
            
#         # #top right corner
#         # elif y[0]<=2*dim  and max_x-2*dim<x[0]<=max_x:
#         #     #print(i, x, y)
#         #     cropped_reco[i]=reco[i,0:2*dim,  max_x-2*dim:max_x, :]
#         #     cropped_true[i]=true[i,0:2*dim,  max_x-2*dim:max_x, :]
#         #     j+=1
            
#         # #bottom right corner
#         # elif max_y-2*dim<y[0]<=max_y and max_x-2*dim<x[0]<=max_x:
#         #     #print(i, x, y)
#         #     cropped_reco[i]=reco[i,max_y-2*dim:max_y, max_x-2*dim:max_x, :]
#         #     cropped_true[i]=true[i,max_y-2*dim:max_y, max_x-2*dim:max_x, :]
#         #     j+=1
            
#         # #bottom left corner
#         # elif max_y-2*dim<y[0]<=max_y and x[0] <=2*dim:
#         #     #print(i, x, y)
#         #     cropped_reco[i]=reco[i,max_y-2*dim:max_y, 0:2*dim, :]
#         #     cropped_true[i]=true[i,max_y-2*dim:max_y, 0:2*dim, :]
#         #     j+=1
            
#         # #BORDERS
        
#         # #bottom border without corners
#         # elif max_y-2*dim<=y[0]<max_y and 2*dim<x[0]<=max_x-2*dim:
#         #     #print(i, x, y)
#         #     cropped_reco[i]=reco[i, max_y-2*dim:max_y, x[0]-dim:x[0]+dim,  :]
#         #     cropped_true[i]=true[i, max_y-2*dim:max_y, x[0]-dim:x[0]+dim,  :]
#         #     j+=1
            
#         # #top border without corners
#         # elif y[0]-2*dim<=0 and 2*dim<x[0]<=max_x-2*dim:
#         #     #print(i, x, y)
#         #     cropped_reco[i]=reco[i, 0:2*dim, x[0]-dim:x[0]+dim,  :]
#         #     cropped_true[i]=true[i, 0:2*dim, x[0]-dim:x[0]+dim,  :]
#         #     j+=1
            
#         # #left border without corners
#         # elif 2*dim<y[0]<=max_y-2*dim and x[0]<=2*dim:
#         #     #print(i, x, y)
#         #     cropped_reco[i]=reco[i,y[0]-dim:y[0]+dim,0:2*dim,  :]
#         #     cropped_true[i]=true[i,y[0]-dim:y[0]+dim,0:2*dim,  :]
#         #     j+=1
            
#         # #right border without corners
#         # elif 2*dim<y[0]<=max_y-2*dim and max_x-2*dim<x[0]<=max_x:
#         #     #print(i, x, y)
#         #     cropped_reco[i]=reco[i, y[0]-dim:y[0]+dim, max_x-2*dim:max_x,  :]
#         #     cropped_true[i]=true[i, y[0]-dim:y[0]+dim, max_x-2*dim:max_x,  :]
#         #     j+=1
 
        
#         #CENTER OF IMAGE
#         if 2*dim<y[0]<=max_y-2*dim and 2*dim<x[0]<=max_x-2*dim:
#             #print(i, x, y)
#             cropped_reco[i]=reco[i, y[0]-dim:y[0]+dim, x[0]-dim:x[0]+dim, :]
#             cropped_true[i]=true[i, y[0]-dim:y[0]+dim, x[0]-dim:x[0]+dim, :]
#             j+=1
#         #assert i==j-1
#     return cropped_true, cropped_reco