# File: R_phipi/tools/select_and_fit_funct.py
# Last change: Davide Lancierini, 15 Nov 2018 (28 KB) -- "big changes"
import ROOT as r
import root_numpy as rn
import pickle
import numpy as np
import matplotlib.pyplot as plt

from xgboost import XGBClassifier
from tools.data_processing import *
from tools.mc_fitter import MC_fit
import os
import argparse

# Flavour labels, looked up as l_flv[l_index] to build input/output paths and
# plot titles (presumably the lepton flavours: electron and muon -- confirm).
l_flv=['e','mu']
# Mother-particle labels: entries 0 and 1 name the MC samples / cached fit
# parameters; mother_ID[mother_index_fit] selects which peak(s) the fit models.
mother_ID=["Ds","Dplus","both"]

def plot_MC_vs_data(data, weight_dict, MC, variable, 
                    sw_idx=None, l_index=None, mother_index_fit=None):
    """Overlay the weighted data and the MC distribution of one variable.

    Both histograms are normalised to unit area (bin-width weighted), drawn on
    a single canvas and saved as a PNG under
    ``<flavour>/fits/<fit mode>/<sw_idx>/``.

    Args:
        data: mapping from variable name to an array of values; the entry for
            ``variable`` must provide ``min()``, ``max()`` and ``shape``.
        weight_dict: indexable holding one per-event weight sequence per
            species; index 1 is used for Ds and index 0 for Dplus.
        MC: mapping from variable name to an array of MC values. Its key names
            decide the species: a key containing 'Ds' selects Ds, a key
            containing 'Dplus' selects Dplus (the last matching key wins).
        variable: name of the variable in ``data``. For Dplus the MC column is
            looked up under the same name with "Ds" replaced by "Dplus".
        sw_idx: label used for the output sub-directory and file suffix.
        l_index: index into ``l_flv``.
        mother_index_fit: index into ``mother_ID``.

    Raises:
        ValueError: if no key of ``MC`` contains 'Ds' or 'Dplus'.
    """
    mother = None
    for key in MC:
        if 'Ds' in key:
            mother = 'Ds'
            weight_idx = 1
            variable_MC = variable
        if 'Dplus' in key:
            mother = 'Dplus'
            weight_idx = 0
            variable_MC = variable.replace("Ds", "Dplus")

    # Fail early with a clear message instead of an unbound-name error later.
    if mother is None:
        raise ValueError("MC has no key containing 'Ds' or 'Dplus'; "
                         "cannot decide which species to plot")

    data_values = data[variable]
    mc_values = MC[variable_MC]
    weights = weight_dict[weight_idx]

    # Both histograms share the binning derived from the data range.
    inf = data_values.min()
    sup = data_values.max()

    data_entries = data_values.shape[0]

    data_hist = r.TH1F(mother+" sWeighted data "+ variable,
                       mother+" sWeighted data "+ variable,
                       70, inf, sup)

    mc_hist = r.TH1F(mother+" MC "+ variable,
                     mother+" MC "+ variable,
                     70, inf, sup)

    mc_hist.Sumw2()

    # Positional indexing is kept on purpose so that a weight sequence
    # shorter than the data still raises instead of being silently truncated.
    for i in range(data_entries):
        data_hist.Fill(data_values[i], weights[i])

    for i in range(len(mc_values)):
        mc_hist.Fill(mc_values[i])

    # Normalise to unit area; an empty histogram is left unscaled rather
    # than dividing by zero.
    n1 = data_hist.Integral("width")
    if n1 != 0:
        data_hist.Scale(1.0/n1)
    n2 = mc_hist.Integral("width")
    if n2 != 0:
        mc_hist.Scale(1.0/n2)

    c1 = r.TCanvas(variable, variable, 900, 600)

    mc_hist.Draw("E")
    data_hist.Draw("E Same")
    # The axes shown on the pad belong to the first histogram drawn, so the
    # titles have to be set on mc_hist to be visible.
    mc_hist.GetXaxis().SetTitle(variable)
    mc_hist.GetYaxis().SetTitleOffset(1.5)
    mc_hist.GetYaxis().SetTitle("distr dN/d"+variable)
    data_hist.SetLineColor(r.kRed)
    r.gStyle.SetOptStat("11")
    legend = r.TLegend(0.89, 0.89, 0.75, 0.8)
    legend.AddEntry(data_hist,"data","f")
    legend.AddEntry(mc_hist,"mc","f")
    legend.Draw()
    save_path=l_flv[l_index]+"/fits/"+mother_ID[mother_index_fit]+"/{0}".format(sw_idx)+"/"+mother+"_sweighted_"+variable+"_{0}.png".format(sw_idx)
    c1.SaveAs(save_path)
    print("Saved: "+save_path)

