Newer
Older
FairShipTools / pleaseRunMe / StrawHits.py
@Ubuntu Ubuntu on 22 Mar 2015 19 KB software run on yandex
import ROOT,os,sys,getopt
import rootUtils as ut
import shipunit as u
from pythia8_conf import addHNLtoROOT
from array import array

import RecoSettings
from FitTrackInfo import FitTrackInfo


########################################################################
class StrawHits(object):
    """Per-event straw-tube hit bookkeeping and genfit refit of MC tracks.

    #readEvent groups the hits of ``tree.strawtubesPoint`` by MC track ID
    (together with the smeared distance stored in ``tree.SmearedHits``).
    #FitTracks refits every track with enough hits behind the veto tracker
    using a genfit DAF fitter and, when exactly two tracks were refitted,
    iteratively refines their common vertex.
    """

    def __init__(self, tree, modules, resolution, debug=0, mhistdict=None, ship_geo=None):
        ## root tree to be read.
        self.__tree = tree
        ## geometry description modules.
        self.__modules = modules
        ## debug level [0,3]
        self.__debug = debug
        ## hit resolution
        self.__resolution = resolution
        ## {MCtrackID : [{'pos':TVector3, 'det':detID, 'dw':distance to wire, 'smdw': smeared dw}]},
        ## one dict per hit. Filled for every MCtrackID>=0 (see #readEvent).
        self.__trackHits = {}
        ## not filled anywhere in this class; only cleared in #__clean.
        self.__oldSmearedHits = {}
        ## {MCtrackID : {X : TVector3}} where X='entry' or 'exit': first hit with Z>=0 / last hit.
        ## Created for tracks with more than #RecoSettings .trackMinNofHits hits at Z>=0.
        self.__trackEdgeHits = {}
        ## {MCtrackID : number of hits at Z<0 (veto tracker)}.
        self.__vetoHits = {}
        ## {MCtrackID: number of crossed stations (exclude veto tracker)}.
        self.__nStations = {}
        ## root random engine for hit smearing (see #__hitSmear).
        self.__random = ROOT.TRandom()
        # NOTE(review): this seeds ROOT.gRandom, not self.__random used in #__hitSmear -- confirm intent.
        ROOT.gRandom.SetSeed(13)
        # alternatives: ROOT.genfit.KalmanFitter(), ROOT.genfit.KalmanFitterRefTrack()
        self.__fitter = ROOT.genfit.DAF()
        ## refitted tracks
        self.__reFitTracks = FitTrackInfo(tree=None, debug=self.__debug)
        ## DOCA of the two refitted tracks at each vertex-refinement step (see #getStepDoca).
        self.__docaEval = []

        if mhistdict and ship_geo:
            self.__fillFieldHistos(mhistdict, ship_geo)
########################################################################


    ## \brief fills field-map comparison histograms: field seen by genfit ("fit") vs. ShipBellField ("sim").
    #  \param mhistdict - dict with histograms 'magZfit', 'magZsim' and 'magXY<1..3>fit'/'magXY<1..3>sim'.
    #  \param ship_geo - geometry configuration providing Bfield.max, Bfield.z and Yheight.
    def __fillFieldHistos(self, mhistdict, ship_geo):
        fm = ROOT.genfit.FieldManager.getInstance()
        # copy from python/shipDet_conf.py
        sbf = ROOT.ShipBellField("wilfried", ship_geo.Bfield.max, ship_geo.Bfield.z, 2, ship_geo.Yheight / 2. * u.m)

        def fieldPair(x, y, z):
            # (genfit field, simulation field) at (x, y, z) as TVector3s
            fx = ROOT.Double(0)
            fy = ROOT.Double(0)
            fz = ROOT.Double(0)
            fm.getField().get(x, y, z, fx, fy, fz)
            fvec3f = ROOT.TVector3(fx, fy, fz)
            fvec3s = ROOT.TVector3(sbf.GetBx(x, y, z), sbf.GetBy(x, y, z), sbf.GetBz(x, y, z))
            return fvec3f, fvec3s

        # field magnitude along the beam axis
        for i in range(300):
            z = 1000. + i * 10
            fvec3f, fvec3s = fieldPair(0., 0., z)
            mhistdict['magZfit'].Fill(z, fvec3f.Mag())
            mhistdict['magZsim'].Fill(z, fvec3s.Mag())

        # field magnitude in transverse planes at three z positions
        zdict = {1: 2500., 2: 2800., 3: 3000.}
        for zi in zdict:
            z = zdict[zi]
            for xi in range(-30, 30):
                for yi in range(-30, 30):
                    x = xi * 10.
                    y = yi * 10.
                    fvec3f, fvec3s = fieldPair(x, y, z)
                    mhistdict['magXY' + str(zi) + "fit"].Fill(x, y, fvec3f.Mag())
                    mhistdict['magXY' + str(zi) + "sim"].Fill(x, y, fvec3s.Mag())
