Newer
Older
Master_thesis / data / CLs / finished / f1d1 / 2257334 / raremodel-nb.py
#!/usr/bin/env python
# coding: utf-8

# # Import

# In[1]:


import os

# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import numpy as np
from pdg_const import pdg
import matplotlib
import matplotlib.pyplot as plt
import pickle as pkl
import sys
import time
from helperfunctions import display_time, prepare_plot
import cmath as c
import scipy.integrate as integrate
from scipy.optimize import fminbound
from array import array as arr
import collections
from itertools import compress
import tensorflow as tf
import zfit
from zfit import ztf
# from IPython.display import clear_output
import os
import tensorflow_probability as tfp
tfd = tfp.distributions


# In[2]:


# chunksize = 10000
# zfit.run.chunking.active = True
# zfit.run.chunking.max_n_points = chunksize


# # Build model and graphs
# ## Create graphs

# In[ ]:





# In[3]:


def formfactor(q2, subscript, b0_0, b0_1, b0_2, bplus_0, bplus_1, bplus_2, bT_0, bT_1, bT_2):
    """Return the B -> K form factor ``f_0``, ``f_+`` or ``f_T`` at ``q2``.

    Each form factor is a single-pole prefactor times a truncated z-expansion
    with ``N = 3`` coefficients. ``f_0`` uses the ``mbstar0`` pole; ``f_+`` and
    ``f_T`` use the ``mbstar`` pole and include the extra ``z**N`` term.

    Args:
        q2: squared dimuon mass (same units as the masses in ``pdg``).
        subscript: ``"0"``, ``"+"`` or ``"T"`` -- which form factor to return.
        b0_0, b0_1, b0_2: z-expansion coefficients of ``f_0``.
        bplus_0, bplus_1, bplus_2: z-expansion coefficients of ``f_+``.
        bT_0, bT_1, bT_2: z-expansion coefficients of ``f_T``.

    Returns:
        The real-valued form factor, cast to a complex tensor.

    Raises:
        ValueError: if ``subscript`` is not ``"0"``, ``"+"`` or ``"T"``.
    """
    if subscript not in ("0", "+", "T"):
        raise ValueError('Wrong subscript entered, choose either 0, + or T')

    mK = ztf.constant(pdg['Ks_M'])
    mB = ztf.constant(pdg['Bplus_M'])

    # Number of expansion terms (N comes from the derivation in the paper).
    N = 3

    # Conformal variable z(q2), built from tpos = (mB - mK)^2 and tzero.
    tpos = (mB - mK)**2
    tzero = (mB + mK)*(ztf.sqrt(mB)-ztf.sqrt(mK))**2

    z_oben = ztf.sqrt(tpos - q2) - ztf.sqrt(tpos - tzero)
    z_unten = ztf.sqrt(tpos - q2) + ztf.sqrt(tpos - tzero)
    z = tf.divide(z_oben, z_unten)

    if subscript == "0":
        mbstar0 = ztf.constant(pdg["mbstar0"])
        prefactor = 1/(1 - q2/(mbstar0**2))
        coefficients = [b0_0, b0_1, b0_2]
        _sum = 0
        for i in range(N):
            _sum += coefficients[i]*(tf.pow(z,i))
        return ztf.to_complex(prefactor * _sum)

    # f_+ and f_T share the pole and the expansion; only the coefficients differ.
    mbstar = ztf.constant(pdg["mbstar"])
    prefactor = 1/(1 - q2/(mbstar**2))
    if subscript == "T":
        coefficients = [bT_0, bT_1, bT_2]
    else:
        coefficients = [bplus_0, bplus_1, bplus_2]

    _sum = 0
    for i in range(N):
        _sum += coefficients[i] * (tf.pow(z, i) - ((-1)**(i-N)) * (i/N) * tf.pow(z, N))
    return ztf.to_complex(prefactor * _sum)

def resonance(q, _mass, width, phase, scale):
    """Return a rescaled, phase-rotated Breit-Wigner amplitude at dimuon mass ``q``.

    The running width is ``(p(q) / q) * mass * width / p(mass)``, where
    ``p(x) = 0.5 * sqrt(x^2 - 4 m_mu^2)``. The Breit-Wigner is
    ``mass*width / (mass^2 - q^2 - i*mass*running_width)``. Its magnitude is
    multiplied by ``scale`` and ``phase`` is added to its argument.

    Returns:
        Complex tensor with the same shape as ``q``.
    """
    mu_mass = ztf.constant(pdg['muon_M'])
    q_sq = tf.pow(q, 2)

    momentum = 0.5 * ztf.sqrt(q_sq - 4*(mu_mass**2))
    momentum_at_pole = 0.5 * ztf.sqrt(_mass**2 - 4*mu_mass**2)
    running_width = tf.divide(momentum, q) * _mass * width / momentum_at_pole

    numerator = tf.complex(_mass * width, ztf.constant(0.0))
    denominator = tf.complex(_mass**2 - q_sq, -_mass*running_width)
    breit_wigner = numerator / denominator

    # Rebuild the amplitude from a rescaled modulus and a shifted argument.
    magnitude = ztf.to_complex(scale*tf.abs(breit_wigner))
    rotated_phase = tf.angle(breit_wigner) + phase
    return magnitude * tf.exp(tf.complex(ztf.constant(0.0), rotated_phase))


def axiv_nonres(q, b0_0, b0_1, b0_2, bplus_0, bplus_1, bplus_2, bT_0, bT_1, bT_2):
    """Return the axial-vector (C10eff) contribution to the unnormalized q-spectrum.

    The result is ``prefactor1 * (bracket_left + bracket_middle) * 2q``, where
    ``bracket_left`` uses ``f_+`` and ``bracket_middle`` uses ``f_0`` (both
    multiplied by ``C10eff``).

    Args:
        q: dimuon mass.
        b0_*, bplus_*, bT_*: form-factor coefficients forwarded to ``formfactor``.

    Returns:
        Real tensor with the same shape as ``q``.
    """
    GF = ztf.constant(pdg['GF'])
    alpha_ew = ztf.constant(pdg['alpha_ew'])
    Vtb = ztf.constant(pdg['Vtb'])
    Vts = ztf.constant(pdg['Vts'])
    C10eff = ztf.constant(pdg['C10eff'])

    mmu = ztf.constant(pdg['muon_M'])
    mK = ztf.constant(pdg['Ks_M'])
    mB = ztf.constant(pdg['Bplus_M'])

    q2 = tf.pow(q, 2)

    #Some helperfunctions

    beta = 1. - 4. * mmu**2. / q2

    kabs = ztf.sqrt(mB**2. + tf.pow(q2, 2)/mB**2. + mK**4./mB**2. - 2. * (mB**2. * mK**2. + mK**2. * q2 + mB**2. * q2) / mB**2.)

    #prefactor in front of whole bracket

    prefactor1 = GF**2. *alpha_ew**2. * (tf.abs(Vtb*Vts))**2. * kabs * beta / (128. * np.pi**5.)

    #left term in bracket (f_+ part)

    bracket_left = 2./3. * tf.pow(kabs,2) * tf.pow(beta,2) * tf.pow(tf.abs(ztf.to_complex(C10eff)*formfactor(q2, "+", b0_0, b0_1, b0_2, bplus_0, bplus_1, bplus_2, bT_0, bT_1, bT_2)),2)

    #middle term in bracket (f_0 part)

    _top = 4. * mmu**2. * (mB**2. - mK**2.) * (mB**2. - mK**2.)

    _under = q2 * mB**2.

    bracket_middle = _top/_under *tf.pow(tf.abs(ztf.to_complex(C10eff) * formfactor(q2, "0", b0_0, b0_1, b0_2, bplus_0, bplus_1, bplus_2, bT_0, bT_1, bT_2)), 2)

    # 2 * q is the Jacobian dq^2/dq: the spectrum is expressed in q rather than q^2.
    return prefactor1 * (bracket_left + bracket_middle) * 2 * q

def vec(q, funcs, b0_0, b0_1, b0_2, bplus_0, bplus_1, bplus_2, bT_0, bT_1, bT_2):
    """Return the vector (C9eff/C7eff) contribution to the unnormalized q-spectrum.

    Computes ``overall * |k|^2 (1 - beta/3) * |C9eff(q) f_+ + 2 C7eff (mb + ms)/(mB + mK) f_T|^2 * 2q``.

    Args:
        q: dimuon mass.
        funcs: complex resonant/two-particle terms added to C9 via ``c9eff``.
        b0_*, bplus_*, bT_*: form-factor coefficients forwarded to ``formfactor``.

    Returns:
        Real tensor with the same shape as ``q``.
    """
    coeffs = (b0_0, b0_1, b0_2, bplus_0, bplus_1, bplus_2, bT_0, bT_1, bT_2)
    q_sq = tf.pow(q, 2)

    # Couplings and Wilson coefficient
    fermi = ztf.constant(pdg['GF'])
    alpha = ztf.constant(pdg['alpha_ew'])
    v_tb = ztf.constant(pdg['Vtb'])
    v_ts = ztf.constant(pdg['Vts'])
    c7 = ztf.constant(pdg['C7eff'])

    # Masses
    m_mu = ztf.constant(pdg['muon_M'])
    m_b = ztf.constant(pdg['bquark_M'])
    m_s = ztf.constant(pdg['squark_M'])
    m_K = ztf.constant(pdg['Ks_M'])
    m_B = ztf.constant(pdg['Bplus_M'])

    beta = 1. - 4. * m_mu**2. / q_sq
    k_abs = ztf.sqrt(m_B**2. + tf.pow(q_sq, 2)/m_B**2. + m_K**4./m_B**2. - 2 * (m_B**2 * m_K**2 + m_K**2 * q_sq + m_B**2 * q_sq) / m_B**2)

    overall = fermi**2. *alpha**2. * (tf.abs(v_tb*v_ts))**2 * k_abs * beta / (128. * np.pi**5.)

    amplitude = (c9eff(q, funcs) * formfactor(q_sq, "+", *coeffs)
                 + ztf.to_complex(2.0 * c7 * (m_b + m_s)/(m_B + m_K)) * formfactor(q_sq, "T", *coeffs))
    bracket_right = tf.pow(k_abs,2) * (1. - 1./3. * beta) * tf.pow(tf.abs(amplitude),2)

    # 2 * q is the Jacobian dq^2/dq: the spectrum is expressed in q rather than q^2.
    return overall * bracket_right * 2 * q

