
#include "modules/MadGraphMatchingTreeWriter.h"


#include "ExRootAnalysis/ExRootResult.h"
#include "ExRootAnalysis/ExRootClasses.h"
#include "ExRootAnalysis/ExRootTreeBranch.h"

#include "ExRootAnalysis/ExRootCandidate.h"

#include "TClonesArray.h"

#include "TH1.h"
#include "TH2.h"
#include "TString.h"
#include "TCanvas.h"
#include "TLorentzVector.h"

#include <iostream>

using namespace std;

//------------------------------------------------------------------------------

// Default constructor. It does no work of its own: the jet cuts and the
// branch configuration are read later, in Init().
MadGraphMatchingTreeWriter::MadGraphMatchingTreeWriter()
{
}

//------------------------------------------------------------------------------

// Destructor. It releases nothing itself: the iterators stored in fBranchMap
// are deleted in Finish().
MadGraphMatchingTreeWriter::~MadGraphMatchingTreeWriter()
{
}

//------------------------------------------------------------------------------

void MadGraphMatchingTreeWriter::Init()
{
  fJetPTMin = GetDouble("JetPTMin", 20.0);
  fJetEtaMax = GetDouble("JetEtaMax", 4.5);

  fClassMap[ExRootGenParticle::Class()] = &MadGraphMatchingTreeWriter::ProcessPartons;

  fClassMap[ExRootMatching::Class()] = &MadGraphMatchingTreeWriter::ProcessMatching;

  fClassMap[ExRootGenJet::Class()] = &MadGraphMatchingTreeWriter::ProcessJets;

  TBranchMap::iterator itBranchMap;
  map< TClass *, TProcessMethod >::iterator itClassMap;

  // read branch configuration and
  // import array with output from filter/classifier/jetfinder modules

  ExRootConfParam param = GetParam("Branch");
  Long_t i, size;
  TString branchName, branchClassName, branchInputArray;
  TClass *branchClass;
  const TObjArray *array;
  ExRootTreeBranch *branch;

  size = param.GetSize();
  for(i = 0; i < size; ++i)
  {
    branchName = param[i][0].GetString();
    branchClassName = param[i][1].GetString();
    branchInputArray = param[i][2].GetString();

    branchClass = gROOT->GetClass(branchClassName);

    if(!branchClass)
    {
      cout << "** ERROR: cannot find class '" << branchClassName << "'" << endl;
      continue;
    }

    itClassMap = fClassMap.find(branchClass);
    if(itClassMap == fClassMap.end())
    {
      cout << "** ERROR: cannot create branch for class '" << branchClassName << "'" << endl;
      continue;
    }

    array = ImportArray(branchInputArray);
    branch = NewBranch(branchName, branchClass);

    fBranchMap.insert(make_pair(branch, make_pair(itClassMap->second, array->MakeIterator())));
  }

}

//------------------------------------------------------------------------------

// Releases the iterators created in Init() and empties fBranchMap.
//
// The map is cleared after the iterators are deleted so that it never holds
// dangling pointers: a repeated call to Finish() then deletes nothing twice,
// and a later Process() does not touch freed iterators.
void MadGraphMatchingTreeWriter::Finish()
{
  TBranchMap::iterator itBranchMap;

  for(itBranchMap = fBranchMap.begin(); itBranchMap != fBranchMap.end(); ++itBranchMap)
  {
    // deleting a null pointer is a no-op, so no explicit check is needed
    delete itBranchMap->second.second;
  }

  fBranchMap.clear();
}

//------------------------------------------------------------------------------

