Use of GEKPLS
from smt.sampling_methods import LHS
from smt.problems import Sphere
from smt.surrogate_models import GEKPLS
import numpy as np
import otsmt
import openturns as ot
Definition of initial data
Training of the SMT GEKPLS model
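# Number of PLS components used to reduce the number of hyperparameters
n_comp = 2
sm_gekpls = GEKPLS(
    theta0=[1e-2] * n_comp,  # initial hyperparameters, one per PLS component
    xlimits=fun.xlimits,
    extra_points=1,  # number of extra points per training point
    print_prediction=False,
    n_comp=n_comp,
)
# Function values: first column of yt, as a column vector
sm_gekpls.set_training_values(xt, yt[:, 0][:, np.newaxis])
# Partial derivatives with respect to input i: column 1 + i of yt
for i in range(2):
    sm_gekpls.set_training_derivatives(xt, yt[:, 1 + i].reshape((yt.shape[0], 1)), i)
sm_gekpls.train()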
Out:
___________________________________________________________________________
GEKPLS
___________________________________________________________________________
Problem size
# training points. : 40
___________________________________________________________________________
Training
Training ...
/usr/share/miniconda3/envs/test/lib/python3.9/site-packages/scikit_learn-1.1.1-py3.9-linux-x86_64.egg/sklearn/cross_decomposition/_pls.py:304: UserWarning: Y residual is constant at iteration 1
warnings.warn(f"Y residual is constant at iteration {k}")
Training - done. Time (sec): 0.1289790
Creation of OpenTURNS functions (PythonFunction) for the prediction, its conditional variance and its derivatives
# Wrap the trained SMT model so that it can be used from OpenTURNS
otgekpls = otsmt.smt2ot(sm_gekpls)
otgekplsprediction = otgekpls.getPredictionFunction()
otgekplsvariances = otgekpls.getConditionalVarianceFunction()
otgekplsgradient = otgekpls.getPredictionDerivativesFunction()
# Evaluate the prediction, its conditional variance and its derivatives at the validation points
print('Predicted values by GEKPLS:', otgekplsprediction(xv))
print('Predicted variances values by GEKPLS:', otgekplsvariances(xv))
print('Prediction derivatives by GEKPLS:', otgekplsgradient(xv))
Out:
Predicted values by GEKPLS: [ y0 ]
0 : [ 1.00997 ]
1 : [ 5.00001 ]
Predicted variances values by GEKPLS: [ y0 ]
0 : [ 1.4479e-09 ]
1 : [ 1.7761e-09 ]
Prediction derivatives by GEKPLS: [ y0 y1 ]
0 : [ 0.200015 2.00001 ]
1 : [ 2.00003 4.00002 ]
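Because the Sphere function has the closed form f(x) = x_1^2 + x_2^2 with gradient 2x, the printed results can be checked by hand. This is a minimal sketch. It assumes the validation points are `xv = [[0.1, 1.0], [1.0, 2.0]]`, which is not shown on this page and is inferred from the printed derivatives. The GEKPLS numbers are copied from the output above.

```python
import numpy as np

# Assumption: validation points inferred from the printed derivatives (2 * xv)
xv = np.array([[0.1, 1.0], [1.0, 2.0]])

# Exact Sphere values f(x) = sum_i x_i^2 and exact gradient 2 * x
y_exact = np.sum(xv**2, axis=1)
dy_exact = 2.0 * xv

# GEKPLS predictions copied from the printed output
y_gekpls = np.array([1.00997, 5.00001])
dy_gekpls = np.array([[0.200015, 2.00001], [2.00003, 4.00002]])

print("max |error| on values:", np.max(np.abs(y_gekpls - y_exact)))
print("max |error| on derivatives:", np.max(np.abs(dy_gekpls - dy_exact)))
```

Both maximum errors are of the order of 1e-5. This is comparable to the standard deviations implied by the printed conditional variances, whose square roots are about 4e-5.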