def c9eff(q, funcs):
    """Return the effective C9 coefficient: ``pdg['C9eff']`` (as complex) plus ``funcs``.

    ``q`` is accepted for interface compatibility but is not used.
    """
    return ztf.to_complex(ztf.constant(pdg['C9eff'])) + funcs


# In[4]:


def G(y):
    """Return the loop function G(y) used by ``h_S`` and ``h_P``.

    For ``Re(y) > 0``: ``sqrt(|y|) * (log((1 + sqrt(y)) / (1 - sqrt(y))) + i*pi)``.
    Otherwise:         ``sqrt(|y|) * 2 * atan(1 / sqrt(-Re(y)))``.

    Args:
        y: complex tensor (``h_S`` passes ``1 - 4 m^2 / q^2``).

    Returns:
        Complex tensor with the same shape as ``y``.

    Note:
        ``tf.where`` evaluates both branches for every element; the branch that
        is not selected may contain NaN values, which are discarded.
    """

    def inner_rect_bracket(q):
        # The constant imaginary term is added to the logarithm, not placed
        # inside its argument.
        # NOTE(review): the +i*pi sign ("- complex(0, -pi)") is kept as written;
        # confirm it against the reference formula.
        ratio = ztf.to_complex((1+tf.sqrt(q))/(1-tf.sqrt(q)))
        return tf.log(ratio) - tf.complex(ztf.constant(0), -1*ztf.constant(np.pi))

    def inner_right(q):
        return ztf.to_complex(2 * tf.atan(1/tf.sqrt(tf.math.real(-q))))

    big_bracket = tf.where(tf.math.real(y) > ztf.constant(0.0), inner_rect_bracket(y), inner_right(y))

    return ztf.to_complex(tf.sqrt(tf.abs(y))) * big_bracket

def h_S(m, q):
    """Return h_S(m, q) = 2 - G(1 - 4 m^2 / q^2) as a complex tensor."""
    threshold_ratio = ztf.to_complex(4*tf.pow(m, 2)) / ztf.to_complex(tf.pow(q, 2))
    return ztf.to_complex(2) - G(ztf.to_complex(1) - threshold_ratio)

def h_P(m, q):
    """Return h_P(m, q) = 2/3 + (1 - 4 m^2 / q^2) * h_S(m, q) as a complex tensor."""
    y = ztf.to_complex(1) - ztf.to_complex(4*tf.pow(m, 2)) / ztf.to_complex(tf.pow(q, 2))
    return ztf.to_complex(2/3) + y * h_S(m, q)

def two_p_ccbar(mD, m_D_bar, m_D_star, q):
    """Return the summed two-particle contribution at dimuon mass ``q``.

    Each term is ``nu * exp(i * phase) * h(m, q)``, with ``nu_*`` and
    ``phase_*`` read from ``pdg``. ``h_S`` is used for the ``m_D_bar`` term and
    ``h_P`` for the ``mD`` and ``m_D_star`` terms.

    Args:
        mD: mass for the first ``h_P`` term.
        m_D_bar: mass for the ``h_S`` term.
        m_D_star: mass for the second ``h_P`` term.
        q: dimuon mass.

    Returns:
        Complex tensor with the same shape as ``q``.
    """
    # Load constants
    nu_D_bar = ztf.to_complex(pdg["nu_D_bar"])
    nu_D = ztf.to_complex(pdg["nu_D"])
    nu_D_star = ztf.to_complex(pdg["nu_D_star"])

    # tf.complex only accepts real (float) inputs, so the phases stay real.
    phase_D_bar = ztf.constant(pdg["phase_D_bar"])
    phase_D = ztf.constant(pdg["phase_D"])
    phase_D_star = ztf.constant(pdg["phase_D_star"])

    def _rotation(phase):
        # exp(i * phase) as a complex tensor.
        return tf.exp(tf.complex(ztf.constant(0.0), phase))

    # Calculation
    left_part = nu_D_bar * _rotation(phase_D_bar) * h_S(m_D_bar, q)

    right_part_D = nu_D * _rotation(phase_D) * h_P(mD, q)

    right_part_D_star = nu_D_star * _rotation(phase_D_star) * h_P(m_D_star, q)

    return left_part + right_part_D + right_part_D_star


# ## Build pdf

# In[5]:


# Parameter names shared by ``total_pdf_cut`` and ``total_pdf_full``.
_TOTAL_PDF_PARAMS = ['b0_0', 'b0_1', 'b0_2',
                     'bplus_0', 'bplus_1', 'bplus_2',
                     'bT_0', 'bT_1', 'bT_2',
                     'rho_mass', 'rho_scale', 'rho_phase', 'rho_width',
                     'jpsi_mass', 'jpsi_scale', 'jpsi_phase', 'jpsi_width',
                     'psi2s_mass', 'psi2s_scale', 'psi2s_phase', 'psi2s_width',
                     'p3770_mass', 'p3770_scale', 'p3770_phase', 'p3770_width',
                     'p4040_mass', 'p4040_scale', 'p4040_phase', 'p4040_width',
                     'p4160_mass', 'p4160_scale', 'p4160_phase', 'p4160_width',
                     'p4415_mass', 'p4415_scale', 'p4415_phase', 'p4415_width',
                     'omega_mass', 'omega_scale', 'omega_phase', 'omega_width',
                     'phi_mass', 'phi_scale', 'phi_phase', 'phi_width',
                     'Dbar_mass', 'Dbar_scale', 'Dbar_phase',
                     'Dstar_mass', 'DDstar_scale', 'DDstar_phase', 'D_mass',
                     'tau_mass', 'C_tt']

# Form-factor coefficients, in the positional order expected by ``vec``/``axiv_nonres``.
_FORMFACTOR_PARAMS = ('b0_0', 'b0_1', 'b0_2',
                      'bplus_0', 'bplus_1', 'bplus_2',
                      'bT_0', 'bT_1', 'bT_2')

# Resonances added as plain ``resonance`` terms (summation order).
_LIGHT_RESONANCES = ('rho', 'omega', 'phi')

# Charmonium resonances, each weighted by q^2 / m^2 (summation order).
_CHARMONIUM_RESONANCES = ('jpsi', 'psi2s', 'p3770', 'p4040', 'p4160', 'p4415')


def _resonance_term(params, q, name):
    """Evaluate ``resonance`` with the ``<name>_mass/_scale/_phase/_width`` entries of ``params``."""
    return resonance(q, _mass=params[name + '_mass'], scale=params[name + '_scale'],
                     phase=params[name + '_phase'], width=params[name + '_width'])


def _phase_factor(phase):
    """Return exp(i * phase) for a real-valued ``phase``."""
    return tf.exp(tf.complex(ztf.constant(0.0), phase))


def _total_rate(params, x):
    """Return the unnormalized spectrum (vector + axial-vector parts) at dimuon mass ``x``.

    ``params`` must map every name in ``_TOTAL_PDF_PARAMS`` to a parameter; the
    PDF classes below pass their own ``self.params``.
    """
    terms = [_resonance_term(params, x, name) for name in _LIGHT_RESONANCES]
    for name in _CHARMONIUM_RESONANCES:
        weight = ztf.to_complex(tf.pow(x, 2) / tf.pow(params[name + '_mass'], 2))
        terms.append(weight * _resonance_term(params, x, name))

    # Two-particle Dbar and D/D* contributions.
    dbar_contrib = (ztf.to_complex(params['Dbar_scale']) * _phase_factor(params['Dbar_phase'])
                    * ztf.to_complex(h_S(params['Dbar_mass'], x)))
    ddstar_contrib = (ztf.to_complex(params['DDstar_scale']) * _phase_factor(params['DDstar_phase'])
                      * (ztf.to_complex(h_P(params['Dstar_mass'], x))
                         + ztf.to_complex(h_P(params['D_mass'], x))))
    terms.append(dbar_contrib + ddstar_contrib)

    # Tau-pair term weighted by C_tt.
    terms.append(ztf.to_complex(params['C_tt'])
                 * (ztf.to_complex(h_S(params['tau_mass'], x))
                    - ztf.to_complex(h_P(params['tau_mass'], x))))

    # Sum strictly left to right, in the order the terms were listed.
    funcs = terms[0]
    for term in terms[1:]:
        funcs = funcs + term

    formfactor_coeffs = [params[name] for name in _FORMFACTOR_PARAMS]
    vec_f = vec(x, funcs, *formfactor_coeffs)
    axiv_nr = axiv_nonres(x, *formfactor_coeffs)
    return vec_f + axiv_nr


def _cut_charmonium_windows(params, x, tot):
    """Zero ``tot`` for x in [jpsi_mass - 60, jpsi_mass + 70] and [psi2s_mass - 50, psi2s_mass + 50].

    The window edges use the PDF's own ``jpsi_mass``/``psi2s_mass`` parameters.
    NOTE(review): these windows should match the gaps in the fit observable
    (obs1/obs2/obs3), which is the case while those masses are held fixed.
    """
    jpsi_mass = params['jpsi_mass']
    psi2s_mass = params['psi2s_mass']
    outside_jpsi = tf.math.logical_or(x < jpsi_mass - 60., x > jpsi_mass + 70.)
    outside_psi2s = tf.math.logical_or(x < psi2s_mass - 50., x > psi2s_mass + 50.)
    zeros = tf.zeros_like(tot)
    tot = tf.where(outside_jpsi, tot, zeros)
    return tf.where(outside_psi2s, tot, zeros)


