import pstats
from xml.dom.pulldom import PROCESSING_INSTRUCTION
import ase
from ase import Atoms
from ase.neb import NEB
from ase.io import Trajectory
from ase.optimize import BFGS, LBFGS,LBFGSLineSearch,BFGSLineSearch,FIRE
# from sella import Sella, Constraints
from math import sqrt
from pathlib import Path
import pickle
from logging import getLogger,StreamHandler,FileHandler,Formatter,DEBUG,INFO,WARNING,ERROR,CRITICAL
import numpy as np
from tqdm.notebook import tqdm
from timeout_decorator import timeout, TimeoutError
import sys,os,pathlib,logging,math
from .log import reset_logger, add_filehandler
from .neb import ListNEB
class ListOptimizer():
    """Base class of every list optimizer.

    Runs one optimizer per element of ``atoms_list`` and records, for every
    element, the final structure and a numeric state code.

    Parameters:

    atoms_list: list of Atoms
        List of Atoms. For NEB calculations, a ListNEB object.
    savename: str
        | ``{savename}.traj`` and ``{savename}.pickle`` are written
          incrementally while :meth:`run` is executing.
        | The traj file holds the structures after the calculation, the pickle
          file holds the result information (converged, not converged, ...).
    logfolder: str or path object
        Counterpart of ASE's ``logfile`` argument. Unlike ASE, several log
        files are created, so a folder name is given.
    trajfolder: str or path object
        Counterpart of ASE's ``trajectory`` argument. Unlike ASE, several traj
        files are created, so a folder name is given.
    errorlog: str or path object
        | Error log file created while :meth:`run` is executing.
        | Its content resembles the ``savename`` pickle file but is plain text.
    indexes: list of integers
        Index numbers of the elements to compute when only some of them
        should be computed. ``None`` selects every element.
    max_fmax: float
        The calculation is skipped when the force of the initial structure
        exceeds ``max_fmax``. Especially useful for NEB calculations.
    timeout: integer
        Give up the calculation when a single iteration takes longer than
        ``timeout`` seconds. Especially useful for TS searches with ListSella.
    maxstep: float
        | Largest displacement of an atom in one iteration. Corresponds to
          POTIM in VASP.
        | An argument provided by ASE, not by MatlanticGrrm.

    State codes appended to ``state_list`` (see :meth:`write_log`):
    0 finished, 1 index not selected, 2 object missing, 3 LinAlgError,
    4 not converged, 5 timeout, 6 initial force above ``max_fmax``,
    7 force could not be computed, 8 any other exception.
    """
    def __init__(self,atoms_list:list,savename:str,logfolder:str,trajfolder:str,errorlog:str,indexes:list=None,
                 max_fmax:float=10000,timeout:int=60*20,# user-facing arguments end here
                 restart=None,maxstep=None,master=None,force_consistent=None):
        ## user-defined ##
        self.atoms_list = atoms_list
        self.savename = savename # run() creates {self.savename}.traj and {self.savename}.pickle
        self.logfolder = logfolder
        self.trajfolder = trajfolder
        self.errorlog = errorlog
        self.indexes = indexes
        self.max_fmax = max_fmax
        self.timeout = timeout
        self.state_list = [] # list that is stored in the pickle file
        ## inherited from ASE ##
        self.restart = restart
        self.maxstep = maxstep
        self.master = master
        self.force_consistent = force_consistent

        if isinstance(self.atoms_list, ListNEB):
            # For NEB calculations the objects to store are ListNEB.imagess
            self.save_datas = self.atoms_list.imagess
        else:
            # Otherwise the Atoms list itself is stored in the traj file
            self.save_datas = self.atoms_list

        n_items = len(self.atoms_list)
        if self.indexes is None:
            self.indexes = list(range(n_items))

        if self.logfolder is not None:
            self.logfolder = Path(logfolder)
            self.logfolder.mkdir(parents=True, exist_ok=True)
            self.logfiles = [self.logfolder/f"{i}.log" for i in range(n_items)]
        else:
            self.logfiles = [None for _ in range(n_items)]

        if self.trajfolder:
            self.trajfolder = Path(trajfolder)
            self.trajfolder.mkdir(parents=True, exist_ok=True)
            # the paths are handed to the optimizer as plain strings
            self.trajectorys = [str(self.trajfolder/f"{i}.traj") for i in range(n_items)]
        else:
            self.trajectorys = [None for _ in range(n_items)]

    def mk_opt_func(self,atoms,logfile,trajectory):
        """Build the optimizer for one element; subclasses must override this.

        Parameters:

        atoms:
            Element of ``atoms_list`` to optimize.
        logfile:
            Entry of ``self.logfiles`` for this element (or None).
        trajectory:
            Entry of ``self.trajectorys`` for this element (or None).
        """
        raise NotImplementedError(f"{type(self).__name__} must implement mk_opt_func")

    @staticmethod
    def _max_force(atoms):
        """Return the largest per-row Euclidean norm of ``atoms.get_forces()``."""
        return sqrt((atoms.get_forces()**2).sum(axis=1).max())

    @staticmethod
    def _write_frame(path,mode,atoms):
        """Write ``atoms`` as one frame to the traj file and close the writer."""
        writer = Trajectory(path, mode=mode, atoms=atoms)
        try:
            writer.write()
        finally:
            # the writer owns an open file; release it even if write() fails
            writer.close()

    def start_number(self):
        """Return the index at which the calculation starts.

        When ``stop.log`` exists, the number written in it is used as the
        start index, the file is removed and the error log is opened in
        append mode. Otherwise the start index is 0 and the error log is
        opened in write mode.
        """
        self.logger = reset_logger() # reset the logger
        if Path("stop.log").exists():
            with open("stop.log","r") as f:
                n = int(f.read()) # start from the n-th element
            os.remove("stop.log")
            add_filehandler(self.logger,"a",self.errorlog,DEBUG)
        else:
            n = 0
            add_filehandler(self.logger,"w",self.errorlog,DEBUG)
        return n

    def write_stop_log(self,i):
        """Write index ``i`` into ``stop.log`` so a later run can resume there."""
        with open("stop.log","w") as f:
            f.write(str(i))

    def ase_run(self,atoms,i,fmax,steps):
        """Optimize one element and return the optimizer object.

        The step loop is written out here (instead of calling ``opt.run``) so
        that every single ``opt.step()`` can be wrapped in the timeout
        decorator. When a step times out, ``self.dnf`` is set to True and the
        loop stops.

        Parameters:

        atoms:
            Element to optimize.
        i: int
            Index of the element; selects its log file and trajectory path.
        fmax:
            Stored on the optimizer as ``opt.fmax``.
        steps:
            When truthy, stored on the optimizer as ``opt.max_steps``.
        """
        opt = self.mk_opt_func(atoms,self.logfiles[i],self.trajectorys[i])

        @timeout(self.timeout)
        def _internal(opt):
            opt.step()

        def irun(opt):
            # compute initial structure and log the first step
            opt.atoms.get_forces()
            # yield the first time to inspect before logging
            yield False

            if opt.nsteps == 0:
                opt.log()
                opt.call_observers()
            # run the algorithm until converged or max_steps reached
            while not opt.converged() and opt.nsteps < opt.max_steps:
                # compute the next step
                try:
                    _internal(opt)
                except TimeoutError:
                    self.dnf = True
                    break
                opt.nsteps += 1
                # let the user inspect the step and change things before logging
                # and predicting the next step
                yield False
                # log the step
                opt.log()
                opt.call_observers()
            # finally check if algorithm was converged
            yield opt.converged()

        opt.fmax = fmax
        if steps:
            opt.max_steps = steps
        # exhaust the generator; the yielded values are not needed here
        for _ in irun(opt):
            pass

        return opt

    def save_object(self,atoms,mode):
        """Append the result of one element to the traj and pickle files.

        Parameters:

        atoms:
            Object to store. For a ListNEB ``atoms_list`` this is a sequence
            of images; otherwise a single Atoms. A falsy value is replaced by
            empty ``Atoms()`` placeholders so the frame count stays aligned.
        mode: str
            ``"w"`` to start a new traj file, ``"a"`` to append.
        """
        traj_path = f"{self.savename}.traj"
        if isinstance(self.atoms_list, ListNEB):
            # NEB case: one frame per image
            if not atoms: # the object does not exist
                atoms = [Atoms() for _ in range(self.atoms_list.n_images)]
            n_images = len(atoms)
            for j,image in enumerate(atoms):
                # only the very first image may truncate the file
                image_mode = "w" if (mode == "w" and j == 0) else "a"
                self._write_frame(traj_path,image_mode,image)
        else:
            # every other case: one frame per element
            n_images  = 1
            if not atoms:
                atoms = Atoms()
            self._write_frame(traj_path,mode,atoms)

        # the pickle file is rewritten in full each time
        with open(f"{self.savename}.pickle","wb") as f:
            pickle.dump((self.state_list, n_images), f)

    def write_log(self,i,n,*comment):
        """Record state code ``n`` for element ``i`` and log a matching message.

        Parameters:

        i: int
            Index of the element.
        n: int
            State code; appended to ``self.state_list``.
        comment:
            Extra values for the message: the force for codes 0, 3, 4, 5 and
            6; exception type, file name and line number for code 8.
        """
        self.state_list.append(n)
        if n == 0:
            self.logger.debug(f"{i}: 完了 force={comment[0]} state {n}")
        elif n == 1:
            self.logger.info(f"{i}: index未指定 state {n}")
        elif n == 2:
            self.logger.info(f"{i}: objectが存在しない state {n}")
        elif n == 3:
            self.logger.warning(f"{i}: np.linalg.LinAlgError force={comment[0]} state {n}")
        elif n == 4:
            self.logger.warning(f"{i}: 未収束 force={comment[0]} state {n}")
        elif n == 5:
            self.logger.warning(f"{i}: TimeOut force={comment[0]} state {n}")
        elif n == 6:
            self.logger.warning(f"{i}: forceがmax_fmaxを上回っている force={comment[0]} state {n}")
        elif n == 7:
            self.logger.error(f"{i}: force計算不可 state {n}")
        elif n == 8:
            self.logger.critical(f"{i}: {comment[0]} {comment[1]} {comment[2]} state {n}")

    def initialize(self,i,atoms):
        """Check the structure and the object before calculating.

        Parameters:

        i: int
            Index of the calculation.
        atoms: Atoms or NEB
            The i-th Atoms object; for NEB calculations the NEB object.

        Return:
            False as soon as one check fails (the reason is recorded through
            :meth:`write_log`), True otherwise.
        """
        self.dnf = False # True once a step timed out ("did not finish")
        if i not in self.indexes:
            # the index was not selected through ``indexes``
            self.write_log(i,1)
            return False
        if not atoms:
            # the Atoms / NEB object does not exist (None)
            self.write_log(i,2)
            return False
        try:
            force = self._max_force(atoms)
        except Exception:
            # only the force evaluation is guarded; KeyboardInterrupt and
            # SystemExit are deliberately allowed to propagate
            self.write_log(i,7)
            return False
        if math.isnan(force):
            self.write_log(i,7)
            return False
        if force > self.max_fmax:
            self.write_log(i,6,force)
            return False
        return True

    def run(self,fmax=0.05, steps=100000000):
        """Run the calculation.

        Parameters:

        fmax:
            Force used as the convergence criterion.
        steps: integer
            Maximum number of steps.
        """
        n = self.start_number() # start from the n-th element
        for i, atoms in enumerate(tqdm(self.atoms_list[n:]),start=n):
            if Path("stop.log").exists():
                # an external stop request: remember where to resume
                self.write_stop_log(i)
                break
            if self.initialize(i,atoms):
                try:
                    opt = self.ase_run(atoms,i,fmax,steps)
                    force = round(self._max_force(atoms),4)
                    if opt.converged():
                        self.write_log(i,0,force)
                    elif self.dnf:
                        self.write_log(i,5,force)
                    else:
                        self.write_log(i,4,force)
                except Exception:
                    # one failing element must not abort the whole list
                    exc_type, _, exc_tb = sys.exc_info()
                    fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
                    self.write_log(i,8,exc_type,fname,exc_tb.tb_lineno)
            # store the result; only the very first element overwrites the file
            mode = "w" if (i == 0 and n == 0) else "a"
            self.save_object(self.save_datas[i],mode=mode)
                            
