#!/usr/bin/env python
# coding: utf-8

# First do 'make' in all subdirectories of 'gen_xxx_standalone_cpp'. These directories
# contain the standalone_cpp output of MadGraph for the u d > t~ ta+ and t > ta+ u~ d~
# matrix elements. Select the process to check with the 'decay' switch.
# The function 'set' below can be used to set the relevant a,b,c,d,a',b',c',d' parameters to
# desired values. The function 'check' computes the matrix element at 10 random phase-space
# points and compares these to the analytical computation presented in [1107.3805],
# PRD85(2012)016006. Note that the published version has the correct definition of the O^(s)
# operators (with the beta and gamma colour indices interchanged compared to the arXiv version).
# Note also that in both versions, the differential partonic cross section has a typo on the
# first line of Eq.(6): C should be replaced by -C (on this first line only).
# Each line of output from the 'check' function displays the parameters (A,B,C) derived
# from a,b,c, etc., the matrix element computed analytically, and the relative difference
# (analytical - MadGraph) / MadGraph between the two computations.


from subprocess import run, STDOUT, PIPE
import re
import numpy as np

# Which standalone MadGraph output to validate:
#   False -> the u d > t~ ta+ process (gen_ud2tta_standalone_cpp)
#   True  -> the t > ta+ u~ d~ process (gen_t2taud_standalone_cpp)
decay = True

# Parameter card rewritten by `set` for the selected process.
fname = ('./gen_t2taud_standalone_cpp/Cards/param_card.dat' if decay
         else './gen_ud2tta_standalone_cpp/Cards/param_card.dat')

def check():
    """Run MadGraph's ``./check`` executable and compare it with the analytic formula.

    The process is selected by the module-level ``decay`` flag. When it is
    false, the ``u d > t~ ta+`` standalone output is used; when it is true,
    the ``t > ta+ u~ d~`` output is used. The couplings, the top mass
    ``mdl_MT`` and ``mdl_Lambda`` are read back from the executable's
    printout and combined into

        A = (a^2 + b^2) (c^2 + d^2)
        B = (a'^2 + b'^2) (c'^2 + d'^2)
        C = a c a' c' + b d b' d'

    For every ``Matrix element<n> = <value>`` line, the invariants s and t
    are built from the matching ``Momenta<n>:`` block. One tab-separated line
    is then printed: ``(A,B,C)``, the analytic value, and
    ``analytic / MadGraph - 1``.

    Raises:
        OSError: (e.g. FileNotFoundError) if the SubProcesses directory or
            its ``check`` executable is missing.
        RuntimeError: if a parameter or a momenta block cannot be found in
            the executable's output.
    """
    if not decay:
        workdir = 'gen_ud2tta_standalone_cpp/SubProcesses/P1_Sigma_bnv_mediator_ufo_ud_txtap'
    else:
        workdir = 'gen_t2taud_standalone_cpp/SubProcesses/P1_Sigma_bnv_mediator_ufo_t_tapuxdx'

    # Run the executable from its own directory without going through a shell;
    # subprocess resolves the relative './check' against `cwd`.
    out = run(['./check'], cwd=workdir, stdout=PIPE).stdout.decode()

    # A printed float such as 1.720000e+02 or -3.5. '-' comes first so it is
    # literal; an unescaped '+-.' inside a class is a range that also admits ','.
    number = r'([-+0-9.eE]+)'

    def param(name):
        """Return the first value printed as '<name> = <value>'."""
        match = re.search(re.escape(name) + ' = ' + number, out)
        if match is None:
            raise RuntimeError('{} not found in ./check output'.format(name))
        return float(match.group(1))

    aaa = param('mdl_aaa3x1')
    bbb = param('mdl_bbb3x1')
    ccc = param('mdl_ccc1x3')
    ddd = param('mdl_ddd1x3')
    aaaprime = param('mdl_aaaprime3x3')
    bbbprime = param('mdl_bbbprime3x3')
    cccprime = param('mdl_cccprime1x1')
    dddprime = param('mdl_dddprime1x1')

    mt = param('mdl_MT')
    Lambda = param('mdl_Lambda')

    A = (aaa**2 + bbb**2) * (ccc**2 + ddd**2)
    B = (aaaprime**2 + bbbprime**2) * (cccprime**2 + dddprime**2)
    C = aaa*ccc*aaaprime*cccprime + bbb*ddd*bbbprime*dddprime

    mt2 = mt**2
    metric = np.array([1., -1., -1., -1.])  # (+,-,-,-), energy component first

    def minkowski_sq(v):
        """Return the Minkowski square of the four-vector `v`."""
        return np.sum(metric * v**2)

    results = re.findall(r'Matrix element[0-9]+ = ' + number, out)

    for idx, me in enumerate(results):
        # Non-greedy: point 1 must stop at 'Matrix element1' rather than run
        # on to 'Matrix element10' and swallow the following momenta blocks.
        block = re.search(r'Momenta{0}:.+?Matrix element{0}'.format(idx),
                          out, flags=re.DOTALL)
        if block is None:
            raise RuntimeError(
                'momenta for point {} not found in ./check output'.format(idx))

        # Skip the 'Momenta<n>:' header line, the final 'Matrix element<n>'
        # line and the line before it (presumably a separator). In each
        # remaining row the first column (presumably the leg number) is dropped.
        rows = []
        for line in block.group(0).split('\n')[1:-2]:
            row = [float(tok) for tok in line.split()][1:]
            if row:
                rows.append(row)
        p = np.array(rows)

        if not decay:
            # ordering is u d > t~ ta+
            s = minkowski_sq(p[0] + p[1])
            t = minkowski_sq(p[0] - p[3])
            norm = 6. * Lambda**4
        else:
            # ordering is t > ta+ u~ d~
            s = minkowski_sq(p[2] + p[3])
            t = minkowski_sq(p[1] + p[2])
            norm = -1 * Lambda**4

        xsec = (A*t*(t - mt2) + B*(s - mt2)*s - 2*C*t*s) / norm

        print('({:},{:},{:})\t{:}\t{:}'.format(A, B, C, xsec, xsec/float(me) - 1))

    
