
#include "modules/MadGraphJetParticleSelector.h"


#include "ExRootAnalysis/ExRootResult.h"
#include "ExRootAnalysis/ExRootClasses.h"

#include "ExRootAnalysis/ExRootFilter.h"
#include "ExRootAnalysis/ExRootClassifier.h"

#include "ExRootAnalysis/ExRootFactory.h"
#include "ExRootAnalysis/ExRootCandidate.h"

#include "TMath.h"
#include "TString.h"
#include "TLorentzVector.h"
#include "TClonesArray.h"

#include <iostream>
#include <set>

using namespace std;


//------------------------------------------------------------------------------

/// Classifier used by MadGraphJetParticleSelector.
/// GetCategory returns 0 for generator particles that pass the selection and
/// -1 for everything else; the selection is configured through SetEtaMax and
/// the Insert*ID methods.
class MadGraphJetParticleClassifier : public ExRootClassifier
{
public:

  /// @param branch particle branch used to resolve mother indices (M1) when
  ///        walking ancestors; it is not deleted by this class.
  explicit MadGraphJetParticleClassifier(TClonesArray *branch);

  /// @return 0 if the particle is accepted, -1 otherwise.
  Int_t GetCategory(TObject *object);

  /// Upper bound on |eta| for accepted particles.
  void SetEtaMax(Double_t eta);
  /// |PID| values whose ancestry is checked against the excluded-ancestor set.
  void InsertSpecialParticleID(Int_t pid);
  /// |PID| values that disqualify a special particle when found among its ancestors.
  void InsertExcludedAncestorID(Int_t pid);
  /// |PID| values that are always rejected.
  void InsertExcludedParticleID(Int_t pid);
  // NOTE(review): no definition of SetHadronizationInfo is visible in this
  // file; calling it would fail to link.
  void SetHadronizationInfo(Bool_t info);

private:

  /// True if one of the first few ancestors has |PID| in fExcludedAncestorIDSet.
  Bool_t hasBadAncestor(ExRootGenParticle *object);

  Double_t fEtaMax;             // |eta| acceptance used by GetCategory

  TClonesArray *fBranchParticle; // particle branch, not owned

  set< Int_t > fSpecialParticleIDSet;
  set< Int_t > fExcludedAncestorIDSet;
  set< Int_t > fExcludedParticleIDSet;
};

//------------------------------------------------------------------------------

// Stores the particle branch used for ancestor lookups. fEtaMax is given the
// same default that MadGraphJetParticleSelector::Init() uses for "EtaMax" so
// that GetCategory never reads an uninitialized value, even if SetEtaMax has
// not been called yet.
MadGraphJetParticleClassifier::MadGraphJetParticleClassifier(TClonesArray *branch) :
  fEtaMax(5.0), fBranchParticle(branch)
{
}

//------------------------------------------------------------------------------

void MadGraphJetParticleClassifier::SetEtaMax(Double_t eta)
{
  // GetCategory rejects particles with |eta| above this bound
  this->fEtaMax = eta;
}

//------------------------------------------------------------------------------

void MadGraphJetParticleClassifier::InsertSpecialParticleID(Int_t pid)
{
  // Particles with this |PID| are accepted by GetCategory only if
  // hasBadAncestor() is false. GetCategory looks up |PID|, so a negative
  // pid stored here never matches.
  this->fSpecialParticleIDSet.insert(pid);
}

//------------------------------------------------------------------------------

void MadGraphJetParticleClassifier::InsertExcludedAncestorID(Int_t pid)
{
  // hasBadAncestor() compares each ancestor's |PID| against this set, so a
  // negative pid stored here never matches.
  this->fExcludedAncestorIDSet.insert(pid);
}

//------------------------------------------------------------------------------

void MadGraphJetParticleClassifier::InsertExcludedParticleID(Int_t pid)
{
  // GetCategory rejects any particle whose |PID| is in fExcludedParticleIDSet,
  // so the id must be stored there (not in the special-particle set, which
  // only triggers the ancestor check). A negative pid never matches |PID|.
  fExcludedParticleIDSet.insert(pid);
}

//------------------------------------------------------------------------------

// Follows the first-mother chain (M1) of `object` for at most kMaxAncestors
// steps and returns kTRUE as soon as an ancestor's |PID| is in
// fExcludedAncestorIDSet. Returns kFALSE when the chain ends (negative M1 or
// an index that does not resolve to a particle) or the step limit is reached.
Bool_t MadGraphJetParticleClassifier::hasBadAncestor(ExRootGenParticle *object)
{
  const int kMaxAncestors = 10;

  ExRootGenParticle *particle = object;

  for(Int_t i = 0; i < kMaxAncestors; ++i)
  {
    // negative mother index: no further ancestor recorded
    if(particle->M1 < 0) return kFALSE;

    particle = static_cast<ExRootGenParticle*>(fBranchParticle->At(particle->M1));

    // At() yields 0 for an out-of-range index; treat it as the end of the chain
    // instead of dereferencing a null pointer
    if(particle == 0) return kFALSE;

    // this ancestor is in the excluded-ancestor list
    if(fExcludedAncestorIDSet.find(TMath::Abs(particle->PID)) != fExcludedAncestorIDSet.end()) return kTRUE;
  }

  return kFALSE;
}