// Fills one ExRootGenParticle entry in 'branch' for every candidate delivered
// by 'iterator'. The iterator is reset first, so the whole input collection is
// traversed on every call.
//
// When the transverse momentum is exactly zero, Eta and Rapidity are stored
// as +/-999.9 (sign taken from Pz) instead of being computed from the
// four-momentum.
void MadGraphMatchingTreeWriter::ProcessPartons(ExRootTreeBranch *branch, TIterator *iterator)
{
  iterator->Reset();

  // walk through every parton of the input collection
  while(ExRootCandidate *parton = static_cast<ExRootCandidate*>(iterator->Next()))
  {
    const TLorentzVector &p4 = parton->GetP4();

    ExRootGenParticle *out = static_cast<ExRootGenParticle*>(branch->NewEntry());

    const Double_t transverse = p4.Pt();

    Double_t direction = -1.0;
    if(p4.Pz() >= 0.0) direction = 1.0;

    Double_t pseudoRapidity, trueRapidity;
    if(transverse == 0.0)
    {
      // particle along the beam axis: store a signed placeholder value
      pseudoRapidity = direction*999.9;
      trueRapidity = direction*999.9;
    }
    else
    {
      pseudoRapidity = p4.Eta();
      trueRapidity = p4.Rapidity();
    }

    out->PID = parton->GetType()->PdgCode();

    out->E = p4.E();
    out->Px = p4.Px();
    out->Py = p4.Py();
    out->Pz = p4.Pz();

    out->Eta = pseudoRapidity;
    out->Phi = p4.Phi();
    out->PT = transverse;

    out->Rapidity = trueRapidity;
  }
}

//------------------------------------------------------------------------------

// Copies the DMerge and YMerge values of every ExRootMatching object delivered
// by 'iterator' into a new entry of 'branch'. The iterator is reset first, so
// the whole input collection is traversed on every call.
void MadGraphMatchingTreeWriter::ProcessMatching(ExRootTreeBranch *branch, TIterator *iterator)
{
  iterator->Reset();

  // walk through every matching record of the input collection
  while(ExRootMatching *source = static_cast<ExRootMatching*>(iterator->Next()))
  {
    ExRootMatching *target = static_cast<ExRootMatching*>(branch->NewEntry());

    target->DMerge = source->DMerge;
    target->YMerge = source->YMerge;
  }
}

//------------------------------------------------------------------------------

// Fills one ExRootGenJet entry in 'branch' for every candidate delivered by
// 'iterator' that passes the jet cuts. The iterator is reset first, so the
// whole input collection is traversed on every call.
//
// A jet is skipped when its PT is below fJetPTMin or when |Eta| exceeds
// fJetEtaMax. When the transverse momentum is exactly zero, Eta and Rapidity
// take the value +/-999.9 (sign taken from Pz) instead of being computed from
// the four-momentum.
void MadGraphMatchingTreeWriter::ProcessJets(ExRootTreeBranch *branch, TIterator *iterator)
{
  iterator->Reset();

  // walk through every jet of the input collection
  while(ExRootCandidate *jet = static_cast<ExRootCandidate*>(iterator->Next()))
  {
    const TLorentzVector &p4 = jet->GetP4();

    const Double_t transverse = p4.Pt();

    Double_t direction = -1.0;
    if(p4.Pz() >= 0.0) direction = 1.0;

    Double_t pseudoRapidity, trueRapidity;
    if(transverse == 0.0)
    {
      // jet along the beam axis: use a signed placeholder value
      pseudoRapidity = direction*999.9;
      trueRapidity = direction*999.9;
    }
    else
    {
      pseudoRapidity = p4.Eta();
      trueRapidity = p4.Rapidity();
    }

    // apply the PT and pseudorapidity cuts before allocating an output entry
    if(transverse < fJetPTMin || TMath::Abs(pseudoRapidity) > fJetEtaMax) continue;

    ExRootGenJet *out = static_cast<ExRootGenJet*>(branch->NewEntry());

    out->E = p4.E();
    out->Px = p4.Px();
    out->Py = p4.Py();
    out->Pz = p4.Pz();

    out->Eta = pseudoRapidity;
    out->Phi = p4.Phi();
    out->PT = transverse;

    out->Rapidity = trueRapidity;

    out->Mass = p4.M();
  }
}

//------------------------------------------------------------------------------

void MadGraphMatchingTreeWriter::Process()
{

  TBranchMap::iterator itBranchMap;
  ExRootTreeBranch *branch;
  TProcessMethod method;
  TIterator *iterator;

  for(itBranchMap = fBranchMap.begin(); itBranchMap != fBranchMap.end(); ++itBranchMap)
  {
    branch = itBranchMap->first;
    method = itBranchMap->second.first;
    iterator = itBranchMap->second.second;

    (this->*method)(branch, iterator);
  }

}

//------------------------------------------------------------------------------

