# FairShipTools / newGen / offlineForBarbara_maybeOld.py
from lookAtGeo import *
import tools
import shipunit as u
from ShipGeoConfig import ConfigRegistry
import shipDet_conf

from operator import mul, add

import sys
sys.path.append('KaterinaLight/')
from StrawHits import StrawHits
## Use it like:
# f = TFile(fileName)
# t = f.Get("cbmsim")
# sh = offline.StrawHits(t, offline.shipDet_conf.configure(offline.__run, r['ShipGeo']), r['ShipGeo'].straw.resol, 0, None, r['ShipGeo'])
# t.GetEntry(58)
# sh.readEvent()
# sh.FitTracks()

#dy = 10.
# init geometry and mag. field
#ShipGeo = ConfigRegistry.loadpy("$FAIRSHIP/geometry/geometry_config.py", Yheight = dy )


def _nodeGeometry(node, position, dimension):
    """Return the {'x','y','z','r'} -> {'pos','dim'} entry for one geometry node.

    *position* and *dimension* are the dicts produced by the tools position /
    dimension helpers for *node*. Their 'r' keys are read only for
    cylinder-type shapes; for any other shape the radial entry is zeroed.
    """
    entry = {}
    for axis in ('x', 'y', 'z'):
        entry[axis] = {'pos': position[axis], 'dim': dimension[axis]}
    if node.GetVolume().GetShape().IsCylType():
        entry['r'] = {'pos': position['r'], 'dim': dimension['r']}
    else:
        entry['r'] = {'pos': 0., 'dim': 0.}
    return entry


def _printNodeGeometry(label, position, dimension):
    """Print a node label followed by its x/y/z position and dimension."""
    print(label)
    for axis in ('x', 'y', 'z'):
        print("            %s:  %s %s" % (axis, position[axis], dimension[axis]))


def searchForNodes3_xyz_dict(fGeo, verbose=False):
    """Map node names to their position/size, three levels below the top volume.

    Walks the nodes of fGeo's top volume, their children and grandchildren.
    For each node it stores ``d[name] = {'x','y','z','r'}`` where every axis
    holds ``{'pos': ..., 'dim': ...}``. Children are positioned relative to
    the parent translation; grandchildren use the sum of the parent's and the
    top node's translations (rotations are not taken into account).

    Args:
        fGeo: geometry manager exposing GetTopVolume() (e.g. ROOT.gGeoManager).
        verbose: if True, print every node's position and dimension.

    Returns:
        dict keyed by node name. A later node with the same name overwrites
        an earlier one.
    """
    from tools import findPositionElement, findDimentionBoxElement, findPositionElement2
    d = {}
    topnodes = fGeo.GetTopVolume().GetNodes()
    for (j, topn) in enumerate(topnodes):
        # Level 1: nodes of the top volume.
        position = findPositionElement(topn)
        dimension = findDimentionBoxElement(topn)
        if verbose:
            _printNodeGeometry("%s %s" % (j, topn.GetName()), position, dimension)
        d[topn.GetName()] = _nodeGeometry(topn, position, dimension)

        nodes = topn.GetNodes()
        if not nodes:
            continue
        topPos = topn.GetMatrix().GetTranslation()
        for (i, n) in enumerate(nodes):
            # Level 2: children, offset by the top node's translation.
            position = findPositionElement2(n, topPos)
            dimension = findDimentionBoxElement(n)
            if verbose:
                _printNodeGeometry("%s %s %s %s" % (j, topn.GetName(), i, n.GetName()),
                                   position, dimension)
            d[n.GetName()] = _nodeGeometry(n, position, dimension)

            cnodes = n.GetNodes()
            if not cnodes:
                continue
            localpos = n.GetMatrix().GetTranslation()
            # Use a distinct comprehension variable: the original reused `i` here,
            # which clobbered the child index printed for grandchildren.
            localToGlobal = [localpos[c] + topPos[c] for c in range(3)]
            for (k, cn) in enumerate(cnodes):
                # Level 3: grandchildren, offset by both translations.
                position = findPositionElement2(cn, localToGlobal)
                dimension = findDimentionBoxElement(cn)
                if verbose:
                    _printNodeGeometry("%s %s %s %s %s %s" % (j, topn.GetName(), i, n.GetName(),
                                                              k, cn.GetName()),
                                       position, dimension)
                d[cn.GetName()] = _nodeGeometry(cn, position, dimension)
    return d


# Weighting histograms ("h_Gio") read by calcWeightNu(): one file for
# neutrinos, one for antineutrinos. They are opened at import time from paths
# relative to the current working directory.
# The TFile handles are kept at module scope, presumably so the histograms
# they own stay valid for the lifetime of the module -- TODO confirm.
# NOTE(review): a missing file is not checked here; the corresponding
# histogram would then presumably be None and calcWeightNu() would fail later.
ff_nu = ROOT.TFile("histoForWeights_nu.root")
h_GioHans_nu = ff_nu.Get("h_Gio")

