
# initial imports; set up AIMNet minimization
import sys
import os
import tarfile
import glob
import numpy as np
import pandas as pd
import torch
import multiprocessing as mp
# Restrict torch to a single intra-op CPU thread for this process.
# NOTE(review): presumably so several jobs / worker processes on one node do
# not oversubscribe the cores — confirm.
torch.set_num_threads(1)

from aimnet import load_AIMNetSMD, load_AIMNetSMD_ens, AIMNetCalculator
from dftd4 import D4_model

# Load the AIMNet-SMD ensemble model, moving it to the GPU when one is available.
if torch.cuda.is_available():
    model = load_AIMNetSMD_ens().cuda()
else:
    model = load_AIMNetSMD_ens()


# Wrap AIMNet in dftd4's D4_model (xc='wB97x'). If that cannot be built, fall
# back to the plain AIMNet calculator so that the module-level `calculator`
# used by the minimization code is always defined.
try:
    calculator = D4_model(xc='wB97x', calc=AIMNetCalculator(model))
except Exception as err:
    print('error making D4_model, using vanila AIMNet model')
    print(f'    reason: {err!r}')
    calculator = AIMNetCalculator(model)


# imports other than AIMNet / Torch
import h5py
from collections import defaultdict

import ase
from ase import Atoms
from ase.io import write, read, iread
from ase.units import kcal,mol
from ase.optimize import BFGS
from ase.constraints import FixInternals


# Extract the job's gzipped input archive (expected to be inputs.tgz) into the
# current working directory.
# NOTE(review): members are extracted without path sanitisation, so only use
# this with trusted archives.
_archives = sorted(glob.glob('*gz'))
if not _archives:
    raise FileNotFoundError('no *gz input archive found in the working directory')
TAR = _archives[0]
with tarfile.open(TAR, 'r:gz') as tar:
    for item in tar:
        tar.extract(item)

# Read the job's "command-line" arguments from ./flags.txt.
# Each meaningful line is a whitespace-separated "key value" pair. Blank lines
# and lines starting with '#' are skipped. All values are kept as strings
# (flags are compared against 'True' / 'False' elsewhere).
arguments = {}
with open('./flags.txt') as flag_file:
    for line in flag_file:
        line = line.strip()
        # A blank line would otherwise make the two-value unpacking below raise.
        if not line or line.startswith('#'):
            continue
        k, v = line.split()
        arguments[k] = v
if arguments.get('verbose') == 'True':
    print(arguments.items())




# from this directory
#import utils
#import utils_build
#import params
#import setup_minimization


def _par_minimize(
    XYZs,
    bonds,
    atoms,
    endpoint,
):
    '''
    Minimize each conformer in XYZs with BFGS and the module-level
    `calculator`, holding every bond in `bonds` at its starting length.

    Parameters
        XYZs     : iterable of per-conformer coordinates, passed as
                   `positions` to ase.Atoms
        bonds    : iterable of atom-index pairs constrained via FixInternals
        atoms    : atomic numbers, shared by every conformer
        endpoint : if true, record only the final geometry of each
                   minimization; otherwise record every optimizer step

    Returns a defaultdict(list) that looks like:
    {
        'xyz'      : [np.arrays],
        'energies' : [kcal/mol],
    }

    If a conformer's minimization raises, the error is printed and the loop
    continues. Anything recorded before the error is kept. In endpoint mode,
    save() is still called on the conformer's current geometry.
    '''
    results = defaultdict(list)

    for i, xyz in enumerate(XYZs):
        macrocycle = Atoms(positions=xyz, numbers=atoms)
        # Freeze each listed bond at the length it has in the input geometry.
        macrocycle.set_constraint(
            FixInternals(
                bonds=[[macrocycle.get_distance(*bond), bond] for bond in bonds],
                epsilon=0.03,
            )
        )
        # `.calc =` is the supported spelling; Atoms.set_calculator is deprecated.
        macrocycle.calc = calculator

        # The default arguments bind this iteration's Atoms object, so the
        # callback attached to the optimizer never sees a later conformer.
        def save(snapshot=macrocycle, results=results):
            # * 1/(kcal/mol) converts from ASE energy units to kcal/mol.
            # NOTE(review): [0] assumes the calculator returns the total energy
            # as a length-1 sequence (AIMNet ensemble output) — confirm.
            results['energies'].append(snapshot.get_total_energy()[0] * (1/(kcal/mol)))
            results['xyz'].append(snapshot.get_positions())

        try:
            opt = BFGS(macrocycle, maxstep=0.5, logfile='/dev/null')
            if not endpoint:
                # Trajectory mode: record a snapshot after every optimizer step.
                opt.attach(save, interval=1)
            opt.run(steps=100, fmax=0.005)
        except Exception as err:
            # Best effort: report the failure and go on to the next conformer.
            # `except Exception` (not bare) lets Ctrl-C / SystemExit propagate.
            print(f'minimization of conformer {i} failed: {err!r}')

        # Endpoint mode: save only the last geometry reached.
        if endpoint:
            save()

    return results