########################################################################


    ## \brief to be called for each new event (called in #readEvent())
    # cleans all dictionaries (#__trackHits, #__trackEdgeHits, #__vetoHits, #__nStations) and #__docaEval.
    def __clean(self):
        self.__trackHits.clear()
        self.__oldSmearedHits.clear()
        self.__trackEdgeHits.clear()
        self.__vetoHits.clear()
        self.__nStations.clear()
        self.__docaEval = []
########################################################################


    ## \brief returns list of keys #__trackEdgeHits (MCtrackIDs with more than #RecoSettings .trackMinNofHits hits).
    #  \return list of MCtrackIDs of "good" tracks.
    def getTrIDs(self):
        return list(self.__trackEdgeHits)
########################################################################


    ## \brief returns list of keys #__trackHits (all MCtrackIDs>=0 with at least one hit).
    #  \return list of MCtrackIDs of MC assigned tracks.
    def getTrIDsALL(self):
        return list(self.__trackHits)
########################################################################


    ## \brief returns list of MCtrackIDs known to #__reFitTracks.
    #  \return list of MCtrackIDs of refitted tracks.
    def getReFitTrIDs(self):
        return self.__reFitTracks.getTrIDs()
########################################################################


    ## \brief chi2/ndf of a refitted track (delegates to #__reFitTracks).
    def getReFitChi2Ndf(self, tid):
        return self.__reFitTracks.getChi2Ndf(tid)
########################################################################


    ## \brief ndf of a refitted track (delegates to #__reFitTracks).
    def getReFitNdf(self, tid):
        return self.__reFitTracks.getNdf(tid)
########################################################################


    ## \brief returns vertex (if number of tracks!=2 will return None!).
    #  \return new vertex (if number of tracks!=2 will return None!).
    def getReFitVertex(self):
        return self.__reFitTracks.getVertex()
########################################################################


    ## \brief returns the DOCA recorded at a given vertex-refinement step (see #FitTracks).
    #  \param step - index of the refinement step, 0 <= step < number of recorded steps.
    #  \return DOCA value, or None if step is out of range or there is no vertex.
    def getStepDoca(self, step):
        if step < 0 or step >= len(self.__docaEval) or not self.__reFitTracks.Vertex:
            return None
        return self.__docaEval[step]
########################################################################


    ## \brief returns position, direction and p-value of a refitted track (delegates to #__reFitTracks).
    #  \param tid - MCtrackID.
    def getReFitPosDirPval(self, tid):
        return self.__reFitTracks.getPosDirPval(tid)
########################################################################


    ## \brief returns number of hits in proper tracker stations (Z>=0) calculated from #__trackHits and #__vetoHits.
    #  \param tid - MCtrackID (must be one of #getTrIDs).
    #  \return number of hits in proper tracker stations (Z>=0).
    def getNofPHits(self, tid):
        return len(self.__trackHits[tid]) - self.__vetoHits[tid]
