
#include "modules/MadGraphPartonSelector.h"


#include "ExRootAnalysis/ExRootResult.h"
#include "ExRootAnalysis/ExRootClasses.h"

#include "ExRootAnalysis/ExRootFilter.h"
#include "ExRootAnalysis/ExRootClassifier.h"

#include "ExRootAnalysis/ExRootFactory.h"
#include "ExRootAnalysis/ExRootCandidate.h"

#include "TMath.h"
#include "TString.h"
#include "TLorentzVector.h"
#include "TClonesArray.h"

#include <iostream>
#include <set>

using namespace std;

//------------------------------------------------------------------------------

class MadGraphPartonClassifier : public ExRootClassifier
{
public:

  // branch: GenParticle array that the classifier reads particles from.
  // The pointer is stored as-is and is never deleted by this class.
  explicit MadGraphPartonClassifier(TClonesArray *branch);

  // Returns 0 for a selected parton and -1 for any rejected particle.
  Int_t GetCategory(TObject *object);

  // Register a PID to select; GetCategory() compares it against |PID|.
  void InsertParticleID(Int_t pid);

  // Register an ancestor PID that vetoes a parton; hasBadAncestor()
  // compares it against the |PID| of the walked ancestors.
  void InsertAncestorID(Int_t pid);

private:

  // True if an ancestor on the single-mother chain of `object` has a |PID|
  // registered through InsertAncestorID().
  Bool_t hasBadAncestor(ExRootGenParticle *object);

  TClonesArray *fBranchParticle; // not owned

  set< Int_t > fParticleIDSet;   // PIDs accepted by GetCategory()
  set< Int_t > fAncestorIDSet;   // ancestor PIDs that veto a parton
};

//------------------------------------------------------------------------------

MadGraphPartonClassifier::MadGraphPartonClassifier(TClonesArray *branch)
  : fBranchParticle(branch)
{
  // Only the branch pointer is stored; the ID sets start out empty and are
  // filled through InsertParticleID() / InsertAncestorID().
}

//------------------------------------------------------------------------------

void MadGraphPartonClassifier::InsertParticleID(Int_t pid)
{
  // Add a PID to the set of selected partons (duplicates are ignored by
  // the set).
  fParticleIDSet.insert(pid);
}

//------------------------------------------------------------------------------

void MadGraphPartonClassifier::InsertAncestorID(Int_t pid)
{
  // Add a PID to the set of vetoing ancestors (duplicates are ignored by
  // the set).
  fAncestorIDSet.insert(pid);
}

//------------------------------------------------------------------------------

Bool_t MadGraphPartonClassifier::hasBadAncestor(ExRootGenParticle *object)
{
  // Walk up the mother chain of `object` for at most kMaxAncestors steps and
  // return kTRUE as soon as an ancestor's |PID| is in fAncestorIDSet.
  //
  // The walk stops with kFALSE when:
  //  - the current particle has no mother (M1 < 0) or a second mother
  //    (M2 > -1), i.e. the chain is no longer a single line;
  //  - the mother index is out of range (At() returns a null pointer);
  //  - a gluon (PID 21) ancestor is reached;
  //  - kMaxAncestors steps were taken without a match.
  const Int_t kMaxAncestors = 10;
  ExRootGenParticle *particle = object;

  for(Int_t step = 0; step < kMaxAncestors; ++step)
  {
    if(particle->M1 < 0 || particle->M2 > -1) return kFALSE;

    particle = static_cast<ExRootGenParticle*>(fBranchParticle->At(particle->M1));

    // Guard against an invalid mother index instead of dereferencing null.
    if(!particle) return kFALSE;

    if(particle->PID == 21) return kFALSE;

    // An ancestor with a vetoed |PID| marks the parton as coming from a
    // bad ancestor.
    if(fAncestorIDSet.find(TMath::Abs(particle->PID)) != fAncestorIDSet.end()) return kTRUE;
  }

  return kFALSE;
}

//------------------------------------------------------------------------------

