The Finite Difference Method: 1D steady-state heat transfer
These examples are based on code originally written by Krzysztof Fidkowski and adapted by Venkat Viswanathan.
import jax
import jax.numpy as jnp
from jax import grad, jit
from jax.scipy.linalg import solve
from scipy.optimize import minimize
jax.config.update("jax_enable_x64", True) # Enable 64-bit precision for better numerical stability
import matplotlib.pyplot as plt
import niceplots
plt.style.use(niceplots.get_style())
# Force the jupyter notebook to use vector graphics
import matplotlib_inline
matplotlib_inline.backend_inline.set_matplotlib_formats("pdf", "svg")
This code implements the example case from section 3.1.1 of the course notes.
We will solve the steady-state heat transfer equation in 1D:

$$-\kappa \frac{d^2 T}{d x^2} = q(x)$$

where $\kappa$ is the thermal conductivity and $q(x) = \sin\left(\frac{\pi x}{L}\right)$ is the heat source.
# Define the heat source function q(x)
def q(x, L):
    return jnp.sin(jnp.pi * x / L)
The 1D domain spans $0 \leq x \leq L$, with fixed temperatures $T(0) = T_\text{left}$ and $T(L) = T_\text{right}$ at the boundaries.
Using the central difference approximation for the second derivative, we can write the equation at each node as:

$$-\kappa \frac{T_{i-1} - 2 T_i + T_{i+1}}{\Delta x^2} = q(x_i) \quad\Longleftrightarrow\quad -T_{i-1} + 2 T_i - T_{i+1} = \frac{q(x_i)}{\kappa} \Delta x^2$$

(A quick numerical check of this stencil's accuracy follows the matrix formulation below.)
We will enforce this equation at all nodes except the boundary nodes, which have known temperatures. This gives $N_x - 1$ equations for the $N_x - 1$ unknown interior temperatures $T_1, \dots, T_{N_x - 1}$.
Writing this in matrix form, we get:

$$\begin{bmatrix} 2 & -1 & & \\ -1 & 2 & -1 & \\ & \ddots & \ddots & \ddots \\ & & -1 & 2 \end{bmatrix} \begin{bmatrix} T_1 \\ T_2 \\ \vdots \\ T_{N_x - 1} \end{bmatrix} = \frac{\Delta x^2}{\kappa} \begin{bmatrix} q(x_1) \\ q(x_2) \\ \vdots \\ q(x_{N_x - 1}) \end{bmatrix} + \begin{bmatrix} T_\text{left} \\ 0 \\ \vdots \\ T_\text{right} \end{bmatrix}$$
Alternatively, we can keep the boundary nodes in the matrix equation and enforce the boundary conditions using the first and last rows of the matrix:

$$\begin{bmatrix} 1 & 0 & & & \\ -1 & 2 & -1 & & \\ & \ddots & \ddots & \ddots & \\ & & -1 & 2 & -1 \\ & & & 0 & 1 \end{bmatrix} \begin{bmatrix} T_0 \\ T_1 \\ \vdots \\ T_{N_x - 1} \\ T_{N_x} \end{bmatrix} = \begin{bmatrix} T_\text{left} \\ \frac{q(x_1)}{\kappa} \Delta x^2 \\ \vdots \\ \frac{q(x_{N_x - 1})}{\kappa} \Delta x^2 \\ T_\text{right} \end{bmatrix}$$
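Before assembling the full system, we can quickly verify (this check is our addition, not part of the original example) that the central difference stencil is indeed second-order accurate by applying it to $\sin(x)$, whose exact second derivative is $-\sin(x)$:
# Check the accuracy of the central difference stencil on a known function:
# the error should shrink by ~100x each time dx shrinks by 10x
x0 = 1.0
for dx in [1e-1, 1e-2, 1e-3]:
    stencil = (jnp.sin(x0 - dx) - 2 * jnp.sin(x0) + jnp.sin(x0 + dx)) / dx**2
    error = jnp.abs(stencil - (-jnp.sin(x0)))
    print(f"dx = {dx:.0e}, error = {float(error):.2e}")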
The code below implements the second approach and solves the matrix equation with a direct linear solver. Note that although the matrix is tridiagonal, it is stored densely and solved with jax.scipy.linalg.solve, which is a dense direct solver; a banded alternative is sketched after the function.
def heat_conduction(T_left=1.0, T_right=4.0, L=2.0, kappa=0.5, Nx=10):
"""Setup and solve the heat conduction problem using the finite difference method
Parameters
----------
T_left : float, optional
Left boundary temperature, by default 1.0
T_right : float, optional
Right boundary temperature, by default 4.0
L : float, optional
Length of domain, by default 2.0
kappa : float, optional
Thermal conductivity, by default 0.5
Nx : int, optional
Number of intervals, by default 10
Returns
-------
array
Nodal temperature values
array
Nodal x coordinates
"""
# Define the parameters
dx = L / Nx # Grid spacing
# Create the matrix A (tridiagonal with 2 on the diagonal and -1 on the off-diagonals)
diagonal = 2.0 * jnp.ones(Nx - 1)
off_diagonal = -jnp.ones(Nx - 2)
A = jnp.diag(diagonal) + jnp.diag(off_diagonal, k=1) + jnp.diag(off_diagonal, k=-1)
# Add the Boundary conditions
A = jnp.vstack((A, jnp.zeros(Nx - 1)))
A = jnp.vstack((jnp.zeros(Nx - 1), A))
A = jnp.column_stack([jnp.zeros(A.shape[0]), A])
A = jnp.column_stack([A, jnp.zeros(A.shape[0])])
A = A.at[(0, 0)].set(1.0)
A = A.at[(-1, -1)].set(1.0) # Set the bottom-right diagonal element to 1
A = A.at[(1, 0)].set(-1.0)
A = A.at[(-2, -1)].set(-1.0) # Set the bottom-right diagonal element to 1
# Create the vector representing the heat source (modify q(x) as needed)
x_values = jnp.linspace(0, L, Nx + 1)
b = jnp.zeros(Nx + 1)
b = b.at[1:Nx].set(q(x_values[1:Nx], L) / kappa * dx**2)
# Define boundary conditions (e.g., fixed temperature at both ends)
T_left = 1.0
T_right = 4.0
b = b.at[0].set(T_left)
b = b.at[-1].set(T_right)
T = solve(A, b)
return T, x_values
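For the small systems used here a dense solve is fine, but since the matrix is tridiagonal, a banded solver scales much better with the number of nodes. Below is a minimal sketch using SciPy's solve_banded (our addition, not part of the original example; SciPy routines are not differentiable by JAX, which is why the dense JAX version above is used for the inverse problem later):
import numpy as np
from scipy.linalg import solve_banded


def heat_conduction_banded(T_left=1.0, T_right=4.0, L=2.0, kappa=0.5, Nx=10):
    """Same linear system as heat_conduction, but stored and solved in banded form."""
    dx = L / Nx
    x = np.linspace(0, L, Nx + 1)
    # Banded storage: row 0 = superdiagonal, row 1 = diagonal, row 2 = subdiagonal
    ab = np.zeros((3, Nx + 1))
    ab[0, 2:] = -1.0  # superdiagonal (zero in the first boundary row)
    ab[1, :] = 2.0
    ab[1, 0] = ab[1, -1] = 1.0  # diagonal entries of the boundary rows
    ab[2, :-2] = -1.0  # subdiagonal (zero in the last boundary row)
    b = np.sin(np.pi * x / L) / kappa * dx**2
    b[0], b[-1] = T_left, T_right
    return solve_banded((1, 1), ab, b), x
This should return the same temperatures as heat_conduction up to round-off, while storing only three vectors instead of a full matrix.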
Now let’s solve the system for $N_x = 3$, with $L = 2$, $\kappa = 0.5$, $T_\text{left} = 1$, and $T_\text{right} = 4$, and compare against the exact solution:

$$T_\text{true}(x) = \frac{L^2}{\pi^2 \kappa} \sin\left(\frac{\pi x}{L}\right) + T_\text{left} + \left(T_\text{right} - T_\text{left}\right) \frac{x}{L}$$
# Define the true solution
def true_solution(x, L, kappa, T_left, T_right):
    return L**2 / (jnp.pi**2 * kappa) * jnp.sin(jnp.pi * x / L) + T_left + (T_right - T_left) * x / L
# Define the parameters
L = 2.0 # Length of domain
kappa = 0.5 # Thermal conductivity
Nx = 3 # Number of intervals
T0 = 1.0 # Left boundary condition
TN = 4.0 # Right boundary condition
# Solve the finite difference problem
T_soln, x_vals = heat_conduction(T0, TN, L, kappa, Nx)
# Plot the results against the true solution
fig, ax = plt.subplots()
xTrue = jnp.linspace(0, 2.0, 100)
ax.plot(xTrue, true_solution(xTrue, L, kappa, T0, TN), "-", clip_on=False, label="True solution")
ax.plot(x_vals, T_soln, "-o", clip_on=False, label="FD solution")
ax.set_xlabel("$x$")
ax.set_xticks([0, L])
ax.set_ylabel("$T$")
ax.set_yticks([T0, TN])
ax.legend(labelcolor="linecolor", loc="lower right")
niceplots.adjust_spines(ax)
plt.show()
Even with only 3 intervals, the solution is already quite close to the exact solution.
Convergence study
Since we are using a second-order accurate approximation of the second derivative, we expect the error to converge as $\mathcal{O}(\Delta x^2)$, i.e., at a rate of 2 on a log-log plot.
The error is computed using the root-mean-square (RMS) norm:

$$E = \sqrt{\frac{1}{N_x + 1} \sum_{i=0}^{N_x} \left(T_i - T_\text{true}(x_i)\right)^2}$$

This is not equivalent to the typical vector 2-norm: the $1/(N_x + 1)$ factor normalizes by the number of nodes, so that errors computed on grids of different sizes can be compared fairly.
Nsweep = 2 ** jnp.arange(3, 11)
errors = []
for Nx in Nsweep:
    T_soln, x_vals = heat_conduction(T0, TN, L, kappa, Nx)
    error = jnp.sqrt(1 / (Nx + 1) * jnp.sum(jnp.square(T_soln - true_solution(x_vals, L, kappa, T0, TN))))
    errors.append(error)
fig, ax = plt.subplots()
ax.set_xlabel("$\Delta x$")
ax.set_ylabel(r"$||T_{FD} - T_{true}||_2$")
ax.set_xscale("log")
ax.set_yscale("log")
ax.plot(L / Nsweep, errors, "-o", clip_on=False)
# Compute the convergence rate by fitting a line to the log-log plot
rate = jnp.polyfit(jnp.log(L / Nsweep), jnp.log(jnp.array(errors)), 1)[0]
ax.annotate(f"Convergence rate: {float(rate):.2f}", xy=(L / Nsweep[4], errors[5]), ha="left", va="top")
niceplots.adjust_spines(ax)
Learning thermal conductivity from measurement data
Now we will leverage the power of JAX to solve an inverse problem: given some measured temperatures, we will try to find the value of the thermal conductivity that best matches the data.
In this case, we will generate our “experimental data” by running the FD code with a given value of $\kappa$ ($\kappa = 0.5$ here), and then check whether the optimizer can recover that value from the temperatures alone.
Nx = 20
measured_temps, measurement_locations = heat_conduction(T0, TN, L, 0.5, Nx)
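In a real experiment the measured temperatures would be noisy. As an aside (our addition, not used in what follows, so that the optimizer can recover $\kappa$ exactly), synthetic noise could be added like this:
key = jax.random.PRNGKey(0)  # hypothetical RNG seed
noise_scale = 0.01  # hypothetical noise standard deviation
noisy_temps = measured_temps + noise_scale * jax.random.normal(key, measured_temps.shape)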
Next we define our “loss” or “objective” function; this is the function that we will try to minimize. In this case, we will use the sum of squared differences between the predicted and measured temperatures:

$$f(\kappa) = \sum_i \left(T_i(\kappa) - T_i^\text{measured}\right)^2$$
def obj_function(kappa):
    predicted_temps = heat_conduction(T0, TN, L, kappa, Nx)[0]
    error = jnp.sum(jnp.square(predicted_temps - measured_temps))
    # Print the current kappa and error; the try/except block skips the printing
    # when JAX traces this function for AD (float() fails on a tracer)
    try:
        print(f"kappa = {float(kappa[0]): .7e}, error = {float(error): .7e}")
    except TypeError:
        pass
    return error
obj_grad = grad(obj_function)
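As a quick check (our addition, not part of the original example), we can compare the JAX gradient against a forward finite difference estimate at an arbitrary test point; the two should agree to several digits:
kappa_test = jnp.array([1.0])  # arbitrary test point
eps = 1e-6
fd_grad = (obj_function(kappa_test + eps) - obj_function(kappa_test)) / eps
print(f"JAX gradient: {float(obj_grad(kappa_test)[0]): .6e}")
print(f"FD gradient:  {float(fd_grad): .6e}")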
Now we will use an algorithm called an optimizer to find the value of $\kappa$ that minimizes the objective function.
A popular optimization algorithm is gradient descent. There are a few different versions of gradient descent, Adam being one of the most popular in machine learning, but here we will use the simplest one. At each iteration we compute the gradient of the objective function and then simply take a step in the downhill direction, scaled by a step size $\alpha$ (also called the learning rate):

$$\kappa_{k+1} = \kappa_k - \alpha \left.\frac{df}{d\kappa}\right|_{\kappa_k}$$
def gradient_descent(f, grad_f, x0, step_size=1e-2, max_iter=500, tol=1e-6):
    x = x0
    converged = False
    x_hist = []
    for ii in range(max_iter):
        func_val = f(x)
        dfdx = grad_f(x)
        # The minimum objective value is ~0 for this problem, so we can use the
        # function value itself as a convergence criterion
        if jnp.abs(func_val) < tol:
            converged = True
            break
        # Take a step in the downhill (negative gradient) direction
        x = x - step_size * dfdx
        x_hist.append(x)
    if converged:
        print(f"Converged after {ii} iterations")
    else:
        print(f"Did not converge after {max_iter} iterations")
    return x, func_val, x_hist
kappa_grad_descent, error_grad_descent, kappa_hist = gradient_descent(obj_function, obj_grad, jnp.array([2.34]))
kappa = 2.3400000e+00, error = 4.0791601e+00
kappa = 2.3305259e+00, error = 4.0701527e+00
kappa = 2.3209852e+00, error = 4.0610179e+00
...
kappa = 5.0070301e-01, error = 1.3005740e-05
kappa = 5.0033353e-01, error = 2.9317722e-06
kappa = 5.0015785e-01, error = 6.5713316e-07
Converged after 106 iterations
The optimizer converges to $\kappa = 0.5$, the value used to generate the data. Plotting the history of $\kappa$ shows the path it took:
plt.plot(kappa_hist)
plt.xlabel("Iteration")
plt.ylabel(r"$\kappa$")
plt.show()
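Plain gradient descent needed over 100 iterations to converge. Since scipy.optimize.minimize was imported at the top, we can also hand the objective and its JAX gradient to a quasi-Newton method, which typically converges in far fewer iterations. A minimal sketch (our addition; the original course materials may set this up differently):
import numpy as np

result = minimize(
    lambda k: float(obj_function(jnp.array(k))),  # SciPy works with plain floats/arrays
    x0=np.array([2.34]),
    jac=lambda k: np.asarray(obj_grad(jnp.array(k)), dtype=float),
    method="BFGS",
)
print(f"kappa = {result.x[0]:.6f}, converged: {result.success}")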