# 2D Wave Equation - Finite Difference Implementation
2
3 import os
4 import shutil
5 import numpy
6 from matplotlib import pyplot as plt
7 from PIL import Image
8
9
# Solver for the 2D wave equation.
class WaveEq(object):
    """Explicit finite-difference solver for the 2D wave equation.

    Solves ``u_tt = D * (u_xx + u_yy)`` on the square [0, 2] x [0, 2] for
    t in [0, 6], with u held at 0 on the boundary, a Gaussian initial
    displacement centred at (xc, yc) (see ``__calc_u0``) and the initial
    velocity returned by ``__calc_v0``.

    Second derivatives are approximated with central differences in both
    space and time. The scheme is explicit and therefore only conditionally
    stable; the usual CFL bound for it is
    ``sqrt(D) * ht * sqrt(1 / hx**2 + 1 / hy**2) <= 1``.

    Args:
        nx: number of grid intervals along x (nx + 1 grid points).
        ny: number of grid intervals along y (ny + 1 grid points).
        nt: number of time steps (nt + 1 time levels).
        D: coefficient multiplying the Laplacian (the squared wave speed).
        xc, yc: centre of the initial Gaussian pulse.
    """

    def __init__(self, nx, ny, nt, D=0.1, xc=1, yc=1):
        self.__nx = nx  # number of grid intervals along x
        self.__ny = ny  # number of grid intervals along y
        self.__nt = nt  # number of time steps
        self.__D = D
        self.__xc = xc
        self.__yc = yc

        self.__init_grid()  # step sizes and spatial mesh

    def __init_grid(self):
        """Compute the step sizes and the 2D spatial mesh of the fixed domain."""
        xMin, xMax = 0, 2
        yMin, yMax = 0, 2
        tMin, tMax = 0, 6
        self.__hx = (xMax - xMin) / self.__nx
        self.__hy = (yMax - yMin) / self.__ny
        self.__ht = (tMax - tMin) / self.__nt
        self.__X = numpy.linspace(xMin, xMax, self.__nx + 1)
        self.__Y = numpy.linspace(yMin, yMax, self.__ny + 1)
        self.__T = numpy.linspace(tMin, tMax, self.__nt + 1)
        # Default 'xy' indexing: mesh arrays have shape (ny + 1, nx + 1) and
        # are indexed [j, i] with j along y and i along x.
        self.__X, self.__Y = numpy.meshgrid(self.__X, self.__Y)

    def get_solu(self):
        """Yield ``(X, Y, U)`` for every time level t_0 ... t_nt.

        X and Y are the same mesh arrays on every yield. U is a new array of
        shape (ny + 1, nx + 1) holding the solution at that time level. Each
        level is computed lazily when the generator is advanced.
        """
        U0 = self.__calc_U0()
        yield self.__X, self.__Y, U0
        U1 = self.__calc_U1(U0)
        yield self.__X, self.__Y, U1
        # Uk_1 holds level k-1, Uk_2 holds level k-2.
        Uk_1, Uk_2 = U1, U0
        for _ in range(2, self.__nt + 1):
            Uk = self.__calc_Uk(Uk_1, Uk_2)
            yield self.__X, self.__Y, Uk
            Uk_1, Uk_2 = Uk, Uk_1

    def __laplacian(self, U):
        """Return the 5-point discrete Laplacian of U on the interior nodes.

        The result has the shape of ``U[1:-1, 1:-1]``; boundary nodes are
        not included.
        """
        centre = U[1:-1, 1:-1]
        term1 = (U[1:-1, 2:] + U[1:-1, :-2] - 2 * centre) / self.__hx ** 2
        term2 = (U[2:, 1:-1] + U[:-2, 1:-1] - 2 * centre) / self.__hy ** 2
        return term1 + term2

    def __calc_Uk(self, Uk_1, Uk_2):
        """Compute time level k from levels k-1 (Uk_1) and k-2 (Uk_2).

        Interior: U^k = D * ht**2 * Lap(U^{k-1}) + 2 U^{k-1} - U^{k-2}.
        Boundary nodes are left at 0.
        """
        Uk = numpy.zeros(Uk_1.shape)
        Uk[1:-1, 1:-1] = (self.__laplacian(Uk_1) * self.__D * self.__ht ** 2
                          + 2 * Uk_1[1:-1, 1:-1] - Uk_2[1:-1, 1:-1])
        return Uk

    def __calc_U1(self, U0):
        """Compute time level 1 from level 0.

        Uses the second-order Taylor step
        U^1 = U^0 + ht * V0 + (D * ht**2 / 2) * Lap(U^0),
        where V0 is the initial velocity from ``__calc_v0``. Boundary nodes
        are left at 0.
        """
        U1 = numpy.zeros(U0.shape)
        V0 = self.__calc_v0(self.__X, self.__Y)
        U1[1:-1, 1:-1] = (self.__laplacian(U0) * self.__D * self.__ht ** 2 / 2
                          + U0[1:-1, 1:-1] + self.__ht * V0[1:-1, 1:-1])
        return U1

    def __calc_U0(self):
        """Compute time level 0: the initial profile with a zeroed boundary."""
        U0 = self.__calc_u0(self.__X, self.__Y)
        self.__fill_boundary(U0)  # apply the boundary condition
        return U0

    def __fill_boundary(self, mat):
        """Set all four edges of ``mat`` to 0 in place."""
        mat[0, :] = 0
        mat[-1, :] = 0
        mat[:, 0] = 0
        mat[:, -1] = 0

    def __calc_u0(self, x, y):
        """Initial displacement: a Gaussian of amplitude 0.1 centred at (xc, yc).

        The constant 0.001 in the exponent sets the width of the pulse.
        """
        u0 = 0.1 * numpy.exp(-((x - self.__xc) ** 2 + (y - self.__yc) ** 2) / 0.001)
        return u0

    def __calc_v0(self, x, y):
        """Initial velocity u_t(x, y, 0): identically zero, shaped like ``x``."""
        v0 = 0 * x
        return v0