def select_and_fit(BDT_PATH, mother_index_fit, l_index, lower_mass_cut, upper_mass_cut,
                   data_mass, mc_Dplus_mass, mc_Ds_mass, data_to_select, mc_Dplus_to_select=None, mc_Ds_to_select=None,
                   data_dict=None, MC_Dplus_dict=None, MC_Ds_dict=None,  branches_needed=None,
                   x_cut=None, iteration=None, test=None):
    """Apply the BDT and mass-window selection, fit the mass spectrum and compute sWeights.

    Steps, in order:

    1. Get the double Crystal Ball tail parameters for Ds and Dplus, read from
       ``<flavour>/fits/MC_<mother>_CB_params.pickle`` when that file exists
       and otherwise obtained from ``MC_fit``.
    2. Load one pickled classifier ``XG_<k>`` (k drawn at random in [0, 10))
       and keep the candidates whose class-1 probability is above ``x_cut``,
       for data and for both MC samples.
    3. Keep candidates strictly inside ``(lower_mass_cut, upper_mass_cut)``;
       masses and window edges are then divided by 1000.
    4. Write the selected data masses to a ROOT file, read them back and fill
       a RooDataSet.
    5. Build signal (sum of two Crystal Ball shapes per mother) plus
       Chebychev background, run an extended fit and save the fit/pull plot.
    6. Pickle the selected data / MC variable dictionaries.
    7. Compute sWeights with RooStats.SPlot and save the sWeight plots.

    Args:
        BDT_PATH: accepted for interface compatibility; the value is replaced
            inside the function (see the NOTE(review) below).
        mother_index_fit: index into ``mother_ID`` choosing the fit mode
            ('Ds', 'Dplus' or 'both').
        l_index: index into ``l_flv``.
        lower_mass_cut, upper_mass_cut: edges of the mass window, in the same
            units as the mass arrays.
        data_mass, mc_Dplus_mass, mc_Ds_mass: mass arrays aligned with the
            rows of the corresponding ``*_to_select`` inputs.
        data_to_select, mc_Dplus_to_select, mc_Ds_to_select: feature arrays
            passed to the classifier's ``predict_proba``.
        data_dict, MC_Dplus_dict, MC_Ds_dict: mappings from branch name to
            array; the MC dictionaries must also hold ``"Dplus_ConsD_M"`` /
            ``"Ds_ConsD_M"`` when the cached fit parameters are missing.
        branches_needed: branch names to carry through the selection; for the
            Dplus MC, "Ds" in a name is replaced by "Dplus".
        x_cut: threshold on the classifier class-1 probability.
        iteration: label used in canvas names and output file names.
        test: label of the BDT training, used to build the model directory.

    Returns:
        tuple: ``(fit_results, splot)`` where ``fit_results`` is a dict with
        yields, efficiencies, figure of merit and reduced chi-square, and
        ``splot`` is the RooStats.SPlot object.

    Raises:
        TypeError: if one of the keyword inputs that the selection cannot
            work without is left as ``None``.
    """

    # These arguments default to None only for signature reasons; fail fast
    # (before any expensive fit) with the same exception type the first
    # use of a None value would raise.
    required_inputs = (
        ("mc_Dplus_to_select", mc_Dplus_to_select),
        ("mc_Ds_to_select", mc_Ds_to_select),
        ("data_dict", data_dict),
        ("MC_Dplus_dict", MC_Dplus_dict),
        ("MC_Ds_dict", MC_Ds_dict),
        ("branches_needed", branches_needed),
    )
    missing_inputs = [name for name, value in required_inputs if value is None]
    if missing_inputs:
        raise TypeError("select_and_fit() requires non-None values for: "
                        + ", ".join(missing_inputs))

    # ------------------------------------------------------------------
    # Get DCB shapes
    # ------------------------------------------------------------------

    # NOTE(review): the BDT_PATH argument is overwritten here, so the value
    # supplied by the caller is never used -- confirm this is intended.
    BDT_PATH=l_flv[l_index]+'/BDTs/test_'+str(test)
    FIT_PATH=l_flv[l_index]+'/fits'

    Ds_params_file=FIT_PATH+'/MC_'+mother_ID[0]+'_CB_params.pickle'
    if os.path.exists(Ds_params_file):
        with open(Ds_params_file,'rb') as handle:
            Ds_CB_params = pickle.load(handle)

        alpha_Ds_l=Ds_CB_params['alpha_Ds_l']
        power_Ds_l=Ds_CB_params['n_Ds_l']
        alpha_Ds_r=Ds_CB_params['alpha_Ds_r']
        power_Ds_r=Ds_CB_params['n_Ds_r']
        fraction_Ds=Ds_CB_params['CB fraction']

    else:
        # Only build the flat mass array when a fresh MC fit is needed.
        mc_Ds_mass_for_CB = np.array([entry[0] for entry in MC_Ds_dict["Ds_ConsD_M"]])
        MC_Ds_fit_results, alpha_Ds_l, power_Ds_l, alpha_Ds_r, power_Ds_r, fraction_Ds = MC_fit( mc_Ds_mass_for_CB, mother_index=0, l_index=l_index)

    Dplus_params_file=FIT_PATH+'/MC_'+mother_ID[1]+'_CB_params.pickle'
    if os.path.exists(Dplus_params_file):
        with open(Dplus_params_file,'rb') as handle:
            Dplus_CB_params = pickle.load(handle)

        alpha_Dplus_l=Dplus_CB_params['alpha_Dplus_l']
        power_Dplus_l=Dplus_CB_params['n_Dplus_l']
        alpha_Dplus_r=Dplus_CB_params['alpha_Dplus_r']
        power_Dplus_r=Dplus_CB_params['n_Dplus_r']
        fraction_Dplus=Dplus_CB_params['CB fraction']

    else:
        mc_Dplus_mass_for_CB = np.array([entry[0] for entry in MC_Dplus_dict["Dplus_ConsD_M"]])
        MC_Dplus_fit_results, alpha_Dplus_l, power_Dplus_l, alpha_Dplus_r, power_Dplus_r, fraction_Dplus = MC_fit(mc_Dplus_mass_for_CB, mother_index=1, l_index=l_index)

    # Branch names for the Dplus MC: same list with "Ds" -> "Dplus"
    # (str.replace leaves names without "Ds" unchanged).
    branches_needed_Dplus=[label.replace("Ds", "Dplus") for label in branches_needed]

    # ------------------------------------------------------------------
    # BDT selection step
    # ------------------------------------------------------------------

    print("BDT cut at {0} fit".format(x_cut))

    print("BDT selection..")

    # One of the ten stored models is picked at random on every call.
    k = np.random.randint(10)
    MODEL_PATH=BDT_PATH+'/XG_'+str(k)
    # pickle executes arbitrary code on load: only point this at trusted files.
    with open(MODEL_PATH+"/XG_"+str(k)+"_.pickle.dat", "rb") as model_file:
        loaded_model = pickle.load(model_file)

    output_XG_data=loaded_model.predict_proba(data_to_select)
    bdt_selection_data= np.where(output_XG_data[:,1]>x_cut)

    output_XG_Dplus_MC=loaded_model.predict_proba(mc_Dplus_to_select)
    bdt_selection_mc_Dplus= np.where(output_XG_Dplus_MC[:,1]>x_cut)

    output_XG_Ds_MC=loaded_model.predict_proba(mc_Ds_to_select)
    bdt_selection_mc_Ds= np.where(output_XG_Ds_MC[:,1]>x_cut)

    # ------------------------------------------------------------------
    # Keep data/MC variables after the BDT cut for later comparison
    # ------------------------------------------------------------------

    data_bdt_sel={}
    for label in branches_needed:
        data_bdt_sel[label] = data_dict[label][bdt_selection_data]

    MC_Dplus_dict_sel={}
    for label in branches_needed_Dplus:
        MC_Dplus_dict_sel[label] = MC_Dplus_dict[label][bdt_selection_mc_Dplus]

    MC_Ds_dict_sel={}
    for label in branches_needed:
        MC_Ds_dict_sel[label] = MC_Ds_dict[label][bdt_selection_mc_Ds]

    print(len(bdt_selection_mc_Ds[0]),len(bdt_selection_mc_Dplus[0]))

    # ------------------------------------------------------------------
    # Apply BDT cuts on data/MC mass variables
    # ------------------------------------------------------------------

    data_mass_selected=np.array(data_mass[bdt_selection_data])
    mc_Dplus_mass_selected=np.array(mc_Dplus_mass[bdt_selection_mc_Dplus])
    mc_Ds_mass_selected=np.array(mc_Ds_mass[bdt_selection_mc_Ds])

    # ------------------------------------------------------------------
    # Signal selection efficiency of BDT
    # ------------------------------------------------------------------

    fit_mode=mother_ID[mother_index_fit]

    if fit_mode=='both':

        nsig_from_MC=mc_Dplus_to_select.shape[0]+mc_Ds_to_select.shape[0]
        nsig_bdt_sel_MC=mc_Dplus_mass_selected.shape[0]+mc_Ds_mass_selected.shape[0]

    elif fit_mode=='Dplus':

        nsig_from_MC=mc_Dplus_to_select.shape[0]
        nsig_bdt_sel_MC=mc_Dplus_mass_selected.shape[0]

    elif fit_mode=='Ds':

        nsig_from_MC=mc_Ds_to_select.shape[0]
        nsig_bdt_sel_MC=mc_Ds_mass_selected.shape[0]

    # Built-in float: the np.float alias no longer exists in recent NumPy.
    sig_sel_eff=float(nsig_bdt_sel_MC)/float(nsig_from_MC)

    print("BDT Signal selection efficiency: {0}".format(sig_sel_eff))

    # ------------------------------------------------------------------
    # Cut for mass fit
    # ------------------------------------------------------------------

    data_cut_indices=[i for i, mass_value in enumerate(data_mass_selected)
                      if lower_mass_cut<mass_value<upper_mass_cut]

    mc_Dplus_indices=[i for i, mass_value in enumerate(mc_Dplus_mass_selected)
                      if lower_mass_cut<mass_value<upper_mass_cut]

    mc_Ds_indices=[i for i, mass_value in enumerate(mc_Ds_mass_selected)
                   if lower_mass_cut<mass_value<upper_mass_cut]

    # From here on masses and window edges are divided by 1000 (the fit
    # variable below is labelled in GeV/c^2).
    data_mass_cut=data_mass_selected[data_cut_indices]/1000.

    # The rescaled MC mass arrays are not read again in this function.
    mc_Dplus_mass_selected/=1000.   
    mc_Ds_mass_selected/=1000.

    lower_mass_cut/=1000.
    upper_mass_cut/=1000.

    data_cut={}
    for label in branches_needed:
        data_cut[label] = data_bdt_sel[label][data_cut_indices]

    MC_Dplus_cut={}
    for label in branches_needed_Dplus:
        MC_Dplus_cut[label] = MC_Dplus_dict_sel[label][mc_Dplus_indices]

    MC_Ds_cut={}
    for label in branches_needed:
        MC_Ds_cut[label] = MC_Ds_dict_sel[label][mc_Ds_indices]

    if fit_mode=='both':

        nsig_mass_cut_MC=len(mc_Dplus_indices)+len(mc_Ds_indices)

    elif fit_mode=='Dplus':

        nsig_mass_cut_MC=len(mc_Dplus_indices)

    elif fit_mode=='Ds':

        nsig_mass_cut_MC=len(mc_Ds_indices)

    mass_cut_eff=float(nsig_mass_cut_MC)/float(nsig_bdt_sel_MC)

    sig_eff = sig_sel_eff*mass_cut_eff

    # Round-trip the selected masses through a ROOT file so they can be read
    # back through a TTree branch below.
    data_mass_array=np.array(

            [data_mass_cut[i] for i in range(len(data_mass_cut))], dtype=[('D_reco_M', np.float32)]

            )

    flavour_pair=l_flv[l_index]+l_flv[l_index]
    data_mass_file='/disk/lhcb_data/davide/Rphipi/BDT_selected_data/'+flavour_pair+'/BDT_sel_data_'+flavour_pair+'_mass.root'

    rn.array2root(data_mass_array,
                  filename=data_mass_file,
                  treename='decay_tree',
                  mode='recreate',
                 )

    f=r.TFile(data_mass_file)
    tree=f.Get("decay_tree") 

    # ------------------------------------------------------------------
    # Create RooDataSet for mass
    # ------------------------------------------------------------------

    # One-element buffer that the branch writes into on every GetEvent call.
    mass = np.array([0],dtype=np.float32)
    branch = tree.GetBranch("D_reco_M")
    branch.SetAddress(mass)

    num_entries=tree.GetEntries()
    m = r.RooRealVar("D mass reco (GeV/c^2)","D mass reco (GeV/c^2)",lower_mass_cut,upper_mass_cut)
    l = r.RooArgSet(m)
    data_set = r.RooDataSet("data set", "data set", l)

    for i in range(num_entries):
        tree.GetEvent(i)
        r.RooAbsRealLValue.__assign__(m, mass[0])
        data_set.add(l, 1.0)

    # ------------------------------------------------------------------
    # Create total PDF
    # ------------------------------------------------------------------

    # Tail parameters and CB fraction are constants (no range given); only
    # the means, widths, background coefficients and yields float.

    #Creating D+ signal PDF left tail
    mean_Dplus= r.RooRealVar("mean_Dplus","mean_Dplus",1.87,1.83,1.91)
    sigma_Dplus = r.RooRealVar("sigma_Dplus","sigma_Dplus",0.020,0.,0.2)

    sig_frc_Dplus =r.RooRealVar("Dplus Double crystalball fraction","Dplus Double crystalball fraction",fraction_Dplus)

    al_Dplus_left = r.RooRealVar("alpha_Dplus_left","alpha_Dplus_left",alpha_Dplus_l)
    pwr_Dplus_left = r.RooRealVar("n_Dplus_left","n_Dplus_left",power_Dplus_l)
    sig_Dplus_left = r.RooCBShape("signal Dplus left","signal Dplus left", m, mean_Dplus, sigma_Dplus, al_Dplus_left, pwr_Dplus_left)

    #Creating D+ signal PDF right tail
    al_Dplus_right = r.RooRealVar("alpha_Dplus_right","alpha_Dplus_right",alpha_Dplus_r)
    pwr_Dplus_right = r.RooRealVar("n_Dplus_right","n_Dplus_right",power_Dplus_r)
    sig_Dplus_right = r.RooCBShape("signal Dplus right","signal Dplus right", m, mean_Dplus, sigma_Dplus, al_Dplus_right, pwr_Dplus_right)

    #Adding the two CBs with their relative fraction
    sig_Dplus = r.RooAddPdf("signal Dplus","Dplus mass peak",r.RooArgList(sig_Dplus_left, sig_Dplus_right),r.RooArgList(sig_frc_Dplus))

    #Creating Ds signal PDF left tail
    mean_Ds= r.RooRealVar("mean_Ds","mean_Ds",1.97,1.93,2.01)
    sigma_Ds = r.RooRealVar("sigma_Ds","sigma_Ds",0.020,0.,0.2)

    sig_frc_Ds =r.RooRealVar("Ds Double crystalball fraction","Ds Double crystalball fraction",fraction_Ds)

    al_Ds_left = r.RooRealVar("alpha_Ds_left","alpha_Ds_left",alpha_Ds_l)
    pwr_Ds_left = r.RooRealVar("n_Ds_left","n_Ds_left",power_Ds_l)
    sig_Ds_left = r.RooCBShape("signal Ds left","signal Ds left", m, mean_Ds, sigma_Ds, al_Ds_left, pwr_Ds_left)

    #Creating Ds signal PDF right tail
    al_Ds_right = r.RooRealVar("alpha_Ds_right","alpha_Ds_right",alpha_Ds_r)
    pwr_Ds_right = r.RooRealVar("n_Ds_right","n_Ds_right",power_Ds_r)
    sig_Ds_right = r.RooCBShape("signal Ds right","signal Ds right", m, mean_Ds, sigma_Ds, al_Ds_right, pwr_Ds_right)

    #Adding the two CBs with their relative fraction
    sig_Ds = r.RooAddPdf("sig_Ds","Ds mass peak",r.RooArgList(sig_Ds_left, sig_Ds_right),r.RooArgList(sig_frc_Ds))

    #Model the background
    coef0 = r.RooRealVar("c0","coefficient #0",1.0,-10.,10)
    coef1 = r.RooRealVar("c1","coefficient #1",0.1,-10.,10)
    bkg = r.RooChebychev("bkg","background p.d.f.",m,r.RooArgList(coef0,coef1))

    #Add the signal source(s) and the background with floating yields
    if fit_mode=='both':

        NDs = r.RooRealVar("nDs","nDs",100.,0.,40000.)
        NDplus = r.RooRealVar("nDplus","nDplus",100.,0.,40000.)
        nbkg = r.RooRealVar("nbkg","nbkg",100.,0.,40000.)
        model = r.RooAddPdf("model","Dplus and Ds mass peaks",r.RooArgList(sig_Ds, sig_Dplus, bkg),r.RooArgList(NDs, NDplus,nbkg))

    elif fit_mode=='Ds':

        NDs = r.RooRealVar("nDs","nDs",100.,0.,40000.)
        nbkg = r.RooRealVar("nbkg","nbkg",100.,0.,40000.)
        model = r.RooAddPdf("model","Ds mass peak",r.RooArgList(sig_Ds, bkg),r.RooArgList(NDs,nbkg))

    elif fit_mode=='Dplus':

        NDplus = r.RooRealVar("nDplus","nDplus",100.,0.,40000.)
        nbkg = r.RooRealVar("nbkg","nbkg",100.,0.,40000.)
        model = r.RooAddPdf("model","Dplus mass peak",r.RooArgList(sig_Dplus, bkg),r.RooArgList(NDplus,nbkg))

    # ------------------------------------------------------------------
    # Fit step
    # ------------------------------------------------------------------

    fitr = model.fitTo(data_set,r.RooFit.Extended(),r.RooFit.Save())
    xframe = m.frame(r.RooFit.Title("Fit to "+l_flv[l_index]+" data"))

    bkg_component = r.RooArgSet(bkg)
    sig_Dplus_component = r.RooArgSet(sig_Dplus)
    sig_Ds_component = r.RooArgSet(sig_Ds)

    # Pulls and chi-square are taken right after the data and the full model
    # are plotted, before the component curves are added to the frame.
    data_set.plotOn(xframe)
    model.plotOn(xframe)
    hpull=xframe.pullHist()

    n_param = fitr.floatParsFinal().getSize()
    reduced_chi_square = xframe.chiSquare(n_param)

    # ------------------------------------------------------------------
    # Plot the fit results
    # ------------------------------------------------------------------

    model.plotOn(xframe,r.RooFit.Components(bkg_component),r.RooFit.LineStyle(2))

    if fit_mode=='Dplus':

        model.plotOn(xframe, r.RooFit.Components(sig_Dplus_component), r.RooFit.LineColor(2),r.RooFit.LineStyle(2) )

    elif fit_mode=='Ds':

        model.plotOn(xframe, r.RooFit.Components(sig_Ds_component), r.RooFit.LineColor(2),r.RooFit.LineStyle(2) )

    elif fit_mode=='both':

        model.plotOn(xframe, r.RooFit.Components(sig_Ds_component), r.RooFit.LineColor(2),r.RooFit.LineStyle(2) )
        model.plotOn(xframe, r.RooFit.Components(sig_Dplus_component), r.RooFit.LineColor(2),r.RooFit.LineStyle(2) )

    model.paramOn(xframe, r.RooFit.Layout(0.69,0.99,0.92), r.RooFit.Format("NEU", r.RooFit.AutoPrecision(1)))
    xframe.getAttText().SetTextSize(0.04)

    xframe2 = m.frame(r.RooFit.Title("Pulls"))
    xframe2.addPlotable(hpull,"P")

    # Upper pad: fit; lower pad: pulls.
    c = r.TCanvas("Fit {0}".format(iteration),"Fit {0}".format(iteration),900,600)
    pad1 = r.TPad("pad1", "pad1", 0, 0.35, 1, 1.0)
    pad1.SetBottomMargin(0)
    pad1.Draw()
    c.cd()
    pad2 = r.TPad("pad2", "pad2", 0, 0.05, 1, 0.35)
    pad2.SetTopMargin(0)
    pad2.SetBottomMargin(0.2)
    pad2.SetGridx()
    pad2.SetGridy()
    pad2.Draw()

    pad1.cd()
    xframe.Draw()
    pad2.cd()
    xframe2.Draw()
    xframe2.GetXaxis().SetTitleSize(0.09)
    xframe2.GetXaxis().SetLabelSize(0.1)
    xframe2.GetYaxis().SetLabelSize(0.05)

    # Directory receiving every plot of this call.
    plot_dir=l_flv[l_index]+"/fits/"+fit_mode+"/{0}".format(iteration)

    print("chi2 {0}".format(reduced_chi_square))
    c.SaveAs(plot_dir+"/fit_{0}.png".format(iteration))

    # ------------------------------------------------------------------
    # Create .pickle files with discriminating variables of selected data
    # ------------------------------------------------------------------

    with open('/disk/lhcb_data/davide/Rphipi/BDT_selected_data/'+flavour_pair+'/BDT_sel_data_'+flavour_pair+'_geom_var.pickle', 'wb') as handle:
        pickle.dump(data_cut, handle, protocol=pickle.HIGHEST_PROTOCOL)
    with open('/disk/lhcb_data/davide/Rphipi/MC/for_fit/'+flavour_pair+'/MC_'+mother_ID[0]+'_'+flavour_pair+'_geom_var.pickle', 'wb') as handle:
        pickle.dump(MC_Ds_cut, handle, protocol=pickle.HIGHEST_PROTOCOL)
    with open('/disk/lhcb_data/davide/Rphipi/MC/for_fit/'+flavour_pair+'/MC_'+mother_ID[1]+'_'+flavour_pair+'_geom_var.pickle', 'wb') as handle:
        pickle.dump(MC_Dplus_cut, handle, protocol=pickle.HIGHEST_PROTOCOL)

    xframe_sw = m.frame(r.RooFit.Title("Weighted "+l_flv[l_index]+" data"))

    # ------------------------------------------------------------------
    # sPlotting
    # ------------------------------------------------------------------
    # The three branches are deliberately kept inline: the ROOT objects they
    # create must stay referenced by these local names while being drawn.

    if fit_mode=='both':

        splot = r.RooStats.SPlot(
        'sData','sData', data_set, model,
        r.RooArgList(NDs, NDplus, nbkg)

        )

        # sWeight variables come back in the order of the yield list above.
        Ds_sw=splot.GetSWeightVars()[0]
        Dplus_sw=splot.GetSWeightVars()[1]
        bkg_sw=splot.GetSWeightVars()[2]

        argset = r.RooArgSet(m, Ds_sw)
        dataset_Ds_w = r.RooDataSet("Ds weighted dataset","Ds weighted dataset",data_set, argset,"","nDs_sw")

        argset = r.RooArgSet(m, Dplus_sw)
        dataset_Dplus_w = r.RooDataSet("Dplus weighted dataset","Dplus weighted dataset",data_set, argset,"","nDplus_sw")

        dataset_Ds_w.plotOn(xframe_sw, r.RooFit.MarkerSize(0.5), r.RooFit.MarkerColor(r.kBlue))
        dataset_Dplus_w.plotOn(xframe_sw, r.RooFit.MarkerSize(0.5), r.RooFit.MarkerColor(r.kRed))

        Ds_sw.setRange(-1.2,1.5)
        frame_Ds_sw = Ds_sw.frame(r.RooFit.Title("signal Ds sWeights"))
        data_set.plotOn(frame_Ds_sw, r.RooFit.MarkerSize(0.05))

        Dplus_sw.setRange(-1.2,1.5)
        frame_Dplus_sw = Dplus_sw.frame(r.RooFit.Title("signal Dplus sWeights"))
        data_set.plotOn(frame_Dplus_sw, r.RooFit.MarkerSize(0.05))

        bkg_sw.setRange(-.6,1.2)
        frame_bkg_sw = bkg_sw.frame(r.RooFit.Title("background sWeights"))
        data_set.plotOn(frame_bkg_sw, r.RooFit.MarkerSize(0.05))

        sframe = m.frame(r.RooFit.Title("m vs sWeights"))
        data_set.plotOnXY(sframe, r.RooFit.YVar(Ds_sw),
                        r.RooFit.MarkerColor(r.kGreen), r.RooFit.Name("Ds_sig"))
        data_set.plotOnXY(sframe, r.RooFit.YVar(Dplus_sw),
                        r.RooFit.MarkerColor(r.kBlue), r.RooFit.Name("Dplus_sig"))
        data_set.plotOnXY(sframe, r.RooFit.YVar(bkg_sw),
                        r.RooFit.MarkerColor(r.kRed), r.RooFit.Name("bkg"))

        legend = r.TLegend(0.89, 0.89, 0.5, 0.7)
        legend.AddEntry(sframe.findObject('Ds_sig'), 'Signal Ds sWeights', 'p')
        legend.AddEntry(sframe.findObject('Dplus_sig'), 'Signal Dplus sWeights', 'p')
        legend.AddEntry(sframe.findObject('bkg'), 'Background sWeights', 'p')

        c = r.TCanvas("sWeighted mass dist {0}".format(iteration),"sWeighted mass dist {0}".format(iteration),900,600)
        xframe_sw.Draw()
        c.SaveAs(plot_dir+"/sweighted_mass_plot_fit{0}.png".format(iteration))

        c = r.TCanvas("sWeights {0}".format(iteration),"sWeights {0}".format(iteration),900,900)
        c.Divide(2,2)
        c.cd(1)
        frame_Ds_sw.Draw()
        c.cd(2)
        frame_Dplus_sw.Draw()
        c.cd(3)
        frame_bkg_sw.Draw()
        c.cd(4)
        sframe.Draw()
        legend.Draw()
        c.SaveAs(plot_dir+"/sweights_{0}.png".format(iteration))

        nbkg_w = splot.GetYieldFromSWeight("nbkg_sw")
        nDs_w = splot.GetYieldFromSWeight("nDs_sw")
        nDplus_w = splot.GetYieldFromSWeight("nDplus_sw")
        raw_sig=nDs_w+nDplus_w

    elif fit_mode=='Ds':

        splot = r.RooStats.SPlot(
        'sData','sData', data_set, model,
        r.RooArgList(NDs, nbkg)

        )

        Ds_sw=splot.GetSWeightVars()[0]
        bkg_sw=splot.GetSWeightVars()[1]

        argset = r.RooArgSet(m, Ds_sw)
        dataset_Ds_w = r.RooDataSet("Ds weighted dataset","Ds weighted dataset",data_set, argset,"","nDs_sw")

        dataset_Ds_w.plotOn(xframe_sw, r.RooFit.MarkerSize(0.5), r.RooFit.MarkerColor(r.kBlue))

        Ds_sw.setRange(-3.,1.4)
        frame_Ds_sw = Ds_sw.frame(r.RooFit.Title("signal Ds sWeights"))
        data_set.plotOn(frame_Ds_sw, r.RooFit.MarkerSize(0.05))

        bkg_sw.setRange(-.5,5.)
        frame_bkg_sw = bkg_sw.frame(r.RooFit.Title("background sWeights"))
        data_set.plotOn(frame_bkg_sw, r.RooFit.MarkerSize(0.05))

        sframe = m.frame(r.RooFit.Title("m vs sWeights"))
        data_set.plotOnXY(sframe, r.RooFit.YVar(Ds_sw),
                        r.RooFit.MarkerColor(r.kGreen), r.RooFit.Name("Ds_sig"))
        data_set.plotOnXY(sframe, r.RooFit.YVar(bkg_sw),
                        r.RooFit.MarkerColor(r.kRed), r.RooFit.Name("bkg"))

        legend = r.TLegend(0.89, 0.89, 0.75, 0.8)
        legend.AddEntry(sframe.findObject('Ds_sig'), 'Signal Ds sWeights', 'p')
        legend.AddEntry(sframe.findObject('bkg'), 'Background sWeights', 'p')

        c = r.TCanvas("sWeighted mass dist {0}".format(iteration),"sWeighted mass dist {0}".format(iteration),900,600)
        xframe_sw.Draw()
        c.SaveAs(plot_dir+"/sweighted_mass_plot_fit{0}.png".format(iteration))

        # Two pads on top (signal / background sWeights), one wide pad below.
        c = r.TCanvas("sWeights {0}".format(iteration),"sWeights  {0}".format(iteration),900,600)
        c.cd()
        pad1 = r.TPad("pad1", "pad1", 0, 0.5, 0.5, 1.0)
        pad1.Draw()
        c.cd()
        pad2 = r.TPad("pad2", "pad2", 0.5, 0.5, 1.0, 1.0)
        pad2.Draw()
        c.cd()
        pad3 = r.TPad("pad3", "pad3", 0, 0.05, 1, 0.50)
        pad3.SetGridx()
        pad3.Draw()

        pad1.cd()
        frame_Ds_sw.Draw()
        pad2.cd()
        frame_bkg_sw.Draw()
        pad3.cd()
        sframe.Draw()
        legend.Draw()

        c.SaveAs(plot_dir+"/sweights_{0}.png".format(iteration))

        nbkg_w = splot.GetYieldFromSWeight("nbkg_sw")
        nDs_w = splot.GetYieldFromSWeight("nDs_sw")
        raw_sig=nDs_w

    elif fit_mode=='Dplus':

        splot = r.RooStats.SPlot(
        'sData','sData', data_set, model,
        r.RooArgList(NDplus, nbkg)

        )

        Dplus_sw=splot.GetSWeightVars()[0]
        bkg_sw=splot.GetSWeightVars()[1]

        argset = r.RooArgSet(m, Dplus_sw)
        dataset_Dplus_w = r.RooDataSet("Dplus weighted dataset","Dplus weighted dataset",data_set, argset,"","nDplus_sw")

        dataset_Dplus_w.plotOn(xframe_sw, r.RooFit.MarkerSize(0.5), r.RooFit.MarkerColor(r.kRed))

        Dplus_sw.setRange(-1.2,1.8)
        frame_Dplus_sw = Dplus_sw.frame(r.RooFit.Title("signal Dplus sWeights"))
        data_set.plotOn(frame_Dplus_sw, r.RooFit.MarkerSize(0.05))

        bkg_sw.setRange(-1.2,2.1)
        frame_bkg_sw = bkg_sw.frame(r.RooFit.Title("background sWeights"))
        data_set.plotOn(frame_bkg_sw, r.RooFit.MarkerSize(0.05))

        sframe = m.frame(r.RooFit.Title("m vs sWeights"))
        data_set.plotOnXY(sframe, r.RooFit.YVar(Dplus_sw),
                        r.RooFit.MarkerColor(r.kBlue), r.RooFit.Name("Dplus_sig"))
        data_set.plotOnXY(sframe, r.RooFit.YVar(bkg_sw),
                        r.RooFit.MarkerColor(r.kRed), r.RooFit.Name("bkg"))

        legend = r.TLegend(0.89, 0.89, 0.75, 0.8)
        legend.AddEntry(sframe.findObject('Dplus_sig'), 'Signal Dplus sWeights', 'p')
        legend.AddEntry(sframe.findObject('bkg'), 'Background sWeights', 'p')

        c = r.TCanvas("sWeighted mass dist {0}".format(iteration),"sWeighted mass dist {0}".format(iteration),900,600)
        xframe_sw.Draw()
        c.SaveAs(plot_dir+"/sweighted_mass_plot_fit{0}.png".format(iteration))

        # Two pads on top (signal / background sWeights), one wide pad below.
        c = r.TCanvas("sWeights {0}".format(iteration),"sWeights  {0}".format(iteration),900,600)
        c.cd()
        pad1 = r.TPad("pad1", "pad1", 0, 0.5, 0.5, 1.0)
        pad1.Draw()
        c.cd()
        pad2 = r.TPad("pad2", "pad2", 0.5, 0.5, 1.0, 1.0)
        pad2.Draw()
        c.cd()
        pad3 = r.TPad("pad3", "pad3", 0, 0.05, 1, 0.50)
        pad3.SetGridx()
        pad3.Draw()

        pad1.cd()
        frame_Dplus_sw.Draw()
        pad2.cd()
        frame_bkg_sw.Draw()
        pad3.cd()
        sframe.Draw()
        legend.Draw()

        c.SaveAs(plot_dir+"/sweights_{0}.png".format(iteration))

        nbkg_w = splot.GetYieldFromSWeight("nbkg_sw")
        nDplus_w = splot.GetYieldFromSWeight("nDplus_sw")
        raw_sig=nDplus_w

    # ------------------------------------------------------------------
    # Save fit results
    # ------------------------------------------------------------------

    # NOTE(review): the figure of merit stored as 'S/sqrt(S+B)' uses
    # nsig_from_MC (MC entries before any selection) as S, not the fitted
    # signal yield -- confirm this is the intended definition.
    fom=(nsig_from_MC)/np.sqrt(nbkg_w+nsig_from_MC)

    # NOTE(review): the yield is corrected with the BDT efficiency only, not
    # with the overall sig_eff that includes the mass window -- confirm.
    eff_corrected_sig = (raw_sig)/(sig_sel_eff)

    fit_results={
                 '# background': nbkg_w,
                 '# fitted signal': raw_sig,
                 'signal selection efficiency from BDT cut': sig_sel_eff,
                 'sig selection efficiency from mass cut': mass_cut_eff,
                 'overall signal efficiency': sig_eff,
                 'S/sqrt(S+B)': fom,
                 'efficiency corrected yeld': eff_corrected_sig,
                 'chi 2': reduced_chi_square
                 }

    return fit_results, splot