class ListBFGS(ListOptimizer):
    """ListOptimizer that builds an ``ase.optimize.BFGS`` for every element.

    Parameters:

    alpha:
        Passed unchanged to ``BFGS(alpha=...)``.

    ``restart``, ``maxstep`` and ``master`` are likewise passed to ``BFGS``;
    all remaining arguments are handled by :class:`ListOptimizer`.
    """
    def __init__(self, atoms_list, savename, logfolder=None, errorlog="error.log", trajfolder=None,
                 indexes=None,max_fmax=100000,timeout=60*20, 
                 restart=None, maxstep=None, master=None, alpha=None):
        # ListOptimizer takes (logfolder, trajfolder, errorlog) in that order;
        # the ASE options are forwarded by keyword so they cannot be misplaced
        super().__init__(atoms_list,savename,logfolder,trajfolder,errorlog,indexes,max_fmax,timeout,
                         restart=restart,maxstep=maxstep,master=master)
        self.alpha = alpha

    def mk_opt_func(self,atoms,logfile,trajectory):
        """Return a ``BFGS`` optimizer for ``atoms`` using the stored options."""
        return BFGS(atoms,restart=self.restart,logfile=logfile,trajectory=trajectory,
                     maxstep=self.maxstep,master=self.master,alpha=self.alpha) 
            