########################################################################


    ## \brief returns TVector3 of a tracker entry hit (Z>=0) from #__trackEdgeHits.
    #  \param tid - MCtrackID.
    #  \return position of a tracker entry hit (Z>=0) from #__trackEdgeHits.
    def getStartHit(self, tid):
        return self.__trackEdgeHits[tid]['entry']
########################################################################


    ## \brief returns number of hits with Z<0 of #__trackHits.
    #  \param tid - MCtrackID.
    #  \return number of hits with Z<0 (0 for tracks not in #__vetoHits).
    def checkVetoHits(self, tid):
        return self.__vetoHits.get(tid, 0)
########################################################################


    ## \brief prints the refitted tracks.
    def PrintNewTracks(self):
        sys.stdout.write("new Fits: ")
        self.__reFitTracks.Print()
########################################################################


    ## \brief compares a refitted track with the same track taken from another FitTrackInfo.
    def compareFitTracks(self, tid, theFitTracks):
        pos, direct, pval = theFitTracks.getPosDirPval(tid)
        return self.__reFitTracks.compareTracks(tid, pos, direct, pval)
########################################################################


    ## \brief returns a dictionary describing a smeared hit (wire end points and drift distance).
    #  Keys: xtop, ytop, ztop, xbot, ybot, zbot, dist; 'z' is kept for existing callers and equals zbot.
    #  \param tid - MCtrackID.
    #  \param hid - hit index of #__trackHits
    #  \param new - to generate new smearing (True) or reuse the stored 'smdw' (det. position still recalculated!)
    #  \return the smeared-hit dictionary.
    def __hitSmear(self, tid, hid, new=False):
        hit = self.__trackHits[tid][hid]
        dw = hit['dw']
        top = ROOT.TVector3()
        bot = ROOT.TVector3()
        self.__modules["Strawtubes"].StrawEndPoints(hit['det'], bot, top)

        if new:
            smear = abs(self.__random.Gaus(dw, self.__resolution))
        else:
            smear = hit['smdw']
        smearedHit = {'xtop': top.x(), 'ytop': top.y(), 'ztop': top.z(),
                      'xbot': bot.x(), 'ybot': bot.y(), 'zbot': bot.z(),
                      'z': bot.z(), 'dist': smear}

        if self.__debug > 2:
            pos = hit['pos']
            fields = ["\tsmear :",
                      "".join("{:8.2f}".format(pos(ii)) for ii in range(3)),
                      "{:6.2f}".format(dw),
                      "\t(xt,xb, yt, yb, z, dw) : "]
            fields += ["{:8.2f}".format(smearedHit[key]) for key in ['xtop', 'xbot', 'ytop', 'ybot', 'z', 'dist']]
            print(" ".join(fields))
        return smearedHit
########################################################################


    ## \brief prints position and distance to wire of all hits of a track (debug helper).
    def __printHits(self, trID, prefix):
        for hinfo in self.__trackHits[trID]:
            vec3 = hinfo['pos']
            print("{0}{1}\t{2}\t{3} {4}".format(prefix, vec3.X(), vec3.Y(), vec3.Z(), hinfo['dw']))