ff_antinu = ROOT.TFile("histoForWeights_antinu.root")
h_GioHans_antinu = ff_antinu.Get("h_Gio")

def calcWeightNu(NC, E, w, entries, nuName, ON=True):
    """Return the interaction weight for a (anti)neutrino background event.

    weight = (6.7e-39 * E * reduction) * flux * h(E) * w * 6.022e+23, where
    h is h_GioHans_antinu when "bar" occurs in *nuName* (reduction 0.5) and
    h_GioHans_nu otherwise (reduction 1.). The flux factor is fixed to 1.
    Historical values were 6.98e+11 * 2.e+20 / 5.e+13 (antinu) and
    1.09e+12 * 2.e+20 / 5.e+13 (nu).

    Args:
        NC, entries: accepted for interface compatibility; not used.
        E: energy used both in the cross-section factor and to pick the bin.
        w: per-event weight factor (e.g. MCTrack.GetWeight()).
        nuName: particle name; "bar" in it selects the antineutrino inputs.
        ON: when False, weighting is disabled and 1 is returned.
    """
    if not ON:
        return 1
    isAntiNu = "bar" in nuName
    reduction = 0.5 if isAntiNu else 1.
    flux = 1.
    hist = h_GioHans_antinu if isAntiNu else h_GioHans_nu
    avogadro = 6.022e+23
    crossSection = 6.7e-39 * E * reduction
    binContent = hist.GetBinContent(hist.GetXaxis().FindBin(E))
    return crossSection * flux * binContent * w * avogadro


def findWeight(sampleType, NC, E, MCTrack, entries, nuName, ON):
    """Return the event weight appropriate for *sampleType*.

    'nuBg'           -> calcWeightNu(...) with MCTrack.GetWeight() as w.
    'sig', 'cosmics' -> MCTrack.GetWeight() unchanged (for signal acceptance,
                        multiply by the normalisation; for cosmics, by
                        1.e6 / 200.).
    anything else    -> None.
    """
    if sampleType == 'nuBg':
        return calcWeightNu(NC, E, MCTrack.GetWeight(), entries, nuName, ON)
    if sampleType in ('sig', 'cosmics'):
        return MCTrack.GetWeight()
    return None



def hasStrawStations(event, trackId, listOfWantedStraws):
    """Return True if *trackId* has a straw hit in every requested station.

    A hit counts for a station when its z lies strictly inside the station's
    [pos - dim, pos + dim] window (from the module-level ``nodes`` dict) and
    tools.inEllipse(x, y, 250., 500.) accepts its x/y. An empty station list
    yields True.
    """
    zWindows = []
    for det in listOfWantedStraws:
        zCentre, zHalf = nodes[det]['z']['pos'], nodes[det]['z']['dim']
        zWindows.append((zCentre - zHalf, zCentre + zHalf))
    stationHit = [False] * len(zWindows)
    for hit in event.strawtubesPoint:
        if hit.GetTrackID() != trackId:
            continue
        for idx, (zLow, zHigh) in enumerate(zWindows):
            if zLow < hit.GetZ() < zHigh and tools.inEllipse(hit.GetX(), hit.GetY(), 250., 500.):
                stationHit[idx] = True
    return all(stationHit)

def hasGoodStrawStations(event, trackId):
    """Return True if *trackId* hits both upstream or both downstream stations.

    Upstream stations are Tr1_1/Tr2_2 and downstream are Tr3_3/Tr4_4. A hit
    counts for a station when its z is strictly inside the station's
    [pos - dim, pos + dim] window (module-level ``nodes``) and
    tools.inEllipse(x, y, 250., 500.) accepts its x/y.
    """
    def zWindow(det):
        zCentre, zHalf = nodes[det]['z']['pos'], nodes[det]['z']['dim']
        return (zCentre - zHalf, zCentre + zHalf)

    stationGroups = [
        [zWindow(det) for det in ('Tr1_1', 'Tr2_2')],   # upstream pair
        [zWindow(det) for det in ('Tr3_3', 'Tr4_4')],   # downstream pair
    ]
    groupHits = [[False, False], [False, False]]
    for hit in event.strawtubesPoint:
        if hit.GetTrackID() != trackId:
            continue
        for g, windows in enumerate(stationGroups):
            for s, (zLow, zHigh) in enumerate(windows):
                if zLow < hit.GetZ() < zHigh and tools.inEllipse(hit.GetX(), hit.GetY(), 250., 500.):
                    groupHits[g][s] = True
    return any(all(flags) for flags in groupHits)

def findHNLvertex(event):
    """Return GetStartZ() of the first MC track whose mother id is 1.

    Returns False when no such track exists. Note that callers testing the
    result for truthiness cannot tell "not found" from a start z of 0.
    """
    return next((track.GetStartZ() for track in event.MCTrack
                 if track.GetMotherId() == 1), False)

def has_muon_station(event, trackId, station):
    """Return True if *trackId* has a muon hit inside muon station *station*.

    Station n (1-based) is the geometry node 'muondet<n-1>_1' in the
    module-level ``nodes`` dict. A hit counts when its z lies in the closed
    interval [pos - dim, pos + dim].
    """
    zInfo = nodes['muondet%s_1' % (station - 1)]['z']
    zIn = zInfo['pos'] - zInfo['dim']
    zOut = zInfo['pos'] + zInfo['dim']
    return any(hit.GetTrackID() == trackId and zIn <= hit.GetZ() <= zOut
               for hit in event.muonPoint)

def hasEcalDeposit(event, trackId, ELossThreshold):
    """Return True if *trackId*'s summed ECAL energy loss reaches the threshold."""
    totalLoss = sum((hit.GetEnergyLoss() for hit in event.EcalPoint
                     if hit.GetTrackID() == trackId), 0.)
    return totalLoss >= ELossThreshold

def hasMuons(event, trackId):
    """Return four booleans: does *trackId* hit muon detector IDs 476..479?

    The list order follows the detector IDs 476, 477, 478, 479.
    """
    muonDetectorIds = (476, 477, 478, 479)
    hitDetectors = set(hit.GetDetectorID() for hit in event.muonPoint
                       if hit.GetTrackID() == trackId)
    return [detId in hitDetectors for detId in muonDetectorIds]

def myVertex(t1,t2,PosDir):
    """Closest approach of two straight tracks.

    Each track is the line P(s) = a + s*u (track t1) or Q(t) = c + t*v
    (track t2), built from PosDir[key] = [position, direction], whose
    components are read with (0), (1), (2).

    Returns:
        (X, Y, Z, dist): X, Y, Z is the point on track t2's line at the
        parameter t of closest approach (not the midpoint between the two
        lines); dist = |pq . (u x v)| / (|u x v| + 1e-8).

    NOTE(review): for exactly parallel directions u x v == 0, so the
    denominator B*C - A*D (= |u x v|**2) is zero and the division for t
    raises ZeroDivisionError; dist would also come out as 0 in that case.
    """
    # closest distance between two tracks
    # d = |pq . u x v|/|u x v|
    a = ROOT.TVector3(PosDir[t1][0](0) ,PosDir[t1][0](1), PosDir[t1][0](2))
    u = ROOT.TVector3(PosDir[t1][1](0),PosDir[t1][1](1),PosDir[t1][1](2))
    c = ROOT.TVector3(PosDir[t2][0](0) ,PosDir[t2][0](1), PosDir[t2][0](2))
    v = ROOT.TVector3(PosDir[t2][1](0),PosDir[t2][1](1),PosDir[t2][1](2))
    pq = a-c
    uCrossv = u.Cross(v)
    # 1E-8 avoids a division by zero in the distance only (not in t below).
    dist  = pq.Dot(uCrossv)/(uCrossv.Mag()+1E-8)
    # Setting d/ds and d/dt of |P(s) - Q(t)|**2 to zero gives:
    # u.a - u.c + s*|u|**2 - u.v*t    = 0
    # v.a - v.c + s*v.u    - t*|v|**2 = 0
    # i.e. E + A*s + B*t = 0 and F + C*s + D*t = 0, solved for t by Cramer's rule.
    E = u.Dot(a) - u.Dot(c) 
    F = v.Dot(a) - v.Dot(c) 
    A,B = u.Mag2(), -u.Dot(v) 
    C,D = u.Dot(v), -v.Mag2()
    t = -(C*E-A*F)/(B*C-A*D)
    # Point of closest approach on the second track.
    X = c.x()+v.x()*t
    Y = c.y()+v.y()*t
    Z = c.z()+v.z()*t
    return X,Y,Z,abs(dist)

def addFullInfoToTree(elenaTree):
    """Create the "full info" branch set on *elenaTree* and return its buffers.

    Vector branches are created with tools.AddVect and scalar branches with
    tools.AddVar, in the order listed below (same order as before).

    Previously every buffer was bound to a local and dropped on return, so
    callers had no way to fill these branches. Depending on whether the tree
    keeps its own Python reference, the buffers could also be freed while
    still attached -- TODO confirm against tools.AddVect / AddVar.

    Returns:
        dict mapping branch name -> buffer returned by AddVect/AddVar. The
        return value was None before; callers that ignored it are unaffected.

    NOTE(review): DOCA, vtxz, IP0, Mass, P, Pt, DaughtersFitConverged,
    DaughtersAlwaysIn and BadTruthVtx are also created by addOfflineToTree();
    check before calling both on the same tree.
    """
    vectorBranches = [
        ('DaughtersPt', 'float'),
        ('DaughtersChi2', 'float'),
        ('DaughtersNPoints', 'int'),
        ('DaughtersTruthProdX', 'float'),
        ('DaughtersTruthProdY', 'float'),
        ('DaughtersTruthProdZ', 'float'),
        ('DaughtersTruthPDG', 'int'),
        ('DaughtersTruthMotherPDG', 'int'),
        ('DaughtersFitConverged', 'int'),
        ('straw_x', 'float'), ('straw_y', 'float'), ('straw_z', 'float'),
        ('muon_x', 'float'), ('muon_y', 'float'), ('muon_z', 'float'),
        ('ecal_x', 'float'), ('ecal_y', 'float'), ('ecal_z', 'float'),
        ('hcal_x', 'float'), ('hcal_y', 'float'), ('hcal_z', 'float'),
        ('veto5_x', 'float'), ('veto5_y', 'float'), ('veto5_z', 'float'),
        ('liquidscint_x', 'float'), ('liquidscint_y', 'float'), ('liquidscint_z', 'float'),
    ]
    scalarBranches = [
        ('DOCA', 'float'),
        ('vtxx', 'float'), ('vtxy', 'float'), ('vtxz', 'float'),
        ('IP0', 'float'),
        ('Mass', 'float'), ('Pt', 'float'), ('P', 'float'),
        ('NParticles', 'int'),
        ('HNLw', 'float'), ('NuWeight', 'float'),
        ('EventNumber', 'int'),
        ('DaughterMinPt', 'float'), ('DaughterMinP', 'float'),
        ('DaughtersAlwaysIn', 'int'),
        ('BadTruthVtx', 'int'),
    ]
    buffers = {}
    for name, kind in vectorBranches:
        elenaTree, buffers[name] = tools.AddVect(elenaTree, name, kind)
    for name, kind in scalarBranches:
        elenaTree, buffers[name] = tools.AddVar(elenaTree, name, kind)
    return buffers

# Output buffers. Each stays None until addOfflineToTree() binds it to a
# tree branch; the per-event / per-particle functions below then fill it.
DaughtersFitConverged = DOCA = vtxx = vtxy = vtxz = IP0 = HasEcal = None
NoB_DOCA = NoB_vtxx = NoB_vtxy = NoB_vtxz = NoB_IP0 = None
DaughtersAlwaysIn = BadTruthVtx = None
Has1Muon1 = Has1Muon2 = Has2Muon1 = Has2Muon2 = None
MaxDaughtersRedChi2 = MinDaughtersNdf = None
NoB_MaxDaughtersRedChi2 = NoB_MinDaughtersNdf = None
DaughtersMinP = DaughtersMinPt = Mass = P = Pt = None
NoB_DaughtersMinP = NoB_DaughtersMinPt = NoB_Mass = NoB_P = NoB_Pt = None

def addOfflineToTree(elenaTree):
    """Create the offline-analysis vector branches and bind them to module globals.

    Each entry of the table below creates one branch with tools.AddVect, in
    table order, and stores the returned buffer in the module-level name
    given in the first column. Note that vtxx/vtxy (and their NoB_ versions)
    are written to branches with a "Sqr" suffix. Finally the liquid
    scintillator segmentation branches are added via tools.makeLSsegments.
    """
    # (module-level buffer name, branch name, element type), in creation order.
    branchTable = [
        ('DaughtersFitConverged', 'DaughtersFitConverged', 'int'),
        ('DOCA', 'DOCA', 'float'),
        ('NoB_DOCA', 'NoB_DOCA', 'float'),
        ('vtxx', 'vtxxSqr', 'float'),
        ('vtxy', 'vtxySqr', 'float'),
        ('vtxz', 'vtxz', 'float'),
        ('NoB_vtxx', 'NoB_vtxxSqr', 'float'),
        ('NoB_vtxy', 'NoB_vtxySqr', 'float'),
        ('NoB_vtxz', 'NoB_vtxz', 'float'),
        ('IP0', 'IP0', 'float'),
        ('NoB_IP0', 'NoB_IP0', 'float'),
        ('DaughtersAlwaysIn', 'DaughtersAlwaysIn', 'int'),
        ('BadTruthVtx', 'BadTruthVtx', 'int'),
        ('Has1Muon1', 'Has1Muon1', 'int'),
        ('Has1Muon2', 'Has1Muon2', 'int'),
        ('Has2Muon1', 'Has2Muon1', 'int'),
        ('Has2Muon2', 'Has2Muon2', 'int'),
        ('HasEcal', 'HasEcal', 'int'),
        ('MaxDaughtersRedChi2', 'MaxDaughtersRedChi2', 'float'),
        ('MinDaughtersNdf', 'MinDaughtersNdf', 'int'),
        ('NoB_MaxDaughtersRedChi2', 'NoB_MaxDaughtersRedChi2', 'float'),
        ('NoB_MinDaughtersNdf', 'NoB_MinDaughtersNdf', 'int'),
        ('DaughtersMinP', 'DaughtersMinP', 'float'),
        ('DaughtersMinPt', 'DaughtersMinPt', 'float'),
        ('P', 'P', 'float'),
        ('Pt', 'Pt', 'float'),
        ('Mass', 'Mass', 'float'),
        ('NoB_DaughtersMinP', 'NoB_DaughtersMinP', 'float'),
        ('NoB_DaughtersMinPt', 'NoB_DaughtersMinPt', 'float'),
        ('NoB_P', 'NoB_P', 'float'),
        ('NoB_Pt', 'NoB_Pt', 'float'),
        ('NoB_Mass', 'NoB_Mass', 'float'),
    ]
    moduleScope = globals()
    for bufferName, branchName, kind in branchTable:
        elenaTree, moduleScope[bufferName] = tools.AddVect(elenaTree, branchName, kind)
    # Add liquid scintillator segmentation
    tools.makeLSsegments(nodes, elenaTree)

# Geometry lookup {node name: {'x','y','z','r': {'pos','dim'}}}; None until loadNodes() runs.
nodes = None
def loadNodes(fGeo):
    """Build the module-level ``nodes`` geometry dictionary from *fGeo*."""
    global nodes
    geometryByName = searchForNodes3_xyz_dict(fGeo, verbose=False)
    nodes = geometryByName

# Number of events whose true HNL vertex fell outside the accepted z range.
num_bad_z = 0

def signalNormalisationZ(tree, datatype, verbose):
    """Flag (per event) whether the MC-truth HNL vertex z lies outside the fiducial range.

    Uses MC truth. Pushes int(bad_z) into BadTruthVtx, where bad_z is True
    unless  Veto_5.zmin - 500. < z < Tr4_4.zmax  (zmin/zmax = pos -/+ dim from
    the module-level ``nodes``). Increments num_bad_z for every bad event.

    For signal samples ("sig" in *datatype*), a missing HNL vertex is
    reported; if the event has at least 3 MC tracks the process exits with a
    non-zero status. With fewer tracks the event is kept.

    Args:
        verbose: accepted for interface compatibility; not used.
    """
    global BadTruthVtx, num_bad_z
    tools.PutToZero(BadTruthVtx)
    isSignal = "sig" in datatype
    z_hnl_vtx = findHNLvertex(tree)
    # findHNLvertex returns False when nothing is found; compare by identity
    # so that a genuine vertex at z == 0. is not treated as missing.
    if z_hnl_vtx is False and isSignal:
        print('ERROR: hnl vertex not found!')
        nTracks = sum(1 for _ in tree.MCTrack)
        if nTracks >= 3:
            # Fatal for signal: exit with a failure status (was sys.exit() == status 0).
            sys.exit(1)
    zLow = nodes['Veto_5']['z']['pos'] - nodes['Veto_5']['z']['dim'] - 500.
    zHigh = nodes['Tr4_4']['z']['pos'] + nodes['Tr4_4']['z']['dim']
    bad_z = not (zLow < z_hnl_vtx < zHigh)
    if bad_z:
        num_bad_z += 1
        if isSignal:
            print(z_hnl_vtx)
    tools.Push(BadTruthVtx, int(bad_z))

def nParticles(tree):
    """Store the number of entries in tree.Particles into NParticles (per event).

    NOTE(review): NParticles is not bound at module level anywhere in the
    visible code (its AddVar call is commented out), so calling this as-is
    would presumably raise NameError -- confirm before re-enabling.
    """
    global NParticles
    particleCount = sum(1 for _ in tree.Particles)
    tools.PutToZero(NParticles)
    tools.Push(NParticles, particleCount)

def hasEcalAndMuons(tree, particle):
    """Push muon-station and ECAL flags for a two-daughter candidate (per particle).

    The daughters' fit-track indices are mapped to MC track ids via
    tree.fitTrack2MC. Pushes:
      Has2Muon1/Has2Muon2 -- both daughters hit muon station 1 / 2;
      Has1Muon1/Has1Muon2 -- at least one daughter hits muon station 1 / 2;
      HasEcal             -- at least one daughter deposits >= 150 MeV in the ECAL.
    """
    global Has1Muon1, Has1Muon2, Has2Muon1
    global Has2Muon2, HasEcal
    t1, t2 = tree.fitTrack2MC[particle.GetDaughter(0)], tree.fitTrack2MC[particle.GetDaughter(1)]
    # Each has_muon_station call scans all muon hits; evaluate each
    # (track, station) pair once instead of up to twice.
    t1InStation1 = has_muon_station(tree, t1, 1)
    t2InStation1 = has_muon_station(tree, t2, 1)
    t1InStation2 = has_muon_station(tree, t1, 2)
    t2InStation2 = has_muon_station(tree, t2, 2)
    flagEcal = hasEcalDeposit(tree, t1, 150.*u.MeV) or hasEcalDeposit(tree, t2, 150.*u.MeV)
    tools.Push(Has2Muon1, int(t1InStation1 and t2InStation1))
    tools.Push(Has2Muon2, int(t1InStation2 and t2InStation2))
    tools.Push(Has1Muon1, int(t1InStation1 or t2InStation1))
    tools.Push(Has1Muon2, int(t1InStation2 or t2InStation2))
    tools.Push(HasEcal, int(flagEcal))

def chi2Ndf(tree, particle, ntr, nref):
    """Push the largest reduced chi2 and the smallest ndf of the two daughters (per particle).

    The StrawHits refit (module-level ``sh``) is used when ``ntr > 1`` and
    ``nref == 2`` and none of its two chi2/ndf and two ndf values is zero.
    Otherwise the stored fit status of each daughter in tree.FitTracks is used.
    """
    global MaxDaughtersRedChi2, MinDaughtersNdf
    daughters = (particle.GetDaughter(0), particle.GetDaughter(1))
    refitUsable = False
    if ntr > 1 and nref == 2:
        refitIds = sh.getReFitTrIDs()
        reducedChi2, ndfs = [], []
        for refitId in (refitIds[0], refitIds[1]):
            reducedChi2.append(sh.getReFitChi2Ndf(refitId))
            ndfs.append(int(round(sh.getReFitNdf(refitId))))
        # Any zero value is taken as a failed refit.
        refitUsable = all(ndfs) and all(reducedChi2)
    if not refitUsable:
        reducedChi2, ndfs = [], []
        for daughter in daughters:
            status = tree.FitTracks[daughter].getFitStatus()
            ndfs.append(int(round(status.getNdf())))
            reducedChi2.append(status.getChi2() / status.getNdf())
    tools.Push(MaxDaughtersRedChi2, max(reducedChi2))
    tools.Push(MinDaughtersNdf, min(ndfs))

    
def NoB_chi2Ndf(tree, particle):
    """Push max reduced chi2, min ndf and joint convergence from the stored daughter fits.

    Reads each daughter's fit status from tree.FitTracks (no refit).
    DaughtersFitConverged gets 1 only if both daughter fits converged.
    """
    global NoB_MaxDaughtersRedChi2, NoB_MinDaughtersNdf, DaughtersFitConverged
    statuses = [tree.FitTracks[particle.GetDaughter(k)].getFitStatus() for k in (0, 1)]
    ndfs = [int(round(status.getNdf())) for status in statuses]
    reducedChi2 = [status.getChi2() / status.getNdf() for status in statuses]
    bothConverged = statuses[0].isFitConverged() * statuses[1].isFitConverged()
    tools.Push(NoB_MaxDaughtersRedChi2, max(reducedChi2))
    tools.Push(NoB_MinDaughtersNdf, min(ndfs))
    tools.Push(DaughtersFitConverged, int(bothConverged))

def NoB_kinematics(tree, particle):
    """Push daughter min p / min pt and candidate mass, pt, p from the stored fits (per particle).

    Daughter momenta come from tree.FitTracks[...].getFittedState(). The
    candidate four-momentum comes from particle.Momentum().
    """
    global NoB_DaughtersMinP, NoB_DaughtersMinPt, NoB_P, NoB_Pt, NoB_Mass
    daughterMoms = [tree.FitTracks[particle.GetDaughter(k)].getFittedState().getMom()
                    for k in (0, 1)]
    tools.Push(NoB_DaughtersMinP, min(mom.Mag() for mom in daughterMoms))
    tools.Push(NoB_DaughtersMinPt, min(mom.Pt() for mom in daughterMoms))
    candidate = ROOT.TLorentzVector()
    particle.Momentum(candidate)
    tools.Push(NoB_Mass, candidate.M())
    tools.Push(NoB_Pt, candidate.Pt())
    tools.Push(NoB_P, candidate.P())
    
def goodBehavedTracks(tree, particle):
    """Push 1 if both daughters pass hasGoodStrawStations, else 0 (per particle; MC truth).

    Daughter fit-track indices are mapped to MC track ids via tree.fitTrack2MC.
    """
    global DaughtersAlwaysIn
    daughters = (particle.GetDaughter(0), particle.GetDaughter(1))
    allInAcceptance = all(hasGoodStrawStations(tree, tree.fitTrack2MC[d]) for d in daughters)
    tools.Push(DaughtersAlwaysIn, int(allInAcceptance))

def NoB_vertexInfo(tree, particle):
    """Push vertex, DOCA and impact parameter from the particle's own vertex (per particle).

    Uses particle.ProductionVertex() and particle.Momentum() (no StrawHits
    refit). Pushes DOCA, x**2, y**2, z of the vertex, and the impact parameter
    of the candidate's flight line with respect to (0, 0, ShipGeo.target.z0).
    The previous implementation (recomputing the vertex with myVertex) was
    kept as a dead string literal and has been removed.
    """
    global NoB_vtxx, NoB_vtxy, NoB_vtxz
    global NoB_IP0, NoB_DOCA
    HNLPos = ROOT.TLorentzVector()
    particle.ProductionVertex(HNLPos)
    # The T() component of the production-vertex 4-vector is pushed as the
    # DOCA -- presumably how the vertexing step stores it; TODO confirm.
    xv, yv, zv, doca = HNLPos.X(), HNLPos.Y(), HNLPos.Z(), HNLPos.T()
    tools.Push(NoB_DOCA, doca)
    tools.Push(NoB_vtxx, xv*xv); tools.Push(NoB_vtxy, yv*yv); tools.Push(NoB_vtxz, zv)
    # Impact parameter to target: distance from the target point to the line
    # through the vertex along the candidate momentum direction.
    HNLMom = ROOT.TLorentzVector()
    particle.Momentum(HNLMom)
    target = ROOT.TVector3(0, 0, ShipGeo.target.z0)
    pMag = HNLMom.P()
    # Projection of (target - vertex) on the unit flight direction.
    t = sum(HNLMom(i) / pMag * (target(i) - HNLPos(i)) for i in range(3))
    ipSquared = sum((target(i) - HNLPos(i) - t * HNLMom(i) / pMag) ** 2 for i in range(3))
    tools.Push(NoB_IP0, ROOT.TMath.Sqrt(ipSquared))


def kinematics(tree, particle, ntr, nref):
    """Push daughter min p / min pt and candidate mass, pt, p (per particle).

    When ``ntr > 1`` and ``nref == 2``, daughter four-vectors are built from
    the StrawHits refit momenta (module-level ``sh``). Masses come from the
    PDG code of each daughter's stored fitted state. If that path yields a
    zero min p, a zero min pt or a falsy sum vector, or the refit was not
    requested, the stored fits and particle.Momentum() are used instead.
    """
    global DaughtersMinP, DaughtersMinPt, P, Pt, Mass
    t1, t2 = particle.GetDaughter(0), particle.GetDaughter(1)
    minP, minPt = 0., 0.
    refitUsable = False
    if ntr > 1 and nref == 2:
        refitIds = sh.getReFitTrIDs()
        _, dir1, mom1 = sh.getReFitPosDirPval(refitIds[0])
        _, dir2, mom2 = sh.getReFitPosDirPval(refitIds[1])
        mass1 = pdg.GetParticle(tree.FitTracks[t1].getFittedState().getPDG()).Mass()
        mass2 = pdg.GetParticle(tree.FitTracks[t2].getFittedState().getPDG()).Mass()
        lv1 = ROOT.TLorentzVector(mom1*dir1, ROOT.TMath.Sqrt(mass1*mass1 + mom1*mom1))
        lv2 = ROOT.TLorentzVector(mom2*dir2, ROOT.TMath.Sqrt(mass2*mass2 + mom2*mom2))
        HNLMom = lv1 + lv2
        if lv1 and lv2:
            minPt = min(lv1.Pt(), lv2.Pt())
            minP = min(lv1.P(), lv2.P())
        refitUsable = bool(minP) and bool(minPt) and bool(HNLMom)
    if not refitUsable:
        daughterMoms = [tree.FitTracks[tr].getFittedState().getMom() for tr in (t1, t2)]
        minPt = min(mom.Pt() for mom in daughterMoms)
        minP = min(mom.Mag() for mom in daughterMoms)
        HNLMom = ROOT.TLorentzVector()
        particle.Momentum(HNLMom)
    tools.Push(DaughtersMinP, minP)
    tools.Push(DaughtersMinPt, minPt)
    tools.Push(Mass, HNLMom.M())
    tools.Push(Pt, HNLMom.Pt())
    tools.Push(P, HNLMom.P())
    
def vertexInfo(tree, particle, ntr, nref):
    """Push vertex, DOCA and impact parameter for a candidate (per particle).

    When ``ntr > 1`` and ``nref == 2``, the StrawHits refit vertex and DOCA
    are used (module-level ``sh``), with the candidate momentum rebuilt from
    the refit daughter momenta. If the refit vertex or DOCA is falsy, or the
    summed momentum is falsy, the vertex falls back to myVertex() on the
    stored fitted states and the momentum to particle.Momentum().
    Pushes DOCA, x**2, y**2, z and the impact parameter with respect to
    (0, 0, ShipGeo.target.z0).
    """
    global vtxx, vtxy, vtxz
    global IP0, DOCA
    t1, t2 = particle.GetDaughter(0), particle.GetDaughter(1)
    refitUsable = False
    if ntr > 1 and nref == 2:
        refitIds = sh.getReFitTrIDs()
        assert len(refitIds) == 2
        doca = sh.getDoca()
        v = sh.getReFitVertex()
        if v and doca:
            xv, yv, zv = v.X(), v.Y(), v.Z()
            _, dir1, mom1 = sh.getReFitPosDirPval(refitIds[0])
            _, dir2, mom2 = sh.getReFitPosDirPval(refitIds[1])
            mass1 = pdg.GetParticle(tree.FitTracks[t1].getFittedState().getPDG()).Mass()
            mass2 = pdg.GetParticle(tree.FitTracks[t2].getFittedState().getPDG()).Mass()
            lv1 = ROOT.TLorentzVector(mom1*dir1, ROOT.TMath.Sqrt(mass1*mass1 + mom1*mom1))
            lv2 = ROOT.TLorentzVector(mom2*dir2, ROOT.TMath.Sqrt(mass2*mass2 + mom2*mom2))
            HNLMom = lv1 + lv2
            refitUsable = bool(HNLMom)

    if not refitUsable:
        # Standard values: straight-line closest approach of the stored fits.
        posDir = {}
        for tr in (t1, t2):
            state = tree.FitTracks[tr].getFittedState()
            posDir[tr] = [state.getPos(), state.getDir()]
        xv, yv, zv, doca = myVertex(t1, t2, posDir)
        HNLMom = ROOT.TLorentzVector()
        particle.Momentum(HNLMom)

    tools.Push(DOCA, doca)
    tools.Push(vtxx, xv*xv); tools.Push(vtxy, yv*yv); tools.Push(vtxz, zv)

    # Impact parameter to target: distance from the target point to the line
    # through the vertex along the candidate momentum direction.
    HNLPos = ROOT.TVector3(xv, yv, zv)
    target = ROOT.TVector3(0, 0, ShipGeo.target.z0)
    t = 0
    for i in range(3):
        t += HNLMom(i)/HNLMom.P()*(target(i)-HNLPos(i))
    ip = 0
    for i in range(3):
        ip += (target(i)-HNLPos(i)-t*HNLMom(i)/HNLMom.P())**2
    tools.Push(IP0, ROOT.TMath.Sqrt(ip))