class ListLBFGS(ListOptimizer):
    """``ase.optimize.LBFGS`` extended so that a list of ``ase.Atoms`` can be
    given as the argument.

    ``restart``, ``maxstep``, ``memory``, ``damping``, ``alpha``,
    ``use_line_search``, ``master`` and ``force_consistent`` are passed
    unchanged to ``LBFGS``; all remaining arguments are handled by
    :class:`ListOptimizer`.
    """
    def __init__(self, atoms_list, savename, logfolder=None, errorlog="error.log", trajfolder=None,
                 indexes=None,max_fmax=100000,timeout=60*20,
                 restart=None, maxstep=None, memory=100, damping=1.0,
                 alpha=70.0,use_line_search=False, master=None,force_consistent=None):
        # ListOptimizer takes (logfolder, trajfolder, errorlog) in that order;
        # the ASE options are forwarded by keyword so they cannot be misplaced
        super().__init__(atoms_list,savename,logfolder,trajfolder,errorlog,indexes,max_fmax,timeout,
                         restart=restart,maxstep=maxstep,master=master,
                         force_consistent=force_consistent)
        self.memory = memory
        self.damping = damping
        self.alpha = alpha
        self.use_line_search = use_line_search

    def mk_opt_func(self,atoms,logfile,trajectory):
        """Return an ``LBFGS`` optimizer for ``atoms`` using the stored options."""
        return LBFGS(atoms,restart=self.restart,logfile=logfile,trajectory=trajectory,
                     maxstep=self.maxstep,memory=self.memory,damping=self.damping,
                     alpha=self.alpha,use_line_search=self.use_line_search,master=self.master,
                     force_consistent=self.force_consistent) 
    