class total_pdf_cut(zfit.pdf.ZPDF):
    """Rare-decay PDF in q with the J/psi and psi(2S) windows set to zero."""
    _N_OBS = 1  # dimension, can be omitted
    _PARAMS = list(_TOTAL_PDF_PARAMS)  # the name of the parameters

    def _unnormalized_pdf(self, x):
        x = x.unstack_x()
        tot = _total_rate(self.params, x)
        return _cut_charmonium_windows(self.params, x, tot)


class total_pdf_full(zfit.pdf.ZPDF):
    """Rare-decay PDF in q without any charmonium cut."""
    _N_OBS = 1  # dimension, can be omitted
    _PARAMS = list(_TOTAL_PDF_PARAMS)  # the name of the parameters

    def _unnormalized_pdf(self, x):
        x = x.unstack_x()
        return _total_rate(self.params, x)


# ## Setup parameters

# In[6]:


# formfactors

# f_0 and f_T coefficients are held fixed; the f_+ coefficients float within [-2, 2].
b0_0, b0_1, b0_2 = (
    zfit.Parameter(par_name, ztf.constant(start), floating=False)
    for par_name, start in (("b0_0", 0.292), ("b0_1", 0.281), ("b0_2", 0.150))
)

bplus_0, bplus_1, bplus_2 = (
    zfit.Parameter(par_name, ztf.constant(start), lower_limit=-2.0, upper_limit=2.0)
    for par_name, start in (("bplus_0", 0.466), ("bplus_1", -0.885), ("bplus_2", -0.213))
)

bT_0, bT_1, bT_2 = (
    zfit.Parameter(par_name, ztf.constant(start), floating=False)
    for par_name, start in (("bT_0", 0.460), ("bT_1", -1.089), ("bT_2", -1.114))
)


#rho

def _resonance_parameters(name, float_scale=True):
    """Create the (m, w, p, s) zfit parameters for the resonance ``pdg[name]``.

    ``pdg[name]`` is unpacked as ``(mass, width, phase, scale)``. Mass and width
    are fixed, and the phase floats in [-2*pi, 2*pi]. The scale floats in
    ``[scale - sqrt(scale), scale + sqrt(scale)]``, or is fixed when
    ``float_scale`` is False.

    Returns:
        ``((mass, width, phase, scale), (m, w, p, s))``.
    """
    mass, width, phase, scale = pdg[name]
    m = zfit.Parameter(name + "_m", ztf.constant(mass), floating = False)
    w = zfit.Parameter(name + "_w", ztf.constant(width), floating = False)
    p = zfit.Parameter(name + "_p", ztf.constant(phase), lower_limit=-2*np.pi, upper_limit=2*np.pi)
    if float_scale:
        s = zfit.Parameter(name + "_s", ztf.constant(scale),
                           lower_limit=scale-np.sqrt(scale), upper_limit=scale+np.sqrt(scale))
    else:
        s = zfit.Parameter(name + "_s", ztf.constant(scale), floating = False)
    return (mass, width, phase, scale), (m, w, p, s)


(rho_mass, rho_width, rho_phase, rho_scale), (rho_m, rho_w, rho_p, rho_s) = _resonance_parameters("rho")

(omega_mass, omega_width, omega_phase, omega_scale), (omega_m, omega_w, omega_p, omega_s) = _resonance_parameters("omega")

(phi_mass, phi_width, phi_phase, phi_scale), (phi_m, phi_w, phi_p, phi_s) = _resonance_parameters("phi")

# J/psi and psi(2S) keep their scale fixed.
(jpsi_mass, jpsi_width, jpsi_phase, jpsi_scale), (jpsi_m, jpsi_w, jpsi_p, jpsi_s) = _resonance_parameters("jpsi", float_scale=False)

(psi2s_mass, psi2s_width, psi2s_phase, psi2s_scale), (psi2s_m, psi2s_w, psi2s_p, psi2s_s) = _resonance_parameters("psi2s", float_scale=False)

(p3770_mass, p3770_width, p3770_phase, p3770_scale), (p3770_m, p3770_w, p3770_p, p3770_s) = _resonance_parameters("p3770")

(p4040_mass, p4040_width, p4040_phase, p4040_scale), (p4040_m, p4040_w, p4040_p, p4040_s) = _resonance_parameters("p4040")

(p4160_mass, p4160_width, p4160_phase, p4160_scale), (p4160_m, p4160_w, p4160_p, p4160_s) = _resonance_parameters("p4160")

(p4415_mass, p4415_width, p4415_phase, p4415_scale), (p4415_m, p4415_w, p4415_p, p4415_s) = _resonance_parameters("p4415")


# ## Dynamic generation of 2 particle contribution

# In[7]:


m_c = 1300

# Two-particle contribution: the masses are fixed, the scales start at 0 within
# [-0.3, 0.3] and the phases start at 0 within [-2*pi, 2*pi].
Dbar_phase = DDstar_phase = 0.0
Dstar_mass = pdg['Dst_M']
Dbar_mass = D_mass = pdg['D0_M']

_phase_limits = dict(lower_limit=-2*np.pi, upper_limit=2*np.pi)
_two_particle_scale_limits = dict(lower_limit=-0.3, upper_limit=0.3)

Dbar_s = zfit.Parameter("Dbar_s", ztf.constant(0.0), **_two_particle_scale_limits)
Dbar_m = zfit.Parameter("Dbar_m", ztf.constant(Dbar_mass), floating=False)
Dbar_p = zfit.Parameter("Dbar_p", ztf.constant(Dbar_phase), **_phase_limits)
DDstar_s = zfit.Parameter("DDstar_s", ztf.constant(0.0), **_two_particle_scale_limits)
Dstar_m = zfit.Parameter("Dstar_m", ztf.constant(Dstar_mass), floating=False)
D_m = zfit.Parameter("D_m", ztf.constant(D_mass), floating=False)
DDstar_p = zfit.Parameter("DDstar_p", ztf.constant(DDstar_phase), **_phase_limits)


# ## Tau parameters

# In[8]:


tau_m = zfit.Parameter(
    "tau_m", ztf.constant(pdg['tau_M']),
    floating=False,  # mass fixed to pdg['tau_M']
)
Ctt = zfit.Parameter(
    "Ctt", ztf.constant(0.0),
    lower_limit=-0.5, upper_limit=0.5,  # floating coefficient (passed as C_tt), starts at 0
)


# ## Load data

# In[9]:


# q range: from 2 * m_mu up to mB - mK - 0.1.
x_min = 2*pdg['muon_M']
x_max = (pdg["Bplus_M"]-pdg["Ks_M"]-0.1)

# Full spectrum
obs_toy = zfit.Space('q', limits=(x_min, x_max))

# Fit range: full spectrum minus [jpsi - 60, jpsi + 70] and [psi2s - 50, psi2s + 50].
obs1, obs2, obs3 = (
    zfit.Space('q', limits=edges)
    for edges in (
        (x_min, jpsi_mass - 60.),
        (jpsi_mass + 70., psi2s_mass - 50.),
        (psi2s_mass + 50., x_max),
    )
)
obs_fit = obs1 + obs2 + obs3

# with open(r"./data/slim_points/slim_points_toy_0_range({0}-{1}).pkl".format(int(x_min), int(x_max)), "rb") as input_file:
#     part_set = pkl.load(input_file)

# x_part = part_set['x_part']

# x_part = x_part.astype('float64')

# data = zfit.data.Data.from_numpy(array=x_part, obs=obs)


# ## Setup pdf

# In[10]:


# Both PDFs share the same parameter objects; only the class and observable differ.
_pdf_kwargs = dict(
    jpsi_mass=jpsi_m, jpsi_scale=jpsi_s, jpsi_phase=jpsi_p, jpsi_width=jpsi_w,
    psi2s_mass=psi2s_m, psi2s_scale=psi2s_s, psi2s_phase=psi2s_p, psi2s_width=psi2s_w,
    p3770_mass=p3770_m, p3770_scale=p3770_s, p3770_phase=p3770_p, p3770_width=p3770_w,
    p4040_mass=p4040_m, p4040_scale=p4040_s, p4040_phase=p4040_p, p4040_width=p4040_w,
    p4160_mass=p4160_m, p4160_scale=p4160_s, p4160_phase=p4160_p, p4160_width=p4160_w,
    p4415_mass=p4415_m, p4415_scale=p4415_s, p4415_phase=p4415_p, p4415_width=p4415_w,
    rho_mass=rho_m, rho_scale=rho_s, rho_phase=rho_p, rho_width=rho_w,
    omega_mass=omega_m, omega_scale=omega_s, omega_phase=omega_p, omega_width=omega_w,
    phi_mass=phi_m, phi_scale=phi_s, phi_phase=phi_p, phi_width=phi_w,
    Dstar_mass=Dstar_m, DDstar_scale=DDstar_s, DDstar_phase=DDstar_p, D_mass=D_m,
    Dbar_mass=Dbar_m, Dbar_scale=Dbar_s, Dbar_phase=Dbar_p,
    tau_mass=tau_m, C_tt=Ctt,
    b0_0=b0_0, b0_1=b0_1, b0_2=b0_2,
    bplus_0=bplus_0, bplus_1=bplus_1, bplus_2=bplus_2,
    bT_0=bT_0, bT_1=bT_1, bT_2=bT_2,
)

# Cut PDF over the full spectrum, and uncut PDF over the fit range with the gaps.
total_f = total_pdf_cut(obs=obs_toy, **_pdf_kwargs)
total_f_fit = total_pdf_full(obs=obs_fit, **_pdf_kwargs)
                   
# print(total_pdf.obs)

# print(calcs_test)

# for param in total_f.get_dependents():
#     print(zfit.run(param))


# In[11]:


