
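"""
2D wave equation solved with a Chebyshev spectral method

    u_tt = u_xx + u_yy

on [-1, 1] x [-1, 1], t > 0, with Dirichlet boundary condition u = 0.
Spatial derivatives are computed via the FFT; time stepping is leapfrog.
"""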

import numpy as np
from scipy import interpolate
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm


class WaveEquation:
    """Chebyshev spectral solver for the 2D wave equation u_tt = u_xx + u_yy.

    The square [-1, 1] x [-1, 1] is discretised with (N + 1) x (N + 1)
    Chebyshev points.  Second derivatives in x and y are computed with FFTs
    (Chebyshev differentiation via the substitution x = cos(theta)), and the
    solution is advanced in time with the leapfrog scheme.  Only interior
    points are updated, so the boundary keeps its initial, nearly-zero values,
    approximating the Dirichlet condition u = 0.

    Attributes:
        N: number of Chebyshev intervals per direction (N + 1 points).
        x0, xf, y0, yf: domain bounds, used for the plotting grid and axes.
        x, y: Chebyshev points cos(k*pi/N), k = 0..N, running from 1 to -1.
        xx, yy: np.meshgrid(x, y); vv[i, j] is the value at (x[j], y[i]).
        dt: time step, adjusted so that t = 1/3 takes exactly plotgap steps.
        plotgap: number of time steps per 1/3 time unit.
        vv, vvold: solution at the current and at the previous time level.
    """

    def __init__(self, N):
        """Build the grid and the initial condition for N Chebyshev intervals.

        Raises:
            ValueError: if N < 2 (the grid would have no interior points).
        """
        if N < 2:
            raise ValueError("N must be at least 2, got %r" % (N,))
        self.N = N
        self.x0 = -1.0
        self.xf = 1.0
        self.y0 = -1.0
        self.yf = 1.0
        self.initialization()
        self.initCond()

    def initialization(self):
        """Create the Chebyshev grid and choose the time step."""
        k = np.arange(self.N + 1)
        self.x = np.cos(k * np.pi / self.N)
        self.y = self.x.copy()
        self.xx, self.yy = np.meshgrid(self.x, self.y)

        # dt = O(N^-2): presumably chosen for leapfrog stability, since
        # Chebyshev points cluster near the boundary.
        self.dt = 6 / self.N**2
        # Adjust dt so that t = 1/3 is reached after exactly `plotgap` steps.
        # Clamp to at least one step: for small N the rounding gives 0, which
        # previously caused a ZeroDivisionError.
        self.plotgap = max(1, round((1 / 3) / self.dt))
        self.dt = (1 / 3) / self.plotgap

    def initCond(self):
        """Set a Gaussian pulse centred at (0.4, 0) as the initial state."""
        self.vv = np.exp(-40 * ((self.xx - 0.4)**2 + self.yy**2))
        # vvold == vv makes the first leapfrog step start (approximately) from rest.
        self.vvold = self.vv.copy()

    @staticmethod
    def _cheb_d2(v, x):
        """Return d^2 v / dx^2 along the last axis of v at Chebyshev points x.

        Each row of `v` holds samples at x = cos(k*pi/N), k = 0..N.  Only the
        interior points are computed; the two end columns are left at zero.
        """
        N = v.shape[-1] - 1
        # Even extension in theta: [v_0 .. v_N, v_{N-1} .. v_1] (length 2N).
        V = np.concatenate((v, v[..., N - 1:0:-1]), axis=-1)
        U = np.fft.fft(V, axis=-1).real
        # Wavenumbers for d/dtheta (Nyquist mode zeroed) and d^2/dtheta^2.
        k1 = np.concatenate((np.arange(N), [0], np.arange(1 - N, 0)))
        k2 = np.concatenate((np.arange(N + 1), np.arange(1 - N, 0)))
        W1 = np.fft.ifft(1j * k1 * U, axis=-1).real
        W2 = np.fft.ifft(-k2**2 * U, axis=-1).real

        xi = x[1:N]
        d2 = np.zeros(v.shape)
        # Chain rule from theta back to x; singular at x = +-1, so interior only.
        d2[..., 1:N] = (W2[..., 1:N] / (1 - xi**2)
                        - xi * W1[..., 1:N] / (1 - xi**2)**1.5)
        return d2

    def _second_derivatives(self):
        """Return (uxx, uyy) for the current vv; both are zero on the boundary."""
        N = self.N
        uxx = np.zeros((N + 1, N + 1))
        uyy = np.zeros((N + 1, N + 1))
        # Rows of vv vary in x, columns vary in y; boundary rows/columns skipped.
        uxx[1:N, :] = self._cheb_d2(self.vv[1:N, :], self.x)
        uyy[:, 1:N] = self._cheb_d2(self.vv[:, 1:N].T, self.y).T
        return uxx, uyy

    def _leapfrog_step(self):
        """Advance (vvold, vv) by one leapfrog step of size dt."""
        uxx, uyy = self._second_derivatives()
        vvnew = 2 * self.vv - self.vvold + self.dt**2 * (uxx + uyy)
        # vvnew is a fresh array and nothing mutates in place: rebinding suffices.
        self.vvold, self.vv = self.vv, vvnew

    def _interpolate_surface(self, xp, yp):
        """Interpolate vv onto the tensor grid (xp, yp) with a bicubic spline.

        `xp` and `yp` must be increasing.  Returns an array of shape
        (len(yp), len(xp)), matching np.meshgrid(xp, yp).
        """
        # RectBivariateSpline needs strictly increasing coordinates, but the
        # Chebyshev points run from 1 to -1, so both data axes are reversed.
        # It indexes data as z[x, y], hence the transposes.
        degree = min(3, self.N)
        spline = interpolate.RectBivariateSpline(
            self.x[::-1], self.y[::-1], self.vv[::-1, ::-1].T,
            kx=degree, ky=degree)
        return spline(xp, yp).T

    def solve_and_animate(self):
        """Integrate from t = 0 to t = 1 and animate u as a 3D surface.

        Before every leapfrog step the current solution is interpolated onto a
        uniform grid with spacing 1/16 and redrawn; ``plt.pause`` lets the GUI
        refresh.  The solver state is advanced in place, so afterwards vv holds
        the solution at t = 1 + dt.
        """
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')

        # Static axis decorations: configured once rather than every frame.
        ax.set_xlim3d(self.x0, self.xf)
        ax.set_ylim3d(self.y0, self.yf)
        ax.set_zlim3d(-0.15, 1)
        ax.set_xlabel("x")
        ax.set_ylabel("y")
        ax.set_zlabel("U")
        ax.set_xticks([-1.0, -0.5, 0.0, 0.5, 1.0])
        ax.set_yticks([-1.0, -0.5, 0.0, 0.5, 1.0])
        ax.view_init(elev=30., azim=-110)

        # Uniform plotting grid (spacing 1/16); loop-invariant, so built once.
        h = 1 / 16
        xp = np.arange(self.x0, self.xf + h, h)
        yp = np.arange(self.y0, self.yf + h, h)
        xxp, yyp = np.meshgrid(xp, yp)

        nstep = 3 * self.plotgap + 1  # frames at t = 0, dt, ..., 1
        wframe = None
        for tc in range(nstep):
            if wframe is not None:
                # Axes.collections is an immutable ArtistList in recent
                # Matplotlib; remove the artist itself instead.
                wframe.remove()

            Z = self._interpolate_surface(xp, yp)
            wframe = ax.plot_surface(xxp, yyp, Z, cmap=cm.coolwarm,
                                     linewidth=0, antialiased=False)
            # vv currently holds the solution at t = tc * dt.
            fig.suptitle("Time = %1.3f" % (tc * self.dt))
            plt.pause(0.01)

            self._leapfrog_step()
        
    
def main():
    """Run the default simulation (N = 30 Chebyshev intervals) and show it."""
    WaveEquation(30).solve_and_animate()
    # Keep the final frame on screen until the user closes the window.
    plt.show()


if __name__ == "__main__":
    main()