class ListBFGSLineSearch(ListLBFGS):
    """ListOptimizer that builds an ``ase.optimize.BFGSLineSearch`` for every
    element.

    Parameters:

    c1, c2, alpha, stpmax:
        Passed unchanged to ``BFGSLineSearch``.

    ``restart``, ``maxstep``, ``master`` and ``force_consistent`` are likewise
    passed to ``BFGSLineSearch``; all remaining arguments are handled by the
    parent classes.
    """
    def __init__(self, atoms_list, savename, logfolder=None, errorlog="error.log", trajfolder=None,
                 indexes=None, max_fmax=100000,timeout=60*20,
                 restart=None, maxstep=None, c1=0.23, c2=0.46, alpha=10.0, stpmax=50.0,
                 master=None, force_consistent=None):
        # The parent is ListLBFGS, whose positional order is
        # (logfolder, errorlog, trajfolder, ..., maxstep, memory, ...).
        # Everything is forwarded by keyword: positional forwarding in
        # ListOptimizer order would swap errorlog/trajfolder and put
        # ``master`` into ``memory``.
        super().__init__(atoms_list,savename,logfolder=logfolder,errorlog=errorlog,
                         trajfolder=trajfolder,indexes=indexes,max_fmax=max_fmax,timeout=timeout,
                         restart=restart,maxstep=maxstep,alpha=alpha,master=master,
                         force_consistent=force_consistent)
        self.c1 = c1
        self.c2 = c2
        self.alpha = alpha
        self.stpmax = stpmax

    def mk_opt_func(self,atoms,logfile,trajectory):
        """Return a ``BFGSLineSearch`` optimizer for ``atoms`` using the stored options."""
        return BFGSLineSearch(atoms, restart=self.restart, logfile=logfile, maxstep=self.maxstep,
                 trajectory=trajectory, c1=self.c1, c2=self.c2, alpha=self.alpha, stpmax=self.stpmax,
                 master=self.master, force_consistent=self.force_consistent) 