# Notebook display cell: request the fit PDF's normalization over the full q range (obs_toy).
# The returned value is not stored.
total_f_fit.normalization(obs_toy)


# ## Test if graphs actually work and compute values

# In[12]:


# def total_test_tf(xq):

#     def jpsi_res(q):
#         return resonance(q, jpsi_m, jpsi_s, jpsi_p, jpsi_w)

#     def psi2s_res(q):
#         return resonance(q, psi2s_m, psi2s_s, psi2s_p, psi2s_w)

#     def cusp(q):
#         return bifur_gauss(q, cusp_m, sig_L, sig_R, cusp_s)

#     funcs = jpsi_res(xq) + psi2s_res(xq) + cusp(xq)

#     vec_f = vec(xq, funcs)

#     axiv_nr = axiv_nonres(xq)

#     tot = vec_f + axiv_nr
    
#     return tot

# def jpsi_res(q):
#     return resonance(q, jpsi_m, jpsi_s, jpsi_p, jpsi_w)

# calcs = zfit.run(total_test_tf(x_part))

# Evaluate the fit PDF (norm_range=False) on 2e6 evenly spaced points in [x_min, x_max].
test_q = np.linspace(start=x_min, stop=x_max, num=int(2e6))
probs = total_f_fit.pdf(test_q, norm_range=False)
calcs_test = zfit.run(probs)
# res_y = zfit.run(jpsi_res(test_q))
# b0 = [b0_0, b0_1, b0_2]
# bplus = [bplus_0, bplus_1, bplus_2]
# bT = [bT_0, bT_1, bT_2]
# f0_y = zfit.run(tf.math.real(formfactor(test_q,"0", b0, bplus, bT)))
# fplus_y = zfit.run(tf.math.real(formfactor(test_q,"+", b0, bplus, bT)))
# fT_y = zfit.run(tf.math.real(formfactor(test_q,"T", b0, bplus, bT)))


# In[13]:


# Plot the evaluated PDF on a cleared figure and save it as test.png (y clipped to [0, 1.5e-6]).
plt.clf()
_ax = plt.gca()
_ax.plot(test_q, calcs_test, label = 'pdf')
_ax.legend()
_ax.set_ylim(0.0, 1.5e-6)
plt.savefig('test.png')
# print(jpsi_width)


# In[14]:




# probs = mixture.prob(test_q)
# probs_np = zfit.run(probs)
# probs_np *= np.max(calcs_test) / np.max(probs_np)
# plt.figure()
# plt.semilogy(test_q, probs_np,label="importance sampling")
# plt.semilogy(test_q, calcs_test, label = 'pdf')


# In[15]:


# 0.213/(0.00133+0.213+0.015)


# ## Adjust scaling of different parts

# In[16]:


# Numeric integration of total_f: 2e6 draws per dimension, mc_sampler=None.
total_f.update_integration_options(mc_sampler=None, draws_per_dim=2000000)
# inte = total_f.integrate(limits = (950., 1050.), norm_range=False)
# inte_fl = zfit.run(inte)
# print(inte_fl/4500)
# print(pdg["jpsi_BR"]/pdg["NR_BR"], inte_fl*pdg["psi2s_auc"]/pdg["NR_auc"])


# In[17]:


# # print("jpsi:", inte_fl)
# # print("Increase am by factor:", np.sqrt(pdg["jpsi_BR"]/pdg["NR_BR"]*pdg["NR_auc"]/inte_fl))
# # print("New amp:", pdg["jpsi"][3]*np.sqrt(pdg["jpsi_BR"]/pdg["NR_BR"]*pdg["NR_auc"]/inte_fl))

# # print("psi2s:", inte_fl)
# # print("Increase am by factor:", np.sqrt(pdg["psi2s_BR"]/pdg["NR_BR"]*pdg["NR_auc"]/inte_fl))
# # print("New amp:", pdg["psi2s"][3]*np.sqrt(pdg["psi2s_BR"]/pdg["NR_BR"]*pdg["NR_auc"]/inte_fl))

# name = "phi"

# print(name+":", inte_fl)
# print("Increase am by factor:", np.sqrt(pdg[name+"_BR"]/pdg["NR_BR"]*pdg["NR_auc"]/inte_fl))
# print("New amp:", pdg[name][0]*np.sqrt(pdg[name+"_BR"]/pdg["NR_BR"]*pdg["NR_auc"]/inte_fl))


# print(x_min)
# print(x_max)
# # total_f.update_integration_options(draws_per_dim=2000000, mc_sampler=None)
# total_f.update_integration_options(mc_sampler=lambda dim, num_results,
#                                     dtype: tf.random_uniform(maxval=1., shape=(num_results, dim), dtype=dtype),
#                                    draws_per_dim=1000000)
# # _ = []

# # for i in range(10):

# #     inte = total_f.integrate(limits = (x_min, x_max))
# #     inte_fl = zfit.run(inte)
# #     print(inte_fl)
# #     _.append(inte_fl)

# # print("mean:", np.mean(_))

# _ = time.time()

# inte = total_f.integrate(limits = (x_min, x_max))
# inte_fl = zfit.run(inte)
# print(inte_fl)
# print("Time taken: {}".format(display_time(int(time.time() - _))))

# print(pdg['NR_BR']/pdg['NR_auc']*inte_fl)
# print(0.25**2*4.2/1000)


# # Sampling
# ## Mixture distribution for sampling

# In[18]:



    
# print(list_of_borders[:9])
# print(list_of_borders[-9:])


class UniformSampleAndWeights(zfit.util.execution.SessionHolderMixin):
    """Importance sampler built from a mixture of uniform windows.

    Calling an instance returns ``(sample, thresholds, weights, weights_max,
    n_to_produce)``:

    * ``sample``: ``(n_to_produce, 1)`` draws from a six-component uniform
      mixture (the whole ``[x_min, x_max]`` range plus narrow windows
      around 1000-1040, 3070-3110, 3090-3102, 3660-3710 and 3681-3691);
    * ``thresholds``: ``n_to_produce`` uniform random numbers;
    * ``weights``: the mixture density evaluated at each sampled point;
    * ``weights_max``: always ``None``;
    * ``n_to_produce``: passed through unchanged.
    """

    def __call__(self, limits, dtype, n_to_produce):
        # The 1D limits are unpacked and cast, but the mixture below is built
        # from the module-level x_min/x_max instead of these bounds.
        low, high = limits.limit1d
        low = tf.cast(low, dtype=dtype)
        high = tf.cast(high, dtype=dtype)

        # (mixture weight, lower edge, upper edge) of each uniform component.
        # NOTE(review): the weights add up to 1.185, not 1 -- confirm that
        # relying on tfd.Categorical/MixtureSameFamily normalisation is intended.
        components = (
            (0.05, x_min, x_max),  # flat over the whole range
            (0.93, 3090, 3102),
            (0.05, 3681, 3691),
            (0.065, 3070, 3110),
            (0.04, 1000, 1040),
            (0.05, 3660, 3710),
        )

        categorical = tfd.Categorical(
            probs=[tf.constant(weight, dtype=dtype) for weight, _, _ in components])
        uniforms = tfd.Uniform(
            low=[tf.constant(lo, dtype=dtype) for _, lo, _ in components],
            high=[tf.constant(hi, dtype=dtype) for _, _, hi in components])
        mixture = tfd.MixtureSameFamily(mixture_distribution=categorical,
                                        components_distribution=uniforms)

        sample = mixture.sample((n_to_produce, 1))
        # Drop the trailing axis so there is one weight per event.
        weights = mixture.prob(sample)[:, 0]
        weights_max = None
        thresholds = tf.random_uniform(shape=(n_to_produce,), dtype=dtype)
        return sample, thresholds, weights, weights_max, n_to_produce


# In[19]:


# total_f._sample_and_weights = UniformSampleAndWeights


# In[20]:


# 0.00133/(0.00133+0.213+0.015)*(x_max-3750)/(x_max-x_min)


# In[21]:


# zfit.settings.set_verbosity(10)


# In[22]:


# # zfit.run.numeric_checks = False   

# nr_of_toys = 1
# nevents = int(pdg["number_of_decays"])
# nevents = pdg["number_of_decays"]
# event_stack = 1000000
# # zfit.settings.set_verbosity(10)
# calls = int(nevents/event_stack + 1)

# total_samp = []

# start = time.time()

# sampler = total_f.create_sampler(n=event_stack)

# for toy in range(nr_of_toys):
    
#     dirName = 'data/zfit_toys/toy_{0}'.format(toy)
    
#     if not os.path.exists(dirName):
#         os.mkdir(dirName)
#         print("Directory " , dirName ,  " Created ")

#     for call in range(calls):

#         sampler.resample(n=event_stack)
#         s = sampler.unstack_x()
#         sam = zfit.run(s)
# #         clear_output(wait=True)

#         c = call + 1
        
#         print("{0}/{1} of Toy {2}/{3}".format(c, calls, toy+1, nr_of_toys))
#         print("Time taken: {}".format(display_time(int(time.time() - start))))
#         print("Projected time left: {}".format(display_time(int((time.time() - start)/(c+calls*(toy))*((nr_of_toys-toy)*calls-c)))))

#         with open("data/zfit_toys/toy_{0}/{1}.pkl".format(toy, call), "wb") as f:
#             pkl.dump(sam, f, pkl.HIGHEST_PROTOCOL)


# In[23]:


# with open(r"data/zfit_toys/toy_0/0.pkl", "rb") as input_file:
#     sam = pkl.load(input_file)
# print(sam[:10])

# with open(r"data/zfit_toys/toy_0/1.pkl", "rb") as input_file:
#     sam2 = pkl.load(input_file)
# print(sam2[:10])

# print(np.sum(sam-sam2))


# In[24]:


# print("Time to generate full toy: {} s".format(int(time.time()-start)))

# total_samp = []

