add project2
This commit is contained in:
160
project2/problem1.py
Normal file
160
project2/problem1.py
Normal file
@@ -0,0 +1,160 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from iterative_solvers import *
|
||||
|
||||
# (a): Exact Solution
|
||||
def exact_solution(t: np.ndarray) -> np.ndarray:
|
||||
"""Computes the exact solution u(t) = t/6 * (1 - t^2)."""
|
||||
return t / 6.0 * (1 - t**2)
|
||||
|
||||
def plot_exact_solution():
|
||||
"""Plots the exact solution of the BVP."""
|
||||
t_fine = np.linspace(0, 1, 500)
|
||||
u_exact_fine = exact_solution(t_fine)
|
||||
|
||||
plt.figure(figsize=(10, 6))
|
||||
plt.plot(t_fine, u_exact_fine, 'k-', label='Exact Solution $u(t) = t/6 \cdot (1 - t^2)$')
|
||||
plt.title('Exact Solution of $-u\'\'=t$ with $u(0)=u(1)=0$')
|
||||
plt.xlabel('$t$')
|
||||
plt.ylabel('$u(t)$')
|
||||
plt.grid(True)
|
||||
plt.legend()
|
||||
plt.savefig('problem1_exact_solution.png')
|
||||
# plt.show()
|
||||
print("Exact solution plot saved to 'problem1_exact_solution.png'")
|
||||
|
||||
|
||||
# (b): Helper functions for iterative methods
|
||||
def mat_vec_prod_p1(x: np.ndarray, n: int) -> np.ndarray:
|
||||
"""Computes the matrix-vector product Ax for Problem 1."""
|
||||
y = np.zeros(n)
|
||||
# A is tridiagonal with [ -1, 2, -1 ]
|
||||
# y_i = -x_{i-1} + 2x_i - x_{i+1}
|
||||
for i in range(n):
|
||||
term_prev = -x[i-1] if i > 0 else 0
|
||||
term_next = -x[i+1] if i < n - 1 else 0
|
||||
y[i] = term_prev + 2 * x[i] + term_next
|
||||
return y
|
||||
|
||||
def get_diag_inv_p1(i: int, n: int) -> float:
|
||||
"""Returns 1/A_ii for Problem 1. Diagonal elements are all 2, so inverse is 1/2."""
|
||||
return 1.0 / 2.0
|
||||
|
||||
def get_off_diag_sum_p1(x: np.ndarray, i: int, n: int) -> float:
|
||||
"""Computes sum_{j!=i} A_ij * x_j. For Problem 1, A_ij = -1 for j=i-1, i+1, else 0."""
|
||||
sum_val = 0.0
|
||||
if i > 0:
|
||||
sum_val += -x[i-1]
|
||||
if i < n - 1:
|
||||
sum_val += -x[i+1]
|
||||
return sum_val
|
||||
|
||||
def solve_and_report(n: int):
|
||||
"""
|
||||
Solves the linear system for a given n and reports results.
|
||||
"""
|
||||
print(f"\n--- Solving for n = {n} ---")
|
||||
h = 1.0 / (n + 1)
|
||||
|
||||
# Setup the right-hand side vector b
|
||||
t = np.linspace(h, 1.0 - h, n)
|
||||
b = (h**2) * t
|
||||
|
||||
# Calculate SOR parameter omega
|
||||
omega = 2.0 / (1.0 + np.sin(np.pi * h))
|
||||
|
||||
# Store results
|
||||
iteration_counts = []
|
||||
all_residuals = {}
|
||||
all_solutions = {}
|
||||
# Define methods and run solvers
|
||||
methods_to_run = {
|
||||
'Jacobi': ('jacobi', {}),
|
||||
'Gauss-Seidel': ('gauss_seidel', {}),
|
||||
'SOR': ('sor', {'omega': omega}),
|
||||
'Steepest Descent': (steepest_descent, {}),
|
||||
'Conjugate Gradient': (conjugate_gradient, {})
|
||||
}
|
||||
|
||||
print("\nStarting iterative solvers...")
|
||||
for name, (method_func, params) in methods_to_run.items():
|
||||
print(f"Running {name}...")
|
||||
common_args = {'b': b, 'n': n}
|
||||
if isinstance(method_func, str):
|
||||
solver_args = { 'method': method_func, 'get_diag_inv': get_diag_inv_p1,
|
||||
'get_off_diag_sum': get_off_diag_sum_p1, 'mat_vec_prod': mat_vec_prod_p1,
|
||||
**common_args, **params }
|
||||
solution, iters, res_hist = stationary_method(**solver_args)
|
||||
else:
|
||||
solver_args = {'mat_vec_prod': mat_vec_prod_p1, **common_args}
|
||||
solution, iters, res_hist = method_func(**solver_args)
|
||||
iteration_counts.append((name, iters))
|
||||
all_residuals[name] = res_hist
|
||||
all_solutions[name] = solution if iters < MAX_ITER else None
|
||||
# (c): Report iteration counts
|
||||
print("\n--- Iteration Counts ---")
|
||||
for name, iters in iteration_counts:
|
||||
status = "converged" if iters < MAX_ITER else "did NOT converge"
|
||||
print(f"{name:<20}: {iters} iterations ({status}, last residual = {all_residuals[name][-1]:.2e})")
|
||||
# (d): Plot residue history
|
||||
plt.figure(figsize=(12, 8))
|
||||
for name, res_hist in all_residuals.items():
|
||||
res_log = np.log10(np.array(res_hist) + 1e-20)
|
||||
plt.plot(res_log, label=name, linestyle='-')
|
||||
|
||||
plt.title(f'Convergence History for n = {n}')
|
||||
plt.xlabel('Iteration Step (m)')
|
||||
plt.ylabel('log10(||r^(m)||_2)')
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
|
||||
# Set the y-axis limit to make the plot readable
|
||||
current_ylim = plt.ylim()
|
||||
plt.ylim(bottom=max(current_ylim[0], -13), top=current_ylim[1])
|
||||
plt.savefig(f'problem1_convergence_n{n}.png')
|
||||
# plt.show()
|
||||
print(f"\nConvergence plot saved to 'problem1_convergence_n{n}.png'")
|
||||
|
||||
if all_solutions:
|
||||
plt.figure(figsize=(12, 8))
|
||||
|
||||
# Plot exact solution on a fine grid
|
||||
t_fine = np.linspace(0, 1, 500)
|
||||
u_exact_fine = exact_solution(t_fine)
|
||||
plt.plot(t_fine, u_exact_fine, 'k-', label='Exact Solution', linewidth=1, zorder=10) # Black, thick, on top
|
||||
|
||||
# Define some styles for the different numerical methods
|
||||
styles = {
|
||||
'Jacobi': {'color': 'red', 'linestyle': '--'},
|
||||
'Gauss-Seidel': {'color': 'blue', 'linestyle': '-.'},
|
||||
'SOR': {'color': 'green', 'linestyle': ':'},
|
||||
'Steepest Descent':{'color': 'purple','linestyle': '--'},
|
||||
'Conjugate Gradient':{'color': 'orange','linestyle': '-.'},
|
||||
}
|
||||
|
||||
# Plot each numerical solution
|
||||
for name, solution in all_solutions.items():
|
||||
# Add boundary points for a complete plot
|
||||
t_numerical_with_bounds = np.concatenate(([0], t, [1]))
|
||||
u_numerical_with_bounds = np.concatenate(([0], solution, [0]))
|
||||
|
||||
style = styles.get(name, {'color': 'gray', 'linestyle': '-'})
|
||||
plt.plot(t_numerical_with_bounds, u_numerical_with_bounds,
|
||||
label=f'Numerical ({name})', **style, linewidth=2)
|
||||
|
||||
plt.title(f'Comparison of Exact and Numerical Solutions for n = {n}')
|
||||
plt.xlabel('$t$')
|
||||
plt.ylabel('$u(t)$')
|
||||
plt.legend()
|
||||
plt.grid(True)
|
||||
plt.savefig(f'problem1_all_solutions_comparison_n{n}.png')
|
||||
# plt.show()
|
||||
print(f"All solutions comparison plot saved to 'problem1_all_solutions_comparison_n{n}.png'")
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Part (a)
|
||||
plot_exact_solution()
|
||||
|
||||
# Part (b), (c), (d)
|
||||
solve_and_report(n=20)
|
||||
solve_and_report(n=40)
|
||||
Reference in New Issue
Block a user