########################################################################


    ## \brief to be called per each event. Fills #__trackHits, #__trackEdgeHits, #__vetoHits, #__nStations.
    #  \return number of "good" tracks (size of #__trackEdgeHits)
    def readEvent(self):
        self.__clean()
        toSort = []       # MCtrackIDs whose hits did not arrive ordered in Z (hits of different tracks can be interleaved)
        stationList = {}  # {MCtrackID: [stations]}

        # loop over all hits and fill __trackHits[MCtrackID]
        smearedHits = self.__tree.SmearedHits
        for hindx, ahit in enumerate(self.__tree.strawtubesPoint):
            # smearing stored in the tree for the same point index; only checked, never changed
            origSmHit = smearedHits.At(hindx)
            if abs(ahit.GetZ() - origSmHit[3]) > 0.8 or abs(ahit.dist2Wire() - origSmHit[7]) > 0.2:
                print("problem getting smeared his, but do not change anything")
                print("=> {0} {1} {2} {3}".format(ahit.GetZ(), origSmHit[3], ahit.dist2Wire(), origSmHit[7]))

            trID = ahit.GetTrackID()
            if trID < 0:
                continue  # these are hits not assigned to MC track because low E cut

            detID = ahit.GetDetectorID()
            hits = self.__trackHits.setdefault(trID, [])
            stations = stationList.setdefault(trID, [])
            hits.append({'pos': ROOT.TVector3(ahit.GetX(), ahit.GetY(), ahit.GetZ()),
                         'det': detID,
                         'dw': ahit.dist2Wire(),
                         'smdw': origSmHit[7]})

            if len(hits) > 1 and hits[-1]['pos'].Z() < hits[-2]['pos'].Z() and trID not in toSort:
                toSort.append(trID)
                if self.__debug > 0:
                    print("StrawHitsEntry: wrong order of hits for track {0}".format(trID))

            # station number is encoded in the leading digits of the detector ID; stations > 4 are not counted
            station = detID // 10000000
            if station <= 4 and station not in stations:
                stations.append(station)

        # sort hits along Z
        for trID in toSort:
            if self.__debug > 0:
                print("StrawHitsEntry: will sort hits for track {0}".format(trID))
            if self.__debug > 2:
                print("\t\thits to be sorted")
                self.__printHits(trID, "\t\t\t\t")
            self.__trackHits[trID].sort(key=lambda hinfo: hinfo['pos'].Z())
            if self.__debug > 2:
                print("\t\thits after sorting")
                self.__printHits(trID, "\t\t\t\t")

        # fill self.__nStations
        for trID, stations in stationList.items():
            self.__nStations[trID] = len(stations)
            if self.__debug > 0:
                print("Number of crossed stations (trID:n) {0}  :  {1}".format(trID, self.__nStations[trID]))

        # find entry and exit positions
        for trID, hits in self.__trackHits.items():
            if self.__debug > 1:
                print("hits for trID {0}".format(trID))
                self.__printHits(trID, "\t")
            if self.__debug > 0:
                print("start/stop position for hits of track {0}".format(trID))

            # leading hits with Z<0 belong to the veto tracker
            nHits = len(hits)
            firstHit = 0
            while firstHit < nHits and hits[firstHit]['pos'].Z() < 0:
                firstHit += 1

            # edge hits are filled only for tracks with more than RecoSettings.trackMinNofHits hits at Z>=0
            if firstHit < nHits and (nHits - firstHit) > RecoSettings.trackMinNofHits:
                self.__trackEdgeHits[trID] = {'entry': hits[firstHit]['pos'],
                                              'exit': hits[-1]['pos']}
                self.__vetoHits[trID] = firstHit
                if self.__debug > 0:
                    for edge, vec3 in self.__trackEdgeHits[trID].items():
                        print("\t{0} {1}\t{2}\t{3}".format(edge, vec3.X(), vec3.Y(), vec3.Z()))
            elif self.__debug > 0:
                print("not set due to small number of hits")

        return len(self.__trackEdgeHits)
########################################################################


    ## \brief unit vector from the tracker entry hit to the next hit of the track.
    #  Falls back to (0,0,1) when there is no second hit or both hits coincide.
    def __getIniDir(self, trID):
        hits = self.__trackHits[trID]
        v1 = self.__trackEdgeHits[trID]['entry']
        i2 = self.__vetoHits[trID] + 1
        if len(hits) > i2:
            dv = hits[i2]['pos'] - v1
            if dv.Mag() > 0:
                return dv * (1. / dv.Mag())
            if self.__debug > 0:
                print("first two tracker hits coincide, will set initial direction (0,0,1)")
        elif self.__debug > 0:
            print("trying to get initial direction having just one hit, will set (0,0,1)")
        return ROOT.TVector3(0., 0., 1.)