def set(a, b, c, d, ap, bp, cp, dp, path=None):
    """Write the couplings a, b, c, d, a', b', c', d' into a MadGraph param_card.

    Each value replaces the number that immediately precedes the matching
    ``# <name>`` comment in the card and is formatted with ``{:g}``. The tau
    mass (``# MTA``) is also forced to ``0.``. Every matching occurrence in
    the file is rewritten, and the file is overwritten in place.

    Args:
        a, b, c, d: new values of ``aaa3x1``, ``bbb3x1``, ``ccc1x3``, ``ddd1x3``.
        ap, bp, cp, dp: new values of ``aaaprime3x3``, ``bbbprime3x3``,
            ``cccprime1x1``, ``dddprime1x1``.
        path: card to edit; defaults to the module-level ``fname``.

    Note: this function shadows the builtin ``set`` within the module.
    """
    if path is None:
        path = fname

    with open(path, encoding='utf-8') as card:
        txt = card.read()

    # A card value such as 1.000000e+00, -2.5 or 1e-07. '-' comes first so it
    # is literal; an unescaped '+-.' inside a class is a range that admits ','.
    number = r'[-+0-9.eE]+'

    # (comment name, replacement text); the tau mass is set to zero.
    updates = [('MTA', '0.')]
    updates += [(name, '{:g}'.format(value)) for name, value in (
        ('aaa3x1', a), ('bbb3x1', b), ('ccc1x3', c), ('ddd1x3', d),
        ('aaaprime3x3', ap), ('bbbprime3x3', bp),
        ('cccprime1x1', cp), ('dddprime1x1', dp),
    )]

    for name, text in updates:
        # '\b' keeps a name from also matching a longer parameter name.
        txt = re.sub(number + ' # ' + re.escape(name) + r'\b',
                     text + ' # ' + name, txt)

    with open(path, 'w', encoding='utf-8') as card:
        card.write(txt)

if __name__ == '__main__':
    # Write one coupling configuration into the card, then compare MadGraph
    # with the analytic formula for the process selected by `decay`.
    # Guarded so that importing this module neither rewrites the card nor
    # runs the external executable.
    set(0,1,0,1, 1,1,1,1)
    check()