# for call in range(calls):
#     with open(r"data/zfit_toys/toy_0/{}.pkl".format(call), "rb") as input_file:
#         sam = pkl.load(input_file)
#         total_samp = np.append(total_samp, sam)

# total_samp = total_samp.astype('float64')

# data2 = zfit.data.Data.from_numpy(array=total_samp[:int(nevents)], obs=obs)

# data3 = zfit.data.Data.from_numpy(array=total_samp, obs=obs)

# print(total_samp[:nevents].shape)


# In[25]:


# plt.clf()

# bins = int((x_max-x_min)/7)

# # calcs = zfit.run(total_test_tf(samp))
# print(total_samp[:nevents].shape)

# plt.hist(total_samp[:nevents], bins = bins, range = (x_min,x_max), label = 'data')
# # plt.plot(test_q, calcs_test*nevents , label = 'pdf')

# # plt.plot(sam, calcs, '.')
# # plt.plot(test_q, calcs_test)
# # plt.yscale('log')
# plt.ylim(0, 200)
# # plt.xlim(3080, 3110)

# plt.legend()

# plt.savefig('test2.png')


# In[26]:


# sampler = total_f.create_sampler(n=nevents)
# nll = zfit.loss.UnbinnedNLL(model=total_f, data=sampler, fit_range = (x_min, x_max))

# # for param in pdf.get_dependents():
# #     param.set_value(initial_value)

# sampler.resample(n=nevents)

# # Randomise initial values
# # for param in pdf.get_dependents():
# #     param.set_value(random value here)

# # Minimise the NLL
# minimizer = zfit.minimize.MinuitMinimizer(verbosity = 10)
# minimum = minimizer.minimize(nll)


# In[27]:


# jpsi_width


# In[28]:


# plt.hist(sample, weights=1 / prob(sample))


# # Fitting

# In[29]:


# start = time.time()

# for param in total_f.get_dependents():
#     param.randomize()
    
# # for param in total_f.get_dependents():
# #     print(zfit.run(param))
    
# nll = zfit.loss.UnbinnedNLL(model=total_f, data=data2, fit_range = (x_min, x_max))

# minimizer = zfit.minimize.MinuitMinimizer(verbosity = 5)
# # minimizer._use_tfgrad = False
# result = minimizer.minimize(nll)

# # param_errors = result.error()

# # for var, errors in param_errors.items():
# #     print('{}: ^{{+{}}}_{{{}}}'.format(var.name, errors['upper'], errors['lower']))

# print("Function minimum:", result.fmin)
# # print("Results:", result.params)
# print("Hesse errors:", result.hesse())


# In[30]:


# print("Time taken for fitting: {}".format(display_time(int(time.time()-start))))

# # probs = total_f.pdf(test_q)

# calcs_test = zfit.run(probs)
# res_y = zfit.run(jpsi_res(test_q))


# In[31]:


# plt.clf()
# # plt.plot(x_part, calcs, '.')
# plt.plot(test_q, calcs_test, label = 'pdf')
# # plt.plot(test_q, res_y, label = 'res')
# plt.legend()
# plt.ylim(0.0, 10e-6)
# # plt.yscale('log')
# # plt.xlim(3080, 3110)
# plt.savefig('test3.png')
# # print(jpsi_width)


# In[32]:


# _tot = 4.37e-7+6.02e-5+4.97e-6
# _probs = []
# _probs.append(6.02e-5/_tot)
# _probs.append(4.97e-6/_tot)
# _probs.append(4.37e-7/_tot)
# print(_probs)


# In[33]:


# dtype = 'float64'
# # mixture = tfd.Uniform(tf.constant(x_min, dtype=dtype), tf.constant(x_max, dtype=dtype))
# mixture = tfd.MixtureSameFamily(mixture_distribution=tfd.Categorical(probs=[tf.constant(0.007, dtype=dtype),
#                                                                             tf.constant(0.917, dtype=dtype),
#                                                                             tf.constant(0.076, dtype=dtype)]),
#                                 components_distribution=tfd.Uniform(low=[tf.constant(x_min, dtype=dtype), 
#                                                                          tf.constant(3080, dtype=dtype),
#                                                                          tf.constant(3670, dtype=dtype)], 
#                                                                     high=[tf.constant(x_max, dtype=dtype),
#                                                                           tf.constant(3112, dtype=dtype), 
#                                                                           tf.constant(3702, dtype=dtype)]))
# # for i in range(10):
# #     print(zfit.run(mixture.prob(mixture.sample((10, 1)))))


# In[34]:


# print((zfit.run(jpsi_p)%(2*np.pi))/np.pi)
# print((zfit.run(psi2s_p)%(2*np.pi))/np.pi)


# In[35]:


#         def jpsi_res(q):
#             return resonance(q, _mass = jpsi_mass, scale = jpsi_scale,
#                              phase = jpsi_phase, width = jpsi_width)

#         def psi2s_res(q):
#             return resonance(q, _mass = psi2s_mass, scale = psi2s_scale,
#                              phase = psi2s_phase, width = psi2s_width)
        
#         def p3770_res(q):
#             return resonance(q, _mass = p3770_mass, scale = p3770_scale,
#                              phase = p3770_phase, width = p3770_width)
        
#         def p4040_res(q):
#             return resonance(q, _mass = p4040_mass, scale = p4040_scale,
#                              phase = p4040_phase, width = p4040_width)
        
#         def p4160_res(q):
#             return resonance(q, _mass = p4160_mass, scale = p4160_scale,
#                              phase = p4160_phase, width = p4160_width)
        
#         def p4415_res(q):
#             return resonance(q, _mass = p4415_mass, scale = p4415_scale,
#                              phase = p4415_phase, width = p4415_width)


# In[36]:


# 0.15**2*4.2/1000
# result.hesse()


# ## Constraints

# In[37]:


# 1. Constraint - Real part of the sum of the psi contributions and D contributions.
#    Each charmonium resonance enters as scale * exp(i*phase) * width / mass**3;
#    the DDstar term uses 1/(10*Dstar_m**2) + 1/(10*D_m**2) and the Dbar term
#    1/(6*Dbar_m**2), each multiplied by scale * exp(i*phase).

sum_list = []

sum_list.append(ztf.to_complex(jpsi_s) * tf.exp(tf.complex(ztf.constant(0.0), jpsi_p)) * ztf.to_complex(jpsi_w / (tf.pow(jpsi_m,3))))
sum_list.append(ztf.to_complex(psi2s_s) * tf.exp(tf.complex(ztf.constant(0.0), psi2s_p)) * ztf.to_complex(psi2s_w / (tf.pow(psi2s_m,3))))
sum_list.append(ztf.to_complex(p3770_s) * tf.exp(tf.complex(ztf.constant(0.0), p3770_p)) * ztf.to_complex(p3770_w / (tf.pow(p3770_m,3))))
sum_list.append(ztf.to_complex(p4040_s) * tf.exp(tf.complex(ztf.constant(0.0), p4040_p)) * ztf.to_complex(p4040_w / (tf.pow(p4040_m,3))))
sum_list.append(ztf.to_complex(p4160_s) * tf.exp(tf.complex(ztf.constant(0.0), p4160_p)) * ztf.to_complex(p4160_w / (tf.pow(p4160_m,3))))
sum_list.append(ztf.to_complex(p4415_s) * tf.exp(tf.complex(ztf.constant(0.0), p4415_p)) * ztf.to_complex(p4415_w / (tf.pow(p4415_m,3))))
sum_list.append(ztf.to_complex(DDstar_s) * tf.exp(tf.complex(ztf.constant(0.0), DDstar_p)) * (ztf.to_complex(1.0 / (10.0*tf.pow(Dstar_m,2)) + 1.0 / (10.0*tf.pow(D_m,2)))))
sum_list.append(ztf.to_complex(Dbar_s) * tf.exp(tf.complex(ztf.constant(0.0), Dbar_p)) * ztf.to_complex(1.0 / (6.0*tf.pow(Dbar_m,2))))

sum_ru_1 = ztf.to_complex(ztf.constant(0.0))

for part in sum_list:
    sum_ru_1 += part

sum_1 = tf.math.real(sum_ru_1)

# Gaussian penalty (sum_1 - mu)**2 / (2 * sigma**2) with mu = 1.7e-8, sigma = 2.2e-8.
constraint1 = tf.pow((sum_1-ztf.constant(1.7*10**-8))/ztf.constant(2.2*10**-8),2)/ztf.constant(2.)

# The step-function penalties below use ztf.constant in both tf.cond branches so
# they share zfit's float type with constraint1. Bare Python floats (100000., 0.)
# would become float32 tensors sitting next to the float64 terms in `constraints`.

# 2. Constraint - Flat penalty of 1e5 once |sum of psi and D contributions| >= 5e-8.

sum_2 = tf.abs(sum_ru_1)
constraint2 = tf.cond(tf.greater_equal(sum_2, 5.0e-8),
                      lambda: ztf.constant(100000.), lambda: ztf.constant(0.))

# 3. Constraint - Flat penalty of 1e5 once |scale| (eta) of a D contribution >= 0.2.

constraint3_0 = tf.cond(tf.greater_equal(tf.abs(Dbar_s), 0.2),
                        lambda: ztf.constant(100000.), lambda: ztf.constant(0.))

constraint3_1 = tf.cond(tf.greater_equal(tf.abs(DDstar_s), 0.2),
                        lambda: ztf.constant(100000.), lambda: ztf.constant(0.))

# 4. Constraint - Form-factor multivariate Gaussian covariance for f+.
#    9x9 symmetric correlation matrix (unit diagonal). Only the off-diagonal
#    entries m[0][1], m[0][2] and m[1][2] are read by triGauss.
#    Entry [0][8] is 0.117, mirroring entry [8][0]. It previously read
#    `ztf.constant(0,117)`, which passed 117 as the dtype argument instead of a value.