########################################################################


    ## \brief seed position, momentum and 6x6 covariance for the fit of a track.
    #  \param tid - MCtrackID (must be one of #getTrIDs).
    #  \param original - True: seed at (0,0,0) with momentum (0,0,3 GeV);
    #                    False: seed at the tracker entry hit with the direction of #__getIniDir.
    #  \return (pos TVector3, mom TVector3, cov TMatrixDSym(6)).
    def __prepareIniPosMomCov(self, tid, original=True):
        if original:
            pos = ROOT.TVector3(0, 0, 0)
            mom = ROOT.TVector3(0, 0, 3. * u.GeV)
        else:
            pos = self.__trackEdgeHits[tid]['entry']
            mom = self.__getIniDir(tid)

        resolution = self.__resolution
        cov = ROOT.TMatrixDSym(6)
        for i in range(3):
            cov[i][i] = resolution * resolution
        cov[0][0] = resolution * resolution * 100.
        # float() guards against integer division if resolution is ever given as an int
        nM = float(self.getNofPHits(tid))
        for i in range(3, 6):
            cov[i][i] = ROOT.TMath.pow(resolution / nM / ROOT.TMath.sqrt(3), 2)
        return pos, mom, cov
########################################################################


    ## \brief adds one genfit WireMeasurement (in its own TrackPoint) per tracker hit (Z>=0) to fTrack.
    #  Raw hit coordinates are the two wire end points (top, bottom) followed by the smeared drift distance.
    def __prepareWireMeasurements(self, tid, fTrack):
        # NOTE(review): veto-tracker hits (index < __vetoHits) are skipped -- confirm this is intended.
        for hindx in range(self.__vetoHits[tid], len(self.__trackHits[tid])):
            sm = self.__hitSmear(tid, hindx)
            mVector = ROOT.TVectorD(7, array('d', [sm['xtop'], sm['ytop'], sm['ztop'],
                                                   sm['xbot'], sm['ybot'], sm['zbot'],
                                                   sm['dist']]))
            hitCov = ROOT.TMatrixDSym(7)
            hitCov[6][6] = self.__resolution * self.__resolution

            tp = ROOT.genfit.TrackPoint(fTrack)  # the point is told which track it belongs to
            # NOTE(review): detId=1 / hitId=6 are fixed placeholders -- confirm they are not used downstream.
            measurement = ROOT.genfit.WireMeasurement(mVector, hitCov, 1, 6, tp)
            measurement.setMaxDistance(0.5 * u.cm)
            tp.addRawMeasurement(measurement)  # package measurement in the TrackPoint
            if self.__debug > 2:
                tp.Print()
            fTrack.insertPoint(tp)  # add point to Track
########################################################################


    ## \brief builds and fits one genfit track; on success registers it in #__reFitTracks.
    #  \param trID - MCtrackID (one of #getTrIDs).
    #  \param old - passed as 'original' to #__prepareIniPosMomCov.
    #  \return the fitted genfit.Track, or None if the track was skipped or the fit failed/did not converge.
    def __fitOneTrack(self, trID, old):
        if self.__debug > 0:
            print("ELENA Number of crossed stations (trID:n) {0}  :  {1}".format(trID, self.__nStations[trID]))
            print(self.__trackEdgeHits[trID]['entry'].Z())

        # minimal requirements on number of crossed stations
        if self.__nStations[trID] < RecoSettings.trackMinNofStations:
            return None

        pdg = self.__tree.MCTrack[trID].GetPdgCode()
        # remove unknown (None) or neutral particles
        charge = RecoSettings.chargePDG(pdg)
        if not charge:
            print("StrawHits.FitTracks for TrID {0} finds charge of track of {1}  and does nothing.".format(trID, charge))
            return None

        posM, momM, covM = self.__prepareIniPosMomCov(trID, old)
        rep = ROOT.genfit.RKTrackRep(pdg)
        stateSmeared = ROOT.genfit.MeasuredStateOnPlane(rep)
        rep.setPosMomCov(stateSmeared, posM, momM, covM)

        seedState = ROOT.TVectorD(6)
        seedCov = ROOT.TMatrixDSym(6)
        rep.get6DStateCov(stateSmeared, seedState, seedCov)

        track = ROOT.genfit.Track(rep, seedState, seedCov)
        ROOT.SetOwnership(track, False)

        if self.__debug > 1:
            print("ELENA:")
            posM.Print()
            momM.Print()
            covM.Print()
            rep.Print()
            print("")
            print("")
            track.Print()
            self.__fitter.Print()

        if self.__debug > 2:
            print("preparing measurements for track ID {0}".format(trID))
        self.__prepareWireMeasurements(trID, track)

        if not track.checkConsistency():
            print("Problem with track before fit, not consistent, trID {0}".format(trID))
            return None
        try:
            self.__fitter.processTrack(track)
        except Exception:
            print("genfit failed to fit track")
            return None
        if not track.checkConsistency():
            print("Problem with track after fit, not consistent, trID {0}".format(trID))
            return None

        stat = track.getFitStatus()
        if not stat.isFitConverged():
            return None
        f = track.getFittedState()
        self.__reFitTracks.addNewTrack(trID, f.getPos(), f.getDir(), f.getMomMag(),
                                       stat.getNdf(), stat.getChi2())
        return track