105
106
# Animated plot rendering.
class WavePlot(object):
    """Render the frames produced by a wave solver into an animated GIF."""

    def __init__(self, waveObj):
        # waveObj must provide get_solu() yielding (X, Y, U) tuples, as
        # WaveEq does.
        self.__waveObj = waveObj

    def ani_plot(self, aniPath, gifPath="wave_plot.gif", step=5, duration=5):
        """Plot every ``step``-th solver frame and assemble them into a GIF.

        Each selected frame is drawn with ``pcolor`` into a PNG inside
        ``aniPath``, and the PNGs are combined in order into ``gifPath``. The
        index of every solver frame is printed as progress output.

        The temporary PNGs are always removed afterwards. ``aniPath`` itself
        is removed only if this method created it, so a pre-existing
        directory and any unrelated files in it are left untouched.

        Args:
            aniPath: directory for the temporary per-frame PNG files.
            gifPath: path of the output GIF.
            step: keep every ``step``-th solver frame; must be >= 1.
            duration: per-frame display time passed to Pillow, in ms.
                NOTE(review): GIF stores delays in 10 ms units, so values
                below 10 are written as 0 and most viewers then fall back to
                their own default speed.

        Raises:
            ValueError: if ``step`` is less than 1.
        """
        if step < 1:
            raise ValueError("step must be >= 1, got {}".format(step))

        created_dir = not os.path.exists(aniPath)
        os.makedirs(aniPath, exist_ok=True)

        imgs = []
        written = []  # PNG files this call wrote, for cleanup
        try:
            for idx, solu in enumerate(self.__waveObj.get_solu()):
                print(idx)
                if idx % step != 0:
                    continue
                X, Y, Uk = solu
                fig = plt.figure(figsize=(8, 8))
                ax1 = fig.add_subplot(111)
                # Uk is trimmed by one row/column so X, Y act as cell corners
                # (flat shading).
                # NOTE(review): vmin=0 with no vmax clips negative values and
                # rescales the colours on every frame.
                ax1.pcolor(X, Y, Uk[:-1, :-1], cmap="jet", vmin=0)
                filename = os.path.join(aniPath, "{}.png".format(idx))
                written.append(filename)
                fig.savefig(filename, dpi=80)
                plt.close(fig)
                # Load the pixels into memory and release the file handle so
                # the PNG can be deleted afterwards.
                with Image.open(filename) as img:
                    imgs.append(img.copy())

            if imgs:
                # Save from the first frame and append the rest so the
                # animation plays in chronological order.
                imgs[0].save(gifPath, save_all=True, append_images=imgs[1:],
                             duration=duration)
        finally:
            if created_dir:
                shutil.rmtree(aniPath)
            else:
                for filename in written:
                    if os.path.exists(filename):
                        os.remove(filename)
133
134
135
if __name__ == "__main__":
    # Demo run: 300 x 300 grid intervals, 1000 time steps, D = 0.1; frames
    # are rendered through the temporary directory ./wave_plot.
    solver = WaveEq(nx=300, ny=300, nt=1000, D=0.1)
    WavePlot(solver).ani_plot("./wave_plot")