Cov_matrix = [[ztf.constant(   1.), ztf.constant( 0.45), ztf.constant( 0.19), ztf.constant(0.857), ztf.constant(0.598), ztf.constant(0.531), ztf.constant(0.752), ztf.constant(0.229), ztf.constant(0.117)],
              [ztf.constant( 0.45), ztf.constant(   1.), ztf.constant(0.677), ztf.constant(0.708), ztf.constant(0.958), ztf.constant(0.927), ztf.constant(0.227), ztf.constant(0.443), ztf.constant(0.287)],
              [ztf.constant( 0.19), ztf.constant(0.677), ztf.constant(   1.), ztf.constant(0.595), ztf.constant(0.770), ztf.constant(0.819),ztf.constant(-0.023), ztf.constant( 0.07), ztf.constant(0.196)],
              [ztf.constant(0.857), ztf.constant(0.708), ztf.constant(0.595), ztf.constant(   1.), ztf.constant( 0.83), ztf.constant(0.766), ztf.constant(0.582), ztf.constant(0.237), ztf.constant(0.192)],
              [ztf.constant(0.598), ztf.constant(0.958), ztf.constant(0.770), ztf.constant( 0.83), ztf.constant(   1.), ztf.constant(0.973), ztf.constant(0.324), ztf.constant(0.372), ztf.constant(0.272)],
              [ztf.constant(0.531), ztf.constant(0.927), ztf.constant(0.819), ztf.constant(0.766), ztf.constant(0.973), ztf.constant(   1.), ztf.constant(0.268), ztf.constant(0.332), ztf.constant(0.269)],
              [ztf.constant(0.752), ztf.constant(0.227),ztf.constant(-0.023), ztf.constant(0.582), ztf.constant(0.324), ztf.constant(0.268), ztf.constant(   1.), ztf.constant( 0.59), ztf.constant(0.515)],
              [ztf.constant(0.229), ztf.constant(0.443), ztf.constant( 0.07), ztf.constant(0.237), ztf.constant(0.372), ztf.constant(0.332), ztf.constant( 0.59), ztf.constant(   1.), ztf.constant(0.897)],
              [ztf.constant(0.117), ztf.constant(0.287), ztf.constant(0.196), ztf.constant(0.192), ztf.constant(0.272), ztf.constant(0.269), ztf.constant(0.515), ztf.constant(0.897), ztf.constant(   1.)]]

def triGauss(val1, val2, val3, m=Cov_matrix):
    """Return the chi-square of three correlated Gaussian variables.

    Each value is standardised with fixed means (0.466, -0.885, -0.213) and
    widths (0.014, 0.128, 0.548). The quadratic form x^T R^-1 x is then
    evaluated, where R is the 3x3 correlation matrix built from m[0][1],
    m[0][2] and m[1][2].

    Args:
        val1, val2, val3: values to constrain (e.g. the bplus_* parameters).
        m: correlation matrix; only the three entries above are read.

    Returns:
        The full chi-square (= -2 * log-likelihood up to a constant). It is
        NOT halved, unlike the (x/sigma)**2 / 2 Gaussian penalties used for
        the other constraints.
    """
    means = (ztf.constant(0.466), ztf.constant(-0.885), ztf.constant(-0.213))
    widths = (ztf.constant(0.014), ztf.constant(0.128), ztf.constant(0.548))
    x1, x2, x3 = [(value - mu) / sd
                  for value, mu, sd in zip((val1, val2, val3), means, widths)]

    rho12, rho13, rho23 = m[0][1], m[0][2], m[1][2]

    # Sign-flipped adjugate terms of the correlation matrix applied to x.
    diag = x1*x1*(rho23*rho23-1) + x2*x2*(rho13*rho13-1) + x3*x3*(rho12*rho12-1)
    cross = x1*x2*(rho12-rho13*rho23) + x1*x3*(rho13-rho12*rho23) + x2*x3*(rho23-rho12*rho13)
    w = diag + 2*cross
    # -2 * det(R).
    d = 2*(rho12*rho12 + rho13*rho13 + rho23*rho23 - 2*rho12*rho13*rho23 - 1)

    half_neg_chisq = -w/d
    return -2*half_neg_chisq

# Form-factor constraint on the f+ coefficients (full chi-square from triGauss).
# Not included in `constraints` below (commented out there).
constraint4 = triGauss(bplus_0, bplus_1, bplus_2)

# Per-coefficient alternative without correlations (unused):
# constraint4_0 = tf.pow((bplus_0-ztf.constant(0.466))/ztf.constant(0.014),2)/ztf.constant(2.)
# constraint4_1 = tf.pow((bplus_1-ztf.constant(-0.885))/ztf.constant(0.128),2)/ztf.constant(2.)
# constraint4_2 = tf.pow((bplus_2-ztf.constant(-0.213))/ztf.constant(0.548),2)/ztf.constant(2.)

# 5. Constraint - Flat penalty of 1e5 once |sum of light contributions|
#    (scale * width / mass for rho, omega and phi) >= 0.02.
#    Both branches use ztf.constant so the penalty has zfit's float type rather
#    than the float32 a bare Python float would produce.

sum_list_5 = []

sum_list_5.append(rho_s*rho_w/rho_m)
sum_list_5.append(omega_s*omega_w/omega_m)
sum_list_5.append(phi_s*phi_w/phi_m)


sum_ru_5 = ztf.constant(0.0)

for part in sum_list_5:
    sum_ru_5 += part

constraint5 = tf.cond(tf.greater_equal(tf.abs(sum_ru_5), ztf.constant(0.02)),
                      lambda: ztf.constant(100000.), lambda: ztf.constant(0.))

# 6. Constraint - Gaussian penalties on the jpsi and psi2s phases for the cut-out fit:
#    (phase - nominal)**2 / (2 * unc**2), using jpsi_phase / psi2s_phase as the
#    nominal values and pdg["jpsi_phase_unc"] / pdg["psi2s_phase_unc"] as widths.

constraint6_0  =  tf.pow((jpsi_p-ztf.constant(jpsi_phase))/ztf.constant(pdg["jpsi_phase_unc"]),2)/ztf.constant(2.)
constraint6_1  =  tf.pow((psi2s_p-ztf.constant(psi2s_phase))/ztf.constant(pdg["psi2s_phase_unc"]),2)/ztf.constant(2.)

# 7. Constraint - Flat penalty of 1e5 once Ctt**2 >= 0.25 (i.e. |Ctt| >= 0.5).

constraint7 = tf.cond(tf.greater_equal(Ctt*Ctt, 0.25),
                      lambda: ztf.constant(100000.), lambda: ztf.constant(0.))

# NOTE(review): not read anywhere in the visible code; kept so the module-level name still exists.
constraint7dtype = tf.float64

# Constraints passed to the NLL fits. constraint4, constraint5 and constraint7
# are currently excluded.
constraints = [constraint1, constraint2, constraint3_0, constraint3_1,  # constraint4, #constraint4_0, constraint4_1, constraint4_2,
               constraint6_0, constraint6_1]  # , constraint7]


# ## Reset params

# In[38]:


def reset_param_values():
    """Reset every model parameter to its nominal starting value.

    - Resonance masses, scales, phases and widths go back to the module-level
      ``*_mass`` / ``*_scale`` / ``*_phase`` / ``*_width`` values.
    - The D-meson masses go back to ``Dstar_mass``, ``D_mass`` and ``Dbar_mass``.
    - The DDstar/Dbar scales and phases and ``Ctt`` are set to 0.0.
    - ``tau_m`` is set to ``pdg['tau_M']``.
    - The form-factor coefficients ``b0_*``, ``bplus_*`` and ``bT_*`` get fixed numbers.

    ``set_value`` is called in the same order as the table below.
    """
    nominal_values = (
        (jpsi_m, jpsi_mass), (jpsi_s, jpsi_scale), (jpsi_p, jpsi_phase), (jpsi_w, jpsi_width),
        (psi2s_m, psi2s_mass), (psi2s_s, psi2s_scale), (psi2s_p, psi2s_phase), (psi2s_w, psi2s_width),
        (p3770_m, p3770_mass), (p3770_s, p3770_scale), (p3770_p, p3770_phase), (p3770_w, p3770_width),
        (p4040_m, p4040_mass), (p4040_s, p4040_scale), (p4040_p, p4040_phase), (p4040_w, p4040_width),
        (p4160_m, p4160_mass), (p4160_s, p4160_scale), (p4160_p, p4160_phase), (p4160_w, p4160_width),
        (p4415_m, p4415_mass), (p4415_s, p4415_scale), (p4415_p, p4415_phase), (p4415_w, p4415_width),
        (rho_m, rho_mass), (rho_s, rho_scale), (rho_p, rho_phase), (rho_w, rho_width),
        (omega_m, omega_mass), (omega_s, omega_scale), (omega_p, omega_phase), (omega_w, omega_width),
        (phi_m, phi_mass), (phi_s, phi_scale), (phi_p, phi_phase), (phi_w, phi_width),
        (Dstar_m, Dstar_mass),
        (DDstar_s, 0.0), (DDstar_p, 0.0),
        (D_m, D_mass),
        (Dbar_m, Dbar_mass), (Dbar_s, 0.0), (Dbar_p, 0.0),
        (tau_m, pdg['tau_M']),
        (Ctt, 0.0),
        (b0_0, 0.292), (b0_1, 0.281), (b0_2, 0.150),
        (bplus_0, 0.466), (bplus_1, -0.885), (bplus_2, -0.213),
        (bT_0, 0.460), (bT_1, -1.089), (bT_2, -1.114),
    )
    for parameter, value in nominal_values:
        parameter.set_value(value)


# # Analysis

# In[39]:


# # zfit.run.numeric_checks = False   

# fitting_range = 'cut'
# total_BR = 1.7e-10 + 4.9e-10 + 2.5e-9 + 6.02e-5 + 4.97e-6 + 1.38e-9 + 4.2e-10 + 2.6e-9 + 6.1e-10 + 4.37e-7
# cut_BR = 1.0 - (6.02e-5 + 4.97e-6)/total_BR

