1 | from inspect import getargspec
|
---|
2 | from ROOT import TLorentzVector,TVector3,TChain,TClass,TDatabasePDG
|
---|
3 | from datetime import datetime
|
---|
4 | from collections import Iterable
|
---|
5 | from types import StringTypes
|
---|
6 | from os import path
|
---|
7 |
|
---|
class AnalysisEvent(TChain):
    """A Delphes TChain complemented with analysis facilities.

    The class provides the following additional functionalities:
      1. instrumentation for event weight
           A set of weight classes can be defined, and the event weight
           is computed and cached using those.
      2. list of event products used in the analysis
           It makes the iteration faster by only enabling required branches.
      3. a list of "producers" of analysis high-level quantities
           It allows to run "analysis on demand", by automatically running
           the defined producers to fill the cache, and later use that one.
      4. a volatile dictionary
           It allows to use the event as a heterogeneous container for
           any analysis product. The event is properly reset when moving
           to another event (by iteration or by an explicit jump).
    """

    def __init__(self, inputFiles='', maxEvents=0):
        """Initialize the underlying "Delphes" chain plus the additional features.

        Arguments:
          * inputFiles is either one file name or an iterable of file names.
            Names that are not existing files are skipped with a warning.
          * maxEvents is the maximum number of events visited by one
            iteration; 0 (or a negative value) means no limit.
        """
        # initialization of base functionalities
        TChain.__init__(self, "Delphes", "Delphes")
        if isinstance(inputFiles, StringTypes):
            self._addInputFile(inputFiles)
        elif isinstance(inputFiles, Iterable):
            for thefile in inputFiles:
                self._addInputFile(thefile)
        else:
            print("Warning: invalid inputFiles")
        self.BuildIndex("Event[0].Number")
        # everything is disabled by default; addCollection re-enables branches
        self.SetBranchStatus("*", 0)
        self._eventCounts = 0
        self._maxEvents = maxEvents
        # additional features:
        # 1. instrumentation for event weight
        self._weightCache = {}
        self._weightEngines = {}
        # 2. a list of event products used in the analysis
        self._collections = {}
        # GetListOfBranches() is a null pointer when the chain is empty
        self._branches = dict((branch.GetName(), False)
                              for branch in (self.GetListOfBranches() or []))
        # 3. a list of "producers" of analysis high-level quantities
        self._producers = {}
        # 4. volatile dictionary. User can add any quantity to the event and it will be
        #    properly erased in the iteration step.
        # Written through __dict__ on purpose: its presence is what switches
        # __setattr__ from "plain attribute" mode to "volatile" mode.
        self.__dict__["vardict"] = {}

    def _addInputFile(self, thefile):
        """Add one file to the chain, or warn if it does not exist on disk."""
        if path.isfile(thefile):
            self.AddFile(thefile)
        else:
            # same text as the historical Python 2 print statement
            print("Warning:  %s  do not exist." % (thefile,))

    def _resetCache(self):
        """Drop every per-event cached quantity (volatile objects and weights)."""
        self.vardict.clear()
        self._weightCache.clear()

    def addWeight(self, name, weightClass):
        """Declare a new class (engine) to compute the weights.

        weightClass must have a weight() method returning a float.
        Raises KeyError if an engine with that name is already declared.
        """
        if name in self._weightEngines:
            raise KeyError("%s weight engine is already declared" % name)
        self._weightEngines[name] = weightClass
        self._weightCache.clear()

    def delWeight(self, name):
        """Remove one weight engine from the internal list.

        Raises KeyError if no engine with that name is declared.
        """
        del self._weightEngines[name]
        self._weightCache.clear()

    def weight(self, weightList=None, **kwargs):
        """Return the event weight. Arguments:
          * weightList is the list of engines to use, as a list of strings.
              Default: all defined engines.
          * the other named arguments are forwarded to the engines.
        The output is the product of the selected individual weights."""
        if weightList is None:
            weightList = list(self._weightEngines)
        # the list of engines is part of the cache key
        kwargs["weightList"] = weightList
        myhash = self._dicthash(kwargs)
        if myhash not in self._weightCache:
            w = 1.
            for weightElement in weightList:
                engine = self._weightEngines[weightElement]
                # forward only the arguments that this engine declares
                engineArgs = getargspec(engine.weight).args
                subargs = dict((k, v) for k, v in kwargs.items() if k in engineArgs)
                elementKey = "weightElement:%s # %s" % (weightElement, self._dicthash(subargs))
                # explicit test instead of setdefault(): setdefault evaluates
                # its default eagerly and would recompute the weight each time
                if elementKey not in self._weightCache:
                    self._weightCache[elementKey] = engine.weight(self, **subargs)
                w *= self._weightCache[elementKey]
            self._weightCache[myhash] = w
        return self._weightCache[myhash]

    def addCollection(self, name, inputTag):
        """Register an event collection as used by the analysis.

        Example: addCollection("myjets","jets")
        Note that the direct access to the branch is still possible but unsafe.
        Raises KeyError if the name is already used by a collection or a
        producer, and AttributeError if the name clashes with an existing
        attribute or if inputTag is not a branch of the chain.
        """
        if name in self._collections:
            raise KeyError("%r collection is already declared" % (name,))
        if name in self._producers:
            raise KeyError("%r is already declared as a producer" % (name,))
        if hasattr(self, name):
            raise AttributeError("%r object already has attribute %r" % (type(self).__name__, name))
        if inputTag not in self._branches:
            raise AttributeError("%r object has no branch %r" % (type(self).__name__, inputTag))
        self._collections[name] = inputTag
        self.SetBranchStatus(inputTag + "*", 1)
        self._branches[inputTag] = True

    def removeCollection(self, name):
        """Forget about the named event collection.

        This method will delete both the product from the cache (if any) and the definition.
        To simply clear the cache, use "del event.name" instead.
        Raises KeyError if the collection is not declared.
        """
        inputTag = self._collections.pop(name)
        # keep the branch enabled while another collection still reads it
        if inputTag not in self._collections.values():
            self.SetBranchStatus(inputTag + "*", 0)
            self._branches[inputTag] = False
        if name in self.vardict:
            delattr(self, name)

    def getCollection(self, name):
        """Retrieve the event product or return the cached collection.

        Note that the preferred way to get the collection is instead to access
        the "event.name" attribute.
        Raises AttributeError if the collection is not declared.
        """
        if name not in self._collections:
            raise AttributeError("%r object has no attribute %r" % (type(self).__name__, name))
        if name not in self.vardict:
            self.vardict[name] = TChain.__getattr__(self, self._collections[name])
        return self.vardict[name]

    def addProducer(self, name, producer, **kwargs):
        """Register a producer to create new high-level analysis objects.

        The producer is called as producer(event, **kwargs) the first time
        "event.name" is accessed for a given event.
        Raises KeyError if the name is already used by a producer or a
        collection, and AttributeError if it clashes with an existing attribute.
        """
        # sanity checks
        if name in self._producers:
            raise KeyError("%r producer is already declared" % (name,))
        if name in self._collections:
            raise KeyError("%r is already declared as a collection" % (name,))
        if hasattr(self, name):
            raise AttributeError("%r object already has attribute %r" % (type(self).__name__, name))
        # remove name and producer from kwargs
        kwargs.pop("name", None)
        kwargs.pop("producer", None)
        # store
        self._producers[name] = (producer, kwargs)

    def removeProducer(self, name):
        """Forget about the producer.

        This method will delete both the product from the cache (if any) and the producer.
        To simply clear the cache, use "del event.name" instead.
        Raises KeyError if the producer is not declared.
        """
        del self._producers[name]
        if name in self.vardict:
            delattr(self, name)

    def event(self):
        """Event number, or 0 when the Event branch is not enabled."""
        if self._branches.get("Event"):
            return self.Event.At(0).Number
        return 0

    def to(self, event):
        """Jump to some event"""
        self.GetEntryWithIndex(event)
        # cached quantities belong to the previously loaded event
        self._resetCache()

    def __getitem__(self, index):
        """Jump to some event and return the event object itself."""
        self.GetEntryWithIndex(index)
        # cached quantities belong to the previously loaded event
        self._resetCache()
        return self

    def __iter__(self):
        """Iterate over the entries, resetting the caches for each of them.

        The same object is yielded at every step. The loop stops when
        GetEntry() returns 0 or once maxEvents entries have been visited.
        """
        self._eventCounts = 0
        while self.GetEntry(self._eventCounts):
            self._resetCache()
            yield self
            self._eventCounts += 1
            if 0 < self._maxEvents <= self._eventCounts:
                break

    def __getattr__(self, attr):
        """Overloaded getter to handle properly:
             - volatile analysis objects
             - event collections
             - data producers"""
        cache = self.__dict__["vardict"]
        if attr in cache:
            return cache[attr]
        if attr in self._collections:
            value = TChain.__getattr__(self, self._collections[attr])
        elif attr in self._producers:
            producer, producerArgs = self._producers[attr]
            value = producer(self, **producerArgs)
        else:
            return TChain.__getattr__(self, attr)
        # setdefault keeps a value stored while the producer was running
        return cache.setdefault(attr, value)

    def __setattr__(self, name, value):
        """Overloaded setter that puts any new attribute in the volatile dict.

        Existing instance attributes, names starting with an underscore and
        anything set before the volatile dictionary exists are stored as
        regular attributes. Collection and producer names are read-only.
        """
        state = self.__dict__
        if name in state or "vardict" not in state or name.startswith('_'):
            state[name] = value
            return
        if name in self._collections or name in self._producers:
            raise AttributeError("%r object %r attribute is read-only (event collection)" % (type(self).__name__, name))
        self.vardict[name] = value

    def __delattr__(self, name):
        """Overloaded del method to handle the volatile internal dictionary."""
        # the volatile dictionary itself can never be removed
        if name == "vardict":
            raise AttributeError("%r object has no attribute %r" % (type(self).__name__, name))
        if name in self.__dict__:
            del self.__dict__[name]
        elif name in self.vardict:
            del self.vardict[name]
        else:
            raise AttributeError("%r object has no attribute %r" % (type(self).__name__, name))

    def _dicthash(self, dict):
        """Build a string key "k1=v1;k2=v2;..." from a dictionary.

        Items are sorted by key so that the result does not depend on the
        iteration order of the dictionary.
        """
        ordered = sorted(dict.items(), key=lambda item: str(item[0]))
        return ';'.join('='.join((str(k), str(v))) for k, v in ordered)

    def _describeCached(self, key, value):
        """Format one entry of the volatile dictionary for the text dump."""
        if isinstance(value, Iterable) and not isinstance(value, StringTypes):
            try:
                lines = ["%s => vector of %d objects(s)\n" % (key, len(value))]
                lines.extend("%s[%d] = %s\n" % (key, it, str(vec))
                             for it, vec in enumerate(value))
            except Exception:
                # no usable len() or iteration: fall back to the plain form
                pass
            else:
                return "".join(lines)
        return "%s => %s\n" % (key, str(value))

    def __str__(self):
        """Event text dump."""
        def dictjoin(d, j=' => ', s='\n'):
            return s.join(j.join((str(k), str(v))) for k, v in d.items())

        rule = "-----------------------------------------------------------------\n"
        parts = ["=================================================================\n"]
        # general information
        if self._branches.get("Event"):
            parts.append(str(self.Event.At(0)))
        else:
            parts.append("Event %d\n" % self.GetReadEvent())
        parts.append(rule)
        # weights
        if not self._weightCache:
            parts.append("No weight computed so far. Default weight is %f.\n" % self.weight())
        else:
            parts.append("Weights:\n")
            parts.append(dictjoin(self._weightCache))
        parts.append("\n" + rule)
        # list the collections
        parts.append("Collections:\n")
        for colname in list(self._collections):
            collection = self.getCollection(colname)
            nEntries = collection.GetEntries()
            if nEntries <= 0:
                continue
            # HepMCEvent collections are deliberately left out of the dump
            if collection.At(0).IsA() == TClass.GetClass("HepMCEvent"):
                continue
            parts.append("*** %s has %d element(s)\n" % (colname, nEntries))
            parts.append("".join(map(str, collection)))
        parts.append("\n" + rule)
        # list the registered producers
        parts.append("Producers:\n")
        parts.append(dictjoin(self._producers))
        parts.append("\n" + rule)
        # list the content of vardict, excluding collections
        parts.append("Content of the cache:\n")
        for k, v in list(self.vardict.items()):
            if k in self._collections:
                continue
            parts.append(self._describeCached(k, v))
        return "".join(parts)

    def decayTree(self, genparticles):
        """Return the concatenated decay printout of the particles without mother.

        A particle is used as a starting point when both its M1 and M2
        attributes are -1; its printDecay(db, genparticles) output is appended.
        """
        db = TDatabasePDG()
        return "".join(part.printDecay(db, genparticles)
                       for part in genparticles
                       if part.M1 == -1 and part.M2 == -1)
|
---|
277 |
|
---|