# Newer
# Older
# Master_thesis / raremodel-nb.py
# @saslie saslie on 8 Jul 2019 38 KB ...
#!/usr/bin/env python
# coding: utf-8

# # Import

# In[1]:


import os

# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import numpy as np
from pdg_const import pdg
import matplotlib
import matplotlib.pyplot as plt
import pickle as pkl
import sys
import time
from helperfunctions import display_time, prepare_plot
import cmath as c
import scipy.integrate as integrate
from scipy.optimize import fminbound
from array import array as arr
import collections
from itertools import compress
import tensorflow as tf
import zfit
from zfit import ztf
# from IPython.display import clear_output
import os
import tensorflow_probability as tfp
tfd = tfp.distributions


# In[2]:


# chunksize = 10000
# zfit.run.chunking.active = True
# zfit.run.chunking.max_n_points = chunksize


# # Build model and graphs
# ## Create graphs

# In[3]:


def formfactor(q2, subscript):
    """Return the form factor f_0, f_+ or f_T at q2 via a z-expansion.

    The result is a complex tensor whose imaginary part is always zero (the
    form factor itself is real; it is wrapped so it can multiply complex
    amplitudes directly).

    Args:
        q2: squared momentum transfer q**2 (not q).
        subscript: "0" for f_0, "+" for f_+, "T" for f_T.

    Returns:
        Complex tensor ``prefactor * sum_i b_i * (...)`` with zero imaginary part.

    Raises:
        ValueError: if ``subscript`` is not one of "0", "+", "T".
    """
    if subscript not in ("0", "+", "T"):
        raise ValueError('Wrong subscript entered, choose either 0, + or T')

    # Constants from the pdg table: meson masses, pole masses and the
    # z-expansion coefficient arrays for each form factor.
    mK = ztf.constant(pdg['Ks_M'])
    mB = ztf.constant(pdg['Bplus_M'])
    mbstar0 = ztf.constant(pdg["mbstar0"])
    mbstar = ztf.constant(pdg["mbstar"])
    b0 = ztf.constant(pdg["b0"])
    bplus = ztf.constant(pdg["bplus"])
    bT = ztf.constant(pdg["bT"])

    # Truncation order of the z-expansion (from the derivation in the paper).
    N = 3

    # Conformal variable z(q2) with t+ = (mB - mK)^2 and
    # t0 = (mB + mK) * (sqrt(mB) - sqrt(mK))^2.
    tpos = (mB - mK)**2
    tzero = (mB + mK)*(ztf.sqrt(mB)-ztf.sqrt(mK))**2

    z_oben = ztf.sqrt(tpos - q2) - ztf.sqrt(tpos - tzero)
    z_unten = ztf.sqrt(tpos - q2) + ztf.sqrt(tpos - tzero)
    z = tf.divide(z_oben, z_unten)

    if subscript == "0":
        # f_0: scalar pole and the plain series sum_i b0_i z^i.
        prefactor = 1/(1 - q2/(mbstar0**2))
        _sum = 0
        for i in range(N):
            _sum += b0[i]*(tf.pow(z, i))
    else:
        # f_+ / f_T: vector pole and the series with the z^N correction term.
        prefactor = 1/(1 - q2/(mbstar**2))
        b = bT if subscript == "T" else bplus
        _sum = 0
        for i in range(N):
            _sum += b[i] * (tf.pow(z, i) - ((-1)**(i-N)) * (i/N) * tf.pow(z, N))

    return tf.complex(prefactor * _sum, ztf.constant(0.0))

def resonance(q, _mass, width, phase, scale):
    """Breit-Wigner amplitude at q with a momentum-dependent width.

    The amplitude ``_mass*width / (_mass**2 - q**2 - i*_mass*gamma(q))`` is
    rescaled in magnitude by ``scale`` and rotated by ``phase`` radians.

    Args:
        q: evaluation points (q, not q**2).
        _mass, width: pole mass and nominal width.
        phase: extra phase added to the Breit-Wigner phase.
        scale: multiplier applied to the Breit-Wigner magnitude.

    Returns:
        Complex tensor ``scale*|BW| * exp(i*(arg(BW) + phase))``.
    """
    mmu = ztf.constant(pdg['muon_M'])
    q2 = tf.pow(q, 2)

    # Half the dimuon momentum-like factor at q and at the pole mass.
    p_at_q = 0.5 * ztf.sqrt(q2 - 4*(mmu**2))
    p_at_pole = 0.5 * ztf.sqrt(_mass**2 - 4*mmu**2)

    # Running width: width * (p/p0) * (_mass/q).
    running_width = tf.divide(p_at_q, q) * _mass * width / p_at_pole

    numerator = tf.complex(_mass * width, ztf.constant(0.0))
    denominator = tf.complex(_mass**2 - q2, -_mass*running_width)
    breit_wigner = numerator/denominator

    # Re-express in polar form so the magnitude can be scaled and the phase shifted.
    magnitude = ztf.to_complex(scale*tf.abs(breit_wigner))
    total_phase = tf.angle(breit_wigner) + phase

    return magnitude * tf.exp(tf.complex(ztf.constant(0.0), total_phase))

def bifur_gauss(q, mean, sigma_L, sigma_R, scale):
    """Bifurcated Gaussian in q, returned as a complex tensor with zero imaginary part.

    The left half (q < mean) uses width sigma_L and the right half uses sigma_R.
    Both halves equal 1 at q == mean, so the peak is continuous.

    NOTE(review): the normalisation factor
    2*sigma_L*sigma_R / ((sigma_L + sigma_R) * sqrt(2*pi)) gives a total area of
    scale*sigma_L*sigma_R, not 1. Confirm which normalisation is intended
    before relying on it.
    """
    left_half = ztf.exp(- tf.pow((q-mean), 2) / (2 * sigma_L**2))
    right_half = ztf.exp(- tf.pow((q-mean), 2) / (2 * sigma_R**2))
    shape = tf.where(q < mean, left_half, right_half)

    dgamma = scale*shape/(ztf.sqrt(2*np.pi))*2*(sigma_L*sigma_R)/(sigma_L+sigma_R)

    return ztf.complex(dgamma, ztf.constant(0.0))

