
#include "modules/MadGraphShowerPartonSelector.h"


#include "ExRootAnalysis/ExRootResult.h"
#include "ExRootAnalysis/ExRootClasses.h"

#include "ExRootAnalysis/ExRootFilter.h"
#include "ExRootAnalysis/ExRootClassifier.h"

#include "ExRootAnalysis/ExRootFactory.h"
#include "ExRootAnalysis/ExRootCandidate.h"

#include "TMath.h"
#include "TString.h"
#include "TLorentzVector.h"
#include "TClonesArray.h"

#include <iostream>
#include <set>

using namespace std;


//------------------------------------------------------------------------------

class MadGraphShowerPartonClassifier : public ExRootClassifier
{
public:

  // branch: GenParticle array used to look up mothers (M1) and daughters (D1)
  // by index; the classifier only stores the pointer and never deletes it.
  explicit MadGraphShowerPartonClassifier(TClonesArray *branch);

  // Returns 0 when the particle passes the PID, |eta|, status and ancestry
  // requirements, and -1 otherwise.
  Int_t GetCategory(TObject *object);

  // Upper bound on |eta| of accepted particles.
  void SetEtaMax(Double_t eta);

  // Configuration of the accepted particle IDs and of the excluded/included
  // ancestor IDs (lookups are done on |PID|).
  void InsertParticleID(Int_t pid);
  void InsertExclAncestorID(Int_t pid);
  void InsertInclAncestorID(Int_t pid);

  // NOTE(review): declared but never defined in this file; any call to it
  // will fail at link time.
  void SetHadronizationInfo(Bool_t info);

private:

  // kTRUE if the single-mother ancestry contains a PID from fExclAncestorIDSet.
  Bool_t hasBadAncestor(ExRootGenParticle *object);
  // kTRUE if the ancestry satisfies fInclAncestorIDSet (see definition for details).
  Bool_t hasGoodAncestor(ExRootGenParticle *object);

  Double_t fEtaMax;

  // not owned
  TClonesArray *fBranchParticle;

  set< Int_t > fParticleIDSet;
  set< Int_t > fExclAncestorIDSet;
  set< Int_t > fInclAncestorIDSet;
};

//------------------------------------------------------------------------------

MadGraphShowerPartonClassifier::MadGraphShowerPartonClassifier(TClonesArray *branch) :
  // fEtaMax used to be left uninitialized until SetEtaMax() was called;
  // start from the same value Init() uses as the "EtaMax" default.
  fEtaMax(5.0),
  fBranchParticle(branch)
{
}

//------------------------------------------------------------------------------

void MadGraphShowerPartonClassifier::SetEtaMax(Double_t eta)
{
  // GetCategory() rejects particles with |eta| above this value
  this->fEtaMax = eta;
}

//------------------------------------------------------------------------------

void MadGraphShowerPartonClassifier::InsertParticleID(Int_t pid)
{
  // Add a PID to the accepted list; duplicates are ignored by std::set.
  // GetCategory() looks up |PID|, so a negative entry would never match.
  set< Int_t > &acceptedIDs = fParticleIDSet;
  acceptedIDs.insert(pid);
}

//------------------------------------------------------------------------------

void MadGraphShowerPartonClassifier::InsertExclAncestorID(Int_t pid)
{
  // Add a PID whose descendants hasBadAncestor() reports as excluded;
  // the lookup there is done on |PID|.
  set< Int_t > &excludedIDs = fExclAncestorIDSet;
  excludedIDs.insert(pid);
}

//------------------------------------------------------------------------------

void MadGraphShowerPartonClassifier::InsertInclAncestorID(Int_t pid)
{
  // Add a PID that hasGoodAncestor() accepts as an ancestor;
  // the lookup there is done on |PID|.
  set< Int_t > &includedIDs = fInclAncestorIDSet;
  includedIDs.insert(pid);
}

//------------------------------------------------------------------------------