########################################################################


    ## \brief if exactly two tracks were refitted, builds their vertex and refines it
    #  #RecoSettings .VertexExtrSteps times by extrapolating both tracks to the current vertex.
    #  The DOCA at the start of every step is stored in #__docaEval.
    #  \param fitTrack - {MCtrackID: fitted genfit.Track} of the successfully fitted tracks.
    #  \return number of refitted track IDs (as seen at the start of the last refinement step).
    def __refineVertex(self, fitTrack):
        newFitTrIDs = self.__reFitTracks.getTrIDs()
        self.__docaEval = []
        if len(newFitTrIDs) != 2:
            return len(newFitTrIDs)

        self.__reFitTracks.createVertex(newFitTrIDs[0], newFitTrIDs[1], flag=0)  # original
        iniY = self.__reFitTracks.Vertex.Y()

        for theStep in range(RecoSettings.VertexExtrSteps):
            flag = 1
            newFitTrIDs = self.__reFitTracks.getTrIDs()
            if self.__debug > 1:
                print("==>vertex {0}   {1}".format(theStep, self.__reFitTracks.Doca))
                self.__reFitTracks.Vertex.Print()
            self.__docaEval.append(self.__reFitTracks.Doca)

            for tid in newFitTrIDs:
                try:
                    state = fitTrack[tid].getFittedState()
                except Exception:
                    print("can't get fittedState")
                    flag = -1
                    continue
                try:
                    state.extrapolateToPoint(self.__reFitTracks.Vertex)
                except Exception:
                    flag = -1
                    print("track exctrapolation failed!tid: {0}".format(tid))
                    continue
                # once any track failed in this step, the remaining ones are not updated either
                if flag > 0:
                    status = fitTrack[tid].getFitStatus()
                    self.__reFitTracks.addNewTrack(tid, state.getPos(), state.getDir(), state.getMomMag(),
                                                   status.getNdf(), status.getChi2(), verb=False)

            self.__reFitTracks.createVertex(newFitTrIDs[0], newFitTrIDs[1], flag)
            # FIX temporary: keep the vertex Y at its initial value
            self.__reFitTracks.Vertex.SetY(iniY)
            if len(self.__reFitTracks.getTrIDs()) != 2:
                break
        return len(newFitTrIDs)
########################################################################


    ## \brief refits all "good" tracks (see #getTrIDs) with genfit and, for exactly two
    #  refitted tracks, refines their vertex (see #getReFitVertex, #getStepDoca).
    #  \param old - True: seed fits at (0,0,0) with 3 GeV along z; False: seed at the tracker entry hit.
    #  \return number of refitted tracks.
    def FitTracks(self, old=True):
        self.__reFitTracks.clean()
        fitTrack = {}
        for trID in self.__trackEdgeHits:  # these are already tracks with large number of hits
            track = self.__fitOneTrack(trID, old)
            if track is not None:
                fitTrack[trID] = track
        return self.__refineVertex(fitTrack)