def minimize_macrocycle(
    XYZs,
    atoms,
    bonds,
    endpoint,
):
    '''
    Minimize all conformers of one macrocycle.

    Parameters
        XYZs     : conformer coordinates, forwarded to _par_minimize
        atoms    : atomic numbers shared by every conformer
        bonds    : atom-index pairs held fixed during minimization
        endpoint : 'True' / 'False' (as read from flags.txt) or a bool. If
                   true, keep only the final geometry of each minimization;
                   otherwise keep every optimizer snapshot. Any other value
                   falls back to False with a printed warning.

    Returns
        Msnapshots x Natoms x 3 np array of XYZs after/during minimization
        Msnapshots np array of energies (kcal/mol) relative to the lowest one
        the lowest energy itself (kcal/mol)

    Raises
        ValueError if minimization produced no snapshots at all.
    '''
    if endpoint is True or endpoint == 'True':
        end = True
    elif endpoint is False or endpoint == 'False':
        end = False
    else:
        print('invalid input for enpoint, defualting to False')
        end = False

    # Run the minimization in a single worker process. The worker uses the
    # module-level `calculator` through _par_minimize.
    with mp.Pool(processes=1) as pool:
        results = pool.starmap(
            _par_minimize,
            [(XYZs, bonds, atoms, end)],
        )

    # Each result is a dict {'xyz': [arrays], 'energies': [kcal/mol]}.
    # Flatten them all into one list of geometries and one list of energies.
    list_of_xyz = []
    list_of_energies = []
    for partial in results:
        for energy, xyz in zip(partial['energies'], partial['xyz']):
            list_of_xyz.append(xyz[np.newaxis, :, :])
            list_of_energies.append(energy)

    if not list_of_xyz:
        raise ValueError('minimization produced no conformers to return')

    min_XYZs = np.concatenate(list_of_xyz, axis=0)
    min_Es = np.array(list_of_energies)

    # Free the intermediate copies before returning.
    del results, list_of_xyz, list_of_energies

    lowest_energy = np.min(min_Es)
    return min_XYZs, min_Es - lowest_energy, lowest_energy

def save_HDF5(
    XYZs,
    atoms,
    deltaEs,
    pdb_dir,
    macrocycle_name,
):
    '''
    Write the conformers and related info to <pdb_dir>/<macrocycle_name>.hdf5.

    Creates the datasets 'XYZs', 'deltaEs' and 'atoms', each holding the
    corresponding argument unchanged. An existing file at that path is
    overwritten (mode 'w').
    '''
    # os.path.join keeps the file inside pdb_dir even when pdb_dir has no
    # trailing separator. Plain concatenation would prefix the filename instead.
    path = os.path.join(pdb_dir, f'{macrocycle_name}.hdf5')
    with h5py.File(path, 'w') as f:
        f.create_dataset('XYZs', data=XYZs)
        f.create_dataset('deltaEs', data=deltaEs)
        f.create_dataset('atoms', data=atoms)

if __name__ == '__main__':
    verbose = arguments.get('verbose') == 'True'

    # MACRO 1, load data out of the input HDF5 file
    with h5py.File(arguments['HDF5'], 'r') as f:
        macrocycle_xyzs = f['XYZS'][:]
        macrocycle_atoms = f['ATOMS'][:]
        macrocycle_bonds = list(map(tuple, f['BONDS'][:]))

    if verbose:
        print(f'shape of XYZs           : {macrocycle_xyzs.shape}')
        print(f'shape of ATOMs          : {macrocycle_atoms.shape}')
        print(f'BONDs                   : {macrocycle_bonds}')

    # For the log file, to see which macrocycles take a very long time.
    print(f'{arguments["macrocycle_name"]} XYZs : {macrocycle_xyzs.shape}')

    # MACRO 5
    # start minimizing the conformers with AIMNet
    min_xyzs, min_Es, minimum_energy = minimize_macrocycle(
        macrocycle_xyzs,
        macrocycle_atoms,
        macrocycle_bonds,
        arguments['endpoint_only'],
    )

    # In endpoint mode there must be exactly one minimized geometry (and one
    # energy) per input conformer. These are explicit checks rather than
    # `assert`, because asserts are removed when Python runs with -O.
    if arguments['endpoint_only'] == 'True':
        if min_xyzs.shape != macrocycle_xyzs.shape:
            raise RuntimeError(
                f'minimized XYZs have shape {min_xyzs.shape}, '
                f'expected {macrocycle_xyzs.shape}'
            )
        if min_xyzs.shape[0] != min_Es.shape[0]:
            raise RuntimeError(
                f'{min_xyzs.shape[0]} geometries but {min_Es.shape[0]} energies'
            )

    if verbose:
        print(f'shape of min XYZs      : {min_xyzs.shape}')
        print(f'shape of min energies  : {min_Es.shape}')
        print(f'energies               : {min_Es[:5]}')

    # MACRO 8
    # NOTE(review): the minimum is added back here, so the 'deltaEs' dataset
    # holds absolute energies (kcal/mol), not relative ones — confirm intended.
    save_HDF5(
        min_xyzs,
        macrocycle_atoms,
        min_Es + minimum_energy,
        arguments['out_dir'],
        arguments['macrocycle_name'],
    )