def axiv_nonres(q):
    """Axial-vector (C10) non-resonant contribution, evaluated at q.

    Computes ``prefactor * (left + middle) * 2*sqrt(q**2)``, where the left term
    uses |C10eff * f_+|^2 and the middle (muon-mass suppressed) term uses
    |C10eff * f_0|^2. The trailing 2*sqrt(q**2) converts from a density in q**2
    to a density in q.

    Args:
        q: evaluation points (q, not q**2).

    Returns:
        Real tensor with the same shape as ``q``.
    """
    GF = ztf.constant(pdg['GF'])
    alpha_ew = ztf.constant(pdg['alpha_ew'])
    Vtb = ztf.constant(pdg['Vtb'])
    Vts = ztf.constant(pdg['Vts'])
    C10eff = ztf.constant(pdg['C10eff'])

    mmu = ztf.constant(pdg['muon_M'])
    mK = ztf.constant(pdg['Ks_M'])
    mB = ztf.constant(pdg['Bplus_M'])

    q2 = tf.pow(q, 2)

    # Muon velocity factor and kaon momentum-like factor.
    beta = ztf.sqrt(tf.abs(1. - 4. * mmu**2. / q2))

    kabs = ztf.sqrt(mB**2. +tf.pow(q2, 2)/mB**2. + mK**4./mB**2. - 2. * (mB**2. * mK**2. + mK**2. * q2 + mB**2. * q2) / mB**2.)

    # Prefactor in front of the whole bracket.
    prefactor1 = GF**2. *alpha_ew**2. * (tf.abs(Vtb*Vts))**2. * kabs * beta / (128. * np.pi**5.)

    # C10eff promoted to complex once, so it can multiply the complex form factors.
    C10 = tf.complex(C10eff, ztf.constant(0.0))

    # Left term in the bracket.
    bracket_left = 2./3. * kabs**2. * beta**2. *tf.abs(C10*formfactor(q2, "+"))**2.

    # Middle term in the bracket.
    _top = 4. * mmu**2. * (mB**2. - mK**2.) * (mB**2. - mK**2.)
    _under = q2 * mB**2.
    bracket_middle = _top/_under *tf.pow(tf.abs(C10 * formfactor(q2, "0")), 2)

    # Note: sqrt(q2) comes from the derivation, as we use q2 but plot q.
    return prefactor1 * (bracket_left + bracket_middle) * 2 *ztf.sqrt(q2)

def vec(q, funcs):
    """Vector (C9 + C7) contribution, evaluated at q.

    Args:
        q: evaluation points (q, not q**2).
        funcs: complex tensor added to C9eff by ``c9eff`` (resonant and
            two-particle terms).

    Returns:
        Real tensor ``prefactor * kabs^2 (1 - beta^2/3) |C9 f_+ + 2 C7 (mb+ms)/(mB+mK) f_T|^2 * 2 sqrt(q^2)``.
    """
    GF = ztf.constant(pdg['GF'])
    alpha_ew = ztf.constant(pdg['alpha_ew'])
    Vtb = ztf.constant(pdg['Vtb'])
    Vts = ztf.constant(pdg['Vts'])
    C7eff = ztf.constant(pdg['C7eff'])

    mmu = ztf.constant(pdg['muon_M'])
    mb = ztf.constant(pdg['bquark_M'])
    ms = ztf.constant(pdg['squark_M'])
    mK = ztf.constant(pdg['Ks_M'])
    mB = ztf.constant(pdg['Bplus_M'])

    q2 = tf.pow(q, 2)

    beta = ztf.sqrt(tf.abs(1. - 4. * mmu**2. / q2))
    kabs = ztf.sqrt(mB**2. + tf.pow(q2, 2)/mB**2. + mK**4./mB**2. - 2 * (mB**2 * mK**2 + mK**2 * q2 + mB**2 * q2) / mB**2)

    overall = GF**2. *alpha_ew**2. * (tf.abs(Vtb*Vts))**2 * kabs * beta / (128. * np.pi**5.)

    # Amplitude: C9 with f_+ plus the C7 dipole term with f_T.
    amplitude = c9eff(q, funcs) * formfactor(q2, "+") + tf.complex(2.0 * C7eff * (mb + ms)/(mB + mK), ztf.constant(0.0)) * formfactor(q2, "T")
    kinematic = kabs**2 * (1. - 1./3. * beta**2)
    bracket_right = kinematic * tf.abs(amplitude)**2

    # The factor 2*sqrt(q2) converts a density in q2 into a density in q.
    return overall * bracket_right * 2 * ztf.sqrt(q2)

def c9eff(q, funcs):
    """Return the effective C9: the constant pdg['C9eff'] (as complex) plus ``funcs``.

    ``q`` is accepted for interface symmetry but is not used here; the
    q-dependence is carried entirely by ``funcs``.
    """
    nonresonant = tf.complex(ztf.constant(pdg['C9eff']), ztf.constant(0.0))
    return nonresonant + funcs


# In[4]:


def G(y):
    """Loop function G(y) used by h_S and h_P, returned as a complex tensor.

        G(y) = sqrt(y) * (ln((1 + sqrt(y)) / (1 - sqrt(y))) - i*pi)   if Re(y) > 0
        G(y) = 2 * sqrt(-y) * arctan(1 / sqrt(-y))                      otherwise

    Args:
        y: complex tensor; the callers in this file pass 1 - 4*m**2/q**2.
    """

    def inner_rect_bracket(q):
        # Bracket for Re(y) > 0. The -i*pi term is subtracted AFTER the
        # logarithm (previously it sat inside tf.log, giving ln(ratio + i*pi)).
        ratio = ztf.to_complex((1+tf.sqrt(q))/(1-tf.sqrt(q)))
        return tf.log(ratio) - tf.complex(ztf.constant(0.0), ztf.constant(np.pi))

    def inner_right(q):
        # Bracket for Re(y) <= 0. tf.where evaluates both branches for every
        # element, so take |Re(-y)| to keep the discarded branch free of NaNs
        # (NaNs there would still poison gradients).
        return ztf.to_complex(2 * tf.atan(1/tf.sqrt(tf.abs(tf.math.real(-q)))))

    big_bracket = tf.where(tf.math.real(y) > ztf.constant(0.0), inner_rect_bracket(y), inner_right(y))

    # sqrt(|y|) equals sqrt(y) for y > 0 and sqrt(-y) for y < 0.
    return ztf.to_complex(tf.sqrt(tf.abs(y))) * big_bracket

def h_S(m, q):
    """S-wave two-particle loop function: 2 - G(1 - 4*m**2/q**2), as a complex tensor."""
    threshold = ztf.to_complex(1) - ztf.to_complex(4*tf.pow(m, 2)) / ztf.to_complex(tf.pow(q, 2))
    return ztf.to_complex(2) - G(threshold)

def h_P(m, q):
    """P-wave two-particle loop function, as a complex tensor.

        h_P(m, q) = 2/3 + (1 - 4*m**2/q**2) * h_S(m, q)
    """
    # Build 2/3 as a float64 constant before casting: ztf.to_complex on a bare
    # Python float may go through tf.cast, which converts Python floats to a
    # float32 tensor first and would round 2/3.
    two_thirds = ztf.to_complex(ztf.constant(2/3))
    return two_thirds + (ztf.to_complex(1) - ztf.to_complex(4*tf.pow(m, 2)) / ztf.to_complex(tf.pow(q, 2))) * h_S(m, q)

