Use of GEKPLS
from smt.sampling_methods import LHS
from smt.problems import Sphere
from smt.surrogate_models import GEKPLS
import numpy as np
import otsmt
import openturns as ot
Definition of initial data
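The code cell for this step is not shown on this page. The following is a minimal sketch of data with the layout the training step expects, not the original cell. The imports above suggest the original builds `fun` from `Sphere` and samples `xt` with `LHS`; here the sphere function and its derivatives are written by hand with NumPy and the points are drawn uniformly at random instead. The helper names `sphere` and `sphere_derivative`, the bounds in `xlimits`, the random seed, and the validation points `xv` are assumptions (the chosen `xv` is consistent with the predictions printed further down).

```python
import numpy as np


def sphere(x):
    """Sphere function: sum of squared inputs per row, returned as a column."""
    return np.sum(x**2, axis=1, keepdims=True)


def sphere_derivative(x, kx):
    """Partial derivative of the sphere function with respect to input kx."""
    return 2.0 * x[:, [kx]]


ndim = 2
n_train = 40  # matches the "# training points" reported in the training log
xlimits = np.array([[-10.0, 10.0]] * ndim)  # assumed bounds, one row per input

# Uniform random design (assumption: stands in for the LHS design)
rng = np.random.default_rng(0)
xt = rng.uniform(xlimits[:, 0], xlimits[:, 1], size=(n_train, ndim))

# Column 0: function values; columns 1..ndim: partial derivatives
yt = sphere(xt)
for i in range(ndim):
    yt = np.concatenate((yt, sphere_derivative(xt, i)), axis=1)

# Assumed validation points
xv = np.array([[0.1, 1.0], [1.0, 2.0]])

print(xt.shape, yt.shape)
```

With this layout, `yt[:, 0]` holds the responses and `yt[:, 1 + i]` the derivative with respect to input `i`, which is how the training cell slices the array. The sketch does not define `fun`, so `fun.xlimits` in the training cell would have to be replaced by `xlimits`.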
Training of the SMT model for GEKPLS
n_comp = 2  # number of PLS components
sm_gekpls = GEKPLS(
    theta0=[1e-2] * n_comp,
    xlimits=fun.xlimits,
    extra_points=1,
    print_prediction=False,
    n_comp=n_comp,
)
# Column 0 of yt holds the function values
sm_gekpls.set_training_values(xt, yt[:, 0][:, np.newaxis])
# Column 1 + i of yt holds the partial derivative with respect to input i
for i in range(2):
    sm_gekpls.set_training_derivatives(xt, yt[:, 1 + i].reshape((yt.shape[0], 1)), i)
sm_gekpls.train()
Out:
___________________________________________________________________________
                                  GEKPLS
___________________________________________________________________________
 Problem size
      # training points.        : 40
___________________________________________________________________________
 Training
   Training ...
/usr/share/miniconda3/envs/test/lib/python3.9/site-packages/scikit_learn-1.1.1-py3.9-linux-x86_64.egg/sklearn/cross_decomposition/_pls.py:304: UserWarning: Y residual is constant at iteration 1
  warnings.warn(f"Y residual is constant at iteration {k}")
   Training - done. Time (sec):  0.1289790
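The training calls above pass each column of `yt` as a 2-D array of shape `(n, 1)`. A small sketch of that slicing on a hypothetical 4-point array (the values are placeholders, not training data) shows that the two idioms used in the example are equivalent to indexing with a one-element list:

```python
import numpy as np

# Hypothetical array: 4 points, columns = value, d/dx0, d/dx1
yt = np.arange(12.0).reshape(4, 3)

# Idioms used in the training cell
values = yt[:, 0][:, np.newaxis]
d0 = yt[:, 1].reshape((yt.shape[0], 1))

# Equivalent list indexing, which keeps the column dimension directly
values_alt = yt[:, [0]]
d0_alt = yt[:, [1]]

print(values.shape, d0.shape)
```

All four arrays are column vectors, so either form can be used when splitting a combined value-and-gradient array.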
Creation of OpenTURNS PythonFunction for prediction
otgekpls = otsmt.smt2ot(sm_gekpls)
otgekplsprediction = otgekpls.getPredictionFunction()
otgekplsvariances = otgekpls.getConditionalVarianceFunction()
otgekplsgradient = otgekpls.getPredictionDerivativesFunction()
print('Predicted values by GEKPLS:', otgekplsprediction(xv))
print('Predicted variances values by GEKPLS:', otgekplsvariances(xv))
print('Prediction derivatives by GEKPLS:', otgekplsgradient(xv))
Out:
Predicted values by GEKPLS:     [ y0      ]
0 : [ 1.00997 ]
1 : [ 5.00001 ]
Predicted variances values by GEKPLS:     [ y0         ]
0 : [ 1.4479e-09 ]
1 : [ 1.7761e-09 ]
Prediction derivatives by GEKPLS:     [ y0       y1       ]
0 : [ 0.200015 2.00001  ]
1 : [ 2.00003  4.00002  ]
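As a sanity check, the printed predictions can be compared with the analytical sphere function, f(x) = sum of x_i^2 with gradient 2 x_i. The sketch below assumes the validation points are `[0.1, 1.0]` and `[1.0, 2.0]`; the definition of `xv` is not shown on this page, but these points are consistent with the output. The predicted numbers are copied from the output above.

```python
import numpy as np

# Assumed validation points (not shown on this page)
xv = np.array([[0.1, 1.0], [1.0, 2.0]])

# Analytical sphere function and gradient
exact_values = np.sum(xv**2, axis=1)
exact_grad = 2.0 * xv

# Values copied from the printed GEKPLS output
pred_values = np.array([1.00997, 5.00001])
pred_grad = np.array([[0.200015, 2.00001], [2.00003, 4.00002]])

value_err = np.abs(pred_values - exact_values)
grad_err = np.abs(pred_grad - exact_grad)

print("max value error:", value_err.max())
print("max gradient error:", grad_err.max())
```

Under that assumption, both the predicted values and the predicted derivatives agree with the analytical ones to within 1e-4, and the small conditional variances printed above are in line with such a close fit.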
Total running time of the script: ( 0 minutes 0.134 seconds)