//------------------------------------------------------------------------------

// Returns 0 for an accepted particle and -1 for a rejected one.
// A particle is rejected when:
//   - its first-mother index M1 is below 2 (beam particles / initial-state partons),
//   - |eta| exceeds fEtaMax,
//   - its Status is not 1,
//   - its |PID| is in the excluded-particle set, or
//   - its |PID| is in the special-particle set and hasBadAncestor() is true.
Int_t MadGraphJetParticleClassifier::GetCategory(TObject *object)
{
  ExRootGenParticle *genParticle = static_cast<ExRootGenParticle*>(object);

  Bool_t outsideSelection = genParticle->M1 < 2
    || TMath::Abs(genParticle->Eta) > fEtaMax
    || genParticle->Status != 1;

  if(outsideSelection) return -1;

  const Int_t pidAbs = TMath::Abs(genParticle->PID);

  // explicitly excluded species are dropped outright
  if(fExcludedParticleIDSet.count(pidAbs) > 0) return -1;

  // special species survive only if none of their checked ancestors is excluded
  if(fSpecialParticleIDSet.count(pidAbs) > 0 && hasBadAncestor(genParticle)) return -1;

  return 0;
}

//------------------------------------------------------------------------------

// Filter and classifier start out null; Init() allocates them and Finish()
// deletes them.
MadGraphJetParticleSelector::MadGraphJetParticleSelector() :
  fFilter(0),
  fClassifier(0)
{
}

//------------------------------------------------------------------------------

MadGraphJetParticleSelector::~MadGraphJetParticleSelector()
{
  // intentionally empty: fFilter and fClassifier are deleted in Finish()
}

//------------------------------------------------------------------------------

void MadGraphJetParticleSelector::Init()
{
  ExRootConfParam param;

  Int_t i, pid, sizeParam;

  // import ROOT tree branch

  fBranchParticle = UseBranch("GenParticle");

  // create classifier and filter

  fClassifier = new MadGraphJetParticleClassifier(fBranchParticle);
  fFilter = new ExRootFilter(fBranchParticle);

  fEtaMax = GetDouble("EtaMax", 5.0);
  fClassifier->SetEtaMax(fEtaMax);

  // read particle IDs from configuration file and setup classifier

  param = GetParam("SpecialParticleIDs");
  sizeParam = param.GetSize();

  for(i = 0; i < sizeParam; ++i)
  {
    pid = param[i].GetInt();
    fClassifier->InsertSpecialParticleID(pid);
  }

  // read ancestor IDs from configuration file and setup classifier

  param = GetParam("ExcludedAncestorIDs");
  sizeParam = param.GetSize();

  for(i = 0; i < sizeParam; ++i)
  {
    pid = param[i].GetInt();
    fClassifier->InsertExcludedAncestorID(pid);
  }

  // read particle IDs from configuration file and setup classifier

  param = GetParam("ExcludedParticleIDs");
  sizeParam = param.GetSize();

  for(i = 0; i < sizeParam; ++i)
  {
    pid = param[i].GetInt();
    fClassifier->InsertExcludedParticleID(pid);
  }

  // create output arrays

  fOutputArray = ExportArray("candidates");

}

//------------------------------------------------------------------------------

// Releases the filter and classifier created in Init(). The pointers are
// reset afterwards so that calling Finish() again is a harmless no-op rather
// than a double delete (deleting a null pointer does nothing, so no guard is
// needed before delete).
void MadGraphJetParticleSelector::Finish()
{
  delete fFilter;
  fFilter = 0;

  delete fClassifier;
  fClassifier = 0;
}

//------------------------------------------------------------------------------

void MadGraphJetParticleSelector::Process()
{
  TObjArray *array = 0;
  ExRootGenParticle *particle = 0;
  ExRootCandidate *candidate = 0;
  ExRootFactory *factory = GetFactory();

  TLorentzVector momentum;

  fFilter->Reset();
  array = fFilter->GetSubArray(fClassifier, 0);

  if(array == 0) return;

  TIter itArray(array);

  while(particle = static_cast<ExRootGenParticle*>(itArray.Next()))
  {
    momentum.SetPxPyPzE(particle->Px, particle->Py, particle->Pz, particle->E);

    candidate = factory->NewCandidate();

    candidate->SetP4(momentum);
    candidate->SetType(particle->PID);

    fOutputArray->Add(candidate);
  }

/*
  cout << "==============================" << endl;
  Int_t indexParticle = -1;
  itArray.Reset();
  while(particle = static_cast<ExRootGenParticle*>(itArray.Next()))
  {
    ++indexParticle;
    cout << "--->\t" << particle->Status << "\t" << particle->PID << "\t";
    cout << particle->M1 << "\t" << particle->M2 << "\t";
    cout << particle->Px << "\t" << particle->Py << "\t" << particle->Pz << endl;
  }
*/
}

//------------------------------------------------------------------------------