def two_p_ccbar(mD, m_D_bar, m_D_star, q):
    """Sum of the D-meson two-particle contributions at q, as a complex tensor.

        nu_D_bar  * e^{i phase_D_bar}  * h_S(m_D_bar, q)
      + nu_D      * e^{i phase_D}      * h_P(mD, q)
      + nu_D_star * e^{i phase_D_star} * h_P(m_D_star, q)

    The moduli (nu_*) and phases (phase_*) come from ``pdg``.

    Args:
        mD: mass used in the D term (the body previously referenced an
            undefined name ``m_D`` instead of this parameter).
        m_D_bar: mass used in the D-bar (h_S) term.
        m_D_star: mass used in the D* term.
        q: evaluation points.
    """
    # Moduli are complex multipliers; build them as float64 first so the
    # cast to complex does not round through float32.
    nu_D_bar = ztf.to_complex(ztf.constant(pdg["nu_D_bar"]))
    nu_D = ztf.to_complex(ztf.constant(pdg["nu_D"]))
    nu_D_star = ztf.to_complex(ztf.constant(pdg["nu_D_star"]))

    # Phases must stay real: tf.complex(real, imag) rejects complex inputs.
    phase_D_bar = ztf.constant(pdg["phase_D_bar"])
    phase_D = ztf.constant(pdg["phase_D"])
    phase_D_star = ztf.constant(pdg["phase_D_star"])

    def rotation(phase):
        # e^{i*phase} for a real phase tensor.
        return tf.exp(tf.complex(ztf.constant(0.0), phase))

    left_part = nu_D_bar * rotation(phase_D_bar) * h_S(m_D_bar, q)

    right_part_D = nu_D * rotation(phase_D) * h_P(mD, q)

    right_part_D_star = nu_D_star * rotation(phase_D_star) * h_P(m_D_star, q)

    return left_part + right_part_D + right_part_D_star


# ## C_q,qbar constraint

# In[5]:


# r = rho_scale * rho_width/rho_mass * np.cos(rho_phase)*(1-np.tan(rho_phase)*rho_width/rho_mass)
# o = omega_scale*np.cos(omega_phase)*omega_width/omega_mass
# p = phi_scale*np.cos(phi_phase)*phi_width/phi_mass

# # phi_s = np.linspace(-500, 5000, 100000)

# # p_ = phi_s*np.cos(phi_phase)*phi_width/phi_mass

# # p_y = r+o+p_

# # plt.plot(phi_s, p_y)

# print(r + o + p)


# ## Build pdf

# In[6]:


class total_pdf(zfit.pdf.ZPDF):
    """zfit PDF in q: vector part (``vec``) plus axial non-resonant part (``axiv_nonres``).

    The C9 shift ``funcs`` passed to ``vec`` is the sum, in this order, of the
    Breit-Wigner resonances rho, omega, phi, J/psi, psi(2S), psi(3770),
    psi(4040), psi(4160), psi(4415), then the D-meson two-particle term and
    finally the tau-loop cusp term.
    """
    _N_OBS = 1  # dimension, can be omitted
    # the name of the parameters
    _PARAMS = ['rho_mass', 'rho_scale', 'rho_phase', 'rho_width',
               'jpsi_mass', 'jpsi_scale', 'jpsi_phase', 'jpsi_width',
               'psi2s_mass', 'psi2s_scale', 'psi2s_phase', 'psi2s_width',
               'p3770_mass', 'p3770_scale', 'p3770_phase', 'p3770_width',
               'p4040_mass', 'p4040_scale', 'p4040_phase', 'p4040_width',
               'p4160_mass', 'p4160_scale', 'p4160_phase', 'p4160_width',
               'p4415_mass', 'p4415_scale', 'p4415_phase', 'p4415_width',
               'omega_mass', 'omega_scale', 'omega_phase', 'omega_width',
               'phi_mass', 'phi_scale', 'phi_phase', 'phi_width',
               'Dbar_mass', 'Dbar_scale', 'Dbar_phase',
               'DDstar_mass', 'DDstar_scale', 'DDstar_phase',
               'tau_mass', 'C_tt']

    def _unnormalized_pdf(self, x):
        x = x.unstack_x()
        params = self.params

        def breit_wigner(q, name):
            # Every resonance reads <name>_mass / _scale / _phase / _width.
            return resonance(q, _mass=params[name + '_mass'], scale=params[name + '_scale'],
                             phase=params[name + '_phase'], width=params[name + '_width'])

        def rotated_scale(scale_key, phase_key):
            # scale * e^{i*phase} as a complex tensor.
            return ztf.to_complex(params[scale_key]) * tf.exp(tf.complex(ztf.constant(0.0), params[phase_key]))

        def P2_D(q):
            Dbar_contrib = rotated_scale('Dbar_scale', 'Dbar_phase') * ztf.to_complex(h_S(params['Dbar_mass'], q))
            DDstar_contrib = rotated_scale('DDstar_scale', 'DDstar_phase') * ztf.to_complex(h_P(params['DDstar_mass'], q))
            return Dbar_contrib + DDstar_contrib

        def ttau_cusp(q):
            tau_mass = params['tau_mass']
            return ztf.to_complex(params['C_tt'])*(ztf.to_complex(h_S(tau_mass, q)) - ztf.to_complex(h_P(tau_mass, q)))

        resonance_names = ('rho', 'omega', 'phi', 'jpsi', 'psi2s',
                           'p3770', 'p4040', 'p4160', 'p4415')
        terms = [breit_wigner(x, name) for name in resonance_names]
        terms.append(P2_D(x))
        terms.append(ttau_cusp(x))

        # Accumulate left to right, matching a plain a + b + c + ... chain.
        funcs = terms[0]
        for term in terms[1:]:
            funcs = funcs + term

        return vec(x, funcs) + axiv_nonres(x)


# ## Load data

# In[7]:


# Fit range in q: from the dimuon threshold 2*m_mu up to 0.1 below m_B - m_K.
x_min = 2*pdg['muon_M']
x_max = pdg["Bplus_M"] - pdg["Ks_M"] - 0.1

obs = zfit.Space('q', limits=(x_min, x_max))

# with open(r"./data/slim_points/slim_points_toy_0_range({0}-{1}).pkl".format(int(x_min), int(x_max)), "rb") as input_file:
#     part_set = pkl.load(input_file)

# x_part = part_set['x_part']

# x_part = x_part.astype('float64')

