# FairShipTools / elena_ShipAna.py
# Signal and background reconstruction efficiency (commit by "Ubuntu", 2 Mar 2015, ~31 KB).
# example for accessing smeared hits and fitted tracks
import os,sys,getopt,subprocess,gc
from operator import mul, add

sys.path.append("../../FairShipTools")
import funcsByBarbara as tools

#os.chdir("../../FairShipRun")
#os.system("bash config.sh")
#os.chdir("../FairShip/macro")
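## None of this works anyway: os.system / subprocess calls run in child shells, so sourcing config.sh there does not affect this Python process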
#subprocess.call("cd ../../FairShipRun", shell=True)
#subprocess.call(". config.sh", shell=True)
#subprocess.call("cd ../FairShip/macro", shell=True)

import ROOT
import rootUtils as ut
import shipunit as u
from ShipGeoConfig import ConfigRegistry

PDG = ROOT.TDatabasePDG.Instance()
inputFile = None
dy = None
nEvents   = 999999#99999
theHNLcode = 9900015
signal_file = False
bg_file = False
cosmics_file = False
file_type = None
oldSW = False

try:
        opts, args = getopt.getopt(sys.argv[1:], "n:t:f:o:A:Y:i", ["nEvents=", "type="])
except getopt.GetoptError:
        # print help information and exit:
        print ' enter file name'
        sys.exit()
for o, a in opts:
        if o in ("-f"):
            inputFile = a
        if o in ("-o"):
            outputFile = a
        if o in ("-Y"): 
            dy = float(a)
        if o in ("-n", "--nEvents="):
            nEvents = int(a)
        if o in ("-t", "--type="):
            file_type = str(a)

if not file_type:
  print " please select file type (sig or bg)"
  sys.exit()

if 'old' in file_type:
  oldSW = True
if 'sig' in file_type:
  signal_file = True
elif 'bg' in file_type:
  bg_file = True
elif 'cosmics' in file_type:
  cosmics_file = True
else:
  print " please select file type (sig or bg)"
  sys.exit()

if 'Cosmics' in inputFile:
  print 
  print '\t found Cosmics file'
  cosmics_file = True

print
print "\tFile type is: "+file_type
print

#pdgdb = TDatabasePDG.Instance()

if not outputFile:
  outputFile = "ShipAna.root"

# If directory of output file does not exist, create it (depth=3)
if not os.path.exists(os.path.dirname(outputFile)):
  if not os.path.exists(os.path.dirname(os.path.dirname(outputFile))):
    if not os.path.exists(os.path.dirname(os.path.dirname(os.path.dirname(outputFile)))):
      os.system("mkdir "+os.path.dirname(os.path.dirname(os.path.dirname(outputFile))))
    os.system("mkdir "+os.path.dirname(os.path.dirname(outputFile)))
  os.system("mkdir "+os.path.dirname(outputFile))

if not dy:
  # try to extract from input file name
  tmp = inputFile.replace(os.path.dirname(inputFile),'')
  tmp = tmp.split('.')
  try:
    dy = float( tmp[1]+'.'+tmp[2] )
  except:
    dy = None
#else:
# inputFile = 'ship.'+str(dy)+'.Pythia8-TGeant4_rec.root'

if not inputFile:
  inputFile = 'ship.'+str(dy)+'.Pythia8-TGeant4_rec.root'
  
f     = ROOT.TFile(inputFile)
sTree = f.cbmsim

# init geometry and mag. field
# Load the geometry configuration; dy is forwarded as Yheight.
ShipGeo = ConfigRegistry.loadpy("$FAIRSHIP/geometry/geometry_config.py", Yheight = dy )
# -----Create geometry----------------------------------------------
import shipDet_conf
run = ROOT.FairRunSim()
# `modules` is only referenced from commented-out straw-hit code in this file.
modules = shipDet_conf.configure(run,ShipGeo)

# Import the full geometry file that belongs to the input file:
# ship.<dy>.Pythia8-TGeant4_rec.root -> geofile_full.<dy>.Pythia8-TGeant4.root
tgeom = ROOT.TGeoManager("Geometry", "Geane geometry")
geofile = inputFile.replace('ship.','geofile_full.').replace('_rec.','.')
gMan  = tgeom.Import(geofile)
# Initialise genfit's material effects from the TGeo materials.
geoMat =  ROOT.genfit.TGeoMaterialInterface()
ROOT.genfit.MaterialEffects.getInstance().init(geoMat)

# Register a genfit BellField built from the configured field maximum, z
# position and Yheight/2 with genfit's FieldManager.
bfield = ROOT.genfit.BellField(ShipGeo.Bfield.max, ShipGeo.Bfield.z, 2, ShipGeo.Yheight/2.)
fM = ROOT.genfit.FieldManager.getInstance()
fM.init(bfield)

# Expose the tree leaves as attributes of `ev` (read by access2SmearedHits).
ev     = ut.PyListOfLeaves()
leaves = sTree.GetListOfLeaves()
names  = ut.setAttributes(ev, leaves)

# Correct weights for neutrinos
weightHistFile = ROOT.TFile("../../FairShipTools/histoForWeights.root","read")
weightHist = weightHistFile.Get("h_Gio")

noHcal = False

# Get geometry positions
if oldSW:
    nodes = tools.searchForNodes2_xyz_dict(geofile)
else:
    nodes = tools.searchForNodes3_xyz_dict(geofile)
#scintTankW = [nodes["lidT1lisci_1"]['z']['pos']-nodes["lidT1lisci_1"]['z']['dim'], nodes["lidT1lisci_1"]['z']['pos']+nodes["lidT1lisci_1"]['z']['dim']]
#scintWPos = [scintTankW]
trackStationsPos = [ [nodes["Tr%s_%s"%(i,i)]['z']['pos']-nodes["Tr%s_%s"%(i,i)]['z']['dim'], nodes["Tr%s_%s"%(i,i)]['z']['pos']+nodes["Tr%s_%s"%(i,i)]['z']['dim']] for i in xrange(1,5)]
muonStationsPos = [ [nodes["muondet%s_1"%(i)]['z']['pos']-nodes["muondet%s_1"%(i)]['z']['dim'], nodes["muondet%s_1"%(i)]['z']['pos']+nodes["muondet%s_1"%(i)]['z']['dim']] for i in xrange(0,4)]
strawVetoPos = [ [nodes["Veto_5"]['z']['pos']-nodes["Veto_5"]['z']['dim'], nodes["Veto_5"]['z']['pos']+nodes["Veto_5"]['z']['dim']] ]
# ecal points are before the start of the ecal, systematically.
ecalPos = [ [nodes['Ecal_1']['z']['pos']-2.*nodes['Ecal_1']['z']['dim'], nodes['Ecal_1']['z']['pos']+nodes['Ecal_1']['z']['dim']] ]
try:
    hcalPos = [ [nodes['Hcal_1']['z']['pos']-nodes['Hcal_1']['z']['dim'], nodes['Hcal_1']['z']['pos']+nodes['Hcal_1']['z']['dim']] ]
except:
    print "\t I did not find the Hcal!!! Continuing without it."
    noHcal =  True
if oldSW:
    volume = [nodes["lidT1O_1"]['z']['pos']-nodes["lidT6O_1"]['z']['dim'],nodes["lidT6O_1"]['z']['pos']-nodes["lidT6O_1"]['z']['dim']]
    vetoWall = [ [volume[0], nodes['Tr1_1']['z']['pos']-nodes['Tr1_1']['z']['dim']] ]
else:
    vetoWall = [ [nodes['VetoTimeDet_1']['z']['pos']+nodes['VetoTimeDet_1']['z']['dim']+0.001 , nodes['Tr1_1']['z']['pos']-nodes['Tr1_1']['z']['dim']] ]
print '\t stored some geometry nodes'
#print 'ecal pos: ', ecalPos[0]


h = {}
# Per-track histograms used by myEventLoop / makePlots (booking disabled here):
#ut.bookHist(h,'delPOverP','delP / P',100,0.,50.,100,-0.5,0.5)
#ut.bookHist(h,'delPtOverP','delPt / P',100,0.,50.,100,-0.5,0.5)
#ut.bookHist(h,'delPtOverPt','delPt / Pt',100,0.,50.,100,-0.5,0.5)
#ut.bookHist(h,'delPOverP2','delP / P chi2/nmeas<25',100,0.,50.,100,-0.5,0.5)
#ut.bookHist(h,'delPOverPz','delPz / Pz',100,0.,50.,100,-0.5,0.5)
#ut.bookHist(h,'delPOverP2z','delPz / Pz chi2/nmeas<25',100,0.,50.,100,-0.5,0.5)
#ut.bookHist(h,'chi2','chi2/nmeas after trackfit',100,0.,100.)
#ut.bookHist(h,'IP','Impact Parameter',100,0.,10.)
#ut.bookHist(h,'meas','number of measurements',40,-0.5,39.5)
#ut.bookHist(h,'measVSchi2','number of measurements vs chi2/meas',40,-0.5,39.5,100,0.,100.)
#ut.bookHist(h,'distu','distance to wire',100,0.,1.)
#ut.bookHist(h,'distv','distance to wire',100,0.,1.)
#ut.bookHist(h,'disty','distance to wire',100,0.,1.)
#ut.bookHist(h,'meanhits','mean number of hits / track',50,-0.5,49.5)
#ut.bookHist(h,'Pt','Reconstructed transverse momentum',100,0.,40.)

# Candidate-level histograms filled by the event loops, booked in this order.
# 1D spec: (key, title, nbins, low, high); 2D adds (nbinsY, lowY, highY).
for histSpec in [
        ('Candidate-DOCA', 'DOCA between the two tracks', 300, 0., 50.),
        ('Candidate-IP0', 'Impact Parameter to target', 200, 0., 300.),
        ('Candidate-IP0/mass', 'Impact Parameter to target vs mass', 100, 0., 2., 100, 0., 100.),
        ('Candidate-Mass', 'reconstructed Mass', 100, 0., 4.),
        ('Candidate-Pt', 'reconstructed Pt', 100, 0., 40.),
        ('Candidate-vtxx', 'X position of reconstructed vertex', 150, -4000., 4000.),
        ('Candidate-vtxy', 'Y position of reconstructed vertex', 150, -4000., 4000.),
        ('Candidate-vtxz', 'Z position of reconstructed vertex', 150, -4000., 4000.),
        ('CandidateDaughters-Pt', 'reconstructed Pt of candidate daughters', 100, 0., 40.),
        ('CandidateDaughters-chi2', 'chi2/nmeas after trackfit', 100, 0., 100.),
        ]:
    ut.bookHist(h, *histSpec)

def myVertex(t1,t2,PosDir):
   """Return the point of closest approach of two straight tracks and their DOCA.

   PosDir[t] is a pair (position, direction) for track key t. Components are
   read as vec(i) for i = 0, 1, 2, as for ROOT.TVector3. Directions must be
   non-zero but need not be normalised.

   Returns (x, y, z, dist). (x, y, z) is the midpoint of the shortest segment
   joining the two lines, and dist is that segment's length.
   """
   import math

   def _dot(a, b):
      return a[0]*b[0] + a[1]*b[1] + a[2]*b[2]

   p1 = [PosDir[t1][0](i) for i in range(3)]
   u  = [PosDir[t1][1](i) for i in range(3)]
   p2 = [PosDir[t2][0](i) for i in range(3)]
   v  = [PosDir[t2][1](i) for i in range(3)]
   w  = [p1[i] - p2[i] for i in range(3)]

   uu, uv, vv = _dot(u, u), _dot(u, v), _dot(v, v)
   uw, vw = _dot(u, w), _dot(v, w)
   denom = uu*vv - uv*uv   # >= 0; ~0 for (anti)parallel tracks
   if denom > 1e-12*uu*vv:
      # Minimise |w + s*u - t*v|^2. Each track gets its OWN line parameter:
      # s along track 1 and t along track 2 (a single shared parameter only
      # gives the right point on track 2).
      s = (uv*vw - vv*uw)/denom
      t = (uu*vw - uv*uw)/denom
   else:
      # Parallel tracks: every point is equally close, so take track 1's
      # reference point and its projection onto track 2.
      s = 0.
      t = vw/vv
   x1 = [p1[i] + s*u[i] for i in range(3)]
   x2 = [p2[i] + t*v[i] for i in range(3)]
   dist = math.sqrt(sum((x1[i] - x2[i])**2 for i in range(3)))
   return (x1[0]+x2[0])/2.,(x1[1]+x2[1])/2.,(x1[2]+x2[2])/2.,dist

def fitSingleGauss(x,ba=None,be=None):
    """Fit histogram h[x] with a Gaussian on top of a constant background.

    x  -- key of the histogram in the global dict h
    ba -- lower edge of the fit range (default: centre of the first bin)
    be -- upper edge of the fit range (default: centre of the last bin)

    The TF1 is named 'myGauss_<x>'. If one with that name is already in the
    histogram's list of functions (e.g. from an earlier fit), it is reused as
    the starting point. Otherwise a new one is built with its normalisation,
    mean and sigma seeded from the histogram.
    """
    hist = h[x]
    name = 'myGauss_'+x
    # The range defaults are needed on every call, including re-fits where
    # the function already exists. `is None` lets callers pass an explicit 0.
    if ba is None: ba = hist.GetBinCenter(1)
    if be is None: be = hist.GetBinCenter(hist.GetNbinsX())
    myGauss = hist.GetListOfFunctions().FindObject(name)
    if not myGauss:
       bw    = hist.GetBinWidth(1)
       mean  = hist.GetMean()
       sigma = hist.GetRMS()
       norm  = hist.GetEntries()*0.3
       # The bin width turns the normalised Gaussian into counts per bin,
       # so parameter 0 is the number of signal entries.
       myGauss = ROOT.TF1(name,'[0]*'+str(bw)+'/([2]*sqrt(2*pi))*exp(-0.5*((x-[1])/[2])**2)+[3]',4)
       myGauss.SetParameter(0,norm)
       myGauss.SetParameter(1,mean)
       myGauss.SetParameter(2,sigma)
       myGauss.SetParameter(3,1.)
       myGauss.SetParName(0,'Signal')
       myGauss.SetParName(1,'Mean')
       myGauss.SetParName(2,'Sigma')
       myGauss.SetParName(3,'bckgr')
    hist.Fit(myGauss,'','',ba,be)


def makePlots():
   ut.bookCanvas(h,key='strawanalysis',title='Distance to wire and mean nr of hits',nx=1200,ny=600,cx=2,cy=1)
   cv = h['strawanalysis'].cd(1)
   h['disty'].Draw()
   h['distu'].Draw('same')
   h['distv'].Draw('same')
   cv = h['strawanalysis'].cd(2)
   h['meanhits'].Draw()
   ut.bookCanvas(h,key='fitresults',title='Fit Results',nx=1600,ny=1200,cx=2,cy=2)
   cv = h['fitresults'].cd(1)
   h['delPOverPz'].Draw('box')
   cv = h['fitresults'].cd(2)
   cv.SetLogy(1)
   h['chi2'].Draw()
   cv = h['fitresults'].cd(3)
   h['delPOverPz_proj'] = h['delPOverPz'].ProjectionY()
   ROOT.gStyle.SetOptFit(11111)
   h['delPOverPz_proj'].Draw()
   h['delPOverPz_proj'].Fit('gaus')
   cv = h['fitresults'].cd(4)
   h['delPOverP2z_proj'] = h['delPOverP2z'].ProjectionY()
   h['delPOverP2z_proj'].Draw()
   fitSingleGauss('delPOverP2z_proj')
   h['fitresults'].Print('fitresults.gif')
   ut.bookCanvas(h,key='fitresults2',title='Fit Results',nx=1600,ny=1200,cx=2,cy=2)
   print 'finished with first canvas'
   cv = h['fitresults2'].cd(1)
   h['Doca'].Draw()
   cv = h['fitresults2'].cd(2)
   h['IP0'].Draw()
   cv = h['fitresults2'].cd(3)
   h['Mass'].Draw()
   fitSingleGauss('Mass')
   cv = h['fitresults2'].cd(4)
   h['IP0/mass'].Draw('box')
   h['fitresults2'].Print('fitresults2.gif')
   print 'finished making plots'

# start event loop
def myEventLoop(N):
    """Fill per-track and per-candidate histograms for the first N events.

    Per converged track:
      - number of measurements, chi2/nmeas and reconstructed Pt;
      - P/Pt/Pz resolution against the MC track given by fitTrack2MC;
      - for chi2/nmeas <= 25, the impact parameter to that MC track's start vertex.
    Per candidate in sTree.Particles (HNL candidates only for signal files):
      - DOCA, vertex position, impact parameter to the target, mass and Pt.
    The event weight is MCTrack[0]'s weight, or 1 if that is not positive.
    """
    # The per-track histograms filled below are not booked at module level
    # (those ut.bookHist calls are commented out). Book any missing ones here
    # with the same specs; otherwise the first Fill raises KeyError.
    for spec in [
        ('delPOverP','delP / P',100,0.,50.,100,-0.5,0.5),
        ('delPtOverP','delPt / P',100,0.,50.,100,-0.5,0.5),
        ('delPtOverPt','delPt / Pt',100,0.,50.,100,-0.5,0.5),
        ('delPOverP2','delP / P chi2/nmeas<25',100,0.,50.,100,-0.5,0.5),
        ('delPOverPz','delPz / Pz',100,0.,50.,100,-0.5,0.5),
        ('delPOverP2z','delPz / Pz chi2/nmeas<25',100,0.,50.,100,-0.5,0.5),
        ('chi2','chi2/nmeas after trackfit',100,0.,100.),
        ('IP','Impact Parameter',100,0.,10.),
        ('meas','number of measurements',40,-0.5,39.5),
        ('measVSchi2','number of measurements vs chi2/meas',40,-0.5,39.5,100,0.,100.),
        ('Pt','Reconstructed transverse momentum',100,0.,40.),
    ]:
        if spec[0] not in h:
            ut.bookHist(h, *spec)

    nEvents = min(sTree.GetEntries(),N)
    for n in range(nEvents): 
        rc = sTree.GetEntry(n)
        wg = sTree.MCTrack[0].GetWeight()
        if not wg>0.: wg=1.
        ## make some straw hit analysis
        #hitlist = {}
        #for ahit in sTree.strawtubesPoint:
        #    detID = ahit.GetDetectorID()
        #    top = ROOT.TVector3()
        #    bot = ROOT.TVector3()
        #    modules["Strawtubes"].StrawEndPoints(detID,bot,top)
        #    dw  = ahit.dist2Wire()
        #    if abs(top.y())==abs(bot.y()): h['disty'].Fill(dw)
        #    if abs(top.y())>abs(bot.y()): h['distu'].Fill(dw)
        #    if abs(top.y())<abs(bot.y()): h['distv'].Fill(dw)
        #    trID = ahit.GetTrackID()
        #    if not trID < 0 :
        #        if hitlist.has_key(trID):  hitlist[trID]+=1
        #        else:  hitlist[trID]=1
        #for tr in hitlist:  h['meanhits'].Fill(hitlist[tr])
        key = -1
        fittedTracks = {}
        for atrack in sTree.FitTracks:
            key+=1
            fitStatus   = atrack.getFitStatus()
            nmeas = atrack.getNumPoints()
            h['meas'].Fill(nmeas)
            if not fitStatus.isFitConverged() : continue
            fittedTracks[key] = atrack
            # needs different study why fit has not converged, continue with fitted tracks
            chi2        = fitStatus.getChi2()/nmeas
            fittedState = atrack.getFittedState()
            h['chi2'].Fill(chi2,wg)
            h['measVSchi2'].Fill(atrack.getNumPoints(),chi2)
            P = fittedState.getMomMag()
            Pz = fittedState.getMom().z()
            Pt = ROOT.TMath.Sqrt(P*P - Pz*Pz)
            h['Pt'].Fill(Pt)
            # Compare with the MC particle matched to this fitted track.
            mcPartKey = sTree.fitTrack2MC[key]
            mcPart    = sTree.MCTrack[mcPartKey]
            if not mcPart : continue
            Ptruth    = mcPart.GetP()
            Ptruthz   = mcPart.GetPz()
            Ptrutht   = ROOT.TMath.Sqrt(Ptruth*Ptruth - Ptruthz*Ptruthz)
            delPtOverPt = (Ptrutht - Pt)/Ptrutht
            delPtOverP = (Ptrutht - Pt)/Ptruth
            delPOverP = (Ptruth - P)/Ptruth
            h['delPOverP'].Fill(Ptruth,delPOverP)
            h['delPtOverP'].Fill(Ptruth,delPtOverP)
            h['delPtOverPt'].Fill(Ptrutht,delPtOverPt)
            delPOverPz = (1./Ptruthz - 1./Pz) * Ptruthz
            h['delPOverPz'].Fill(Ptruthz,delPOverPz)
            if chi2>25: continue
            h['delPOverP2'].Fill(Ptruth,delPOverP)
            h['delPOverP2z'].Fill(Ptruth,delPOverPz)
            # try measure impact parameter
            # Distance from the MC start vertex to the fitted straight line
            # through trackPos along trackDir.
            trackDir = fittedState.getDir()
            trackPos = fittedState.getPos()
            vx = ROOT.TVector3()
            mcPart.GetStartVertex(vx)
            t = 0
            for i in range(3):   t += trackDir(i)*(vx(i)-trackPos(i)) 
            dist = 0
            for i in range(3):   dist += (vx(i)-trackPos(i)-t*trackDir(i))**2
            dist = ROOT.TMath.Sqrt(dist)
            h['IP'].Fill(dist) 
        # loop over particles, 2-track combinations
        # From Thomas:
        # An object in sTree.Particles is made out of two fitted tracks.
        # Except of the mass assignment, and fake pattern recognition (only 
        # hits from same MCTrack are used), no other MC truth. No requirement
        # that the two tracks come from the same MCTrack ! GetDaughter allows to
        # get back to the underlying tracks, see ShipReco.py:
        #  particle = ROOT.TParticle(9900015,0,-1,-1,t1,t2,HNL,vx)
        for HNL in sTree.Particles:
            if signal_file and not (HNL.GetPdgCode() == theHNLcode): continue
            #if bg_file and (HNL.GetPdgCode() == theHNLcode): continue
            t1,t2 = HNL.GetDaughter(0),HNL.GetDaughter(1) 
            PosDir = {} 
            for tr in [t1,t2]:
                xx  = sTree.FitTracks[tr].getFittedState()
                PosDir[tr] = [xx.getPos(),xx.getDir()]
                h['CandidateDaughters-Pt'].Fill(xx.getMom().Pt())
            xv,yv,zv,doca = myVertex(t1,t2,PosDir)
            h['Candidate-DOCA'].Fill(doca) 
            h['Candidate-vtxx'].Fill(xv)
            h['Candidate-vtxy'].Fill(yv)
            h['Candidate-vtxz'].Fill(zv)
            #if  doca>5 : continue
            # Impact parameter of the candidate's flight line w.r.t. the
            # target point (0, 0, target.z0).
            HNLPos = ROOT.TLorentzVector()
            HNL.ProductionVertex(HNLPos)
            HNLMom = ROOT.TLorentzVector()
            HNL.Momentum(HNLMom)
            target = ROOT.TVector3(0,0,ShipGeo.target.z0)
            t = 0
            for i in range(3):   t += HNLMom(i)/HNLMom.P()*(target(i)-HNLPos(i)) 
            dist = 0
            for i in range(3):   dist += (target(i)-HNLPos(i)-t*HNLMom(i)/HNLMom.P())**2
            dist = ROOT.TMath.Sqrt(dist)
            h['Candidate-IP0'].Fill(dist)  
            h['Candidate-IP0/mass'].Fill(HNLMom.M(),dist)
            h['Candidate-Mass'].Fill(HNLMom.M())
            h['Candidate-Pt'].Fill(HNLMom.Pt())

# Output tree 'ShipAna'. elenaEventLoop calls elenaTree.Fill() once per
# candidate in sTree.Particles, so each entry describes one candidate.
# tools.AddVect / tools.AddVar return the tree plus the buffer behind the new
# branch; the event loop resets and fills these buffers with
# tools.PutToZero / tools.Push. Presumably AddVect branches hold a list of
# values and AddVar branches a single value -- confirm in funcsByBarbara.
elenaTree = ROOT.TTree('ShipAna','ShipAna')
# Per-daughter reconstructed quantities (one value per daughter track).
# DaughtersChi2 is the total fit chi2, and DaughtersNPoints is filled with the
# rounded fit NDF (see elenaEventLoop).
elenaTree, DaughtersPt = tools.AddVect(elenaTree, 'DaughtersPt', 'float')
elenaTree, DaughtersChi2 = tools.AddVect(elenaTree, 'DaughtersChi2', 'float')
elenaTree, DaughtersNPoints = tools.AddVect(elenaTree, 'DaughtersNPoints', 'int')
# Per-daughter MC truth from tools.retrieveMCParticleInfo.
elenaTree, DaughtersTruthProdX = tools.AddVect(elenaTree, 'DaughtersTruthProdX', 'float')
elenaTree, DaughtersTruthProdY = tools.AddVect(elenaTree, 'DaughtersTruthProdY', 'float')
elenaTree, DaughtersTruthProdZ = tools.AddVect(elenaTree, 'DaughtersTruthProdZ', 'float')
elenaTree, DaughtersTruthPDG = tools.AddVect(elenaTree, 'DaughtersTruthPDG', 'int')
elenaTree, DaughtersTruthMotherPDG = tools.AddVect(elenaTree, 'DaughtersTruthMotherPDG', 'int')
# 1 if the daughter's fit converged, else 0.
elenaTree, DaughtersFitConverged = tools.AddVect(elenaTree, 'DaughtersFitConverged', 'int')
# Hit coordinates passed to tools.wasFired for each detector: straw stations,
# muon stations, ECAL, HCAL, straw veto (Veto_5) and the liquid-scintillator
# veto wall.
elenaTree, straw_x = tools.AddVect(elenaTree, 'straw_x', 'float')
elenaTree, straw_y = tools.AddVect(elenaTree, 'straw_y', 'float')
elenaTree, straw_z = tools.AddVect(elenaTree, 'straw_z', 'float')
elenaTree, muon_x = tools.AddVect(elenaTree, 'muon_x', 'float')
elenaTree, muon_y = tools.AddVect(elenaTree, 'muon_y', 'float')
elenaTree, muon_z = tools.AddVect(elenaTree, 'muon_z', 'float')
elenaTree, ecal_x = tools.AddVect(elenaTree, 'ecal_x', 'float')
elenaTree, ecal_y = tools.AddVect(elenaTree, 'ecal_y', 'float')
elenaTree, ecal_z = tools.AddVect(elenaTree, 'ecal_z', 'float')
elenaTree, hcal_x = tools.AddVect(elenaTree, 'hcal_x', 'float')
elenaTree, hcal_y = tools.AddVect(elenaTree, 'hcal_y', 'float')
elenaTree, hcal_z = tools.AddVect(elenaTree, 'hcal_z', 'float')
elenaTree, veto5_x = tools.AddVect(elenaTree, 'veto5_x', 'float')
elenaTree, veto5_y = tools.AddVect(elenaTree, 'veto5_y', 'float')
elenaTree, veto5_z = tools.AddVect(elenaTree, 'veto5_z', 'float')
elenaTree, liquidscint_x = tools.AddVect(elenaTree, 'liquidscint_x', 'float')
elenaTree, liquidscint_y = tools.AddVect(elenaTree, 'liquidscint_y', 'float')
elenaTree, liquidscint_z = tools.AddVect(elenaTree, 'liquidscint_z', 'float')
# Candidate quantities: DOCA and vertex from myVertex, impact parameter to the
# target, and the candidate's mass, Pt and P.
elenaTree, DOCA = tools.AddVar(elenaTree, 'DOCA', 'float')
elenaTree, vtxx = tools.AddVar(elenaTree, 'vtxx', 'float')
elenaTree, vtxy = tools.AddVar(elenaTree, 'vtxy', 'float')
elenaTree, vtxz = tools.AddVar(elenaTree, 'vtxz', 'float')
elenaTree, IP0 = tools.AddVar(elenaTree, 'IP0', 'float')
elenaTree, Mass = tools.AddVar(elenaTree, 'Mass', 'float')
elenaTree, Pt = tools.AddVar(elenaTree, 'Pt', 'float')
elenaTree, P = tools.AddVar(elenaTree, 'P', 'float')
# Event-level information, weights and selection flags.
elenaTree, NParticles = tools.AddVar(elenaTree, 'NParticles', 'int')
elenaTree, HNLw = tools.AddVar(elenaTree, 'HNLw', 'float')
elenaTree, NuWeight = tools.AddVar(elenaTree, 'NuWeight', 'float')
elenaTree, EventNumber = tools.AddVar(elenaTree, 'EventNumber', 'int')
elenaTree, DaughterMinPt = tools.AddVar(elenaTree, 'DaughterMinPt', 'float')
elenaTree, DaughterMinP = tools.AddVar(elenaTree, 'DaughterMinP', 'float')
elenaTree, DaughtersAlwaysIn = tools.AddVar(elenaTree, 'DaughtersAlwaysIn', 'int')
elenaTree, BadTruthVtx = tools.AddVar(elenaTree, 'BadTruthVtx', 'int')


def extrapolateFitTrackToPosition(FitTrack, NewPosition):
    """Extrapolate a fitted track towards NewPosition and return the new position.

    FitTrack    -- genfit track; the PDG hypothesis, position and momentum of
                   its fitted state seed a fresh genfit.RKTrackRep.
    NewPosition -- ROOT.TVector3 target point.
    Returns the position (rep.getPos) of the extrapolated state.

    Only referenced from commented-out code in elenaEventLoop.
    """
    fittedState = FitTrack.getFittedState()
    pdg = fittedState.getPDG()
    mom = fittedState.getMom()
    pos = fittedState.getPos()
    #print 'current pos: ', pos.X(), pos.Y(), pos.Z()
    rep = ROOT.genfit.RKTrackRep(pdg)
    state = ROOT.genfit.StateOnPlane(rep)
    rep.setPosMom(state, pos, mom)
    # origPlane / origState are not used below.
    origPlane = state.getPlane()
    origState = ROOT.genfit.StateOnPlane(state)
    #rep.extrapolateToPoint(state, NewPosition, False)
    parallelToZ = ROOT.TVector3(0., 0., 1.)
    #extrapPlane = ROOT.genfit.DetPlane( NewPosition, parallelToZ )
    #rep.extrapolateToPlane(state, extrapPlane)
    # NOTE(review): this passes (state, point, direction, bool, bool). As far
    # as I can tell, genfit's extrapolateToPlane takes a SharedPlanePtr, and
    # this argument list matches extrapolateToLine instead. Verify which
    # overload PyROOT resolves before relying on the result.
    rep.extrapolateToPlane(state, NewPosition, parallelToZ, False, False)
    newState = ROOT.genfit.StateOnPlane(state)
    endPos = rep.getPos(newState)
    #print 'extr state: ', endPos.X(), endPos.Y(), endPos.Z()
    #print 'req position: ', NewPosition.X(), NewPosition.Y(), NewPosition.Z()
    return endPos

def hasStrawStations(event, trackId, listOfWantedStraws):
    """Tell whether MC track `trackId` has a straw hit in every requested station.

    event              -- tree entry providing strawtubesPoint
    trackId            -- MC track ID the hits must belong to
    listOfWantedStraws -- geometry node names (keys of the global `nodes`)

    A hit counts for a station when its z lies strictly inside the station's
    [pos - dim, pos + dim] range and tools.inEllipse(x, y, 250., 500.) holds.
    An empty station list gives True.
    """
    zRanges = []
    for det in listOfWantedStraws:
        zCentre, zHalf = nodes[det]['z']['pos'], nodes[det]['z']['dim']
        zRanges.append((zCentre - zHalf, zCentre + zHalf))
    seen = [False] * len(zRanges)
    for hit in event.strawtubesPoint:
        if hit.GetTrackID() != trackId:
            continue
        for idx, (zLow, zHigh) in enumerate(zRanges):
            if zLow < hit.GetZ() < zHigh and tools.inEllipse(hit.GetX(), hit.GetY(), 250., 500.):
                seen[idx] = True
    return all(seen)

def hasGoodStrawStations(event, trackId):
    """Tell whether MC track `trackId` is seen in both upstream or both downstream stations.

    Upstream stations are Tr1_1 and Tr2_2; downstream stations are Tr3_3 and
    Tr4_4. The per-station hit criterion is the one implemented by
    hasStrawStations (z inside the station range and
    tools.inEllipse(x, y, 250., 500.)). That function is reused here instead
    of keeping a second copy of the same test.
    """
    return (hasStrawStations(event, trackId, ['Tr1_1', 'Tr2_2']) or
            hasStrawStations(event, trackId, ['Tr3_3', 'Tr4_4']))

def findHNLvertex(event):
    """Return the start z of the first MC track whose mother ID is 1.

    Presumably MC track 1 is the HNL, so this is its decay-vertex z (confirm
    for non-signal samples). Returns False when no track has mother ID 1.
    """
    children = (trk for trk in event.MCTrack if trk.GetMotherId() == 1)
    firstChild = next(children, None)
    if firstChild is None:
        return False
    return firstChild.GetStartZ()

def hasMuons(event):
    """Report which of the muon detectors 476, 477, 478, 479 recorded a hit.

    Returns a list of four booleans in that detector-ID order, built from
    event.muonPoint.
    """
    hitIds = set(point.GetDetectorID() for point in event.muonPoint)
    return [stationId in hitIds for stationId in (476, 477, 478, 479)]

def elenaEventLoop(N):
    """Fill the ShipAna tree (elenaTree) and the candidate histograms for the first N events.

    Per event:
      - events with fewer than two MC tracks are skipped;
      - the event number is stored;
      - BadTruthVtx flags a truth vertex (findHNLvertex) outside
        (Veto_5 start - 500, Tr4_4 end);
      - the neutrino weight is stored (not for cosmics);
      - the number of candidates is stored.
    Per candidate in sTree.Particles (HNL candidates only for signal files):
      - hit positions in each detector;
      - per-daughter fit quality and MC truth;
      - the straw-acceptance flag;
      - vertex, DOCA, IP to the target, mass, Pt and P.
    One tree entry is written per candidate.
    """
    entries = sTree.GetEntries()
    nEvents = min(entries,N)
    num_bad_z = 0
    for n in range(nEvents): 
        rc = sTree.GetEntry(n)
        tools.PutToZero(EventNumber); tools.Push(EventNumber, n)
        # loop over particles, 2-track combinations
        # From Thomas:
        # An object in sTree.Particles is made out of two fitted tracks.
        # Except of the mass assignment, and fake pattern recognition (only 
        # hits from same MCTrack are used), no other MC truth. No requirement
        # that the two tracks come from the same MCTrack ! GetDaughter allows to
        # get back to the underlying tracks, see ShipReco.py:
        #  particle = ROOT.TParticle(9900015,0,-1,-1,t1,t2,HNL,vx)
        nu = sTree.MCTrack[0]
        if len(sTree.MCTrack) == 1 and len(sTree.Particles) > 0:
            print "There is an error somewhere! Particles out of nothing..."
            sys.exit()
        if len(sTree.MCTrack) < 2:
            print 1, n
            continue
        nu_daughter = sTree.MCTrack[1]
        if not isinstance(nu_daughter, ROOT.ShipMCTrack):
            print 2, n
            continue
        tools.PutToZero(BadTruthVtx)
        z_hnl_vtx = findHNLvertex(sTree)
        # findHNLvertex returns False when no track has mother 1. Test for
        # that explicitly, so a genuine vertex at z == 0 is not "not found".
        if z_hnl_vtx is False:
            if not (bg_file or cosmics_file):
                print 'ERROR: hnl vertex not found!'
            nMCTracks = sum(1 for g in sTree.MCTrack)
            if signal_file and nMCTracks < 3:
                continue
            elif signal_file and nMCTracks >= 3:
                sys.exit()
            #if bg_file or cosmics_file:
            #    continue
        bad_z = True
        if nodes['Veto_5']['z']['pos']-nodes['Veto_5']['z']['dim']-500. < z_hnl_vtx < nodes['Tr4_4']['z']['pos']+nodes['Tr4_4']['z']['dim']:
            bad_z = False
        else:
            num_bad_z += 1
            if not (bg_file or cosmics_file):
                print z_hnl_vtx
        tools.Push(BadTruthVtx, int(bad_z))
        nu_x = nu_daughter.GetStartX(); nu_y = nu_daughter.GetStartY(); nu_z = nu_daughter.GetStartZ()
        nu_energy = nu.GetEnergy()
        if not cosmics_file:
            if oldSW:
                nu_w = tools.calcWeightOldNtuple(nu_x,nu_y,nu_z, nu_energy, nodes, entries, weightHist, file_type)
            else:
                nu_w = tools.calcWeight(nu_energy, nu.GetWeight(), entries, PDG.GetParticle(nu.GetPdgCode()).GetName(), weightHist)
            tools.PutToZero(NuWeight); tools.Push(NuWeight, nu_w)
        nCandidates = sum(1 for HNL in sTree.Particles)
        tools.PutToZero(NParticles); tools.Push(NParticles, nCandidates)
        for HNL in sTree.Particles:
            if signal_file and not (HNL.GetPdgCode() == theHNLcode):
                print 3, n
                continue
            #if bg_file and (HNL.GetPdgCode() == theHNLcode): continue
            # Fill hit arrays
            tools.PutToZero(straw_x); tools.PutToZero(straw_y); tools.PutToZero(straw_z)
            tools.PutToZero(veto5_x); tools.PutToZero(veto5_y); tools.PutToZero(veto5_z)
            tools.PutToZero(muon_x); tools.PutToZero(muon_y); tools.PutToZero(muon_z)
            tools.PutToZero(ecal_x); tools.PutToZero(ecal_y); tools.PutToZero(ecal_z)
            tools.PutToZero(hcal_x); tools.PutToZero(hcal_y); tools.PutToZero(hcal_z)
            tools.PutToZero(liquidscint_x); tools.PutToZero(liquidscint_y); tools.PutToZero(liquidscint_z)
            hasStraws, nothing   = tools.wasFired(None, sTree.strawtubesPoint, trackStationsPos, pointsVects=[straw_x, straw_y, straw_z], Ethr=0.)
            hasVeto5, nothing    = tools.wasFired(None, sTree.strawtubesPoint, strawVetoPos, pointsVects=[veto5_x, veto5_y, veto5_z], Ethr=0.)
            hasMuon, nothing     = tools.wasFired(None, sTree.muonPoint, muonStationsPos, pointsVects=[muon_x, muon_y, muon_z], Ethr=0.)
            hasEcal, ecalOverThr = tools.wasFired(None, sTree.EcalPoint, ecalPos, pointsVects=[ecal_x, ecal_y, ecal_z], Ethr=0.015)
            if not noHcal:
                hasHcal, hcalOverThr = tools.wasFired(None, sTree.HcalPoint, hcalPos, pointsVects=[hcal_x, hcal_y, hcal_z], Ethr=0.015)
            hasliquidscint, liquidscintOverThr = tools.wasFired(None, sTree.vetoPoint, vetoWall, pointsVects=[liquidscint_x, liquidscint_y, liquidscint_z], Ethr=0.015)
            # get "daughters"
            t1,t2 = HNL.GetDaughter(0),HNL.GetDaughter(1) 
            PosDir = {} 
            tools.PutToZero(DaughtersFitConverged)
            tools.PutToZero(DaughtersChi2)
            tools.PutToZero(DaughtersPt)
            tools.PutToZero(DaughtersNPoints)
            w = sTree.MCTrack[1].GetWeight()
            if cosmics_file:
                w = sTree.MCTrack[0].GetWeight()
            tools.PutToZero(HNLw); tools.Push(HNLw, w)
            tools.PutToZero(DaughtersTruthPDG); tools.PutToZero(DaughtersTruthMotherPDG)
            tools.PutToZero(DaughtersTruthProdX); tools.PutToZero(DaughtersTruthProdY); tools.PutToZero(DaughtersTruthProdZ); 
            p, pt = [], []
            accFlag = True
            for tr in [t1,t2]:
                converged = 0
                x = sTree.FitTracks[tr]
                # Acceptance: the matched MC track must hit both upstream or
                # both downstream straw stations (this replaces an older,
                # commented-out check that extrapolated the fitted track to
                # Tr1, the magnet centre and Tr4).
                mctrid = sTree.fitTrack2MC[tr]
                if not hasGoodStrawStations(sTree, mctrid):
                    accFlag = False
                xx  = x.getFittedState()
                if x.getFitStatus().isFitConverged():
                    converged = 1
                tools.Push(DaughtersFitConverged, converged)
                PosDir[tr] = [xx.getPos(),xx.getDir()]
                h['CandidateDaughters-Pt'].Fill(xx.getMom().Pt())
                tools.Push(DaughtersPt, xx.getMom().Pt())
                p.append(xx.getMom().Mag()); pt.append(xx.getMom().Pt())
                h['CandidateDaughters-chi2'].Fill(x.getFitStatus().getChi2() / x.getNumPoints())
                tools.Push(DaughtersChi2, x.getFitStatus().getChi2())
                # NB: DaughtersNPoints holds the rounded fit NDF, not getNumPoints().
                tools.Push(DaughtersNPoints, int(round(x.getFitStatus().getNdf())))
                pdg, mumPdg, truthX, truthY, truthZ = tools.retrieveMCParticleInfo(sTree, tr)
                tools.Push(DaughtersTruthPDG, pdg); tools.Push(DaughtersTruthMotherPDG, mumPdg)
                tools.Push(DaughtersTruthProdX, truthX); tools.Push(DaughtersTruthProdY, truthY); tools.Push(DaughtersTruthProdZ, truthZ); 
            # Reset before pushing, as for every other per-candidate variable.
            # Without the reset these would presumably accumulate values from
            # earlier candidates.
            tools.PutToZero(DaughterMinP); tools.Push(DaughterMinP, min(p))
            tools.PutToZero(DaughterMinPt); tools.Push(DaughterMinPt, min(pt))
            tools.PutToZero(DaughtersAlwaysIn); tools.Push(DaughtersAlwaysIn, int(accFlag))
            xv,yv,zv,doca = myVertex(t1,t2,PosDir)
            h['Candidate-DOCA'].Fill(doca) 
            tools.PutToZero(DOCA); tools.Push(DOCA, doca)
            h['Candidate-vtxx'].Fill(xv)
            tools.PutToZero(vtxx); tools.Push(vtxx, xv)
            h['Candidate-vtxy'].Fill(yv)
            tools.PutToZero(vtxy); tools.Push(vtxy, yv)
            h['Candidate-vtxz'].Fill(zv)
            tools.PutToZero(vtxz); tools.Push(vtxz, zv)
            #if  doca>5 : continue
            # Impact parameter of the candidate's flight line w.r.t. the
            # target point (0, 0, target.z0).
            HNLPos = ROOT.TLorentzVector()
            HNL.ProductionVertex(HNLPos)
            HNLMom = ROOT.TLorentzVector()
            HNL.Momentum(HNLMom)
            target = ROOT.TVector3(0,0,ShipGeo.target.z0)
            t = 0
            for i in range(3):   t += HNLMom(i)/HNLMom.P()*(target(i)-HNLPos(i)) 
            dist = 0
            for i in range(3):   dist += (target(i)-HNLPos(i)-t*HNLMom(i)/HNLMom.P())**2
            dist = ROOT.TMath.Sqrt(dist)
            h['Candidate-IP0'].Fill(dist)  
            tools.PutToZero(IP0); tools.Push(IP0, dist)
            h['Candidate-IP0/mass'].Fill(HNLMom.M(),dist)
            h['Candidate-Mass'].Fill(HNLMom.M())
            tools.PutToZero(Mass); tools.Push(Mass, HNLMom.M())
            h['Candidate-Pt'].Fill(HNLMom.Pt())
            tools.PutToZero(Pt); tools.Push(Pt, HNLMom.Pt())
            tools.PutToZero(P); tools.Push(P, HNLMom.P())
            elenaTree.SetDirectory(0)
            elenaTree.Fill()
    print
    print 'Number of HNLs with a bad vertex: %s out of %s'%(num_bad_z, entries)


def HNLKinematics():
    """Book and fill HNL momentum histograms from MC truth and from reconstructed daughters.

    The true momentum comes from MCTrack[1]. Each daughter track of every
    candidate fills the '_recTracks' histograms with its reconstructed
    momentum. The 'NoW' histograms are unweighted. The others use
    MCTrack[1]'s weight, replaced by 1 when it is not positive.
    """
    for histKey, histTitle in [
            ('HNLmomNoW', 'momentum unweighted'),
            ('HNLmom', 'momentum'),
            ('HNLmom_recTracks', 'HNL momentum from reco tracks'),
            ('HNLmomNoW_recTracks', 'HNL momentum unweighted from reco tracks'),
            ]:
        ut.bookHist(h, histKey, histTitle, 100, 0., 300.)
    for entry in range(sTree.GetEntries()):
        sTree.GetEntry(entry)
        hnlTrack = sTree.MCTrack[1]
        weight = hnlTrack.GetWeight()
        if not weight > 0.:
            weight = 1.
        pTrue = hnlTrack.GetP()
        h['HNLmom'].Fill(pTrue, weight)
        h['HNLmomNoW'].Fill(pTrue)
        for candidate in sTree.Particles:
            for daughter in (candidate.GetDaughter(0), candidate.GetDaughter(1)):
                pReco = sTree.FitTracks[daughter].getFittedState().getMom().Mag()
                h['HNLmom_recTracks'].Fill(pReco, weight)
                h['HNLmomNoW_recTracks'].Fill(pReco)
#
def access2SmearedHits():
 key = 0
 for ahit in ev.SmearedHits.GetObject():
   print ahit[0],ahit[1],ahit[2],ahit[3],ahit[4],ahit[5],ahit[6]
   # follow link to true MCHit
   mchit   = TrackingHits[key]
   mctrack =  MCTracks[mchit.GetTrackID()]
   print mchit.GetZ(),mctrack.GetP(),mctrack.GetPdgCode()
   key+=1

# Only elenaEventLoop runs; the other analyses are disabled.
#myEventLoop(nEvents)
elenaEventLoop(nEvents)
#makePlots()
#HNLKinematics()
# output histograms
ut.writeHists(h,outputFile)
# Reopen the same file in "update" mode to add the ShipAna tree next to the histograms.
ofile = ROOT.TFile(outputFile, "update")
elenaTree.Write()
ofile.Close()
weightHistFile.Close()
print "\tOutput saved to ", outputFile