
#include "modules/MadGraphParticleClassifier.h"

#include "ExRootAnalysis/ExRootClasses.h"

#include "TClass.h"

#include <iostream>

using namespace std;

//------------------------------------------------------------------------------

MadGraphParticleClassifier::MadGraphParticleClassifier() :
  fMaxCategories(0),
  fIsExtendable(kFALSE)
{
  // Starts with no registered categories. Because the classifier is not
  // extendable by default, GetCategory returns -1 for any PID that was not
  // registered through InsertClassPID, until SetExtendable(kTRUE) is called.
}

//------------------------------------------------------------------------------

void MadGraphParticleClassifier::InsertParticleStatus(Int_t status)
{
  // Accept particles carrying this status code in GetCategory; particles whose
  // status is not in this set are never classified. Inserting a value that is
  // already present leaves the set unchanged.
  fParticleStatusSet.insert(fParticleStatusSet.end(), status);
}

//------------------------------------------------------------------------------

void MadGraphParticleClassifier::InsertClassPID(const TString &className, Int_t pid)
{
  Int_t category;
  map< TString, Int_t >::const_iterator itClassNameMap;

  itClassNameMap = fClassNameMap.find(className);
  
  if(itClassNameMap == fClassNameMap.end())
  {
    category = fMaxCategories;
    fClassNameMap[className] = category;
    fClassNameArray.push_back(className);
    ++fMaxCategories;
  }
  else
  {
    category = itClassNameMap->second;
  }
  fParticleIDMap[pid] = category;
}

//------------------------------------------------------------------------------

Int_t MadGraphParticleClassifier::GetCategory(TObject *object)
{
  // Return the category index of an ExRootLHEFParticle or ExRootGenParticle.
  //
  // Returns -1 when:
  //  - object is null or of any other type,
  //  - its status code was not registered with InsertParticleStatus,
  //  - its PID is unmapped and the classifier is not extendable.
  //
  // In extendable mode an unmapped PID is added to a class named after |PID|.
  // The charge-conjugate PID joins the same class only if it is not already
  // classified, so PID assignments made earlier are never overwritten.

  if(!object) return -1;

  Int_t pid = 0;
  Int_t status = 0;

  if(object->IsA()->InheritsFrom(ExRootLHEFParticle::Class()))
  {
    ExRootLHEFParticle *particle = static_cast<ExRootLHEFParticle*>(object);
    pid = particle->PID;
    status = particle->Status;
  }
  else if(object->IsA()->InheritsFrom(ExRootGenParticle::Class()))
  {
    ExRootGenParticle *particle = static_cast<ExRootGenParticle*>(object);
    pid = particle->PID;
    status = particle->Status;
  }
  else
  {
    return -1;
  }

  if(fParticleStatusSet.find(status) == fParticleStatusSet.end()) return -1;

  map< Int_t, Int_t >::const_iterator itParticleIDMap = fParticleIDMap.find(pid);

  if(itParticleIDMap != fParticleIDMap.end()) return itParticleIDMap->second;

  if(!fIsExtendable) return -1;

  TString className = Form("%d", TMath::Abs(pid));

  // InsertClassPID reuses an existing category if a class with this name was
  // already registered, so the category is read back afterwards instead of
  // assuming it equals fMaxCategories.
  InsertClassPID(className, pid);

  // Only attach the charge conjugate if it has not been classified already;
  // otherwise an earlier user-defined assignment would be silently replaced.
  if(fParticleIDMap.find(-pid) == fParticleIDMap.end())
  {
    InsertClassPID(className, -pid);
  }

  return fParticleIDMap[pid];
}

//------------------------------------------------------------------------------

void MadGraphParticleClassifier::SetExtendable(Bool_t extendable)
{
  // When enabled, GetCategory creates a new category for each unmapped PID
  // instead of returning -1 for it.
  fIsExtendable = extendable ? kTRUE : kFALSE;
}

//------------------------------------------------------------------------------

Int_t MadGraphParticleClassifier::GetMaxCategories() const
{
  // Number of distinct class names registered so far. Valid category indices
  // run from 0 to this value minus one.
  return this->fMaxCategories;
}

//------------------------------------------------------------------------------

TString MadGraphParticleClassifier::GetCategoryClassName(Int_t category) const
{
  // Name of the class registered for `category`. The name is either the one
  // given to InsertClassPID, or the decimal |PID| created by GetCategory in
  // extendable mode.
  // An index outside [0, number of classes) yields an empty TString rather
  // than undefined behaviour. This includes the -1 that GetCategory returns
  // for unclassified particles.
  if(category < 0 || category >= static_cast<Int_t>(fClassNameArray.size()))
  {
    return TString();
  }
  return fClassNameArray[category];
}

//------------------------------------------------------------------------------
