Use of GEKPLS

from smt.sampling_methods import LHS
from smt.problems import Sphere
from smt.surrogate_models import GEKPLS
import numpy as np
import otsmt
import openturns as ot
Definition of initial data
# Construction of the DOE: 40 Latin Hypercube points (criterion "m", i.e.
# maximin) over the domain of the 2-D Sphere test function.
fun = Sphere(ndim=2)
sampling = LHS(xlimits=fun.xlimits, criterion="m")
xt = sampling(40)
yt = fun(xt)
# Compute the gradient: append one column per input dimension so that
# yt = [f(x), df/dx_0, df/dx_1, ...]. The number of dimensions is read from
# the sample itself rather than hard-coded, so changing ``ndim`` above keeps
# every derivative column.
yt = np.concatenate(
    [yt] + [fun(xt, kx=kx) for kx in range(xt.shape[1])],
    axis=1,
)

# Validation points at which the trained surrogate is evaluated below.
xv = ot.Sample([[0.1, 1.0], [1.0, 2.0]])
Training of the SMT GEKPLS model
# GEKPLS surrogate (gradient-enhanced Kriging with partial least squares).
# theta0 gets one initial value per PLS component, so its length matches
# n_comp.
n_comp = 2
sm_gekpls = GEKPLS(
    theta0=[1e-2] * n_comp,
    xlimits=fun.xlimits,
    extra_points=1,
    print_prediction=False,
    n_comp=n_comp,
)
# Column 0 of yt holds the function values; ``yt[:, [k]]`` keeps the
# (n_points, 1) column shape passed to SMT.
sm_gekpls.set_training_values(xt, yt[:, [0]])

# Columns 1..ndim hold the partial derivatives, one per input dimension.
# Looping over xt.shape[1] (not a hard-coded 2) keeps this in sync with the DOE.
for kx in range(xt.shape[1]):
    sm_gekpls.set_training_derivatives(xt, yt[:, [1 + kx]], kx)
sm_gekpls.train()

Out:

___________________________________________________________________________

                                  GEKPLS
___________________________________________________________________________

 Problem size

      # training points.        : 40

___________________________________________________________________________

 Training

   Training ...
/usr/share/miniconda3/envs/test/lib/python3.9/site-packages/scikit_learn-1.1.1-py3.9-linux-x86_64.egg/sklearn/cross_decomposition/_pls.py:304: UserWarning: Y residual is constant at iteration 1
  warnings.warn(f"Y residual is constant at iteration {k}")
/usr/share/miniconda3/envs/test/lib/python3.9/site-packages/scikit_learn-1.1.1-py3.9-linux-x86_64.egg/sklearn/cross_decomposition/_pls.py:304: UserWarning: Y residual is constant at iteration 1
  warnings.warn(f"Y residual is constant at iteration {k}")
/usr/share/miniconda3/envs/test/lib/python3.9/site-packages/scikit_learn-1.1.1-py3.9-linux-x86_64.egg/sklearn/cross_decomposition/_pls.py:304: UserWarning: Y residual is constant at iteration 1
  warnings.warn(f"Y residual is constant at iteration {k}")
/usr/share/miniconda3/envs/test/lib/python3.9/site-packages/scikit_learn-1.1.1-py3.9-linux-x86_64.egg/sklearn/cross_decomposition/_pls.py:304: UserWarning: Y residual is constant at iteration 1
  warnings.warn(f"Y residual is constant at iteration {k}")
/usr/share/miniconda3/envs/test/lib/python3.9/site-packages/scikit_learn-1.1.1-py3.9-linux-x86_64.egg/sklearn/cross_decomposition/_pls.py:304: UserWarning: Y residual is constant at iteration 1
  warnings.warn(f"Y residual is constant at iteration {k}")
/usr/share/miniconda3/envs/test/lib/python3.9/site-packages/scikit_learn-1.1.1-py3.9-linux-x86_64.egg/sklearn/cross_decomposition/_pls.py:304: UserWarning: Y residual is constant at iteration 1
  warnings.warn(f"Y residual is constant at iteration {k}")
/usr/share/miniconda3/envs/test/lib/python3.9/site-packages/scikit_learn-1.1.1-py3.9-linux-x86_64.egg/sklearn/cross_decomposition/_pls.py:304: UserWarning: Y residual is constant at iteration 1
  warnings.warn(f"Y residual is constant at iteration {k}")
/usr/share/miniconda3/envs/test/lib/python3.9/site-packages/scikit_learn-1.1.1-py3.9-linux-x86_64.egg/sklearn/cross_decomposition/_pls.py:304: UserWarning: Y residual is constant at iteration 1
  warnings.warn(f"Y residual is constant at iteration {k}")
/usr/share/miniconda3/envs/test/lib/python3.9/site-packages/scikit_learn-1.1.1-py3.9-linux-x86_64.egg/sklearn/cross_decomposition/_pls.py:304: UserWarning: Y residual is constant at iteration 1
  warnings.warn(f"Y residual is constant at iteration {k}")
/usr/share/miniconda3/envs/test/lib/python3.9/site-packages/scikit_learn-1.1.1-py3.9-linux-x86_64.egg/sklearn/cross_decomposition/_pls.py:304: UserWarning: Y residual is constant at iteration 1
  warnings.warn(f"Y residual is constant at iteration {k}")
/usr/share/miniconda3/envs/test/lib/python3.9/site-packages/scikit_learn-1.1.1-py3.9-linux-x86_64.egg/sklearn/cross_decomposition/_pls.py:304: UserWarning: Y residual is constant at iteration 1
  warnings.warn(f"Y residual is constant at iteration {k}")
   Training - done. Time (sec):  0.1289790
Creation of OpenTURNS PythonFunctions for prediction, variance and derivatives
# Wrap the trained SMT model for OpenTURNS, then extract three functions:
# the prediction, its conditional variance, and the prediction derivatives.
otgekpls = otsmt.smt2ot(sm_gekpls)
otgekplsprediction = otgekpls.getPredictionFunction()
otgekplsvariances = otgekpls.getConditionalVarianceFunction()
otgekplsgradient = otgekpls.getPredictionDerivativesFunction()

# Evaluate each wrapped function on the validation sample and print it.
for label, ot_function in (
    ('Predicted values by GEKPLS:', otgekplsprediction),
    ('Predicted variances values by GEKPLS:', otgekplsvariances),
    ('Prediction derivatives by GEKPLS:', otgekplsgradient),
):
    print(label, ot_function(xv))

Out:

Predicted values by GEKPLS:     [ y0      ]
0 : [ 1.00997 ]
1 : [ 5.00001 ]
Predicted variances values by GEKPLS:     [ y0         ]
0 : [ 1.4479e-09 ]
1 : [ 1.7761e-09 ]
Prediction derivatives by GEKPLS:     [ y0       y1       ]
0 : [ 0.200015 2.00001  ]
1 : [ 2.00003  4.00002  ]

Total running time of the script: ( 0 minutes 0.134 seconds)

Gallery generated by Sphinx-Gallery