Fork me on GitHub

source: svn/trunk/python/DelphesAnalysis/AnalysisEvent.py

Last change on this file was 1112, checked in by Pavel Demin, 12 years ago

add DelphesAnalysis

File size: 11.5 KB
Line 
1from inspect import getargspec
2from ROOT import TLorentzVector,TVector3,TChain,TClass,TDatabasePDG
3from datetime import datetime
4from collections import Iterable
5from types import StringTypes
6from os import path
7
class AnalysisEvent(TChain):
    """A ROOT TChain over the "Delphes" tree, extended with analysis facilities.

    The class provides the following additional functionalities:
      1. instrumentation for event weight
         A set of weight classes can be defined, and the event weight
         is computed and cached using those.
      2. list of event products used in the analysis
         All branches are disabled at construction; only the branches
         registered through addCollection() are enabled again.
      3. a list of "producers" of analysis high-level quantities
         It allows to run "analysis on demand", by automatically running
         the defined producers to fill the cache, and later use that one.
      4. a volatile dictionary
         It allows to use the event as an heterogeneous container for
         any analysis product. The volatile content is erased whenever
         another event is loaded (iteration, to() or indexing).
    """

    def __init__(self, inputFiles='', maxEvents=0):
        """Build the chain from the input file(s) and set up the extra features.

        Arguments:
          * inputFiles is either one file name or an iterable of file names.
            Names that are not existing files are skipped with a warning.
          * maxEvents is the maximum number of events delivered by the
            iteration. Zero or a negative value means no limit.
        """
        # initialization of base functionalities
        TChain.__init__(self, "Delphes", "Delphes")
        if isinstance(inputFiles, StringTypes):
            candidates = [inputFiles]
        elif isinstance(inputFiles, Iterable):
            candidates = inputFiles
        else:
            candidates = None
        if candidates is None:
            print("Warning: invalid inputFiles")
        else:
            for thefile in candidates:
                if path.isfile(thefile):
                    self.AddFile(thefile)
                else:
                    # spacing kept identical to the historical message
                    print("Warning:  %s  do not exist." % (thefile,))
        # the index is what to() and __getitem__ use to look up an event number
        self.BuildIndex("Event[0].Number")
        # every branch starts disabled; addCollection() re-enables what is needed
        self.SetBranchStatus("*", 0)
        self._eventCounts = 0
        self._maxEvents = maxEvents
        # additional features:
        # 1. instrumentation for event weight
        self._weightCache = {}
        self._weightEngines = {}
        # 2. a list of event products used in the analysis
        self._collections = {}
        # branch name -> True when the branch has been enabled by addCollection()
        self._branches = dict((b.GetName(), False) for b in self.GetListOfBranches())
        # 3. a list of "producers" of analysis high-level quantities
        self._producers = {}
        # 4. volatile dictionary. User can add any quantity to the event and it
        #    will be erased when another event is loaded.
        #    Written through __dict__ because __setattr__ is overloaded.
        self.__dict__["vardict"] = {}

    def _clearCaches(self):
        """Erase everything that is only valid for the currently loaded event."""
        self.vardict.clear()
        self._weightCache.clear()

    def addWeight(self, name, weightClass):
        """Declare a new class (engine) to compute the weights.
        weightClass must have a weight() method returning a float.
        Raises KeyError if an engine with that name is already declared."""
        if name in self._weightEngines:
            raise KeyError("%s weight engine is already declared" % name)
        self._weightEngines[name] = weightClass
        # cached products of weights are not valid anymore
        self._weightCache.clear()

    def delWeight(self, name):
        """Remove one weight engine from the internal list.
        Raises KeyError if no engine with that name is declared."""
        del self._weightEngines[name]
        self._weightCache.clear()

    def weight(self, weightList=None, **kwargs):
        """Return the event weight. Arguments:
          * weightList is the list of engines to use, as a list of strings.
            Default: all defined engines.
          * the other named arguments are forwarded to the engines that
            declare an argument with the same name.
        The output is the product of the selected individual weights.
        Both the individual weights and the product are cached."""
        if weightList is None:
            weightList = list(self._weightEngines)
        kwargs["weightList"] = weightList
        # compute the weight or use the cached value
        myhash = self._dicthash(kwargs)
        if myhash not in self._weightCache:
            w = 1.
            for weightElement in weightList:
                engine = self._weightEngines[weightElement]
                engineArgs = getargspec(engine.weight).args
                subargs = dict((k, v) for k, v in kwargs.items() if k in engineArgs)
                elementKey = "weightElement:%s # %s" % (weightElement, self._dicthash(subargs))
                # only call the engine when its result is not cached yet
                if elementKey not in self._weightCache:
                    self._weightCache[elementKey] = engine.weight(self, **subargs)
                w *= self._weightCache[elementKey]
            self._weightCache[myhash] = w
        return self._weightCache[myhash]

    def addCollection(self, name, inputTag):
        """Register an event collection as used by the analysis.
        Example: addCollection("myjets","jets")
        The branches whose name starts with inputTag are enabled.
        Note that the direct access to the branch is still possible but unsafe.
        Raises KeyError if name is already a collection or a producer, and
        AttributeError if name is already an attribute or if inputTag is not
        a branch of the chain."""
        if name in self._collections:
            raise KeyError("%r collection is already declared" % (name,))
        if name in self._producers:
            raise KeyError("%r is already declared as a producer" % (name,))
        if hasattr(self, name):
            raise AttributeError("%r object already has attribute %r" % (type(self).__name__, name))
        if inputTag not in self._branches:
            raise AttributeError("%r object has no branch %r" % (type(self).__name__, inputTag))
        self._collections[name] = inputTag
        self.SetBranchStatus(inputTag + "*", 1)
        self._branches[inputTag] = True

    def removeCollection(self, name):
        """Forget about the named event collection.
        This method will delete both the product from the cache (if any) and the definition.
        To simply clear the cache, use "del event.name" instead."""
        inputTag = self._collections[name]
        self.SetBranchStatus(inputTag + "*", 0)
        self._branches[inputTag] = False
        del self._collections[name]
        if name in self.vardict:
            delattr(self, name)

    def getCollection(self, name):
        """Retrieve the event product or return the cached collection.
        Note that the preferred way to get the collection is instead to access the "event.name" attribute.
        Raises AttributeError if name is not a registered collection."""
        if name not in self._collections:
            raise AttributeError("%r object has no attribute %r" % (type(self).__name__, name))
        if name not in self.vardict:
            self.vardict[name] = TChain.__getattr__(self, self._collections[name])
        return getattr(self, name)

    def addProducer(self, name, producer, **kwargs):
        """Register a producer to create new high-level analysis objects.
        The producer is called as producer(event, **kwargs) the first time
        "event.name" is accessed for a given event.
        Raises KeyError if name is already a producer or a collection, and
        AttributeError if name is already an attribute."""
        # sanity checks
        if name in self._producers:
            raise KeyError("%r producer is already declared" % (name,))
        if name in self._collections:
            raise KeyError("%r is already declared as a collection" % (name,))
        if hasattr(self, name):
            raise AttributeError("%r object already has attribute %r" % (type(self).__name__, name))
        # remove name and producer from kwargs
        kwargs.pop("name", None)
        kwargs.pop("producer", None)
        # store
        self._producers[name] = (producer, kwargs)

    def removeProducer(self, name):
        """Forget about the producer.
        This method will delete both the product from the cache (if any) and the producer.
        To simply clear the cache, use "del event.name" instead."""
        del self._producers[name]
        if name in self.vardict:
            delattr(self, name)

    def event(self):
        """Event number, or 0 when the "Event" branch is not enabled."""
        if self._branches.get("Event"):
            return self.Event.At(0).Number
        return 0

    def to(self, event):
        """Jump to the event with the given event number.
        The volatile content and the cached weights of the previous event are erased."""
        self._clearCaches()
        self.GetEntryWithIndex(event)

    def __getitem__(self, index):
        """Jump to the event with the given event number and return the event.
        The volatile content and the cached weights of the previous event are erased."""
        self._clearCaches()
        self.GetEntryWithIndex(index)
        return self

    def __iter__(self):
        """Iterate over the entries of the chain, from the first one.
        The same object is yielded each time, loaded with the next entry and
        with empty caches. The iteration stops when GetEntry returns zero or
        when maxEvents (if positive) entries have been delivered."""
        self._eventCounts = 0
        while self.GetEntry(self._eventCounts):
            self._clearCaches()
            yield self
            self._eventCounts += 1
            if self._maxEvents > 0 and self._eventCounts >= self._maxEvents:
                break

    def __getattr__(self, attr):
        """Overloaded getter to handle properly:
          - volatile analysis objects
          - event collections
          - data producers
        Anything else is delegated to TChain."""
        vardict = self.__dict__.get("vardict")
        if vardict is None:
            # object not fully initialized yet: nothing of ours to look into
            return TChain.__getattr__(self, attr)
        if attr in vardict:
            return vardict[attr]
        if attr in self._collections:
            value = TChain.__getattr__(self, self._collections[attr])
        elif attr in self._producers:
            producer, producerArgs = self._producers[attr]
            value = producer(self, **producerArgs)
        else:
            return TChain.__getattr__(self, attr)
        # setdefault: keep a value stored in the meantime, if any
        return vardict.setdefault(attr, value)

    def __setattr__(self, name, value):
        """Overloaded setter that puts any new attribute in the volatile dict.
        Existing instance attributes, names starting with an underscore and
        anything set before the volatile dict exists go to the instance dict.
        Raises AttributeError when name is a collection or a producer."""
        if name in self.__dict__ or "vardict" not in self.__dict__ or name.startswith('_'):
            self.__dict__[name] = value
        elif name in self._collections or name in self._producers:
            raise AttributeError("%r object %r attribute is read-only (event collection)" % (type(self).__name__, name))
        else:
            self.vardict[name] = value

    def __delattr__(self, name):
        """Overloaded del method to handle the volatile internal dictionary.
        The volatile dictionary itself cannot be deleted."""
        if name == "vardict":
            raise AttributeError("%r object has no attribute %r" % (type(self).__name__, name))
        if name in self.__dict__:
            del self.__dict__[name]
        elif name in self.vardict:
            del self.vardict[name]
        else:
            raise AttributeError("%r object has no attribute %r" % (type(self).__name__, name))

    def _dicthash(self, dict):
        """Return a string "k1=v1;k2=v2;..." identifying the content of dict.
        Keys are sorted so that equal dictionaries always give the same string."""
        return ';'.join('='.join((str(k), str(dict[k]))) for k in sorted(dict))

    def _cacheEntryToString(self, key, value):
        """Text dump of one entry of the volatile dictionary.
        Non-string iterables are listed element by element when possible."""
        if isinstance(value, Iterable) and not isinstance(value, StringTypes):
            try:
                lines = ["%s => vector of %d objects(s)\n" % (key, len(value))]
                for it, vec in enumerate(value):
                    lines.append("%s[%d] = %s\n" % (key, it, str(vec)))
            except Exception:
                # no length or failing iteration: fall back to the plain dump
                pass
            else:
                return "".join(lines)
        return "%s => %s\n" % (key, str(value))

    def __str__(self):
        """Event text dump: event header, weights, collections, producers and cache content."""
        separator = "-----------------------------------------------------------------\n"

        def dictjoin(d):
            return '\n'.join(' => '.join((str(k), str(v))) for k, v in d.items())

        chunks = ["=================================================================\n"]
        # general information
        if self._branches.get("Event"):
            chunks.append(str(self.Event.At(0)))
        else:
            chunks.append("Event %d\n" % self.GetReadEvent())
        chunks.append(separator)
        # weights
        if len(self._weightCache) == 0:
            chunks.append("No weight computed so far. Default weight is %f.\n" % self.weight())
        else:
            chunks.append("Weights:\n")
            chunks.append(dictjoin(self._weightCache))
        chunks.append("\n" + separator)
        # list the collections
        chunks.append("Collections:\n")
        for colname in list(self._collections):
            collection = self.getCollection(colname)
            if collection.GetEntries() > 0:
                # HepMCEvent collections are not dumped
                if collection.At(0).IsA() == TClass.GetClass("HepMCEvent"):
                    continue
                chunks.append("*** %s has %d element(s)\n" % (colname, collection.GetEntries()))
                chunks.append("".join(map(str, collection)))
        chunks.append("\n" + separator)
        # list the registered producers
        chunks.append("Producers:\n")
        chunks.append(dictjoin(self._producers))
        chunks.append("\n" + separator)
        # list the content of vardict, excluding collections
        chunks.append("Content of the cache:\n")
        for k, v in list(self.vardict.items()):
            if k in self._collections:
                continue
            chunks.append(self._cacheEntryToString(k, v))
        return "".join(chunks)

    def decayTree(self, genparticles):
        """Return the concatenated printDecay() output of every particle of
        genparticles that has no mother (M1 and M2 both equal to -1)."""
        # use the shared particle database instead of building a new one per call
        db = TDatabasePDG.Instance()
        theString = ""
        for part in genparticles:
            if part.M1 == -1 and part.M2 == -1:
                theString += part.printDecay(db, genparticles)
        return theString
277
Note: See TracBrowser for help on using the repository browser.