// Returns kTRUE when the particle's ancestry (followed through the first
// mother, M1) contains a particle whose |PID| is in fExclAncestorIDSet.
//
// Step 1 climbs to the first ancestor with Status == 3 (presumably the
// hard-process documentation entry, Pythia 6 convention - confirm). If none is
// found within kMaxAncestors steps, the checks are applied to the last
// particle reached. Step 2 keeps climbing while each entry has a single mother
// (M2 == -1), stopping at PID 21 (gluon in the PDG numbering scheme).
Bool_t MadGraphShowerPartonClassifier::hasBadAncestor(ExRootGenParticle *object)
{
  const int kMaxAncestors = 10;
  Int_t i, pidAbs;
  ExRootGenParticle *particle = object;
  set< Int_t >::const_iterator itAncestorIDSet;

  // step 1: walk up to the Status == 3 ancestor
  for(i = 0; i < kMaxAncestors && particle->Status != 3; ++i)
  {
    if(particle->M1 < 0) return kFALSE;

    particle = static_cast<ExRootGenParticle*>(fBranchParticle->At(particle->M1));

    // At() returns 0 for an index outside the array: no usable ancestry,
    // so do not report the particle as excluded (previously a null dereference)
    if(particle == 0) return kFALSE;
  }

  if(particle->PID == 21) return kFALSE;

  // keep all particles if there is no pid in the list
  if(fExclAncestorIDSet.empty())
  {
    return kFALSE;
  }

  pidAbs = TMath::Abs(particle->PID);

  // skip particles with pid included in list
  itAncestorIDSet = fExclAncestorIDSet.find(pidAbs);

  if(itAncestorIDSet != fExclAncestorIDSet.end()) return kTRUE;
  // an entry with a second mother ends the single-parent chain
  if(particle->M2 > -1) return kFALSE;

  // step 2: continue up the single-mother chain looking for an excluded PID
  for(i = 0; i < kMaxAncestors; ++i)
  {
    if(particle->M1 < 0) return kFALSE;

    if(particle->PID == 21) return kFALSE;

    particle = static_cast<ExRootGenParticle*>(fBranchParticle->At(particle->M1));

    // out-of-range mother index: stop instead of dereferencing a null pointer
    if(particle == 0) return kFALSE;

    pidAbs = TMath::Abs(particle->PID);

    // skip particles with pid included in list
    itAncestorIDSet = fExclAncestorIDSet.find(pidAbs);

    if(itAncestorIDSet != fExclAncestorIDSet.end()) return kTRUE;
    if(particle->M2 > -1) return kFALSE;
  }

  return kFALSE;
}

//------------------------------------------------------------------------------

// Returns kTRUE when the particle's ancestry (followed through the first
// mother, M1) satisfies the inclusion list fInclAncestorIDSet.
//
// Step 1 climbs to the first ancestor with Status == 3 (presumably the
// hard-process documentation entry, Pythia 6 convention - confirm). If none is
// found within kMaxAncestors steps, the checks are applied to the last
// particle reached. Step 2 keeps climbing while each entry has a single mother
// (M2 == -1), stopping at PID 21 (gluon in the PDG numbering scheme).
//
// NOTE(review): the gluon test and the "no mother before Status == 3" exit
// both happen before the empty-list check. So such particles are rejected even
// when no inclusion IDs are configured. Confirm this is intended.
Bool_t MadGraphShowerPartonClassifier::hasGoodAncestor(ExRootGenParticle *object)
{
  const int kMaxAncestors = 10;
  Int_t i, pidAbs;
  ExRootGenParticle *particle = object;
  set< Int_t >::const_iterator itAncestorIDSet;

  // step 1: walk up to the Status == 3 ancestor
  for(i = 0; i < kMaxAncestors && particle->Status != 3; ++i)
  {
    if(particle->M1 < 0) return kFALSE;

    particle = static_cast<ExRootGenParticle*>(fBranchParticle->At(particle->M1));

    // At() returns 0 for an index outside the array: no usable ancestry,
    // treat like a missing mother (previously a null dereference)
    if(particle == 0) return kFALSE;
  }

  if(particle->PID == 21) return kFALSE;

  // keep all particles if there is no pid in the list
  if(fInclAncestorIDSet.empty())
  {
    return kTRUE;
  }

  pidAbs = TMath::Abs(particle->PID);

  // keep particles with pid included in list
  itAncestorIDSet = fInclAncestorIDSet.find(pidAbs);

  if(itAncestorIDSet != fInclAncestorIDSet.end()) return kTRUE;
  // an entry with a second mother ends the single-parent chain
  if(particle->M2 > -1) return kFALSE;

  // step 2: continue up the single-mother chain looking for an included PID
  for(i = 0; i < kMaxAncestors; ++i)
  {
    if(particle->M1 < 0) return kFALSE;

    if(particle->PID == 21) return kFALSE;

    particle = static_cast<ExRootGenParticle*>(fBranchParticle->At(particle->M1));

    // out-of-range mother index: stop instead of dereferencing a null pointer
    if(particle == 0) return kFALSE;

    pidAbs = TMath::Abs(particle->PID);

    // keep particles with pid included in list
    itAncestorIDSet = fInclAncestorIDSet.find(pidAbs);

    if(itAncestorIDSet != fInclAncestorIDSet.end()) return kTRUE;
    if(particle->M2 > -1) return kFALSE;
  }

  return kFALSE;
}

//------------------------------------------------------------------------------

// Classifies one GenParticle: returns 0 if it is accepted and -1 otherwise.
// A particle is accepted when all of the following hold:
//  - its first mother index M1 is >= 2;
//  - |PID| is in fParticleIDSet and |eta| <= fEtaMax;
//  - it has Status 1, or Status 2 with either no daughter or a first daughter
//    of PID 92 (presumably the Pythia string pseudo-particle - confirm);
//  - hasBadAncestor() is false and hasGoodAncestor() is true.
Int_t MadGraphShowerPartonClassifier::GetCategory(TObject *object)
{
  ExRootGenParticle *particle = static_cast<ExRootGenParticle*>(object);
  ExRootGenParticle *daughter;

  set< Int_t >::const_iterator itParticleIDSet;

  Int_t pidAbs = TMath::Abs(particle->PID);
  Double_t etaAbs = TMath::Abs(particle->Eta);

  // skip beam particles and initial state partons
  if(particle->M1 < 2) return -1;

  // skip particles with pid not included in list
  itParticleIDSet = fParticleIDSet.find(pidAbs);

  if(itParticleIDSet == fParticleIDSet.end() || etaAbs > fEtaMax) return -1;

  // with hadronization
  if(particle->Status == 2)
  {
    // skip particles if they do not form a string
    if(particle->D1 > -1)
    {
      daughter = static_cast<ExRootGenParticle*>(fBranchParticle->At(particle->D1));
      // At() returns 0 for an index outside the array. Such a daughter cannot
      // be identified as a string, so reject instead of dereferencing it.
      if(daughter == 0 || daughter->PID != 92) return -1;
    }
  }
  // without hadronization
  else if(particle->Status != 1) return -1;

  if(hasBadAncestor(particle)) return -1;

  if(!hasGoodAncestor(particle)) return -1;

  return 0;
}

//------------------------------------------------------------------------------

MadGraphShowerPartonSelector::MadGraphShowerPartonSelector() :
  fFilter(0), fClassifier(0)
{
  // The remaining members are assigned in Init(). Give them defined values so
  // the object never holds indeterminate pointers before Init() runs.
  // They are assigned in the body, so no assumption about their declaration
  // order is needed.
  fBranchParticle = 0;
  fOutputArray = 0;
  // same value Init() uses as the "EtaMax" default
  fEtaMax = 5.0;
}

//------------------------------------------------------------------------------

MadGraphShowerPartonSelector::~MadGraphShowerPartonSelector()
{
  // Intentionally empty: the filter and classifier created in Init()
  // are deleted in Finish().
}

//------------------------------------------------------------------------------

void MadGraphShowerPartonSelector::Init()
{
  ExRootConfParam param;

  Int_t i, pid, sizeParam;

  // import ROOT tree branch

  fBranchParticle = UseBranch("GenParticle");

  // create classifier and filter

  fClassifier = new MadGraphShowerPartonClassifier(fBranchParticle);
  fFilter = new ExRootFilter(fBranchParticle);

  fEtaMax = GetDouble("EtaMax", 5.0);
  fClassifier->SetEtaMax(fEtaMax);

  // read particle IDs from configuration file and setup classifier

  param = GetParam("PartonIDs");
  sizeParam = param.GetSize();

  for(i = 0; i < sizeParam; ++i)
  {
    pid = param[i].GetInt();
    fClassifier->InsertParticleID(pid);
  }

  // read ancestor IDs from configuration file and setup classifier

  param = GetParam("ExcludedAncestorIDs");
  sizeParam = param.GetSize();

  for(i = 0; i < sizeParam; ++i)
  {
    pid = param[i].GetInt();
    fClassifier->InsertExclAncestorID(pid);
  }

  // read ancestor IDs from configuration file and setup classifier

  param = GetParam("IncludedAncestorIDs");
  sizeParam = param.GetSize();

  for(i = 0; i < sizeParam; ++i)
  {
    pid = param[i].GetInt();
    fClassifier->InsertInclAncestorID(pid);
  }

  // create output arrays

  fOutputArray = ExportArray("candidates");

}

//------------------------------------------------------------------------------

void MadGraphShowerPartonSelector::Finish()
{
  // Release the helpers created in Init(). Deleting a null pointer is a no-op,
  // so no explicit checks are needed. Resetting the pointers makes a repeated
  // Finish() harmless instead of a double delete.
  delete fFilter;
  fFilter = 0;

  delete fClassifier;
  fClassifier = 0;
}

//------------------------------------------------------------------------------

void MadGraphShowerPartonSelector::Process()
{
  TObjArray *array = 0;
  ExRootGenParticle *particle = 0;
  ExRootCandidate *candidate = 0;
  ExRootFactory *factory = GetFactory();

  TLorentzVector momentum;

  fFilter->Reset();
  array = fFilter->GetSubArray(fClassifier, 0);

  if(array == 0) return;

  TIter itArray(array);

  while(particle = static_cast<ExRootGenParticle*>(itArray.Next()))
  {
    momentum.SetPxPyPzE(particle->Px, particle->Py, particle->Pz, particle->E);

    candidate = factory->NewCandidate();

    candidate->SetP4(momentum);
    candidate->SetType(particle->PID);

    fOutputArray->Add(candidate);
  }

/*
  cout << "==============================" << endl;
  Int_t indexParticle = -1;
  itArray.Reset();
  while(particle = static_cast<ExRootGenParticle*>(itArray.Next()))
  {
    ++indexParticle;
    cout << "--->\t" << particle->Status << "\t" << particle->PID << "\t";
    cout << particle->M1 << "\t" << particle->M2 << "\t";
    cout << particle->Px << "\t" << particle->Py << "\t" << particle->Pz << endl;
  }
*/
}

//------------------------------------------------------------------------------