Int_t MadGraphPartonClassifier::GetCategory(TObject *object)
{
  // Classify one GenParticle: 0 means "selected parton", -1 means rejected
  // by one of the checks below.
  ExRootGenParticle *candidate = static_cast<ExRootGenParticle*>(object);

  // Skip beam particles and initial-state partons (first mother index < 2),
  // then anything whose status is not 3.
  if(candidate->M1 < 2 || candidate->Status != 3) return -1;

  // Keep only PIDs registered through InsertParticleID(); the lookup uses
  // the absolute value, so particles and antiparticles are treated alike.
  if(fParticleIDSet.count(TMath::Abs(candidate->PID)) == 0) return -1;

  // Entries 0 and 1 of the branch are read as the two beams.
  ExRootGenParticle *firstBeam = static_cast<ExRootGenParticle*>(fBranchParticle->At(0));
  ExRootGenParticle *secondBeam = static_cast<ExRootGenParticle*>(fBranchParticle->At(1));
  const Int_t firstBeamPid = firstBeam->PID;
  const Int_t secondBeamPid = secondBeam->PID;

  const Bool_t electronPositronBeams =
    TMath::Abs(firstBeamPid) == 11 && TMath::Abs(secondBeamPid) == 11 &&
    firstBeamPid == -secondBeamPid;

  if(electronPositronBeams)
  {
    // For e+e- beams, drop particles whose mother is a PID 22/23 boson
    // (photon or Z in PDG numbering) produced directly from entries 0 and 1.
    ExRootGenParticle *parent = static_cast<ExRootGenParticle*>(fBranchParticle->At(candidate->M1));
    const Bool_t parentIsGammaOrZ = (parent->PID == 22 || parent->PID == 23);
    const Bool_t parentFromBeams = (parent->M1 == 0 && parent->M2 == 1);
    if(parentIsGammaOrZ && parentFromBeams) return -1;
  }

  // Reject the particle when its first daughter also has status 3.
  if(candidate->D1 > -1)
  {
    ExRootGenParticle *child = static_cast<ExRootGenParticle*>(fBranchParticle->At(candidate->D1));
    if(child->Status == 3) return -1;
  }

  return hasBadAncestor(candidate) ? -1 : 0;
}

//------------------------------------------------------------------------------

MadGraphPartonSelector::MadGraphPartonSelector()
  : fFilter(0),
    fClassifier(0)
{
  // The filter and classifier are created in Init() and deleted in Finish().
}

//------------------------------------------------------------------------------

MadGraphPartonSelector::~MadGraphPartonSelector()
{
  // Intentionally empty: fFilter and fClassifier are deleted in Finish().
}

//------------------------------------------------------------------------------

void MadGraphPartonSelector::Init()
{
  ExRootConfParam param;

  Int_t i, pid, sizeParam;

  // import ROOT tree branch

  fBranchParticle = UseBranch("GenParticle");

  // create classifier and filter

  fClassifier = new MadGraphPartonClassifier(fBranchParticle);
  fFilter = new ExRootFilter(fBranchParticle);

  // read particle IDs from configuration file and setup classifier

  param = GetParam("PartonIDs");
  sizeParam = param.GetSize();

  for(i = 0; i < sizeParam; ++i)
  {
    pid = param[i].GetInt();
    fClassifier->InsertParticleID(pid);
  }

  // read ancestor IDs from configuration file and setup classifier

  param = GetParam("ExcludedAncestorIDs");
  sizeParam = param.GetSize();

  for(i = 0; i < sizeParam; ++i)
  {
    pid = param[i].GetInt();
    fClassifier->InsertAncestorID(pid);
  }

  // create output arrays

  fOutputArray = ExportArray("candidates");

}

//------------------------------------------------------------------------------

void MadGraphPartonSelector::Finish()
{
  // Release the objects created in Init(). Deleting a null pointer is a
  // no-op, so no guard is needed. Resetting the pointers afterwards prevents
  // a double delete (or use of freed memory) if Finish() runs more than once.
  delete fFilter;
  fFilter = 0;

  delete fClassifier;
  fClassifier = 0;
}

//------------------------------------------------------------------------------

void MadGraphPartonSelector::Process()
{
  TObjArray *array = 0;
  ExRootGenParticle *particle = 0;
  ExRootCandidate *candidate = 0;
  ExRootFactory *factory = GetFactory();

  TLorentzVector momentum;

  fFilter->Reset();
  array = fFilter->GetSubArray(fClassifier, 0);

  if(array == 0) return;

  TIter itArray(array);

  while(particle = static_cast<ExRootGenParticle*>(itArray.Next()))
  {
    momentum.SetPxPyPzE(particle->Px, particle->Py, particle->Pz, particle->E);

    candidate = factory->NewCandidate();

    candidate->SetP4(momentum);
    candidate->SetType(particle->PID);

    fOutputArray->Add(candidate);
  }

}

//------------------------------------------------------------------------------