# Ctt_list = []
# Ctt_error_list = []

# nr_of_toys = 1
# if fitting_range == 'cut':
#     nevents = int(pdg["number_of_decays"]*cut_BR)
# else:
#     nevents = int(pdg["number_of_decays"])
# # nevents = pdg["number_of_decays"]
# event_stack = 1000000
# # nevents *= 41
# # zfit.settings.set_verbosity(10)
# calls = int(nevents/event_stack + 1)

# total_samp = []

# start = time.time()

# sampler = total_f.create_sampler(n=event_stack)

# for toy in range(nr_of_toys):
    
#     ### Generate data
    
# #     clear_output(wait=True)
    
#     print("Toy {}: Generating data...".format(toy))
    
#     dirName = 'data/zfit_toys/toy_{0}'.format(toy)
    
#     if not os.path.exists(dirName):
#         os.mkdir(dirName)
#         print("Directory " , dirName ,  " Created ")
    
#     reset_param_values()
    
#     if fitting_range == 'cut':
        
#         sampler.resample(n=nevents)
#         s = sampler.unstack_x()
#         sam = zfit.run(s)
#         calls = 0
#         c = 1
        
#     else:    
#         for call in range(calls):

#             sampler.resample(n=event_stack)
#             s = sampler.unstack_x()
#             sam = zfit.run(s)

#             c = call + 1

#             with open("data/zfit_toys/toy_{0}/{1}.pkl".format(toy, call), "wb") as f:
#                 pkl.dump(sam, f, pkl.HIGHEST_PROTOCOL)
            
#     print("Toy {}: Data generation finished".format(toy))
        
#     ### Load data
    
#     print("Toy {}: Loading data...".format(toy))
    
#     if fitting_range == 'cut':
        
#         total_samp = sam
    
#     else:
                
#         for call in range(calls):
#             with open(r"data/zfit_toys/toy_0/{}.pkl".format(call), "rb") as input_file:
#                 sam = pkl.load(input_file)
#             total_samp = np.append(total_samp, sam)

#         total_samp = total_samp.astype('float64')
    
#     if fitting_range == 'full':

#         data = zfit.data.Data.from_numpy(array=total_samp[:int(nevents)], obs=obs)
    
#         print("Toy {}: Loading data finished".format(toy))

#         ### Fit data

#         print("Toy {}: Fitting pdf...".format(toy))

#         for param in total_f.get_dependents():
#             param.randomize()

#         nll = zfit.loss.UnbinnedNLL(model=total_f, data=data, fit_range = (x_min, x_max), constraints = constraints)

#         minimizer = zfit.minimize.MinuitMinimizer(verbosity = 5)
#         # minimizer._use_tfgrad = False
#         result = minimizer.minimize(nll)

#         print("Toy {}: Fitting finished".format(toy))

#         print("Function minimum:", result.fmin)
#         print("Hesse errors:", result.hesse())

#         params = result.params
#         Ctt_list.append(params[Ctt]['value'])
#         Ctt_error_list.append(params[Ctt]['minuit_hesse']['error'])

#         #plotting the result

#         plotdirName = 'data/plots'.format(toy)

#         if not os.path.exists(plotdirName):
#             os.mkdir(plotdirName)
# #             print("Directory " , dirName ,  " Created ")
        
#         probs = total_f.pdf(test_q, norm_range=False)
#         calcs_test = zfit.run(probs)
#         plt.clf()
#         plt.plot(test_q, calcs_test, label = 'pdf')
#         plt.legend()
#         plt.ylim(0.0, 6e-6)
#         plt.savefig(plotdirName + '/toy_fit_full_range{}.png'.format(toy))

#         print("Toy {0}/{1}".format(toy+1, nr_of_toys))
#         print("Time taken: {}".format(display_time(int(time.time() - start))))
#         print("Projected time left: {}".format(display_time(int((time.time() - start)/(c+calls*(toy))*((nr_of_toys-toy)*calls-c)))))
    
#     if fitting_range == 'cut':
        
#         _1 = np.where((total_samp >= x_min) & (total_samp <= (jpsi_mass - 60.)))
        
#         tot_sam_1 = total_samp[_1]
    
#         _2 = np.where((total_samp >= (jpsi_mass + 70.)) & (total_samp <= (psi2s_mass - 50.)))
        
#         tot_sam_2 = total_samp[_2]

#         _3 = np.where((total_samp >= (psi2s_mass + 50.)) & (total_samp <= x_max))
        
#         tot_sam_3 = total_samp[_3]

#         tot_sam = np.append(tot_sam_1, tot_sam_2)
#         tot_sam = np.append(tot_sam, tot_sam_3)
    
#         data = zfit.data.Data.from_numpy(array=tot_sam[:int(nevents)], obs=obs_fit)
        
#         print("Toy {}: Loading data finished".format(toy))
        
#         ### Fit data

#         print("Toy {}: Fitting pdf...".format(toy))

#         for param in total_f_fit.get_dependents():
#             param.randomize()

#         nll = zfit.loss.UnbinnedNLL(model=total_f_fit, data=data, constraints = constraints)

#         minimizer = zfit.minimize.MinuitMinimizer(verbosity = 5)
#         # minimizer._use_tfgrad = False
#         result = minimizer.minimize(nll)

#         print("Function minimum:", result.fmin)
#         print("Hesse errors:", result.hesse())

#         params = result.params
        
#         if result.converged:
#             Ctt_list.append(params[Ctt]['value'])
#             Ctt_error_list.append(params[Ctt]['minuit_hesse']['error'])

#         #plotting the result

#         plotdirName = 'data/plots'.format(toy)

#         if not os.path.exists(plotdirName):
#             os.mkdir(plotdirName)
#         #         print("Directory " , dirName ,  " Created ")
        
#         plt.clf()
#         plt.hist(tot_sam, bins = int((x_max-x_min)/7.), label = 'toy data')
#         plt.savefig(plotdirName + '/toy_histo_cut_region{}.png'.format(toy))

        
#         probs = total_f_fit.pdf(test_q, norm_range=False)
#         calcs_test = zfit.run(probs)
#         plt.clf()
#         plt.plot(test_q, calcs_test, label = 'pdf')
#         plt.axvline(x=jpsi_mass-60.,color='red', linewidth=0.7, linestyle = 'dotted')
#         plt.axvline(x=jpsi_mass+70.,color='red', linewidth=0.7, linestyle = 'dotted')
#         plt.axvline(x=psi2s_mass-50.,color='red', linewidth=0.7, linestyle = 'dotted')
#         plt.axvline(x=psi2s_mass+50.,color='red', linewidth=0.7, linestyle = 'dotted')
#         plt.legend()
#         plt.ylim(0.0, 1.5e-6)
#         plt.savefig(plotdirName + '/toy_fit_cut_region{}.png'.format(toy))
        
#         print("Toy {0}/{1}".format(toy+1, nr_of_toys))
#         print("Time taken: {}".format(display_time(int(time.time() - start))))
#         print("Projected time left: {}".format(display_time(int((time.time() - start)/(toy+1))*((nr_of_toys-toy-1)))))
        


# In[40]:


# with open("data/results/Ctt_list.pkl", "wb") as f:
#     pkl.dump(Ctt_list, f, pkl.HIGHEST_PROTOCOL)
# with open("data/results/Ctt_error_list.pkl", "wb") as f:
#     pkl.dump(Ctt_error_list, f, pkl.HIGHEST_PROTOCOL)


# In[41]:


# print('{0}/{1} fits converged'.format(len(Ctt_list), nr_of_toys))
# print('Mean Ctt value = {}'.format(np.mean(Ctt_list)))
# print('Mean Ctt error = {}'.format(np.mean(Ctt_error_list)))
# print('95 Sensitivity = {}'.format(((2*np.mean(Ctt_error_list))**2)*4.2/1000))


# In[42]:


# plt.hist(tot_sam, bins = int((x_max-x_min)/7.))

# plt.show()

# # _ = np.where((total_samp >= x_min) & (total_samp <= (jpsi_mass - 50.)))

# tot_sam.shape


# In[43]:


# Ctt.floating = False


# In[44]:


# zfit.run(nll.value())


# In[45]:


# result.fmin


# In[46]:


# BR_steps = np.linspace(0.0, 1e-3, 11)


# # CLS Code

# In[48]:


# zfit.run.numeric_checks = False   

# CLs scan: for each Ctt value on a grid, generate toy data and record the
# minimum NLL of fits with Ctt floating and with Ctt fixed.

load = False  # False: run the scan in this cell; the next cell loads saved results when True

bo5 = True  # collect bo5_set converged fits per toy (the load cell keeps the minimum NLL of each list)

bo5_set = 5

fitting_range = 'cut'  # not read within this cell
total_BR = 1.7e-10 + 4.9e-10 + 2.5e-9 + 6.02e-5 + 4.97e-6 + 1.38e-9 + 4.2e-10 + 2.6e-9 + 6.1e-10 + 4.37e-7
# Fraction of total_BR left after removing the two largest terms (presumably J/psi and psi(2S) -- confirm).
cut_BR = 1.0 - (6.02e-5 + 4.97e-6)/total_BR

Ctt_list = []
Ctt_error_list = []

nr_of_toys = 1
# Events per toy: pdg["number_of_decays"] scaled by cut_BR.
nevents = int(pdg["number_of_decays"]*cut_BR)
# nevents = pdg["number_of_decays"]
event_stack = 1000000
# nevents *= 41
# zfit.settings.set_verbosity(10)

# Scan grid: ste branching-ratio values evenly spaced in [mi, ma].
mi = 0.0
ma = 1e-3
ste = 11

BR_steps = np.linspace(mi, ma, ste)

# Convert each BR to a Ctt value by inverting BR = Ctt**2 * 4.2 / 1000.
Ctt_steps = np.sqrt(BR_steps/4.2*1000)