def plot_weighted(f, var_name, splot, i, l_index=None):
    """Build and draw a frame of one tree variable weighted by an sWeight.

    Args:
        f: object whose ``Get("decay_tree")`` returns the tree to read
            (presumably an open ROOT TFile -- confirm with the caller).
        var_name: name of the branch to plot; read through a float32 buffer.
        splot: object providing ``GetSWeightVars()`` (the sPlot result).
        i: index of the sWeight variable to use as event weight.
        l_index: optional index into ``l_flv`` used as title prefix. When it
            is omitted a module-level ``l_index`` is used if one exists;
            otherwise the title carries no flavour prefix.

    Returns:
        The RooPlot frame holding the weighted distribution (already drawn on
        the current pad).
    """
    # The title used to read an undefined name; accept it as an argument
    # and stay compatible with callers that set it as a module global.
    if l_index is None:
        l_index = globals().get("l_index")

    tree=f.Get("decay_tree")
    # One-element buffer that the branch writes into on every GetEvent call.
    proxy_array = np.array([0],dtype=np.float32)
    branch = tree.GetBranch(var_name)
    branch.SetAddress(proxy_array)

    num_entries=tree.GetEntries()
    real_var = r.RooRealVar(var_name, var_name, tree.GetMinimum(var_name), tree.GetMaximum(var_name))
    l = r.RooArgSet(real_var)

    var_data_set = r.RooDataSet(var_name+"_data_set", var_name+"_data_set", l)

    for j in range(num_entries):
        tree.GetEvent(j)
        r.RooAbsRealLValue.__assign__(real_var, proxy_array[0])
        var_data_set.add(l, 1.0)

    weight_set = splot.GetSWeightVars()[i]
    weight_name = weight_set.GetName()

    # NOTE(review): var_data_set was filled with real_var only; confirm that
    # the weight column named weight_name is actually available to this
    # weighted-copy constructor.
    argset = r.RooArgSet(real_var, weight_set)
    var_data_set_sw = r.RooDataSet(var_name+"_weighted_"+weight_name,
                                   var_name+"_weighted_"+weight_name,
                                   var_data_set, argset, "", weight_name)

    frame_title = weight_name+" weighted "+var_name
    if l_index is not None:
        frame_title = l_flv[l_index]+" "+frame_title

    xframe_sw = real_var.frame(r.RooFit.Title(frame_title))

    var_data_set_sw.plotOn(xframe_sw,r.RooFit.MarkerSize(0.5), r.RooFit.MarkerColor(r.kBlue))

    xframe_sw.Draw()

    return xframe_sw