# Rphipi_new / tools / fitAndSplotPhipill.py
# Author: Davide Lancierini (first commit, 28 May 2019)
import ROOT as r
import root_numpy as rn
import pickle
import numpy as np
import matplotlib.pyplot as plt
import array as array
from xgboost import XGBClassifier
from tools.data_processing import *

# Lepton flavours, indexed by ``l_index`` throughout this module.
l_flv = ['e', 'mu']
# Mother-meson labels, indexed by ``mother_index``.
mother_ID = ["Dplus", "Ds", "both"]
# Base directory of the MC and data ntuples.
PATH = '/disk/lhcb_data/davide/Rphipi_new/'
# Sample types, indexed by ``data_index`` (0: simulation, 1: real data).
data_type = ['MC', 'data']

class fitAndSplotphipill(object):
    """Fit D(s)+ -> phi(l+ l-) pi+ mass spectra and compute sWeights.

    One instance handles one lepton flavour (``l_index`` indexes ``l_flv``)
    and one trigger category (``trigCat``).  :meth:`perform_Full_Fit` runs the
    whole chain:

    1. fit a double Crystal Ball to Dplus and Ds simulation, both for the
       signal and for a mass-shifted copy used as mis-ID background shape;
    2. build the data model from the fitted MC shapes (a floating mean shift
       and width scale factor per mother) plus an exponential combinatorial
       background;
    3. fit the BDT-selected data, plot the result and compute sWeights.

    All RooFit objects are persisted in one RooWorkspace file whose location
    is defined by :meth:`set_paths`.
    """

    # (lower bound, start value of the mean, upper bound) of the MC mass
    # window, keyed by (mis_id_bkg, mother label).
    _MC_MASS_WINDOWS = {
        (False, 'Dplus'): (1780, 1850, 1970),
        (False, 'Ds'): (1850, 1950, 2090),
        (True, 'Dplus'): (1650, 1840, 1950),
        (True, 'Ds'): (1750, 1900, 2080),
    }

    # (lower bound, upper bound) of the data fit window per lepton flavour.
    _DATA_MASS_WINDOWS = {
        'mu': (1810, 2040),
        'e': (1780, 2080),
    }

    # Masses used by the mis-ID emulation, in the units of the mass branches
    # (presumably MeV).  The electron mass is approximated as 0.5.
    _PION_MASS = 140
    _LEPTON_MASSES = {'e': 0.5, 'mu': 105.66}

    def __init__(self, l_index=None, trigCat=None):
        """Store the lepton-flavour index and the trigger category.

        Parameters
        ----------
        l_index : int
            Index into ``l_flv`` (0: electrons, 1: muons).
        trigCat : int or str
            Trigger category; only used to build file names and titles.
        """
        self.trigCat = trigCat
        self.l_index = l_index

    def _ensure_paths(self):
        """Call :meth:`set_paths` with its default year if it has not run yet."""
        if not hasattr(self, 'workspace_filename'):
            self.set_paths()

    def set_paths(self, year='2016'):
        """Define every input/output location for ``year``.

        Sets the workspace file name and directory, the BDT-selected data and
        MC directories, the plot directory and ``self.year``.  Methods that
        need these paths call this with the default year if it was not
        called explicitly.

        Parameters
        ----------
        year : str, optional
            Data-taking year embedded in the workspace file name.
        """
        lep = l_flv[self.l_index]
        dilep = lep + lep
        cat = str(self.trigCat)

        self.FIT_PATH = lep + '_fits/'
        self.workspace_filename = 'fitAndsPlot_D_phipi_' + dilep + '_' + year + '_trigCat' + cat + '.root'
        self.workspace_path = '/home/hep/davide/Rphipi_new/workspaces/'

        self.SAVE_BDT_SELECTED_PATH = PATH + data_type[1] + '/' + lep + '_tuples/BDT_selected/trigCats/' + cat + '/'
        self.SAVE_BDT_SELECTED_MC_Dplus_PATH = (PATH + data_type[0] + '/' + lep + '_tuples/BDT_selected/'
                                               + mother_ID[0] + '_phipi_' + dilep + '/trigCats/' + cat + '/')
        self.SAVE_BDT_SELECTED_MC_Ds_PATH = (PATH + data_type[0] + '/' + lep + '_tuples/BDT_selected/'
                                            + mother_ID[1] + '_phipi_' + dilep + '/trigCats/' + cat + '/')
        self.PLOT_DIR = lep + '_fits/trigCat{0}/'.format(self.trigCat)
        self.year = year

    def perform_Full_Fit(self, BDT_CUT, year='2016'):
        """Run the complete MC-fit -> data-fit -> sPlot chain.

        Parameters
        ----------
        BDT_CUT : float
            Data candidates whose ``BDT_selection`` value exceeds this cut are
            fitted.
        year : str, optional
            Data-taking year used for the workspace file name.
        """
        self.set_paths(year=year)

        # Same order as before: Dplus, Ds signal, then Dplus, Ds mis-ID.
        mc_configs = [(mother_index, mis_id_bkg)
                      for mis_id_bkg in (False, True)
                      for mother_index in (0, 1)]

        for mother_index, mis_id_bkg in mc_configs:
            self.prepare_MC_workspaces(mother_index=mother_index, mis_id_bkg=mis_id_bkg)
        for mother_index, mis_id_bkg in mc_configs:
            self.fit_MC(mother_index=mother_index, mis_id_bkg=mis_id_bkg)
        for mother_index, mis_id_bkg in mc_configs:
            self.plot_MC(mother_index=mother_index, mis_id_bkg=mis_id_bkg)

        # Forward the requested working point to the data selection.
        self.prepare_data_workspace(BDT_CUT=BDT_CUT)
        self.fit_data()
        self.plot_data(wantThesisStyle=False)
        self.plot_data(wantThesisStyle=True)
        self.splot(plot=True)

    @staticmethod
    def _misid_shifted_mass(data, Epi, El1, El2, ml, mpi=140):
        """Return the D mass with the lepton mass hypothesis lowered from mpi to ml.

        Computes ``sqrt(data**2 - deltam2 * (2 + Epi*(El1+El2)/(El1*El2)
        + (El1**2+El2**2)/(El1*El2)))`` with ``deltam2 = mpi**2 - ml**2``.
        To first order in deltam2/E**2, the subtracted term is the change in
        m**2 when both lepton candidates are given mass ml instead of mpi at
        fixed momentum.  Negative arguments give NaN.

        Parameters
        ----------
        data : numpy.ndarray
            Reconstructed D mass per candidate.
        Epi, El1, El2 : numpy.ndarray
            Energies of the pion and of the two leptons per candidate.
        ml, mpi : float
            Lepton and pion masses, in the same units as ``data``.
        """
        deltam2 = mpi * mpi - ml * ml
        correction = 2 + Epi * ((El1 + El2) / (El1 * El2)) + (El1 ** 2 + El2 ** 2) / (El1 * El2)
        return np.sqrt(data ** 2 - deltam2 * correction)

    def prepare_MC_workspaces(self, mother_index, mis_id_bkg):
        """Build the MC dataset and its double Crystal Ball model, and save them.

        The MC mass is read from the trigger-category ntuple of the mother.
        With ``mis_id_bkg`` the mass is first shifted by
        :meth:`_misid_shifted_mass` and written to a ``*_for_misID_FIT.root``
        ntuple.  Candidates outside the mass window are skipped.  The dataset
        and the (unfitted) model are imported into the workspace file, which
        is created if it does not exist yet.

        Parameters
        ----------
        mother_index : int
            Index into ``mother_ID``; only 0 (Dplus) and 1 (Ds) are supported.
        mis_id_bkg : bool
            Build the mis-ID background shape instead of the signal shape.

        Raises
        ------
        ValueError
            If no mass window is defined for the requested mother.
        """
        self._ensure_paths()

        lep = l_flv[self.l_index]
        mother = mother_ID[mother_index]
        cat = str(self.trigCat)
        data_index = 0

        print("Preparing MC " + mother + "_phipi_" + lep + " workspace")

        try:
            mean_d_down, mean_d_startval, mean_d_up = self._MC_MASS_WINDOWS[(bool(mis_id_bkg), mother)]
        except KeyError:
            raise ValueError("No MC mass window defined for mother " + repr(mother))

        names = return_names(data_index, mother_index, mis_id_bkg)

        MC_PATH = (PATH + 'MC/' + lep + '_tuples/trigged/' + mother + '_phipi_' + lep + lep
                   + '/trigCats/' + cat + '/')
        file_stem = MC_PATH + mother + '_phipi_' + lep + lep + '_trigCat' + cat
        mass_branch = mother + '_ConsD_M'

        if mis_id_bkg:
            with open(file_stem + '.pickle', 'rb') as handle:
                MC_sig_dict = pickle.load(handle)

            data_shift = self._misid_shifted_mass(
                MC_sig_dict[mass_branch],
                MC_sig_dict['pi_PE'],
                MC_sig_dict[lep + '_minus_PE'],
                MC_sig_dict[lep + '_plus_PE'],
                ml=self._LEPTON_MASSES[lep],
                mpi=self._PION_MASS,
            )
            data_array = np.zeros(len(data_shift), dtype=[(mass_branch, np.float32)])
            data_array[mass_branch] = data_shift

            misid_file = file_stem + '_for_misID_FIT.root'
            rn.array2root(data_array,
                          filename=misid_file,
                          treename='DecayTree',
                          mode='recreate',
                          )
            tree_file = r.TFile(misid_file)
        else:
            tree_file = r.TFile(file_stem + '.root')
        tree = tree_file.Get("DecayTree")

        # One-element float32 buffer refreshed by every tree.GetEvent call.
        mass = np.array([0], dtype=np.float32)
        branch = tree.GetBranch(mass_branch)
        branch.SetAddress(mass)

        m = r.RooRealVar(names['m_name'], names['m_name'], mean_d_down, mean_d_up)
        observables = r.RooArgSet(m)
        data_set = r.RooDataSet(names['data_set_name'], names['data_set_name'], observables)

        for i in range(tree.GetEntries()):
            tree.GetEvent(i)
            # RooFit clips out-of-range assignments onto the range edge, so
            # candidates outside the window (and NaNs from the mis-ID shift)
            # are skipped, as in prepare_data_workspace.
            if mean_d_down < mass[0] < mean_d_up:
                r.RooAbsRealLValue.__assign__(m, mass[0])
                data_set.add(observables, 1.0)

        mean_D = r.RooRealVar(names['mean_D_name'], names['mean_D_name'], mean_d_startval, mean_d_down, mean_d_up)
        sigma_D = r.RooRealVar(names['sigma_D_name'], names['sigma_D_name'], 20, 0., 200)

        # Left/right Crystal Balls share mean and width; the tail parameters
        # start with opposite signs.
        alpha_D_l = r.RooRealVar(names['alpha_left_name'], names['alpha_left_name'], -0.020, -10., 10.)
        n_D_l = r.RooRealVar(names['n_left_name'], names['n_left_name'], 0.020, -50., 50.)
        sig_D_l = r.RooCBShape(names['sig_pdf_left_name'], names['sig_pdf_left_name'], m, mean_D, sigma_D, alpha_D_l, n_D_l)

        alpha_D_r = r.RooRealVar(names['alpha_right_name'], names['alpha_right_name'], +0.020, -10., 10.)
        n_D_r = r.RooRealVar(names['n_right_name'], names['n_right_name'], 0.020, -50., 50.)
        sig_D_r = r.RooCBShape(names['sig_pdf_right_name'], names['sig_pdf_right_name'], m, mean_D, sigma_D, alpha_D_r, n_D_r)

        frac_D = r.RooRealVar(names['frac_lr_name'], names['frac_lr_name'], 0.5, 0., 1.)

        model = r.RooAddPdf(names['model_name'], names['model_name'], r.RooArgList(sig_D_l, sig_D_r), r.RooArgList(frac_D))

        ws_full_path = self.workspace_path + self.workspace_filename
        if os.path.exists(ws_full_path):
            ws_file = r.TFile(ws_full_path)
            w = ws_file.Get(self.workspace_filename)
        else:
            w = r.RooWorkspace(self.workspace_filename)

        # The two CB components are imported implicitly through the model.
        for obj in (data_set, mean_D, sigma_D, alpha_D_r, n_D_r, alpha_D_l, n_D_l, frac_D, model):
            getattr(w, 'import')(obj)

        w.Print()

        w.writeToFile(ws_full_path, r.kTRUE)

        print('Done!')

    def fit_MC(self, mother_index, mis_id_bkg):
        """Fit the MC model of one sample and store the fitted parameters.

        Loads the model and dataset created by :meth:`prepare_MC_workspaces`,
        runs ``fitTo`` and writes the workspace back to its file.

        Parameters
        ----------
        mother_index : int
            Index into ``mother_ID``.
        mis_id_bkg : bool
            Fit the mis-ID shape instead of the signal shape.
        """
        self._ensure_paths()
        data_index = 0

        print("Fitting MC " + mother_ID[mother_index] + "_phipi_" + l_flv[self.l_index])

        names = return_names(data_index, mother_index, mis_id_bkg)

        ws_file = r.TFile(self.workspace_path + self.workspace_filename)
        w = ws_file.Get(self.workspace_filename)

        w.Print()

        model = w.pdf(names['model_name'])
        data_set = w.data(names['data_set_name'])

        model.fitTo(data_set, r.RooFit.Save())

        model_vars = w.allVars()
        getattr(w, 'import')(model_vars)

        w.writeToFile(self.workspace_path + self.workspace_filename, r.kTRUE)

        print('Done!')
        return

    def plot_MC(self, mother_index, mis_id_bkg):
        """Draw the MC dataset with its fitted model and save a PNG.

        The tail parameters (alpha, n of both Crystal Balls) are shown in a
        parameter box.  The file goes to ``self.PLOT_DIR``.

        Parameters
        ----------
        mother_index : int
            Index into ``mother_ID``.
        mis_id_bkg : bool
            Plot the mis-ID shape instead of the signal shape.
        """
        self._ensure_paths()
        data_index = 0
        lep = l_flv[self.l_index]

        print("Plotting MC " + mother_ID[mother_index] + "_phipi_" + lep)

        names = return_names(data_index, mother_index, mis_id_bkg)

        ws_file = r.TFile(self.workspace_path + self.workspace_filename)
        w = ws_file.Get(self.workspace_filename)

        w.Print()

        model = w.pdf(names['model_name'])
        data_set = w.data(names['data_set_name'])
        m = w.var(names['m_name'])

        frame_title = "Fit to " + mother_ID[mother_index] + "_phipi_" + lep + lep + " MC trigCat" + str(self.trigCat)
        if mis_id_bkg:
            frame_title = frame_title + " misID"
        xframe = m.frame(r.RooFit.Title(frame_title))

        data_set.plotOn(xframe)
        model.plotOn(xframe)

        alpha_D_l = w.var(names['alpha_left_name'])
        alpha_D_r = w.var(names['alpha_right_name'])
        n_D_l = w.var(names['n_left_name'])
        n_D_r = w.var(names['n_right_name'])

        argset = r.RooArgSet(alpha_D_l, alpha_D_r, n_D_l, n_D_r)

        model.paramOn(xframe, r.RooFit.Layout(0.6, 0.99, 0.92), r.RooFit.Parameters(argset), r.RooFit.Format("NEU", r.RooFit.AutoPrecision(1)))

        xframe.getAttText().SetTextSize(0.03)
        # Distinct names for the signal and mis-ID canvases of one mother.
        canvas_name = "Fit {0}{1}".format(mother_index, " misID" if mis_id_bkg else "")
        c = r.TCanvas(canvas_name, canvas_name, 900, 600)
        xframe.Draw()

        plot_file_name = "MC" + mother_ID[mother_index] + lep + lep + "_" + str(self.trigCat)
        if mis_id_bkg:
            plot_file_name = plot_file_name + "misID"
        c.SaveAs(self.PLOT_DIR + plot_file_name + ".png")

        print('Done!')

        return

    def prepare_data_workspace(self, BDT_CUT):
        """Select data by BDT and mass, and build the data fit model.

        Candidates with ``BDT_selection > BDT_CUT`` inside the fit window go
        into a RooDataSet.  The same candidates are taken, in the same order,
        from ``*_BDT_selection.pickle`` and saved to
        ``*_BDT_selection_temp.pickle`` for :meth:`splot`.  ``fit_model`` is an
        extended sum of Dplus and Ds signals built from the fitted MC shapes
        (mean shifted and width scaled by free parameters) and an exponential
        combinatorial background.  Everything is imported into the workspace.

        Parameters
        ----------
        BDT_CUT : float
            Working point; also stored in the workspace as ``bdt_cut``.
        """
        self._ensure_paths()

        print('Preparing data workspace')
        data_index = 1
        mother_index = 1
        lep = l_flv[self.l_index]

        mean_d_down, mean_d_up = self._DATA_MASS_WINDOWS[lep]

        names = return_names(data_index, mother_index)

        selection_stem = self.SAVE_BDT_SELECTED_PATH + mother_ID[1] + '_phipi_' + lep + lep + '_BDT_selection'

        # ---- Build the dataset with a hard cut on the BDT output ----
        tree_file = r.TFile(selection_stem + '.root')
        tree = tree_file.Get("DecayTree")

        mass = np.array([0], dtype=np.float32)
        branch = tree.GetBranch(mother_ID[mother_index] + "_ConsD_M")
        branch.SetAddress(mass)

        bdt_weight = np.array([0], dtype=np.float32)
        weight_branch = tree.GetBranch("BDT_selection")
        weight_branch.SetAddress(bdt_weight)

        m = r.RooRealVar(names['m_name'], names['m_name'], mean_d_down, mean_d_up)
        observables = r.RooArgSet(m)

        data_set = r.RooDataSet(names['data_set_name'], names['data_set_name'], observables)
        # Tree entries added to data_set, in insertion order.
        selected = []
        for i in range(tree.GetEntries()):
            tree.GetEvent(i)
            if bdt_weight[0] > BDT_CUT and mean_d_down < mass[0] < mean_d_up:
                selected.append(i)
                r.RooAbsRealLValue.__assign__(m, mass[0])
                data_set.add(observables, 1.0)

        # An explicit integer dtype keeps the index valid when nothing passes.
        selected_np = np.array(selected, dtype=int)

        with open(selection_stem + '.pickle', 'rb') as handle:
            data_dict = pickle.load(handle)

        data_BDT_selected = {label: values[selected_np] for label, values in data_dict.items()}

        FILE_PATH = selection_stem + '_temp.pickle'
        if os.path.exists(FILE_PATH):
            print('Overwriting pickle file at ' + FILE_PATH)
        else:
            print('Saving pickle file at ' + FILE_PATH)

        with open(FILE_PATH, 'wb') as handle:
            pickle.dump(data_BDT_selected, handle, protocol=pickle.HIGHEST_PROTOCOL)

        # ---- Load the workspace and add the dataset ----
        ws_file = r.TFile(self.workspace_path + self.workspace_filename)
        w = ws_file.Get(self.workspace_filename)

        getattr(w, 'import')(m)
        getattr(w, 'import')(data_set)

        # ---- Dplus signal: MC shape with free mean shift and width scale ----
        mean_Dplus_MC = w.var("MC_mean_Dplus")
        sigma_Dplus_MC = w.var("MC_sigma_Dplus")

        mean_Dplus_MC.setConstant(True)
        sigma_Dplus_MC.setConstant(True)

        meanShift_Dplus = r.RooRealVar("mean_shift_Dplus", "mean_shift_Dplus", 1.5, -5, 5)
        scaleSigma_Dplus = r.RooRealVar("scale_factor_Dplus", "scale_factor_Dplus", 1.12, 1, 5)

        mean_Dplus_shifted = r.RooFormulaVar("shifted_mean_Dplus", "shifted_mean_Dplus", "@0+@1", r.RooArgList(mean_Dplus_MC, meanShift_Dplus))
        sigma_Dplus_scaled = r.RooFormulaVar("scaled_sigma_Dplus", "scaled_sigma_Dplus", "@0*@1", r.RooArgList(sigma_Dplus_MC, scaleSigma_Dplus))

        frc_lr_Dplus_MC = w.var("MC_Dplus_lr_fraction")

        al_Dplus_left = w.var("MC_alpha_Dplus_left")
        n_Dplus_left = w.var("MC_n_Dplus_left")
        pdf_Dplus_left = r.RooCBShape("data_pdf_Dplus_left", "data_pdf_Dplus_left", m, mean_Dplus_shifted, sigma_Dplus_scaled, al_Dplus_left, n_Dplus_left)

        al_Dplus_right = w.var("MC_alpha_Dplus_right")
        n_Dplus_right = w.var("MC_n_Dplus_right")
        pdf_Dplus_right = r.RooCBShape("data_pdf_Dplus_right", "data_pdf_Dplus_right", m, mean_Dplus_shifted, sigma_Dplus_scaled, al_Dplus_right, n_Dplus_right)

        sig_Dplus = r.RooAddPdf("data_pdf_Dplus", "data_pdf_Dplus", r.RooArgList(pdf_Dplus_left, pdf_Dplus_right), r.RooArgList(frc_lr_Dplus_MC))

        # ---- Ds signal: MC shape with free mean shift and width scale ----
        mean_Ds_MC = w.var("MC_mean_Ds")
        sigma_Ds_MC = w.var("MC_sigma_Ds")

        mean_Ds_MC.setConstant(True)
        sigma_Ds_MC.setConstant(True)

        meanShift_Ds = r.RooRealVar("mean_shift_Ds", "mean_shift_Ds", 1.5, -5, 5)
        scaleSigma_Ds = r.RooRealVar("scale_factor_Ds", "scale_factor_Ds", 1.12, 1, 1.5)

        mean_Ds_shifted = r.RooFormulaVar("shifted_mean_Ds", "shifted_mean_Ds", "@0+@1", r.RooArgList(mean_Ds_MC, meanShift_Ds))
        sigma_Ds_scaled = r.RooFormulaVar("scaled_sigma_Ds", "scaled_sigma_Ds", "@0*@1", r.RooArgList(sigma_Ds_MC, scaleSigma_Ds))

        frc_lr_Ds_MC = w.var("MC_Ds_lr_fraction")

        al_Ds_left = w.var("MC_alpha_Ds_left")
        n_Ds_left = w.var("MC_n_Ds_left")
        pdf_Ds_left = r.RooCBShape("data_pdf_Ds_left", "data_pdf_Ds_left", m, mean_Ds_shifted, sigma_Ds_scaled, al_Ds_left, n_Ds_left)

        al_Ds_right = w.var("MC_alpha_Ds_right")
        n_Ds_right = w.var("MC_n_Ds_right")
        pdf_Ds_right = r.RooCBShape("data_pdf_Ds_right", "data_pdf_Ds_right", m, mean_Ds_shifted, sigma_Ds_scaled, al_Ds_right, n_Ds_right)

        sig_Ds = r.RooAddPdf("data_pdf_Ds", "data_pdf_Ds", r.RooArgList(pdf_Ds_left, pdf_Ds_right), r.RooArgList(frc_lr_Ds_MC))

        # ---- Combinatorial background ----
        lambd = r.RooRealVar("lambda", "lambda", -5e-4, -0.1, 0.1)
        bkg = r.RooExponential("pdf_comb_bkg", "pdf_comb_bkg", m, lambd)

        # ---- Yields and total model ----
        yield_Ds = r.RooRealVar("yield_Ds", "yield_Ds", 100., 0., 40000.)
        yield_Dplus = r.RooRealVar("yield_Dplus", "yield_Dplus", 100., 0., 40000.)

        # Stored in the workspace but not yet part of fit_model: a mis-ID term
        # would use the MC_*_misID shapes from prepare_MC_workspaces.
        yield_Ds_misID = r.RooRealVar("yield_Ds_misID", "yield_Ds_misID", 100., 0., 200.)
        yield_Dplus_misID = r.RooRealVar("yield_Dplus_misID", "yield_Dplus_misID", 100., 0., 200.)

        yield_comb_bkg = r.RooRealVar("yeld_comb_bkg", "yeld_comb_bkg", 100., 0., 1000.)

        model = r.RooAddPdf("fit_model", "fit_model", r.RooArgList(sig_Dplus, sig_Ds, bkg), r.RooArgList(yield_Dplus, yield_Ds, yield_comb_bkg))

        bdt_cut = r.RooRealVar("bdt_cut", "BDT Working Point", 0., 1.)
        bdt_cut.setVal(BDT_CUT)
        bdt_cut.setConstant(True)

        # The range must contain the fit window (up to 2080); RooRealVar.setVal
        # clips values outside it.
        low_mass_cut = r.RooRealVar("low_mass_cut", "Lower mass cut for fit", 1000., 3000.)
        low_mass_cut.setVal(mean_d_down)
        low_mass_cut.setConstant(True)

        up_mass_cut = r.RooRealVar("up_mass_cut", "Upper mass cut for fit", 1000., 3000.)
        up_mass_cut.setVal(mean_d_up)
        up_mass_cut.setConstant(True)

        for obj in (bdt_cut, low_mass_cut, up_mass_cut, lambd, yield_Dplus, yield_Ds,
                    yield_Dplus_misID, yield_Ds_misID, yield_comb_bkg, model):
            getattr(w, 'import')(obj)

        w.Print()
        w.writeToFile(self.workspace_path + self.workspace_filename, r.kTRUE)

        print('Done!')
        return

    def fit_data(self, mother_index=1):
        """Extended fit of ``fit_model`` to the selected data.

        Also computes the reduced chi^2 between the plotted data and model,
        stores it as ``reduced_chisq`` and writes the workspace back.

        Parameters
        ----------
        mother_index : int, optional
            Index passed to ``return_names`` for the data object names.
        """
        self._ensure_paths()
        data_index = 1
        lep = l_flv[self.l_index]

        print("Fitting data D(s)_phipi_" + lep + lep)

        names = return_names(data_index, mother_index)

        ws_file = r.TFile(self.workspace_path + self.workspace_filename)
        w = ws_file.Get(self.workspace_filename)

        w.Print()

        model = w.pdf(names['model_name'])
        data_set = w.data(names['data_set_name'])
        m = w.var(names['m_name'])

        print(model, data_set, m)

        fitr = model.fitTo(data_set, r.RooFit.Extended(), r.RooFit.Save())

        # chiSquare() compares the last curve and histogram drawn on the frame.
        xframe = m.frame()
        data_set.plotOn(xframe)
        model.plotOn(xframe, r.RooFit.LineColor(r.kBlack), r.RooFit.LineStyle(2))

        chisq = r.RooRealVar("reduced_chisq", "#Chi^2 / n.d.o.f.", 0., 10.)
        n_param = fitr.floatParsFinal().getSize()
        chisq.setVal(xframe.chiSquare(n_param))

        model_vars = w.allVars()
        model_pdfs = w.allPdfs()
        getattr(w, 'import')(model_vars)
        getattr(w, 'import')(chisq)
        getattr(w, 'import')(model_pdfs)

        w.Print()
        w.writeToFile(self.workspace_path + self.workspace_filename, r.kTRUE)

        print('Done!')
        return

    def plot_data(self, mother_index=1, wantThesisStyle=None):
        """Plot the data fit with its components and pulls, and save PNGs.

        Parameters
        ----------
        mother_index : int, optional
            Ignored: the data objects are always looked up with index 1.
        wantThesisStyle : bool or None, optional
            False: draw a parameter box and save linear and log-y plots.
            True: draw a legend instead and save ``*_thesisStyle`` plots.
            None (default): the canvas is built but nothing is saved.
        """
        self._ensure_paths()
        mother_index = 1
        data_index = 1
        lep = l_flv[self.l_index]

        names = return_names(data_index, mother_index)
        ws_file = r.TFile(self.workspace_path + self.workspace_filename)
        w = ws_file.Get(self.workspace_filename)
        w.Print()

        model = w.pdf(names['model_name'])
        data_set = w.data(names['data_set_name'])
        m = w.var(names['m_name'])
        bkg = w.pdf("pdf_comb_bkg")
        sig_Dplus = w.pdf("data_pdf_Dplus")
        sig_Ds = w.pdf("data_pdf_Ds")
        mean_Dplus = w.var("MC_mean_Dplus")
        mean_Ds = w.var("MC_mean_Ds")
        sigma_Dplus = w.var("MC_sigma_Dplus")
        sigma_Ds = w.var("MC_sigma_Ds")
        meanShift_Dplus = w.var("mean_shift_Dplus")
        scaleSigma_Dplus = w.var("scale_factor_Dplus")
        meanShift_Ds = w.var("mean_shift_Ds")
        scaleSigma_Ds = w.var("scale_factor_Ds")
        chisq = w.var("reduced_chisq")

        # In-memory only (the workspace is not written back here); presumably
        # so paramOn shows them like floating parameters - TODO confirm.
        for var in (mean_Ds, mean_Dplus, sigma_Dplus, sigma_Ds, chisq):
            var.setConstant(False)

        xframe = m.frame(r.RooFit.Title("Fit to " + lep + " data trigCat" + str(self.trigCat)))

        if wantThesisStyle == False:
            argset = r.RooArgSet(chisq, mean_Dplus, mean_Ds, sigma_Dplus, sigma_Ds, meanShift_Dplus, scaleSigma_Dplus, meanShift_Ds, scaleSigma_Ds)
            model.paramOn(xframe, r.RooFit.Layout(0.63, 0.95, 0.90), r.RooFit.Parameters(argset), r.RooFit.Format("NEU", r.RooFit.AutoPrecision(2)))
            xframe.getAttText().SetTextSize(0.038)
            xframe.getAttFill(names["model_name"] + "_paramBox").SetFillColor(r.kWhite)
            xframe.getAttFill(names["model_name"] + "_paramBox").SetLineWidth(0)

        # Cumulative component sets, drawn as stacked filled areas.
        third_set = r.RooArgSet(bkg, sig_Dplus, sig_Ds)
        fourth_set = r.RooArgSet(bkg, sig_Dplus)
        fifth_set = r.RooArgSet(bkg)
        data_set.plotOn(xframe)
        model.plotOn(xframe, r.RooFit.LineColor(r.kBlack),
                     r.RooFit.LineStyle(2))
        col1 = r.gROOT.GetColor(92)
        col1.SetRGB(0.992, 0.6823, 0.3804)  # orange for comb
        col2 = r.gROOT.GetColor(93)
        col2.SetRGB(0.6706, 0.8510, 0.9137)  # light blue for Ds
        col3 = r.gROOT.GetColor(94)
        col3.SetRGB(0.8431, 0.098, 0.1098)  # red for charm prc
        col4 = r.gROOT.GetColor(95)  # dark blue for Dplus
        col4.SetRGB(0.1725, 0.4824, 0.7137)

        model.plotOn(xframe, r.RooFit.Name("Ds_component"), r.RooFit.Components(third_set), r.RooFit.FillColor(93),
                     r.RooFit.DrawOption("f"), r.RooFit.LineColor(93))
        model.plotOn(xframe, r.RooFit.Name("Dplus_component"), r.RooFit.Components(fourth_set), r.RooFit.FillColor(95),
                     r.RooFit.DrawOption("f"), r.RooFit.LineColor(95))
        model.plotOn(xframe, r.RooFit.Name("Combinatorial"), r.RooFit.Components(fifth_set), r.RooFit.FillColor(92),
                     r.RooFit.DrawOption("f"), r.RooFit.LineColor(92))
        model.plotOn(xframe, r.RooFit.Name("tot"), r.RooFit.LineColor(r.kBlack), r.RooFit.LineStyle(2))
        data_set.plotOn(xframe, r.RooFit.Name("data"))
        hpull = xframe.pullHist()

        xframe2 = m.frame(r.RooFit.Title("Pulls"))
        xframe2.addPlotable(hpull, "P")

        canvTot = r.TCanvas("Fit", "Fit", 900, 600)
        pad1 = r.TPad("pad1", "pad1", 0, 0.35, 1, 1.0)
        pad1.SetBottomMargin(0)
        pad1.Draw()
        canvTot.cd()
        pad2 = r.TPad("pad2", "pad2", 0, 0.05, 1, 0.35)
        pad2.SetTopMargin(0)
        pad2.SetBottomMargin(0.2)
        pad2.SetGridx()
        pad2.SetGridy()
        pad2.Draw()

        pad1.cd()
        xframe.Draw()

        # NOTE(review): x1=-0.5 and y1=1.25 lie outside the [0, 1] NDC range -
        # confirm the intended placement.
        text = r.TPaveText(-0.5, 1.25, 0.92, 0.45, "NDC")
        text.AddText("#Chi^{2}/NDOF = %.4f" % chisq.getVal())
        text.SetTextSize(0.035)
        text.SetFillColor(0)
        text.SetFillStyle(0)
        text.SetLineColor(0)
        text.SetLineWidth(0)
        text.SetBorderSize(1)
        text.Draw("same")

        # Lepton symbol in the legend follows the channel being plotted.
        lep_tex = '#mu' if lep == 'mu' else 'e'
        decay_tex = " #rightarrow #phi(" + lep_tex + "^{#plus} " + lep_tex + "^{#minus}) #pi^{#plus}"

        leg = r.TLegend(0.78, 0.68, 0.9, 0.9)
        leg.AddEntry("data", "Data", "lep")
        leg.AddEntry("tot", "Total Fit", "l")
        leg.AddEntry("Ds_component", "D_{s}^{#plus}" + decay_tex, "f")
        leg.AddEntry("Dplus_component", "D^{#plus}" + decay_tex, "f")
        leg.AddEntry("Combinatorial", "Comb bkg", "f")
        leg.SetTextSize(xframe.GetYaxis().GetTitleSize())
        leg.SetLineColor(0)
        leg.SetFillColor(0)

        pad2.cd()
        xframe2.Draw()
        xframe2.GetXaxis().SetTitleSize(0.09)
        xframe2.GetXaxis().SetLabelSize(0.1)
        xframe2.GetYaxis().SetLabelSize(0.05)

        stem = "Fit_D(s)_phipi_" + lep + lep

        # NOTE(review): confirm what pad1.DrawClone() is meant to achieve
        # before SetLogy() in both branches below.
        if wantThesisStyle == False:
            canvTot.SaveAs(self.PLOT_DIR + stem + ".png")

            pad1.DrawClone()
            pad1.SetLogy()

            canvTot.SaveAs(self.PLOT_DIR + stem + "_logy.png")

        elif wantThesisStyle == True:
            pad1.cd()
            leg.Draw("same")
            canvTot.SaveAs(self.PLOT_DIR + stem + "_thesisStyle.png")

            pad1.DrawClone()
            pad1.SetLogy()

            canvTot.SaveAs(self.PLOT_DIR + stem + "_logy_thesisStyle.png")

    def splot(self, mother_index=1, plot=True):
        """Compute sWeights from the data fit and save them with the selected data.

        Runs ``RooStats.SPlot`` on ``fit_model`` with the Dplus, Ds and
        combinatorial yields, stores the summed sWeights in the workspace,
        appends the per-candidate sWeights to the branches saved by
        :meth:`prepare_data_workspace` and writes the result to
        ``*_<year>_sWeighted.pickle`` and ``*_<year>_sWeighted.root``.

        Parameters
        ----------
        mother_index : int, optional
            Index passed to ``return_names`` for the data object names.
        plot : bool, optional
            Also draw and save the sWeight distributions.
        """
        self._ensure_paths()
        data_index = 1
        lep = l_flv[self.l_index]
        names = return_names(data_index, mother_index)

        # Kept under its own name: the workspace objects are used until the end.
        ws_file = r.TFile(self.workspace_path + self.workspace_filename)
        w = ws_file.Get(self.workspace_filename)

        w.Print()

        model = w.pdf(names['model_name'])
        data_set = w.data(names['data_set_name'])
        m = w.var(names['m_name'])
        yield_Dplus = w.var("yield_Dplus")
        yield_Ds = w.var("yield_Ds")
        yield_comb_bkg = w.var("yeld_comb_bkg")

        splot = r.RooStats.SPlot('sData', 'sData', data_set, model, r.RooArgList(yield_Dplus, yield_Ds, yield_comb_bkg))

        # Each sum of sWeights should equal the corresponding fitted yield.
        yeld_comb_bkg_sw = splot.GetYieldFromSWeight("yeld_comb_bkg_sw")
        yield_Dplus_sw = splot.GetYieldFromSWeight("yield_Dplus_sw")
        yield_Ds_sw = splot.GetYieldFromSWeight("yield_Ds_sw")

        # Persist the summed sWeights.  The range spans the widest yield range
        # used in the fit (0-40000) so setVal does not clip.
        for var_name, value in (("sweighted_comb_bkg_yield", yeld_comb_bkg_sw),
                                ("sweighted_Dplus_sw_yield", yield_Dplus_sw),
                                ("sweighted_Ds_sw_yield", yield_Ds_sw)):
            sw_var = r.RooRealVar(var_name, var_name, 0., 40000.)
            sw_var.setVal(value)
            getattr(w, 'import')(sw_var)

        w.Print()
        w.writeToFile(self.workspace_path + self.workspace_filename, r.kTRUE)

        print('Done!')

        nentries = data_set.numEntries()
        Dplus_sWeights = np.array([splot.GetSWeight(i, "yield_Dplus_sw") for i in range(nentries)])
        Ds_sWeights = np.array([splot.GetSWeight(i, "yield_Ds_sw") for i in range(nentries)])
        comb_bkg_sWeights = np.array([splot.GetSWeight(i, "yeld_comb_bkg_sw") for i in range(nentries)])

        stem = self.SAVE_BDT_SELECTED_PATH + mother_ID[1] + '_phipi_' + lep + lep
        with open(stem + '_BDT_selection_temp.pickle', 'rb') as handle:
            data_dict = pickle.load(handle)

        branches_updated = list(data_dict.keys()) + ["Dplus_sWeights", "Ds_sWeights", "Comb_bkg_sWeights"]

        data_dict["Dplus_sWeights"] = Dplus_sWeights
        data_dict["Ds_sWeights"] = Ds_sWeights
        data_dict["Comb_bkg_sWeights"] = comb_bkg_sWeights

        # The pickle rows and the dataset entries were filled in the same loop,
        # so they must line up one to one.
        assert nentries == data_dict["Ds_ConsD_M"].shape[0], "Mismatching number of events in two load methods"

        print(data_dict["Ds_ConsD_M"].shape[0])

        FILE_PATH = stem + '_' + self.year + '_sWeighted.pickle'

        if os.path.exists(FILE_PATH):
            print('Overwriting pickle file at ' + FILE_PATH)
        else:
            print('Saving pickle file at ' + FILE_PATH)

        with open(FILE_PATH, 'wb') as handle:
            pickle.dump(data_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

        FILE_PATH = stem + '_' + self.year + '_sWeighted.root'

        if os.path.exists(FILE_PATH):
            print('Overwriting root file at ' + FILE_PATH)
            mode = 'recreate'
        else:
            print('Saving root file at ' + FILE_PATH)
            mode = 'create'

        tuple_array = np.array(
            [tuple(data_dict[label][k] for label in branches_updated) for k in range(nentries)],
            dtype=[(label, np.float32) for label in branches_updated],
        )

        rn.array2root(tuple_array,
                      filename=FILE_PATH,
                      treename='DecayTree',
                      mode=mode,
                      )

        # sWeight variables follow the order of the yields passed to SPlot.
        sweight_vars = splot.GetSWeightVars()
        Dplus_sw = sweight_vars[0]
        Ds_sw = sweight_vars[1]
        comb_bkg_sw = sweight_vars[2]

        if plot:
            Ds_sw.setRange(-0.5, 1.5)
            frame_Ds_sw = Ds_sw.frame(r.RooFit.Title("signal Ds sWeights"))
            data_set.plotOn(frame_Ds_sw, r.RooFit.MarkerSize(0.05))

            Dplus_sw.setRange(-1., 1.5)
            frame_Dplus_sw = Dplus_sw.frame(r.RooFit.Title("signal Dplus sWeights"))
            data_set.plotOn(frame_Dplus_sw, r.RooFit.MarkerSize(0.05))

            comb_bkg_sw.setRange(-.6, 1.2)
            frame_bkg_sw = comb_bkg_sw.frame(r.RooFit.Title("background sWeights"))
            data_set.plotOn(frame_bkg_sw, r.RooFit.MarkerSize(0.05))

            sframe = m.frame(r.RooFit.Title("sWeights vs D mass"))
            data_set.plotOnXY(sframe, r.RooFit.YVar(Ds_sw),
                              r.RooFit.MarkerColor(r.kGreen), r.RooFit.Name("Ds_sig"))
            data_set.plotOnXY(sframe, r.RooFit.YVar(Dplus_sw),
                              r.RooFit.MarkerColor(r.kBlue), r.RooFit.Name("Dplus_sig"))
            data_set.plotOnXY(sframe, r.RooFit.YVar(comb_bkg_sw),
                              r.RooFit.MarkerColor(r.kRed), r.RooFit.Name("bkg"))

            legend = r.TLegend(0.89, 0.89, 0.5, 0.7)
            legend.AddEntry(sframe.findObject('Ds_sig'), 'Signal Ds sWeights', 'p')
            legend.AddEntry(sframe.findObject('Dplus_sig'), 'Signal Dplus sWeights', 'p')
            legend.AddEntry(sframe.findObject('bkg'), 'Background sWeights', 'p')

            c = r.TCanvas("sWeights 0", "sWeights 0", 900, 900)
            c.Divide(2, 2)
            c.cd(1)
            frame_Ds_sw.Draw()
            c.cd(2)
            frame_Dplus_sw.Draw()
            c.cd(3)
            frame_bkg_sw.Draw()
            c.cd(4)
            sframe.Draw()
            legend.Draw()

            c.SaveAs(self.PLOT_DIR + "D(s)_phipi_" + lep + lep + "_sWeights.png")




def return_names(data_index, mother_index, mis_id_bkg=False):
    """Return the names of the RooFit objects for one sample.

    Parameters
    ----------
    data_index : int
        Index into ``data_type``: 0 for MC, 1 for data.
    mother_index : int
        Index into ``mother_ID``.
    mis_id_bkg : bool, optional
        MC only: append ``'_misID'`` to every name so the mis-ID shapes do not
        clash with the signal shapes in the shared workspace.

    Returns
    -------
    dict
        Logical key -> object name.  MC gets the full set of model names;
        data only ``'m_name'``, ``'data_set_name'`` and ``'model_name'``.

    Raises
    ------
    ValueError
        If ``data_index`` is neither 0 nor 1.
    """
    if data_index not in (0, 1):
        raise ValueError("data_index must be 0 (MC) or 1 (data), got {0!r}".format(data_index))

    prefix = data_type[data_index]
    mother = mother_ID[mother_index]

    if data_index == 1:
        return {
            'm_name': prefix + "_" + mother + "_m",
            'data_set_name': prefix + '_' + mother + '_ConsD_M',
            'model_name': 'fit_model',
        }

    names = {
        'm_name': prefix + "_" + mother + "_m",
        'mean_D_name': prefix + "_mean_" + mother,
        'sigma_D_name': prefix + "_sigma_" + mother,
        'alpha_left_name': prefix + "_alpha_" + mother + "_left",
        'n_left_name': prefix + "_n_" + mother + "_left",
        'sig_pdf_left_name': prefix + "_pdf_" + mother + "_left",
        'alpha_right_name': prefix + "_alpha_" + mother + "_right",
        'n_right_name': prefix + "_n_" + mother + "_right",
        'sig_pdf_right_name': prefix + "_pdf_" + mother + "_right",
        'frac_lr_name': prefix + '_' + mother + "_lr_fraction",
        'model_name': prefix + '_' + mother + "_model",
        'data_set_name': prefix + '_' + mother + '_ConsD_M',
    }
    if mis_id_bkg:
        names = {key: value + '_misID' for key, value in names.items()}
    return names