# data = zfit.data.Data.from_numpy(array=x_part, obs=obs)


# ## Setup parameters

# In[8]:


# Resonance parameters. Each pdg entry is unpacked as (mass, width, phase, scale)
# and each value is wrapped in a zfit.Parameter. All masses and widths are
# fixed (floating = False). Only the J/psi and psi(2S) phases and scales float:
# phases are limited to [-2*pi, 2*pi] and scales to scale -/+ sqrt(scale).
# The plain floats (e.g. jpsi_scale) are reused below in the two-particle cell.

#rho

rho_mass, rho_width, rho_phase, rho_scale = pdg["rho"]

rho_m = zfit.Parameter("rho_m", ztf.constant(rho_mass), floating = False) #lower_limit = rho_mass - rho_width,
#                        upper_limit = rho_mass + rho_width)
rho_w = zfit.Parameter("rho_w", ztf.constant(rho_width), floating = False)
rho_p = zfit.Parameter("rho_p", ztf.constant(rho_phase), floating = False) #, lower_limit=-2*np.pi, upper_limit=2*np.pi)
rho_s = zfit.Parameter("rho_s", ztf.constant(rho_scale), floating = False) #, lower_limit=rho_scale-np.sqrt(rho_scale), upper_limit=rho_scale+np.sqrt(rho_scale))

#omega

omega_mass, omega_width, omega_phase, omega_scale = pdg["omega"]

omega_m = zfit.Parameter("omega_m", ztf.constant(omega_mass), floating = False)
omega_w = zfit.Parameter("omega_w", ztf.constant(omega_width), floating = False)
omega_p = zfit.Parameter("omega_p", ztf.constant(omega_phase), floating = False) #, lower_limit=-2*np.pi, upper_limit=2*np.pi)
omega_s = zfit.Parameter("omega_s", ztf.constant(omega_scale), floating = False) #, lower_limit=omega_scale-np.sqrt(omega_scale), upper_limit=omega_scale+np.sqrt(omega_scale))


#phi

phi_mass, phi_width, phi_phase, phi_scale = pdg["phi"]

phi_m = zfit.Parameter("phi_m", ztf.constant(phi_mass), floating = False)
phi_w = zfit.Parameter("phi_w", ztf.constant(phi_width), floating = False)
phi_p = zfit.Parameter("phi_p", ztf.constant(phi_phase), floating = False) #, lower_limit=-2*np.pi, upper_limit=2*np.pi)
phi_s = zfit.Parameter("phi_s", ztf.constant(phi_scale), floating = False) #, lower_limit=phi_scale-np.sqrt(phi_scale), upper_limit=phi_scale+np.sqrt(phi_scale))

#jpsi
# Phase and scale float here (fit parameters); mass and width stay fixed.

jpsi_mass, jpsi_width, jpsi_phase, jpsi_scale = pdg["jpsi"]
# jpsi_scale *= pdg["factor_jpsi"]

jpsi_m = zfit.Parameter("jpsi_m", ztf.constant(jpsi_mass), floating = False)
jpsi_w = zfit.Parameter("jpsi_w", ztf.constant(jpsi_width), floating = False)
jpsi_p = zfit.Parameter("jpsi_p", ztf.constant(jpsi_phase), lower_limit=-2*np.pi, upper_limit=2*np.pi)
jpsi_s = zfit.Parameter("jpsi_s", ztf.constant(jpsi_scale), lower_limit=jpsi_scale-np.sqrt(jpsi_scale), upper_limit=jpsi_scale+np.sqrt(jpsi_scale))

#psi2s
# Phase and scale float here (fit parameters); mass and width stay fixed.

psi2s_mass, psi2s_width, psi2s_phase, psi2s_scale = pdg["psi2s"]

psi2s_m = zfit.Parameter("psi2s_m", ztf.constant(psi2s_mass), floating = False)
psi2s_w = zfit.Parameter("psi2s_w", ztf.constant(psi2s_width), floating = False)
psi2s_p = zfit.Parameter("psi2s_p", ztf.constant(psi2s_phase), lower_limit=-2*np.pi, upper_limit=2*np.pi)
psi2s_s = zfit.Parameter("psi2s_s", ztf.constant(psi2s_scale), lower_limit=psi2s_scale-np.sqrt(psi2s_scale), upper_limit=psi2s_scale+np.sqrt(psi2s_scale))

#psi(3770)

p3770_mass, p3770_width, p3770_phase, p3770_scale = pdg["p3770"]

p3770_m = zfit.Parameter("p3770_m", ztf.constant(p3770_mass), floating = False)
p3770_w = zfit.Parameter("p3770_w", ztf.constant(p3770_width), floating = False)
p3770_p = zfit.Parameter("p3770_p", ztf.constant(p3770_phase), floating = False) #, lower_limit=-2*np.pi, upper_limit=2*np.pi)
p3770_s = zfit.Parameter("p3770_s", ztf.constant(p3770_scale), floating = False) #, lower_limit=p3770_scale-np.sqrt(p3770_scale), upper_limit=p3770_scale+np.sqrt(p3770_scale))

#psi(4040)

p4040_mass, p4040_width, p4040_phase, p4040_scale = pdg["p4040"]

p4040_m = zfit.Parameter("p4040_m", ztf.constant(p4040_mass), floating = False)
p4040_w = zfit.Parameter("p4040_w", ztf.constant(p4040_width), floating = False)
p4040_p = zfit.Parameter("p4040_p", ztf.constant(p4040_phase), floating = False) #, lower_limit=-2*np.pi, upper_limit=2*np.pi)
p4040_s = zfit.Parameter("p4040_s", ztf.constant(p4040_scale), floating = False) #, lower_limit=p4040_scale-np.sqrt(p4040_scale), upper_limit=p4040_scale+np.sqrt(p4040_scale))

#psi(4160)

p4160_mass, p4160_width, p4160_phase, p4160_scale = pdg["p4160"]

p4160_m = zfit.Parameter("p4160_m", ztf.constant(p4160_mass), floating = False)
p4160_w = zfit.Parameter("p4160_w", ztf.constant(p4160_width), floating = False)
p4160_p = zfit.Parameter("p4160_p", ztf.constant(p4160_phase), floating = False) #, lower_limit=-2*np.pi, upper_limit=2*np.pi)
p4160_s = zfit.Parameter("p4160_s", ztf.constant(p4160_scale), floating = False) #, lower_limit=p4160_scale-np.sqrt(p4160_scale), upper_limit=p4160_scale+np.sqrt(p4160_scale))

#psi(4415)

p4415_mass, p4415_width, p4415_phase, p4415_scale = pdg["p4415"]