total_samp = []

start = time.time()

# One list of converged fit minima (result.fmin) per (Ctt step, floating flag),
# appended as [step 1 floating, step 1 fixed, step 2 floating, ...].
Nll_list = []

# NOTE(review): create_sampler is called with its default fixed_params. If that
# default freezes parameter values at creation time, the Ctt.set_value(...) calls
# below would not change the generated toys -- verify against the installed zfit.
sampler = total_f.create_sampler(n=nevents)

__ = 0  # 1-based index of the current Ctt step (incremented at the top of the loop)

#-----------------------------------------------------

if not load:

    for Ctt_step in Ctt_steps:
        
        __ += 1
        
        # In the bo5 branch a toy is generated once per Ctt step and shared by
        # the floating-Ctt and fixed-Ctt fits.
        newset = True
        
        # Fit twice per step: Ctt floating (True) and Ctt fixed (False).
        for floaty in [True, False]:

            Ctt.floating = floaty

            Nll_list.append([])
            
            if bo5:
                
                # NOTE(review): only the first 5 Ctt steps are fitted; later steps
                # leave empty lists in Nll_list. Hard-coded -- confirm this is intended.
                if __ < 6:
                
                    # Keep fitting until nr_of_toys * bo5_set fits have converged.
                    # NOTE(review): non-converged fits are retried without limit.
                    while len(Nll_list[-1])/bo5_set < nr_of_toys:

                        print('Step: {0}/{1}'.format(__, ste))

                        print('Current Ctt: {0}'.format(Ctt_step))
                        print('Ctt floating: {0}'.format(floaty))

                        # NOTE(review): the "Fit" counter is the total number of converged
                        # fits in this list, so it exceeds bo5_set when nr_of_toys > 1.
                        print('Toy {0}/{1} - Fit {2}/{3}'.format(int(len(Nll_list[-1])/bo5_set), nr_of_toys, len(Nll_list[-1]), bo5_set))

                        reset_param_values()

                        if floaty:
                            Ctt.set_value(Ctt_step)
                        else:
                            Ctt.set_value(0.0)

                        if newset:
                            sampler.resample(n=nevents)
                            s = sampler.unstack_x()
                            total_samp = zfit.run(s)
                            calls = 0
                            c = 1
                            newset = False


                            data = zfit.data.Data.from_numpy(array=total_samp[:int(nevents)], obs=obs_fit)

                        ### Fit data

                        # Start each fit from randomized parameter values.
                        for param in total_f_fit.get_dependents():
                            param.randomize()

                        nll = zfit.loss.UnbinnedNLL(model=total_f_fit, data=data, constraints = constraints)

                        minimizer = zfit.minimize.MinuitMinimizer(verbosity = 5)
                        # minimizer._use_tfgrad = False
                        result = minimizer.minimize(nll)

                #         print("Function minimum:", result.fmin)
                #         print("Hesse errors:", result.hesse())

                        params = result.params

                        # Only converged fits are counted.
                        if result.converged:
                            Nll_list[-1].append(result.fmin)

            else:

                # One fit per toy until nr_of_toys fits have converged.
                while len(Nll_list[-1]) < nr_of_toys:

                    print('Step: {0}/{1}'.format(__, ste))

                    print('Current Ctt: {0}'.format(Ctt_step))
                    print('Ctt floating: {0}'.format(floaty))

                    print('Toy {0}/{1}'.format(len(Nll_list[-1]), nr_of_toys))

                    reset_param_values()

                    if floaty:
                        Ctt.set_value(Ctt_step)
                    else:
                        Ctt.set_value(0.0)

                    # New data is generated only for the floating-Ctt fits; the fixed-Ctt
                    # fits reuse `data` from the last floating-Ctt toy.
                    if floaty:
                        sampler.resample(n=nevents)
                        s = sampler.unstack_x()
                        total_samp = zfit.run(s)
                        calls = 0
                        c = 1


                        data = zfit.data.Data.from_numpy(array=total_samp[:int(nevents)], obs=obs_fit)

                    ### Fit data

                    for param in total_f_fit.get_dependents():
                        param.randomize()

                    nll = zfit.loss.UnbinnedNLL(model=total_f_fit, data=data, constraints = constraints)

                    minimizer = zfit.minimize.MinuitMinimizer(verbosity = 5)
                    # minimizer._use_tfgrad = False
                    result = minimizer.minimize(nll)

            #         print("Function minimum:", result.fmin)
            #         print("Hesse errors:", result.hesse())

                    params = result.params

                    if result.converged:
                        Nll_list[-1].append(result.fmin)

            # Progress report, printed after both the floating and the fixed fits.
            _t = int(time.time()-start)

            print('Time Taken: {}'.format(display_time(int(_t))))

            # NOTE(review): `__` is already the 1-based index of the current step, so
            # dividing by (__+1) and multiplying by (ste-__-1) looks off by one.
            print('Predicted time left: {}'.format(display_time(int((_t/(__+1)*(ste-__-1))))))


# In[49]:


if load:
    # Aggregate the results of a batch of CLs toy jobs.  Every entry of
    # ``_dir`` is treated as one job directory that may contain
    #   <job>/data/CLs/{mi}-{ma}_{ste}s--CLs_Nll_list.pkl  (2*ste lists of NLL minima)
    #   <job>/data/CLs/{mi}-{ma}_{ste}s--CLs_list.pkl      (ste lists of CLs values)
    # NOTE(review): pickle.load executes arbitrary code from the file; this is
    # only acceptable because these pickles are produced by our own jobs.
    Nll_list = [[] for _ in range(2*ste)]
    CLs_values = [[] for _ in range(ste)]

    _dir = 'data/CLs/finished/f1d1'

    jobs = os.listdir(_dir)

    for job in jobs:
        _job_dir = os.path.join(_dir, job, 'data', 'CLs')
        _nll_path = os.path.join(_job_dir, '{}-{}_{}s--CLs_Nll_list.pkl'.format(mi, ma, ste))
        _cls_path = os.path.join(_job_dir, '{}-{}_{}s--CLs_list.pkl'.format(mi, ma, ste))

        # Jobs without an NLL list (unfinished, or a non-job entry) are skipped.
        if not os.path.exists(_nll_path):
            print(job)
            continue

        with open(_nll_path, "rb") as input_file:
            _Nll_list = pkl.load(input_file)

        if bo5:
            # best-of-N: keep only the lowest NLL per step from this job
            for s in range(2*ste):
                Nll_list[s].append(np.min(_Nll_list[s]))
        else:
            for s in range(2*ste):
                Nll_list[s].extend(_Nll_list[s])

        # The save cell no longer writes the CLs list, so it is optional here;
        # a missing file must not abort the aggregation after Nll_list was
        # already extended for this job.
        if os.path.exists(_cls_path):
            with open(_cls_path, "rb") as input_file:
                _CLs_values = pkl.load(input_file)

            for s in range(ste):
                CLs_values[s].extend(_CLs_values[s])
        else:
            print('{}: no CLs list found, only NLL values loaded'.format(job))

        print(np.shape(Nll_list))


# In[50]:


dirName = 'data/CLs'

# if bo5 and not load:
#     for s in range(2*ste):
#         Nll_list[s] = [np.min(Nll_list[s])]

# if bo5: 
#     CLs_values= []
#     for i in range(int(len(Nll_list)/2)):
#         CLs_values.append([])
#         for j in range(len(Nll_list[0])):
#             CLs_values[i].append(Nll_list[2*i][j]-Nll_list[2*i+1][j])


if not load:
    # Persist the NLL minima of this run so the aggregation cell (load=True)
    # can combine them across jobs.

    if not os.path.exists(dirName):
        # makedirs (unlike mkdir) also creates a missing parent 'data/' dir
        os.makedirs(dirName)
        print("Directory " , dirName ,  " Created ")

    with open("{}/{}-{}_{}s--CLs_Nll_list.pkl".format(dirName, mi,ma,ste), "wb") as f:
        pkl.dump(Nll_list, f, pkl.HIGHEST_PROTOCOL)
        
#     CLs_values = []
    
#     for i in range(int(len(Nll_list)/2)):
#         CLs_values.append([])
#         for j in range(nr_of_toys):
#             CLs_values[i].append(Nll_list[2*i][j]-Nll_list[2*i+1][j])

#     with open("{}/{}-{}_{}s--CLs_list.pkl".format(dirName, mi,ma,ste), "wb") as f:
#         pkl.dump(CLs_values, f, pkl.HIGHEST_PROTOCOL)


# In[51]:


# print(CLs_values)
# print(Nll_list)


# ## Plot

# In[56]:


# l = []

# if not os.path.exists('data/CLs/plots'):
#     os.mkdir('data/CLs/plots')
#     print("Directory " , 'data/CLs/plots' ,  " Created ")

# for i in range(len(CLs_values)):
#     plt.clf()
#     plt.title('Ctt value: {:.2f}'.format(Ctt_steps[i]))
#     plt.hist(CLs_values[0], bins = 100, range = (-25, 25), label = 'Ctt fixed to 0')
#     plt.hist(CLs_values[i], bins = 100, range = (-25, 25), label = 'Ctt floating')
#     plt.axvline(x=np.mean(CLs_values[0]),color='red', linewidth=1.0, linestyle = 'dotted')
#     plt.legend()
#     plt.savefig('data/CLs/plots/CLs-BR({:.1E}).png'.format(BR_steps[i]))
    
#     l.append(len(np.where(np.array(CLs_values[i]) < np.mean(CLs_values[0]))[0]))


# In[57]:


# for s in range(len(l)):
#     print('BR: {:.4f}'.format(BR_steps[s]))
#     print(2*l[s]/len(CLs_values[0]))
#     print()


# In[ ]:


# print(np.array(Nll_list[0][:10])-np.array(Nll_list[1][:10]))


# In[ ]: