from __future__ import print_function

import ROOT, os, sys, getopt
import math
import rootUtils as ut
import shipunit as u
from pythia8_conf import addHNLtoROOT
from array import array
import RecoSettings
from FitTrackInfo import FitTrackInfo


########################################################################
class StrawHits(object):
    """StrawHit class.

    Collects the strawtube MC hits of one event per MC track, smears them and
    refits the "good" tracks with genfit (DAF). If exactly two tracks are
    refitted, a two-track vertex is created and iteratively refined.
    """

    ## \brief constructor.
    # \param tree - ROOT tree; #readEvent reads its strawtubesPoint and SmearedHits
    #        branches, #FitTracks reads MCTrack.
    # \param modules - geometry modules; modules["Strawtubes"].StrawEndPoints is used.
    # \param resolution - hit resolution used for smearing and for covariances.
    # \param debug - debug level [0,3].
    # \param mhistdict - optional dict of histograms ('magZfit', 'magZsim',
    #        'magXY<i>fit', 'magXY<i>sim', i=1..3) filled with field-map checks.
    # \param ship_geo - optional geometry config; required together with mhistdict.
    def __init__(self, tree, modules, resolution, debug=0, mhistdict=None, ship_geo=None):
        ## root tree to be read.
        self.__tree = tree
        ## geometry description modules.
        self.__modules = modules
        ## debug level [0,3]
        self.__debug = debug
        ## hit resolution
        self.__resolution = resolution
        ## {MCtrackID : [{'pos':TVector3, 'det':detID, 'dw':distance to wire, 'smdw': smeared dw}]}
        ## one entry per hit, sorted by Z. Created for MCtrackID>=0.
        self.__trackHits = {}
        ## cleared in #__clean; not filled anywhere in this class.
        self.__oldSmearedHits = {}
        ## {MCtrackID : {X : TVector3}} where X='entry' or 'exit', the first (Z>0) or last hit.
        ## Created for tracks with more than #RecoSettings .trackMinNofHits hits at Z>0.
        self.__trackEdgeHits = {}
        ## {MCtrackID : number of hits at Z<0 (veto tracker)}.
        self.__vetoHits = {}
        ## {MCtrackID: number of crossed stations (exclude veto tracker)}.
        self.__nStations = {}
        ## root random engine for hit smearing (see #__hitSmear).
        self.__random = ROOT.TRandom()
        # NOTE(review): this seeds the global ROOT.gRandom, not self.__random
        # used by #__hitSmear; confirm which generator was meant to be seeded.
        ROOT.gRandom.SetSeed(13)
        # alternatives: ROOT.genfit.KalmanFitter(), ROOT.genfit.KalmanFitterRefTrack()
        self.__fitter = ROOT.genfit.DAF()
        ## refitted tracks
        self.__reFitTracks = FitTrackInfo(tree=None, debug=self.__debug)
        ## doca of the two-track vertex at each refinement step (see #FitTracks).
        self.__docaEval = []
        if mhistdict and ship_geo:
            self.__fillFieldHists(mhistdict, ship_geo)

    ########################################################################
    ## \brief fills field-map comparison histograms: genfit field vs. ShipBellField.
    # Scans z=1000..3990 on the beam axis, and an x,y grid in [-300,290] at
    # z=2500, 2800, 3000.
    def __fillFieldHists(self, mhistdict, ship_geo):
        fm = ROOT.genfit.FieldManager.getInstance()
        # copy from python/shipDet_conf.py
        sbf = ROOT.ShipBellField("wilfried", ship_geo.Bfield.max, ship_geo.Bfield.z, 2,
                                 ship_geo.Yheight / 2. * u.m)

        def fitField(x, y, z):
            # field as seen by the genfit field manager
            fx = ROOT.Double(0)
            fy = ROOT.Double(0)
            fz = ROOT.Double(0)
            fm.getField().get(x, y, z, fx, fy, fz)
            return ROOT.TVector3(fx, fy, fz)

        def simField(x, y, z):
            # field of the simulation's bell-shaped field model
            return ROOT.TVector3(sbf.GetBx(x, y, z), sbf.GetBy(x, y, z), sbf.GetBz(x, y, z))

        for i in range(300):
            z = 1000. + i * 10
            mhistdict['magZfit'].Fill(z, fitField(0., 0., z).Mag())
            mhistdict['magZsim'].Fill(z, simField(0., 0., z).Mag())

        zdict = {1: 2500., 2: 2800., 3: 3000.}
        for zi in zdict:
            z = zdict[zi]
            for xi in range(-30, 30):
                x = xi * 10.
                for yi in range(-30, 30):
                    y = yi * 10.
                    mhistdict['magXY' + str(zi) + "fit"].Fill(x, y, fitField(x, y, z).Mag())
                    mhistdict['magXY' + str(zi) + "sim"].Fill(x, y, simField(x, y, z).Mag())

    ########################################################################
    ## \brief to be called for each new event (called in #readEvent())
    # cleans all dictionaries (#__trackHits, #__trackEdgeHits, #__vetoHits, #__nStations).
    def __clean(self):
        self.__trackHits.clear()
        self.__oldSmearedHits.clear()
        self.__trackEdgeHits.clear()
        self.__vetoHits.clear()
        self.__nStations.clear()
        self.__docaEval = []

    ########################################################################
    ## \brief returns list of keys #__trackEdgeHits (MCtrackIDs with more than #RecoSettings .trackMinNofHits hits).
    # \return list of MCtrackIDs of "good" tracks.
    def getTrIDs(self):
        return list(self.__trackEdgeHits)

    ########################################################################
    ## \brief returns list of keys #__trackHits (MCtrackIDs>=0).
    # \return list of MCtrackIDs of all MC assigned tracks.
    def getTrIDsALL(self):
        return list(self.__trackHits)

    ########################################################################
    ## \brief returns list of MCtrackIDs of refitted tracks (#__reFitTracks).
    def getReFitTrIDs(self):
        return self.__reFitTracks.getTrIDs()

    ########################################################################
    ## \brief chi2/ndf of the refitted track tid.
    def getReFitChi2Ndf(self, tid):
        return self.__reFitTracks.getChi2Ndf(tid)

    ########################################################################
    ## \brief ndf of the refitted track tid.
    def getReFitNdf(self, tid):
        return self.__reFitTracks.getNdf(tid)

    ########################################################################
    ## \brief returns vertex (if number of tracks!=2 will return None!).
    # \return new vertex (if number of tracks!=2 will return None!).
    def getReFitVertex(self):
        return self.__reFitTracks.getVertex()

    ########################################################################
    ## \brief returns doca of a given vertex-refinement step (at most #RecoSettings .VertexExtrSteps steps).
    # \param step - step index, 0-based.
    # \return doca, or None if there is no vertex or step is out of range.
    def getStepDoca(self, step):
        if (not self.__reFitTracks.Vertex) or not (0 <= step < len(self.__docaEval)):
            return None
        return self.__docaEval[step]

    ########################################################################
    ## \brief returns position, direction and p-value of the refitted track tid.
    def getReFitPosDirPval(self, tid):
        return self.__reFitTracks.getPosDirPval(tid)

    ########################################################################
    ## \brief returns number of hits in proper tracker stations (Z>0) calculated from #__trackHits and #__vetoHits.
    # \param tid - MCtrackID (a key of #__trackHits).
    # \return number of hits in proper tracker stations (Z>0).
    def getNofPHits(self, tid):
        return len(self.__trackHits[tid]) - self.__vetoHits[tid]

    ########################################################################
    ## \brief returns TVector3 of a tracker entry hit (Z>0) from #__trackEdgeHits.
    # \param tid - MCtrackID.
    # \return position of a tracker entry hit (Z>0) from #__trackEdgeHits.
    def getStartHit(self, tid):
        return self.__trackEdgeHits[tid]['entry']

    ########################################################################
    ## \brief returns number of hits with Z<0 of #__trackHits.
    # \param tid - MCtrackID.
    # \return number of hits with Z<0 of #__trackHits (0 for unknown tid).
    def checkVetoHits(self, tid):
        return self.__vetoHits.get(tid, 0)

    ########################################################################
    ## \brief prints the refitted tracks.
    def PrintNewTracks(self):
        print("new Fits: ")
        self.__reFitTracks.Print()

    ########################################################################
    ## \brief compares the refitted track tid with the same track in theFitTracks.
    def compareFitTracks(self, tid, theFitTracks):
        pos, direct, pval = theFitTracks.getPosDirPval(tid)
        return self.__reFitTracks.compareTracks(tid, pos, direct, pval)

    ########################################################################
    ## \brief returns a dictionary {xtop, ytop, ztop, xbot, ybot, zbot, z, dist} for a smeared hit.
    # \param tid - MCtrackID.
    # \param hid - hit index of #__trackHits
    # \param new - True: draw a new |Gaus(dw, resolution)| smearing; False: reuse the
    #        smeared distance read from SmearedHits ('smdw'). Wire end points are always recalculated.
    # \return a dictionary with the wire end points and the smeared distance.
    #         'z' equals 'zbot' (kept for callers of the former single-z layout).
    def __hitSmear(self, tid, hid, new=False):
        hit = self.__trackHits[tid][hid]
        top = ROOT.TVector3()
        bot = ROOT.TVector3()
        dw = hit['dw']
        self.__modules["Strawtubes"].StrawEndPoints(hit['det'], bot, top)
        if new:
            smear = abs(self.__random.Gaus(dw, self.__resolution))
        else:
            smear = hit['smdw']
        smearedHit = {'xtop': top.x(), 'ytop': top.y(), 'ztop': top.z(),
                      'xbot': bot.x(), 'ybot': bot.y(), 'zbot': bot.z(),
                      'z': bot.z(), 'dist': smear}
        if self.__debug > 2:
            print("\tsmear :", "".join("{:8.2f}".format(hit['pos'](ii)) for ii in range(0, 3)), end=" ")
            print("{:6.2f}".format(dw), end=" ")
            print("\t(xt,xb, yt, yb, z, dw) : ", end=" ")
            print(" ".join("{:8.2f}".format(smearedHit[x])
                           for x in ['xtop', 'xbot', 'ytop', 'ybot', 'z', 'dist']))
        return smearedHit

    ########################################################################
    ## \brief debug helper: prints position and distance to wire of each hit.
    def __printHits(self, hits, prefix):
        for hinfo in hits:
            vec3 = hinfo['pos']
            print(prefix, vec3.X(), "\t", vec3.Y(), "\t", vec3.Z(), hinfo['dw'])

    ########################################################################
    ## \brief to be called per each event. Fills #__trackHits, #__trackEdgeHits, #__vetoHits, #__nStations.
    # \return number of "good" tracks (size of #__trackEdgeHits)
    def readEvent(self):
        self.__clean()
        # MCtrackIDs with unsorted hits (hits from different tracks were seen
        # assigned to the same MCtrackID)
        toSort = []
        stationList = {}  # {MCtrackID:[stations]}
        # loop over all hits and fill __trackHits[MCtrackID]
        for hindx, ahit in enumerate(self.__tree.strawtubesPoint):
            detID = ahit.GetDetectorID()
            trID = ahit.GetTrackID()
            # SmearedHits is read at the same index as strawtubesPoint; its layout is
            # [i, xtop, ytop, z, xbot, ybot, z, dist], so [3]=z and [7]=smeared distance.
            origSmHit = self.__tree.SmearedHits.At(hindx)
            if (abs(ahit.GetZ() - origSmHit[3]) > 0.8) or (abs(ahit.dist2Wire() - origSmHit[7]) > 0.2):
                print("problem getting smeared his, but do not change anything")
                print("=>", ahit.GetZ(), origSmHit[3], ahit.dist2Wire(), origSmHit[7])
            if trID < 0:
                continue  # these are hits not assigned to MC track because low E cut
            if trID not in self.__trackHits:
                self.__trackHits[trID] = []
                stationList[trID] = []
            hits = self.__trackHits[trID]
            hits.append({'pos': ROOT.TVector3(ahit.GetX(), ahit.GetY(), ahit.GetZ()),
                         'det': detID,
                         'dw': ahit.dist2Wire(),
                         'smdw': origSmHit[7]})
            if len(hits) > 1 and hits[-1]['pos'].Z() < hits[-2]['pos'].Z():
                if trID not in toSort:
                    toSort.append(trID)
                    if self.__debug > 0:
                        print("StrawHitsEntry: wrong order of hits for track ", trID)
            station = detID // 10000000
            if station > 4:
                continue  # not counted as a crossed station
            if station not in stationList[trID]:
                stationList[trID].append(station)

        # sort hits by Z
        for trID in toSort:
            if self.__debug > 0:
                print("StrawHitsEntry: will sort hits for track ", trID)
            if self.__debug > 2:
                print("\t\thits to be sorted")
                self.__printHits(self.__trackHits[trID], "\t\t\t\t")
            self.__trackHits[trID].sort(key=lambda x: x['pos'].Z())
            if self.__debug > 2:
                print("\t\thits after sorting")
                self.__printHits(self.__trackHits[trID], "\t\t\t\t")

        # fill self.__nStations
        for trID in self.__trackHits:
            self.__nStations[trID] = len(stationList[trID])
            if self.__debug > 0:
                print("Number of crossed stations (trID:n)", trID, " : ", self.__nStations[trID])

        # find entry and exit positions
        for trID, hits in self.__trackHits.items():
            if self.__debug > 1:
                print("hits for trID ", trID)
                self.__printHits(hits, "\t")
            if self.__debug > 0:
                print("start/stop position for hits of track ", trID)
            # hits are sorted by Z, so the Z<0 (veto tracker) hits come first
            nHits = len(hits)
            firstHit = 0
            while firstHit < nHits and hits[firstHit]['pos'].Z() < 0:
                firstHit += 1
            self.__vetoHits[trID] = firstHit
            # the EdgeHits are filled only if nHits(Z>0) > RecoSettings.trackMinNofHits
            if (firstHit < nHits) and ((nHits - firstHit) > RecoSettings.trackMinNofHits):
                self.__trackEdgeHits[trID] = {'entry': hits[firstHit]['pos'],
                                              'exit': hits[-1]['pos']}
                if self.__debug > 0:
                    for pos, vec3 in self.__trackEdgeHits[trID].items():
                        print("\t", pos, vec3.X(), "\t", vec3.Y(), "\t", vec3.Z())
            elif self.__debug > 0:
                print("not set due to small number of hits")
        return len(self.__trackEdgeHits)

    ########################################################################
    ## \brief unit vector from the entry hit to the next proper hit; (0,0,1) if undefined.
    def __getIniDir(self, trID):
        v1 = self.__trackEdgeHits[trID]['entry']
        i2 = self.__vetoHits[trID] + 1
        dv = None
        if len(self.__trackHits[trID]) > i2:
            dv = self.__trackHits[trID][i2]['pos'] - v1
            if dv.Mag() == 0:
                dv = None
                if self.__debug > 0:
                    print("initial direction from coincident hits is undefined, will set (0,0,1)")
        elif self.__debug > 0:
            print("trying to get initial direction having just one hit, will set (0,0,1)")
        if dv is None:
            dv = ROOT.TVector3(0., 0., 1.)
        return dv * (1. / dv.Mag())

    ########################################################################
    ## \brief seed position, momentum and 6x6 covariance for the fit of track tid.
    # \param original - True: pos=(0,0,0), mom=(0,0,3 GeV); False: pos=entry hit,
    #        mom=unit direction from #__getIniDir.
    def __prepareIniPosMomCov(self, tid, original=True):
        if original:
            pos = ROOT.TVector3(0, 0, 0)
            mom = ROOT.TVector3(0, 0, 3. * u.GeV)
        else:
            pos = self.__trackEdgeHits[tid]['entry']
            mom = self.__getIniDir(tid)
        resolution = float(self.__resolution)  # avoid Python-2 integer division below
        cov = ROOT.TMatrixDSym(6)
        for i in range(3):
            cov[i][i] = resolution * resolution
        cov[0][0] = resolution * resolution * 100.
        nM = self.getNofPHits(tid)
        for i in range(3, 6):
            cov[i][i] = (resolution / nM / math.sqrt(3.)) ** 2
        return pos, mom, cov

    ########################################################################
    ## \brief adds one genfit WireMeasurement per proper (Z>0) hit of track tid to fTrack.
    # WireMeasurement(rawHitCoords, rawHitCov, detId, hitId, trackPoint), with
    # rawHitCoords = (wire end 1 xyz, wire end 2 xyz, drift distance).
    def __prepareWireMeasurements(self, tid, fTrack):
        # TMP ??? does it make sense to do for tracks with __vetoHits>0???
        for hindx in range(self.__vetoHits[tid], len(self.__trackHits[tid])):
            sm = self.__hitSmear(tid, hindx)
            mVector = ROOT.TVectorD(7, array('d', [sm['xtop'], sm['ytop'], sm['ztop'],
                                                   sm['xbot'], sm['ybot'], sm['zbot'],
                                                   sm['dist']]))
            hitCov = ROOT.TMatrixDSym(7)
            hitCov[6][6] = self.__resolution * self.__resolution
            tp = ROOT.genfit.TrackPoint(fTrack)  # the point is told which track it belongs to
            measurement = ROOT.genfit.WireMeasurement(mVector, hitCov, 1, 6, tp)  # ...and the measurement its trackpoint
            measurement.setMaxDistance(0.5 * u.cm)
            tp.addRawMeasurement(measurement)  # package measurement in the TrackPoint
            if self.__debug > 2:
                tp.Print()
            fTrack.insertPoint(tp)  # add point to Track

    ########################################################################
    ## \brief builds, fits and, if converged, registers one track in #__reFitTracks.
    # \param trID - MCtrackID (a key of #__trackEdgeHits).
    # \param original - passed to #__prepareIniPosMomCov.
    # \return the fitted genfit track, or None if it was rejected or the fit failed.
    def __fitSingleTrack(self, trID, original):
        if self.__debug > 0:
            print("ELENA Number of crossed stations (trID:n)", trID, " : ", self.__nStations[trID])
            print(self.__trackEdgeHits[trID]['entry'].Z())
        # minimal requirements on number of crossed stations
        if self.__nStations[trID] < RecoSettings.trackMinNofStations:
            return None
        pdg = self.__tree.MCTrack[trID].GetPdgCode()
        # remove unknown (None) or neutral (0) particles
        charge = RecoSettings.chargePDG(pdg)
        if not charge:
            print("StrawHits.FitTracks for TrID ", trID, "finds charge of track of ", charge, " and does nothing.")
            return None
        posM, momM, covM = self.__prepareIniPosMomCov(trID, original)
        rep = ROOT.genfit.RKTrackRep(pdg)
        stateSmeared = ROOT.genfit.MeasuredStateOnPlane(rep)
        rep.setPosMomCov(stateSmeared, posM, momM, covM)
        seedState = ROOT.TVectorD(6)
        seedCov = ROOT.TMatrixDSym(6)
        rep.get6DStateCov(stateSmeared, seedState, seedCov)
        track = ROOT.genfit.Track(rep, seedState, seedCov)
        ROOT.SetOwnership(track, False)
        if self.__debug > 1:
            print("ELENA:")
            posM.Print()
            momM.Print()
            covM.Print()
            rep.Print()
            track.Print()
            self.__fitter.Print()
        if self.__debug > 2:
            print("preparing measurements for track ID", trID)
        self.__prepareWireMeasurements(trID, track)
        if not track.checkConsistency():
            print('Problem with track before fit, not consistent', trID)
            return None
        try:
            self.__fitter.processTrack(track)
        except Exception:
            print("genfit failed to fit track")
            return None
        if not track.checkConsistency():
            print('Problem with track after fit, not consistent', trID)
            return None
        stat = track.getFitStatus()
        if not stat.isFitConverged():
            return None
        f = track.getFittedState()
        self.__reFitTracks.addNewTrack(trID, f.getPos(), f.getDir(), f.getMomMag(),
                                       stat.getNdf(), stat.getChi2())
        return track

    ########################################################################
    ## \brief for exactly two refitted tracks: creates a vertex and refines it by
    # extrapolating both fitted states to it, up to #RecoSettings .VertexExtrSteps times.
    # Fills #__docaEval with the doca before each step.
    # \param fitTrack - {MCtrackID: fitted genfit track} of the refitted tracks.
    # \return number of refitted tracks.
    def __refineVertex(self, fitTrack):
        self.__docaEval = []
        newFitTrIDs = self.__reFitTracks.getTrIDs()
        if len(newFitTrIDs) != 2:
            return len(newFitTrIDs)
        self.__reFitTracks.createVertex(newFitTrIDs[0], newFitTrIDs[1], flag=0)  # original
        iniY = self.__reFitTracks.Vertex.Y()
        theStep = 0
        twoTracks = True
        while theStep < RecoSettings.VertexExtrSteps and twoTracks:
            stepFlag = 1
            newFitTrIDs = self.__reFitTracks.getTrIDs()
            if self.__debug > 1:
                print("==>vertex ", theStep, " ", self.__reFitTracks.Doca)
                self.__reFitTracks.Vertex.Print()
            self.__docaEval.append(self.__reFitTracks.Doca)
            for tid in newFitTrIDs:
                try:
                    state = fitTrack[tid].getFittedState()
                except Exception:
                    print("can't get fittedState")
                    stepFlag = -1
                    continue
                try:
                    state.extrapolateToPoint(self.__reFitTracks.Vertex)
                except Exception:
                    stepFlag = -1
                    print("track exctrapolation failed!tid: ", tid)
                    continue
                status = fitTrack[tid].getFitStatus()
                self.__reFitTracks.addNewTrack(tid, state.getPos(), state.getDir(), state.getMomMag(),
                                               status.getNdf(), status.getChi2(), verb=False)
            self.__reFitTracks.createVertex(newFitTrIDs[0], newFitTrIDs[1], stepFlag)
            # FIX temporary
            self.__reFitTracks.Vertex.SetY(iniY)
            theStep += 1
            twoTracks = (len(self.__reFitTracks.getTrIDs()) == 2)
        return len(newFitTrIDs)

    ########################################################################
    ## \brief refits all "good" tracks (#__trackEdgeHits) and, for two tracks, refines their vertex.
    # \param old - passed as 'original' to #__prepareIniPosMomCov.
    # \return number of refitted tracks.
    def FitTracks(self, old=True):
        self.__reFitTracks.clean()
        fitTrack = {}
        for trID in self.__trackEdgeHits:  # these are already tracks with large number of hits
            track = self.__fitSingleTrack(trID, old)
            if track is not None:
                fitTrack[trID] = track
        return self.__refineVertex(fitTrack)