p4415_m = zfit.Parameter("p4415_m", ztf.constant(p4415_mass), floating = False)
p4415_w = zfit.Parameter("p4415_w", ztf.constant(p4415_width), floating = False)
p4415_p = zfit.Parameter("p4415_p", ztf.constant(p4415_phase), floating = False) #, lower_limit=-2*np.pi, upper_limit=2*np.pi)
p4415_s = zfit.Parameter("p4415_s", ztf.constant(p4415_scale), floating = False) #, lower_limit=p4415_scale-np.sqrt(p4415_scale), upper_limit=p4415_scale+np.sqrt(p4415_scale))


# ## Dynamic generation of 2 particle contribution

# In[9]:


# Sum of scale*cos(phase)*width/mass**3 over the six charmonium resonances,
# split across three lines (_0, _1, _2) for readability.
_0 = jpsi_scale*np.cos(jpsi_phase)*jpsi_width/jpsi_mass**3 + psi2s_scale*np.cos(psi2s_phase)*psi2s_width/psi2s_mass**3
_1 = p3770_scale*np.cos(p3770_phase)*p3770_width/p3770_mass**3 + p4040_scale*np.cos(p4040_phase)*p4040_width/p4040_mass**3
_2 = p4160_scale*np.cos(p4160_phase)*p4160_width/p4160_mass**3 + p4415_scale*np.cos(p4415_phase)*p4415_width/p4415_mass**3

# Perturbative coefficient drawn uniformly from [0.03, 0.1). np.random is not
# seeded here, so each run draws different values.
C_pert = np.random.uniform(0.03, 0.1)
# c_pert = 0.1
# NOTE(review): presumably the charm-quark mass in the same units as pdg masses — confirm.
m_c = 1300

# Not used in the visible code (only referenced in a commented-out line below).
cDDstar_phase = 10


DDstar_eta = 0
# The order of these random draws determines the values obtained.
Dbar_phase = np.random.uniform(0.0, 2*np.pi)
DDstar_phase = np.random.uniform(0.0, 2*np.pi)
DDstar_mass = pdg['D0_M']

# Snap Dbar_phase to 0 or pi, each with probability 1/2.
if Dbar_phase < np.pi:
    Dbar_phase = 0.0
else:
    Dbar_phase = np.pi

# Remainder that the D-bar term has to supply: C_pert/m_c**2 minus the resonance sum.
R = (C_pert/(m_c**2) - ((_0 + _1 + _2)))

Dbar_mass = (pdg['D0_M']+pdg['Dst_M'])/2

# Chosen so that Dbar_eta*cos(Dbar_phase)/(6*Dbar_mass**2) == R.
Dbar_eta = R/np.cos(Dbar_phase)*(6*Dbar_mass**2)

# print(np.cos(Dbar_phase))

# cDDstar_phase = R_*10*DDstar_mass**2/DDstar_eta


# print(Dbar_eta)


# Only Dbar_s floats (limits +-1.464); masses, phases and DDstar_s are fixed.
# NOTE(review): Dbar_eta is not guaranteed to lie within +-1.464 — confirm this is acceptable.
Dbar_s = zfit.Parameter("Dbar_s", ztf.constant(Dbar_eta), lower_limit=-1.464, upper_limit=1.464)
Dbar_m = zfit.Parameter("Dbar_m", ztf.constant(Dbar_mass), floating = False)
Dbar_p = zfit.Parameter("Dbar_p", ztf.constant(Dbar_phase), floating = False)
DDstar_s = zfit.Parameter("DDstar_s", ztf.constant(DDstar_eta), floating = False)
DDstar_m = zfit.Parameter("DDstar_m", ztf.constant(DDstar_mass), floating = False)
DDstar_p = zfit.Parameter("DDstar_p", ztf.constant(DDstar_phase), floating = False)

# Both two-particle scales are reset to 0.0, overriding the initial values above.
Dbar_s.set_value(0.0)
DDstar_s.set_value(0.0)


# ## Tau parameters

# In[10]:


# Tau-loop cusp: the tau mass is fixed from pdg; the coefficient C_tt floats
# in [-0.5, 0.5] and starts at 0.
tau_m = zfit.Parameter("tau_m", ztf.constant(pdg['tau_M']), floating=False)
Ctt = zfit.Parameter(
    "Ctt",
    ztf.constant(0.0),
    lower_limit=-0.5,
    upper_limit=0.5,
)


# ## Setup pdf

# In[11]:


# Bind every zfit.Parameter created above to the matching total_pdf parameter name.
total_f = total_pdf(
    obs=obs,
    # charmonium resonances
    jpsi_mass=jpsi_m, jpsi_scale=jpsi_s, jpsi_phase=jpsi_p, jpsi_width=jpsi_w,
    psi2s_mass=psi2s_m, psi2s_scale=psi2s_s, psi2s_phase=psi2s_p, psi2s_width=psi2s_w,
    p3770_mass=p3770_m, p3770_scale=p3770_s, p3770_phase=p3770_p, p3770_width=p3770_w,
    p4040_mass=p4040_m, p4040_scale=p4040_s, p4040_phase=p4040_p, p4040_width=p4040_w,
    p4160_mass=p4160_m, p4160_scale=p4160_s, p4160_phase=p4160_p, p4160_width=p4160_w,
    p4415_mass=p4415_m, p4415_scale=p4415_s, p4415_phase=p4415_p, p4415_width=p4415_w,
    # light resonances
    rho_mass=rho_m, rho_scale=rho_s, rho_phase=rho_p, rho_width=rho_w,
    omega_mass=omega_m, omega_scale=omega_s, omega_phase=omega_p, omega_width=omega_w,
    phi_mass=phi_m, phi_scale=phi_s, phi_phase=phi_p, phi_width=phi_w,
    # two-particle (D-meson) terms
    DDstar_mass=DDstar_m, DDstar_scale=DDstar_s, DDstar_phase=DDstar_p,
    Dbar_mass=Dbar_m, Dbar_scale=Dbar_s, Dbar_phase=Dbar_p,
    # tau-loop cusp
    tau_mass=tau_m, C_tt=Ctt,
)

# print(total_pdf.obs)

# print(calcs_test)

# for param in total_f.get_dependents():
#     print(zfit.run(param))


# ## Test if graphs actually work and compute values

# In[12]:


# def total_test_tf(xq):

#     def jpsi_res(q):
#         return resonance(q, jpsi_m, jpsi_s, jpsi_p, jpsi_w)

#     def psi2s_res(q):
#         return resonance(q, psi2s_m, psi2s_s, psi2s_p, psi2s_w)

#     def cusp(q):
#         return bifur_gauss(q, cusp_m, sig_L, sig_R, cusp_s)

#     funcs = jpsi_res(xq) + psi2s_res(xq) + cusp(xq)

