Radial basis functions
The radial basis function (RBF) surrogate model represents the interpolating function as a linear combination of basis functions, one for each training point. RBFs are so named because each basis function depends only on the distance from the prediction point to the training point associated with that basis function. The coefficients of the basis functions are computed during the training stage. RBFs are frequently augmented with global polynomials to capture general trends.
The prediction equation for RBFs is
\[ y = \mathbf{p}(\mathbf{x})^T \mathbf{w_p} + \sum_{i=1}^{nt} \phi(\mathbf{x}, \mathbf{xt}_i) \, (\mathbf{w_r})_i , \]
where \(\mathbf{x} \in \mathbb{R}^{nx}\) is the prediction input vector, \(y \in \mathbb{R}\) is the prediction output, \(\mathbf{xt}_i \in \mathbb{R}^{nx}\) is the input vector for the \(i\)th training point, \(\mathbf{p}(\mathbf{x}) \in \mathbb{R}^{np}\) is the vector mapping the polynomial coefficients to the prediction output, \(\phi(\mathbf{x}, \mathbf{xt}_i) \in \mathbb{R}\) is the radial basis function for the \(i\)th training point evaluated at \(\mathbf{x}\) (together, these \(nt\) values form the vector mapping the radial basis function coefficients to the prediction output), \(\mathbf{w_p} \in \mathbb{R}^{np}\) is the vector of polynomial coefficients, and \(\mathbf{w_r} \in \mathbb{R}^{nt}\) is the vector of radial basis function coefficients.
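The coefficients, \(\mathbf{w_p}\) and \(\mathbf{w_r}\), are computed by solving the following linear system:
\[
\begin{bmatrix}
\phi(\mathbf{xt}_1, \mathbf{xt}_1) & \dots & \phi(\mathbf{xt}_1, \mathbf{xt}_{nt}) & \mathbf{p}(\mathbf{xt}_1)^T \\
\vdots & \ddots & \vdots & \vdots \\
\phi(\mathbf{xt}_{nt}, \mathbf{xt}_1) & \dots & \phi(\mathbf{xt}_{nt}, \mathbf{xt}_{nt}) & \mathbf{p}(\mathbf{xt}_{nt})^T \\
\mathbf{p}(\mathbf{xt}_1) & \dots & \mathbf{p}(\mathbf{xt}_{nt}) & \mathbf{0}
\end{bmatrix}
\begin{bmatrix}
(\mathbf{w_r})_1 \\ \vdots \\ (\mathbf{w_r})_{nt} \\ \mathbf{w_p}
\end{bmatrix}
=
\begin{bmatrix}
yt_1 \\ \vdots \\ yt_{nt} \\ \mathbf{0}
\end{bmatrix}
\]
where \(yt_i\) is the training output for the \(i\)th training point.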
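Only Gaussian basis functions are currently implemented. These are given by
\[ \phi(\mathbf{x}, \mathbf{xt}_i) = \exp \left( -\frac{\| \mathbf{x} - \mathbf{xt}_i \|_2^2}{d0^2} \right) , \]
where \(d0\) is the basis function scaling parameter.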
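The training and prediction steps described above can be sketched in plain NumPy. This is only an illustration under stated assumptions, not SMT's implementation: the helper names (`gaussian_basis`, `poly_terms`, `fit_rbf`, `predict_rbf`) are hypothetical, the Gaussian form exp(-d^2 / d0^2) is taken from the description of the `d0` option, the linear trend is assumed to use the terms 1 and x, and the `reg` option is ignored.

```python
import numpy as np


def gaussian_basis(x, xt, d0):
    """Matrix whose (k, i) entry is exp(-||x_k - xt_i||^2 / d0^2)."""
    diff = x[:, None, :] - xt[None, :, :]  # shape (n, nt, nx)
    return np.exp(-np.sum(diff**2, axis=2) / d0**2)


def poly_terms(x, poly_degree):
    """Assumed polynomial terms: none (-1), constant (0), constant + linear (1)."""
    n = x.shape[0]
    if poly_degree == -1:
        return np.zeros((n, 0))
    if poly_degree == 0:
        return np.ones((n, 1))
    if poly_degree == 1:
        return np.hstack([np.ones((n, 1)), x])
    raise ValueError("poly_degree must be -1, 0, or 1")


def fit_rbf(xt, yt, d0=1.0, poly_degree=-1):
    """Assemble the augmented linear system and solve it for (w_r, w_p)."""
    nt = xt.shape[0]
    phi = gaussian_basis(xt, xt, d0)  # (nt, nt) radial block
    p = poly_terms(xt, poly_degree)   # (nt, np) polynomial block
    n_poly = p.shape[1]
    mtx = np.zeros((nt + n_poly, nt + n_poly))
    mtx[:nt, :nt] = phi
    mtx[:nt, nt:] = p
    mtx[nt:, :nt] = p.T
    rhs = np.concatenate([yt, np.zeros(n_poly)])
    w = np.linalg.solve(mtx, rhs)
    return w[:nt], w[nt:]


def predict_rbf(x, xt, w_r, w_p, d0=1.0, poly_degree=-1):
    """Evaluate the polynomial part plus the weighted sum of basis functions."""
    return poly_terms(x, poly_degree) @ w_p + gaussian_basis(x, xt, d0) @ w_r


# Same training data as the usage example, stored as (nt, nx) arrays.
xt = np.array([0.0, 1.0, 2.0, 3.0, 4.0]).reshape(-1, 1)
yt = np.array([0.0, 1.0, 1.5, 0.9, 1.0])

# Pure RBF model (no polynomial) and RBF model with a linear trend.
w_r, w_p = fit_rbf(xt, yt, d0=5.0, poly_degree=-1)
y_plain = predict_rbf(xt, xt, w_r, w_p, d0=5.0, poly_degree=-1)

w_r_lin, w_p_lin = fit_rbf(xt, yt, d0=5.0, poly_degree=1)
y_lin = predict_rbf(xt, xt, w_r_lin, w_p_lin, d0=5.0, poly_degree=1)
```

Without regularization, both variants reproduce the training outputs at the training points up to round-off, which is what makes the model an interpolant.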
Usage
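import numpy as np
import matplotlib.pyplot as plt

from smt.surrogate_models import RBF

# Training data: five 1-D sample points and their outputs
xt = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
yt = np.array([0.0, 1.0, 1.5, 0.9, 1.0])

# Build the RBF surrogate with scaling parameter d0 = 5 and train it
sm = RBF(d0=5)
sm.set_training_values(xt, yt)
sm.train()

# Evaluate the surrogate at 100 evenly spaced points in [0, 4]
num = 100
x = np.linspace(0.0, 4.0, num)
y = sm.predict_values(x)

# Plot the training data and the prediction
plt.plot(xt, yt, "o")
plt.plot(x, y)
plt.xlabel("x")
plt.ylabel("y")
plt.legend(["Training data", "Prediction"])
plt.show()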
___________________________________________________________________________
RBF
___________________________________________________________________________
Problem size
# training points. : 5
___________________________________________________________________________
Training
Training ...
Initializing linear solver ...
Performing LU fact. (5 x 5 mtx) ...
Performing LU fact. (5 x 5 mtx) - done. Time (sec): 0.0000172
Initializing linear solver - done. Time (sec): 0.0000250
Solving linear system (col. 0) ...
Back solving (5 x 5 mtx) ...
Back solving (5 x 5 mtx) - done. Time (sec): 0.0000079
Solving linear system (col. 0) - done. Time (sec): 0.0000150
Training - done. Time (sec): 0.0001552
___________________________________________________________________________
Evaluation
# eval points. : 100
Predicting ...
Predicting - done. Time (sec): 0.0000069
Prediction time/pt. (sec) : 0.0000001
Options
Option | Default | Acceptable values | Acceptable types | Description |
---|---|---|---|---|
print_global | True | None | ['bool'] | Global print toggle. If False, all printing is suppressed |
print_training | True | None | ['bool'] | Whether to print training information |
print_prediction | True | None | ['bool'] | Whether to print prediction information |
print_problem | True | None | ['bool'] | Whether to print problem information |
print_solver | True | None | ['bool'] | Whether to print solver information |
d0 | 1.0 | None | ['int', 'float', 'list', 'ndarray'] | basis function scaling parameter in exp(-d^2 / d0^2) |
poly_degree | -1 | [-1, 0, 1] | ['int'] | -1 means no global polynomial, 0 means constant, 1 means linear trend |
data_dir | None | None | ['str'] | Directory for loading / saving cached data; None means do not save or load |
reg | 1e-10 | None | ['int', 'float'] | Regularization coeff. |
max_print_depth | 5 | None | ['int'] | Maximum depth (level of nesting) to print operation descriptions and times |
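To get a feel for the `d0` option, the short calculation below evaluates the basis function expression exp(-d^2 / d0^2) from the table at a distance of d = 1, the spacing between neighbouring training points in the usage example. It compares the default `d0` = 1.0 with the `d0` = 5 used in that example. The variable names are illustrative only and nothing here calls SMT.

```python
import math

# Distance between neighbouring training points in the usage example.
d = 1.0

# Basis function value exp(-d^2 / d0^2) for the default d0 and for d0 = 5.
values = {d0: math.exp(-(d**2) / d0**2) for d0 in (1.0, 5.0)}

for d0, value in values.items():
    print(f"d0 = {d0}: basis value at d = {d} is {value:.4f}")
```

With `d0` = 1.0 the basis function has already decayed to about 0.37 at the neighbouring point, whereas with `d0` = 5 it is still about 0.96, so a larger `d0` makes each training point influence a wider region.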