def prepareFillingsByParticle():
    """Clear all per-particle buffers, then read and refit the event (per event).

    Returns:
        (ntr, nref): the value of sh.readEvent(), and sh.FitTracks() when
        ntr > 1 (0 otherwise).
    """
    perParticleBuffers = (
        DaughtersAlwaysIn,
        Has2Muon1, Has2Muon2, HasEcal,
        Has1Muon1, Has1Muon2,
        DOCA,
        vtxx, vtxy, vtxz,
        IP0,
        NoB_DOCA,
        NoB_vtxx, NoB_vtxy, NoB_vtxz,
        NoB_IP0,
        MinDaughtersNdf, MaxDaughtersRedChi2,
        NoB_MinDaughtersNdf, NoB_MaxDaughtersRedChi2,
        DaughtersFitConverged,
        DaughtersMinP, DaughtersMinPt,
        P, Pt, Mass,
        NoB_DaughtersMinP, NoB_DaughtersMinPt,
        NoB_P, NoB_Pt, NoB_Mass,
    )
    for buffer in perParticleBuffers:
        tools.PutToZero(buffer)
    nTracks = sh.readEvent()
    nRefitted = sh.FitTracks() if nTracks > 1 else 0
    return nTracks, nRefitted


def pushOfflineByEvent(tree, vetoPoints, datatype, verbose, threshold):
    """Run the per-event offline steps and return (ntr, nref).

    Steps: truth HNL-vertex check (signalNormalisationZ), clearing of the
    per-particle buffers plus event read/refit (prepareFillingsByParticle),
    and liquid-scintillator segment hits (tools.hitSegments).
    """
    signalNormalisationZ(tree, datatype, verbose)
    trackCounts = prepareFillingsByParticle()
    tools.hitSegments(vetoPoints, nodes, threshold)
    return trackCounts

def pushOfflineByParticle(tree, particle, ntr, nref):
    """Fill every per-particle offline quantity for one candidate.

    Each chi2 / vertex / kinematics quantity is filled twice: first by the
    NoB_ variant (stored fits only), then by the variant that may use the
    StrawHits refit (needs ntr, nref).
    """
    hasEcalAndMuons(tree, particle)
    goodBehavedTracks(tree, particle)
    for storedFitFill, refitFill in ((NoB_chi2Ndf, chi2Ndf),
                                     (NoB_vertexInfo, vertexInfo),
                                     (NoB_kinematics, kinematics)):
        storedFitFill(tree, particle)
        refitFill(tree, particle, ntr, nref)

fM, tgeom, gMan, geoMat, matEff, modules, run = None, None, None, None, None, None, None
# genfit field object handed to fM.init(); kept at module scope (like geoMat
# for matEff.init) so Python does not release it while genfit may still use it.
bfield = None

def initBField(fileNameGeo):
    """Set up the FairRun, the imported geometry, genfit materials and the magnetic field.

    Imports the geometry from *fileNameGeo* into a new TGeoManager, registers
    genfit's TGeoMaterialInterface with MaterialEffects, and installs a
    genfit BellField built from ShipGeo.Bfield.max, ShipGeo.Bfield.z and
    ShipGeo.Yheight / 2. in the FieldManager. All created objects are
    stored in module globals.
    """
    global fM, tgeom, gMan, geoMat, matEff, modules, run, bfield
    run = ROOT.FairRunSim()
    modules = shipDet_conf.configure(run, ShipGeo)
    tgeom = ROOT.TGeoManager("Geometry", "Geane geometry")
    gMan = tgeom.Import(fileNameGeo)
    geoMat = ROOT.genfit.TGeoMaterialInterface()
    matEff = ROOT.genfit.MaterialEffects.getInstance()
    matEff.init(geoMat)
    # Previously a local: once initBField returned, the Python-owned field
    # object could be destroyed while FieldManager still pointed to it.
    bfield = ROOT.genfit.BellField(ShipGeo.Bfield.max, ShipGeo.Bfield.z, 2, ShipGeo.Yheight/2.)
    fM = ROOT.genfit.FieldManager.getInstance()
    fM.init(bfield)

pdg, sh = None, None