class ListLBFGSLineSearch(ListLBFGS):
    """ListOptimizer that builds an ``ase.optimize.LBFGSLineSearch`` for every
    element.

    ``restart``, ``maxstep``, ``memory``, ``damping``, ``alpha``,
    ``use_line_search``, ``master`` and ``force_consistent`` are passed
    unchanged to ``LBFGSLineSearch``; all remaining arguments are handled by
    the parent classes.
    """
    def __init__(self, atoms_list, savename, logfolder=None, errorlog="error.log", trajfolder=None, 
                 indexes=None, max_fmax=100000, timeout=60*20,
                 restart=None,maxstep=None, memory=100, damping=1.0, alpha=70.0,
                 use_line_search=False, master=None,
                 force_consistent=None):
        # The parent is ListLBFGS, whose positional order is
        # (logfolder, errorlog, trajfolder, ..., maxstep, memory, ...).
        # Everything is forwarded by keyword: positional forwarding in
        # ListOptimizer order would swap errorlog/trajfolder and put
        # ``master`` into ``memory``.
        super().__init__(atoms_list,savename,logfolder=logfolder,errorlog=errorlog,
                         trajfolder=trajfolder,indexes=indexes,max_fmax=max_fmax,timeout=timeout,
                         restart=restart,maxstep=maxstep,memory=memory,damping=damping,
                         alpha=alpha,use_line_search=use_line_search,master=master,
                         force_consistent=force_consistent)
        # memory, damping, alpha and use_line_search are stored by ListLBFGS

    def mk_opt_func(self,atoms,logfile,trajectory):
        """Return an ``LBFGSLineSearch`` optimizer for ``atoms`` using the stored options."""
        return LBFGSLineSearch(atoms=atoms, restart=self.restart, logfile=logfile, trajectory=trajectory,
                               maxstep=self.maxstep,memory=self.memory, damping=self.damping,
                               alpha=self.alpha,use_line_search=self.use_line_search, 
                               master=self.master, force_consistent=self.force_consistent)  
        
class ListFIRE(ListOptimizer):
    """ListOptimizer that builds an ``ase.optimize.FIRE`` for every element.

    ``restart``, ``dt``, ``maxstep``, ``maxmove``, ``dtmax``, ``Nmin``,
    ``finc``, ``fdec``, ``astart``, ``fa``, ``a``, ``master``,
    ``downhill_check``, ``position_reset_callback`` and ``force_consistent``
    are passed unchanged to ``FIRE``; all remaining arguments are handled by
    :class:`ListOptimizer`.
    """
    def __init__(self, atoms_list, savename, logfolder=None, errorlog="error.log", trajfolder=None,
                 indexes=None, max_fmax=100000, timeout=60*20,
                 restart=None, dt=0.1, maxstep=None, maxmove=None, dtmax=1.0, Nmin=5,
                 finc=1.1, fdec=0.5,
                 astart=0.1, fa=0.99, a=0.1, master=None, downhill_check=False,
                 position_reset_callback=None, force_consistent=None):
        # ListOptimizer takes (logfolder, trajfolder, errorlog) in that order;
        # the ASE options are forwarded by keyword so they cannot be misplaced
        super().__init__(atoms_list,savename,logfolder,trajfolder,errorlog,indexes,max_fmax,timeout,
                         restart=restart,maxstep=maxstep,master=master,
                         force_consistent=force_consistent)
        self.dt = dt
        self.maxmove = maxmove
        self.dtmax = dtmax
        self.Nmin = Nmin
        self.finc = finc
        self.fdec = fdec
        self.astart = astart
        self.fa = fa
        self.a = a
        self.downhill_check = downhill_check
        self.position_reset_callback = position_reset_callback

    def mk_opt_func(self,atoms,logfile,trajectory):
        """Return a ``FIRE`` optimizer for ``atoms`` using the stored options."""
        return FIRE(atoms=atoms,restart=self.restart,logfile=logfile,trajectory=trajectory,
                    dt=self.dt,maxstep=self.maxstep,maxmove=self.maxmove,dtmax=self.dtmax,
                    Nmin=self.Nmin,finc=self.finc,fdec=self.fdec,astart=self.astart,fa=self.fa,
                    a=self.a,master=self.master,downhill_check=self.downhill_check,
                    position_reset_callback=self.position_reset_callback,
                    force_consistent=self.force_consistent) 
        
class ListSella(ListOptimizer):
    """ Extension of :sella:`sella.Sella <optimize/optimize.py>` that accepts a list of Atoms as its first argument.

    Parameters:

    retry: integer
        | Sella calculations often stop with numpy.linalg.LinAlgError, but in many cases simply recomputing solves the problem.
        | Up to ``retry`` times the element is recomputed when a LinAlgError (or a timeout after more than one step) occurs.

    For the other arguments see :class:`ListBFGS` and :sella:`sella.Sella <optimize/optimize.py>`.

    """
    def __init__(self,atoms_list, savename, logfolder=None, errorlog="error.log", trajfolder=None,
                 indexes=None, max_fmax=100000, timeout=60*20, retry=0,
                 restart=None, master=None, force_consistent=False, delta0=None, sigma_inc=None,
                 sigma_dec=None, rho_dec=None, rho_inc=None,order=1, eig=None, eta=1e-4, 
                 method=None, gamma=0.1, threepoint=False, constraints=None, constraints_tol=1e-5,
                 v0=None, internal=False, append_trajectory=False, rs=None, nsteps_per_diag=3):
        # This class has no ``maxstep`` argument, so ``master`` must be passed
        # by keyword; positionally it would land in ListOptimizer's
        # ``maxstep`` slot and ``self.master`` would stay None.
        super().__init__(atoms_list,savename,logfolder,trajfolder,errorlog,indexes,max_fmax,timeout,
                         restart=restart,master=master,force_consistent=force_consistent)
        self.retry = retry # maximum number of retries per element
        self.error_count = 0 # retries used for the current element
        self.delta0 = delta0
        self.sigma_inc = sigma_inc
        self.sigma_dec = sigma_dec
        self.rho_dec = rho_dec
        self.rho_inc = rho_inc
        self.order = order
        self.eig = eig
        self.eta = eta
        self.method = method
        self.gamma = gamma
        self.threepoint = threepoint
        self.constraints = constraints
        self.constraints_tol = constraints_tol
        self.v0 = v0
        self.internal = internal
        self.append_trajectory = append_trajectory
        self.rs = rs
        self.nsteps_per_diag = nsteps_per_diag

    def mk_opt_func(self,atoms,logfile,trajectory):
        """Return a ``Sella`` optimizer for ``atoms`` using the stored options."""
        # The module-level ``from sella import Sella`` is commented out, so the
        # name is imported here; the module stays importable without sella.
        from sella import Sella
        return Sella(atoms=atoms,restart=self.restart,logfile=logfile,trajectory=trajectory,
                    master=self.master,force_consistent=self.force_consistent,delta0=self.delta0,
                    sigma_inc=self.sigma_inc,sigma_dec=self.sigma_dec,rho_dec=self.rho_dec,
                    rho_inc=self.rho_inc,order=self.order,eig=self.eig,eta=self.eta,method=self.method,
                    gamma=self.gamma,threepoint=self.threepoint,constraints=self.constraints,
                    constraints_tol=self.constraints_tol,v0=self.v0,internal=self.internal,
                    append_trajectory=self.append_trajectory,rs=self.rs,nsteps_per_diag=self.nsteps_per_diag)

    def run(self,fmax=0.05, steps=100000000):
        """Run the calculation.

        Unlike :meth:`ListOptimizer.run`, an element is recomputed (up to
        ``self.retry`` times) after a ``numpy.linalg.LinAlgError`` or after a
        timeout that happened once more than one step had been taken.

        Parameters:

        fmax:
            Force used as the convergence criterion.
        steps: integer
            Maximum number of steps.
        """
        n = self.start_number() # start from the n-th element
        i = n
        # ``initial=n`` keeps the bar consistent when resuming from stop.log
        bar = tqdm(total=len(self.atoms_list), initial=n)
        try:
            while i < len(self.atoms_list):
                if Path("stop.log").exists():
                    # an external stop request: remember where to resume
                    self.write_stop_log(i)
                    break
                atoms = self.atoms_list[i]
                if self.initialize(i,atoms):
                    try:
                        opt = self.ase_run(atoms,i,fmax,steps)
                        force = round(sqrt((atoms.get_forces()**2).sum(axis=1).max()),4)
                        if opt.converged():
                            self.write_log(i,0,force)
                        elif self.dnf: # a step timed out
                            if self.error_count < self.retry and opt.nsteps > 1:
                                self.error_count += 1
                                self.logger.info(f"{i}: TimeOut {self.error_count}回目の再計算を開始する")
                                # ``i`` is not advanced: the same element runs again
                                continue
                            self.write_log(i,5,force)
                        else:
                            self.write_log(i,4,force)
                    except np.linalg.LinAlgError:
                        if self.error_count < self.retry:
                            self.error_count += 1
                            self.logger.info(f"{i}: numpyエラー {self.error_count}回目の再計算を開始する")
                            # ``i`` is not advanced: the same element runs again
                            continue
                        force = round(sqrt((atoms.get_forces()**2).sum(axis=1).max()),4)
                        self.write_log(i,3,force)
                    except Exception:
                        # one failing element must not abort the whole list
                        exc_type, _, exc_tb = sys.exc_info()
                        fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
                        self.write_log(i,8,exc_type,fname,exc_tb.tb_lineno)
                # store the result; only the very first element overwrites the file
                mode = "w" if (i == 0 and n == 0) else "a"
                self.save_object(self.save_datas[i],mode=mode)
                i += 1
                self.error_count = 0
                bar.update(1)
        finally:
            bar.close()