#     vec_f = vec(xq, funcs)

#     axiv_nr = axiv_nonres(xq)

#     tot = vec_f + axiv_nr
    
#     return tot

def jpsi_res(q):
    """J/psi Breit-Wigner amplitude at q, using the module-level J/psi parameters.

    resonance() has the signature (q, _mass, width, phase, scale). The previous
    positional call passed jpsi_s as the width and jpsi_w as the scale, so the
    arguments are now passed by keyword.
    """
    return resonance(q, _mass=jpsi_m, width=jpsi_w, phase=jpsi_p, scale=jpsi_s)

# calcs = zfit.run(total_test_tf(x_part))

# Dense q grid over the fit range for evaluating and plotting the model.
test_q = np.linspace(x_min, x_max, 200000)

probs = total_f.pdf(test_q)

calcs_test = zfit.run(probs)
res_y = zfit.run(jpsi_res(test_q))
# formfactor() takes q**2 (as in vec/axiv_nonres), not q.
f0_y = zfit.run(formfactor(test_q**2, "0"))
fplus_y = zfit.run(formfactor(test_q**2, "+"))
fT_y = zfit.run(formfactor(test_q**2, "T"))


# In[13]:


# Plot the PDF values computed above (calcs_test) over test_q, clip the y-axis
# to [0, 6e-6] and write the figure to test.png in the working directory.
# The commented lines are alternative diagnostic plots.
plt.clf()
# plt.plot(x_part, calcs, '.')
plt.plot(test_q, calcs_test, label = 'pdf')
# plt.plot(test_q, f0_y, label = '0')
# plt.plot(test_q, fT_y, label = 'T')
# plt.plot(test_q, fplus_y, label = '+')
# plt.plot(test_q, res_y, label = 'res')
plt.legend()
plt.ylim(0.0, 6e-6)
# plt.yscale('log')
# plt.xlim(770, 785)
plt.savefig('test.png')
# print(jpsi_width)


# In[14]:




# probs = mixture.prob(test_q)
# probs_np = zfit.run(probs)
# probs_np *= np.max(calcs_test) / np.max(probs_np)
# plt.figure()
# plt.semilogy(test_q, probs_np,label="importance sampling")
# plt.semilogy(test_q, calcs_test, label = 'pdf')


# In[15]:


# 0.213/(0.00133+0.213+0.015)


# ## Adjust scaling of different parts

# In[16]:


# Integration settings for total_f: 200000 draws per dimension. mc_sampler=None
# is passed explicitly (presumably selecting zfit's default sampler — TODO confirm).
total_f.update_integration_options(draws_per_dim=200000, mc_sampler=None)
# inte = total_f.integrate(limits = (2000, x_max), norm_range=False)
# inte_fl = zfit.run(inte)
# print(inte_fl)
# print(pdg["jpsi_BR"]/pdg["NR_BR"], inte_fl*pdg["psi2s_auc"]/pdg["NR_auc"])


# In[17]:


# # print("jpsi:", inte_fl)
# # print("Increase am by factor:", np.sqrt(pdg["jpsi_BR"]/pdg["NR_BR"]*pdg["NR_auc"]/inte_fl))
# # print("New amp:", pdg["jpsi"][3]*np.sqrt(pdg["jpsi_BR"]/pdg["NR_BR"]*pdg["NR_auc"]/inte_fl))

# # print("psi2s:", inte_fl)
# # print("Increase am by factor:", np.sqrt(pdg["psi2s_BR"]/pdg["NR_BR"]*pdg["NR_auc"]/inte_fl))
# # print("New amp:", pdg["psi2s"][3]*np.sqrt(pdg["psi2s_BR"]/pdg["NR_BR"]*pdg["NR_auc"]/inte_fl))

# name = "phi"

# print(name+":", inte_fl)
# print("Increase am by factor:", np.sqrt(pdg[name+"_BR"]/pdg["NR_BR"]*pdg["NR_auc"]/inte_fl))
# print("New amp:", pdg[name][3]*np.sqrt(pdg[name+"_BR"]/pdg["NR_BR"]*pdg["NR_auc"]/inte_fl))


# # print(x_min)
# # print(x_max)
# # # total_f.update_integration_options(draws_per_dim=2000000, mc_sampler=None)
# # total_f.update_integration_options(mc_sampler=lambda dim, num_results,
# #                                     dtype: tf.random_uniform(maxval=1., shape=(num_results, dim), dtype=dtype),
# #                                    draws_per_dim=1000000)
# # # _ = []

# # # for i in range(10):

# # #     inte = total_f.integrate(limits = (x_min, x_max))
# # #     inte_fl = zfit.run(inte)
# # #     print(inte_fl)
# # #     _.append(inte_fl)

# # # print("mean:", np.mean(_))

# # _ = time.time()

# # inte = total_f.integrate(limits = (x_min, x_max))
# # inte_fl = zfit.run(inte)
# # print(inte_fl)
# # print("Time taken: {}".format(display_time(int(time.time() - _))))

# print(pdg['NR_BR']/pdg['NR_auc']*inte_fl)
# print(0.25**2*4.2/1000)


# # Sampling
# ## Toys

# In[18]:



    
# print(list_of_borders[:9])
# print(list_of_borders[-9:])


class UniformSampleAndWeights(zfit.util.execution.SessionHolderMixin):
    """Proposal for zfit's accept-reject sampling built from uniform pieces.

    Candidates are drawn from a ``tfd.MixtureSameFamily`` of ``tfd.Uniform``
    components. One component covers the requested sampling limits. The
    others are the narrow, hard-coded ``windows`` below, which concentrate
    draws in selected regions of the observable. The mixture density at
    each candidate is returned as its weight.
    """

    # Mixture weight of the component spanning the full sampling limits.
    full_range_prob = 0.05

    # (mixture weight, lower edge, upper edge) of each additional window.
    # Edges are in the same units as the sampling limits.
    # NOTE(review): together with ``full_range_prob`` these weights add up to
    # 1.185, not 1. The categorical mixture is assumed to renormalise them;
    # confirm this for the installed TensorFlow Probability version.
    windows = (
        (0.93, 3090, 3102),
        (0.05, 3681, 3691),
        (0.065, 3070, 3110),
        (0.04, 1000, 1040),
        (0.05, 3660, 3710),
    )

    def __call__(self, limits, dtype, n_to_produce):
        """Draw candidate events together with their proposal weights.

        Args:
            limits: Space whose ``limit1d`` gives the (lower, upper) range
                covered by the broad mixture component.
            dtype: TensorFlow float dtype used for all constants and draws.
            n_to_produce: Number of candidates to draw.

        Returns:
            Tuple ``(sample, thresholds, weights, weights_max, n_to_produce)``.
            ``sample`` has shape ``(n_to_produce, 1)``. ``thresholds`` holds
            ``tf.random_uniform`` draws (default range) of shape
            ``(n_to_produce,)``. ``weights`` is the mixture density at each
            candidate, shape ``(n_to_produce,)``. ``weights_max`` is always
            ``None``.
        """
        # Sample inside the limits zfit asks for. The previous version read
        # the global x_min/x_max here and ignored `limits` entirely.
        low, high = limits.limit1d

        def constants(values):
            return [tf.constant(value, dtype=dtype) for value in values]

        probs = [self.full_range_prob] + [prob for prob, _, _ in self.windows]
        lows = [low] + [lower for _, lower, _ in self.windows]
        highs = [high] + [upper for _, _, upper in self.windows]

        mixture = tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(probs=constants(probs)),
            components_distribution=tfd.Uniform(low=constants(lows),
                                                high=constants(highs)))

        sample = mixture.sample((n_to_produce, 1))
        # prob() keeps the trailing axis of `sample`; drop it to get one
        # weight per candidate.
        weights = mixture.prob(sample)[:, 0]
        weights_max = None
        thresholds = tf.random_uniform(shape=(n_to_produce,), dtype=dtype)
        return sample, thresholds, weights, weights_max, n_to_produce


# In[19]:


# Swap total_f's candidate generator for the mixture-of-uniforms proposal
# class UniformSampleAndWeights. The class object itself (not an instance)
# is what gets stored on the pdf.
setattr(total_f, "_sample_and_weights", UniformSampleAndWeights)


# In[20]:


# 0.00133/(0.00133+0.213+0.015)*(x_max-3750)/(x_max-x_min)


# In[21]:


# zfit.settings.set_verbosity(10)


# In[22]:


# # zfit.run.numeric_checks = False   

# nr_of_toys = 1
# nevents = int(pdg["number_of_decays"])
# nevents = pdg["number_of_decays"]
# event_stack = 1000000
# # zfit.settings.set_verbosity(10)
# calls = int(nevents/event_stack + 1)

# total_samp = []

# start = time.time()

# sampler = total_f.create_sampler(n=event_stack)

# for toy in range(nr_of_toys):
    
#     dirName = 'data/zfit_toys/toy_{0}'.format(toy)
    
#     if not os.path.exists(dirName):
#         os.mkdir(dirName)
#         print("Directory " , dirName ,  " Created ")

#     for call in range(calls):

#         sampler.resample(n=event_stack)
#         s = sampler.unstack_x()
#         sam = zfit.run(s)
# #         clear_output(wait=True)

#         c = call + 1
        
#         print("{0}/{1} of Toy {2}/{3}".format(c, calls, toy+1, nr_of_toys))
#         print("Time taken: {}".format(display_time(int(time.time() - start))))
#         print("Projected time left: {}".format(display_time(int((time.time() - start)/(c+calls*(toy))*((nr_of_toys-toy)*calls-c)))))

#         with open("data/zfit_toys/toy_{0}/{1}.pkl".format(toy, call), "wb") as f:
#             pkl.dump(sam, f, pkl.HIGHEST_PROTOCOL)


# In[23]:


# with open(r"data/zfit_toys/toy_0/0.pkl", "rb") as input_file:
#     sam = pkl.load(input_file)
# print(sam[:10])

# with open(r"data/zfit_toys/toy_0/1.pkl", "rb") as input_file:
#     sam2 = pkl.load(input_file)
# print(sam2[:10])

# print(np.sum(sam-sam2))


# In[24]:


# print("Time to generate full toy: {} s".format(int(time.time()-start)))

# total_samp = []

# for call in range(calls):
#     with open(r"data/zfit_toys/toy_0/{}.pkl".format(call), "rb") as input_file:
#         sam = pkl.load(input_file)
#         total_samp = np.append(total_samp, sam)

# total_samp = total_samp.astype('float64')

# data2 = zfit.data.Data.from_numpy(array=total_samp[:int(nevents)], obs=obs)

# data3 = zfit.data.Data.from_numpy(array=total_samp, obs=obs)

# print(total_samp[:nevents].shape)


# In[25]:


# plt.clf()

# bins = int((x_max-x_min)/7)

# # calcs = zfit.run(total_test_tf(samp))
# print(total_samp[:nevents].shape)

# plt.hist(total_samp[:nevents], bins = bins, range = (x_min,x_max), label = 'data')
# # plt.plot(test_q, calcs_test*nevents , label = 'pdf')

# # plt.plot(sam, calcs, '.')
# # plt.plot(test_q, calcs_test)
# # plt.yscale('log')
# plt.ylim(0, 200)
# # plt.xlim(3080, 3110)

# plt.legend()

# plt.savefig('test2.png')


# In[26]:


# sampler = total_f.create_sampler(n=nevents)
# nll = zfit.loss.UnbinnedNLL(model=total_f, data=sampler, fit_range = (x_min, x_max))

# # for param in pdf.get_dependents():
# #     param.set_value(initial_value)

# sampler.resample(n=nevents)

# # Randomise initial values
# # for param in pdf.get_dependents():
# #     param.set_value(random value here)

# # Minimise the NLL
# minimizer = zfit.minimize.MinuitMinimizer(verbosity = 10)
# minimum = minimizer.minimize(nll)


# In[27]:


# jpsi_width


# In[28]:


# plt.hist(sample, weights=1 / prob(sample))


# # Fitting

# In[29]:


# start = time.time()

# for param in total_f.get_dependents():
#     param.randomize()
    
# # for param in total_f.get_dependents():
# #     print(zfit.run(param))
    
# nll = zfit.loss.UnbinnedNLL(model=total_f, data=data2, fit_range = (x_min, x_max))

# minimizer = zfit.minimize.MinuitMinimizer(verbosity = 5)
# # minimizer._use_tfgrad = False
# result = minimizer.minimize(nll)

# # param_errors = result.error()

# # for var, errors in param_errors.items():
# #     print('{}: ^{{+{}}}_{{{}}}'.format(var.name, errors['upper'], errors['lower']))

# print("Function minimum:", result.fmin)
# # print("Results:", result.params)
# print("Hesse errors:", result.hesse())


# In[30]:


# print("Time taken for fitting: {}".format(display_time(int(time.time()-start))))

# # probs = total_f.pdf(test_q)

# calcs_test = zfit.run(probs)
# res_y = zfit.run(jpsi_res(test_q))


# In[31]:


# plt.clf()
# # plt.plot(x_part, calcs, '.')
# plt.plot(test_q, calcs_test, label = 'pdf')
# # plt.plot(test_q, res_y, label = 'res')
# plt.legend()
# plt.ylim(0.0, 10e-6)
# # plt.yscale('log')
# # plt.xlim(3080, 3110)
# plt.savefig('test3.png')
# # print(jpsi_width)


# In[32]:


# _tot = 4.37e-7+6.02e-5+4.97e-6
# _probs = []
# _probs.append(6.02e-5/_tot)
# _probs.append(4.97e-6/_tot)
# _probs.append(4.37e-7/_tot)
# print(_probs)


# In[33]:


# dtype = 'float64'
# # mixture = tfd.Uniform(tf.constant(x_min, dtype=dtype), tf.constant(x_max, dtype=dtype))
# mixture = tfd.MixtureSameFamily(mixture_distribution=tfd.Categorical(probs=[tf.constant(0.007, dtype=dtype),
#                                                                             tf.constant(0.917, dtype=dtype),
#                                                                             tf.constant(0.076, dtype=dtype)]),
#                                 components_distribution=tfd.Uniform(low=[tf.constant(x_min, dtype=dtype), 
#                                                                          tf.constant(3080, dtype=dtype),
#                                                                          tf.constant(3670, dtype=dtype)], 
#                                                                     high=[tf.constant(x_max, dtype=dtype),
#                                                                           tf.constant(3112, dtype=dtype), 
#                                                                           tf.constant(3702, dtype=dtype)]))
# # for i in range(10):
# #     print(zfit.run(mixture.prob(mixture.sample((10, 1)))))


# In[34]:


# print((zfit.run(jpsi_p)%(2*np.pi))/np.pi)
# print((zfit.run(psi2s_p)%(2*np.pi))/np.pi)


# In[35]:


#         def jpsi_res(q):
#             return resonance(q, _mass = jpsi_mass, scale = jpsi_scale,
#                              phase = jpsi_phase, width = jpsi_width)

#         def psi2s_res(q):
#             return resonance(q, _mass = psi2s_mass, scale = psi2s_scale,
#                              phase = psi2s_phase, width = psi2s_width)
        
#         def p3770_res(q):
#             return resonance(q, _mass = p3770_mass, scale = p3770_scale,
#                              phase = p3770_phase, width = p3770_width)
        
#         def p4040_res(q):
#             return resonance(q, _mass = p4040_mass, scale = p4040_scale,
#                              phase = p4040_phase, width = p4040_width)
        
#         def p4160_res(q):
#             return resonance(q, _mass = p4160_mass, scale = p4160_scale,
#                              phase = p4160_phase, width = p4160_width)
        
#         def p4415_res(q):
#             return resonance(q, _mass = p4415_mass, scale = p4415_scale,
#                              phase = p4415_phase, width = p4415_width)


# In[36]:


# 0.15**2*4.2/1000
# result.hesse()


# # Analysis

# In[ ]:


# zfit.run.numeric_checks = False   

# Toy study: for each of `nr_of_toys` toys, generate a dataset from total_f,
# fit it, and record the fitted value and Hesse error of Ctt.
#
# Each toy is generated in `calls` chunks of `event_stack` events. The chunks
# are pickled to data/zfit_toys/toy_<toy>/<chunk>.pkl and read back before
# fitting.

Ctt_list = []
Ctt_error_list = []

nr_of_toys = 1
nevents = pdg["number_of_decays"]
event_stack = 1000000
# floor(nevents/event_stack) + 1 chunks always yield at least nevents events.
# The surplus is cut off with [:int(nevents)] below.
calls = int(nevents/event_stack + 1)

total_samp = []

start = time.time()

sampler = total_f.create_sampler(n=event_stack)

for toy in range(nr_of_toys):

    ### Generate data

    print("Toy {}: Generating data...".format(toy))

    dirName = os.path.join('data', 'zfit_toys', 'toy_{0}'.format(toy))

    if not os.path.exists(dirName):
        # makedirs also creates missing parents such as data/zfit_toys.
        os.makedirs(dirName)
        print("Directory " , dirName ,  " Created ")

    for call in range(calls):
        sampler.resample(n=event_stack)
        sam = zfit.run(sampler.unstack_x())

        with open(os.path.join(dirName, '{}.pkl'.format(call)), "wb") as f:
            pkl.dump(sam, f, pkl.HIGHEST_PROTOCOL)

    print("Toy {}: Data generation finished".format(toy))

    ### Load data

    print("Toy {}: Loading data...".format(toy))

    # Read back THIS toy's chunks into a fresh buffer. Previously toy_0 was
    # always re-read, and events piled up across toys.
    # The pickles were written just above by this script. Never point this
    # loader at untrusted files.
    chunks = []
    for call in range(calls):
        with open(os.path.join(dirName, '{}.pkl'.format(call)), "rb") as input_file:
            chunks.append(np.ravel(pkl.load(input_file)))
    total_samp = np.concatenate(chunks).astype('float64')

    data = zfit.data.Data.from_numpy(array=total_samp[:int(nevents)], obs=obs)

    print("Toy {}: Loading data finished".format(toy))

    ### Fit data

    print("Toy {}: Fitting pdf...".format(toy))

    # Start every fit from randomised parameter values.
    for param in total_f.get_dependents():
        param.randomize()

    # Fit only the region between jpsi_mass+50 and psi2s_mass-50.
    nll = zfit.loss.UnbinnedNLL(model=total_f, data=data, fit_range = (jpsi_mass+50.0, psi2s_mass-50.0))

    minimizer = zfit.minimize.MinuitMinimizer(verbosity = 5)
    result = minimizer.minimize(nll)

    print("Toy {}: Fitting finished".format(toy))

    print("Function minimum:", result.fmin)
    # NOTE(review): the 'minuit_hesse' entry read below is presumably filled
    # in by this hesse() call, so keep this order.
    print("Hesse errors:", result.hesse())

    params = result.params
    Ctt_list.append(params[Ctt]['value'])
    Ctt_error_list.append(params[Ctt]['minuit_hesse']['error'])

    elapsed = time.time() - start
    print("Toy {0}/{1}".format(toy+1, nr_of_toys))
    print("Time taken: {}".format(display_time(int(elapsed))))
    # Average time per finished toy times the number of toys still to run.
    print("Projected time left: {}".format(display_time(int(elapsed/(toy+1)*(nr_of_toys-toy-1)))))


# In[ ]:





# In[ ]:


# Summarise the toy study: average fitted Ctt and average Hesse error.
mean_Ctt = np.mean(Ctt_list)
mean_Ctt_error = np.mean(Ctt_error_list)
print('Mean Ctt value = {}'.format(mean_Ctt))
print('Mean Ctt error = {}'.format(mean_Ctt_error))


# In[ ]: