Use of Mixed Variable Kriging
from smt.applications.mixed_integer import MixedIntegerContext, FLOAT, ORD, ENUM
import numpy as np
from smt.surrogate_models import KRG
import otsmt
from smt.sampling_methods import LHS, Random
Definition of the initial data
xtypes = [ORD, FLOAT, (ENUM, 4)]
xlimits = [[0, 5], [0.0, 4.0], ["blue", "red", "green", "yellow"]]
def ftest(x):
    # Column 0: ORD, column 1: FLOAT, column 2: ENUM level index (0 = "blue", ..., 3 = "yellow")
    return (x[:, 0] * x[:, 0] + x[:, 1] * x[:, 1]) * (x[:, 2] + 1)
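The third column is assumed to carry the index of the ENUM level rather than the color string, so `x[:, 2] + 1` weights the quadratic term by 1 for "blue" up to 4 for "yellow". A hypothetical single-point version of `ftest` (the name `ftest_row` and the explicit color lookup are ours, not part of SMT or otsmt) makes that mapping visible:

```python
# Levels in the order given in xlimits; the ENUM column is assumed to hold
# the position of the level in this list (0 = "blue", ..., 3 = "yellow").
COLORS = ["blue", "red", "green", "yellow"]


def ftest_row(x0, x1, color):
    """Hypothetical single-point version of ftest."""
    idx = COLORS.index(color)  # level index, as stored in the ENUM column
    return (x0 * x0 + x1 * x1) * (idx + 1)


print(ftest_row(3, 2.0, "green"))  # prints 39.0
```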
# context to create consistent DOEs and surrogate
mixint = MixedIntegerContext(xtypes, xlimits)
# DOE for training
lhs = mixint.build_sampling_method(LHS, criterion="ese")
num = mixint.get_unfolded_dimension() * 5
print("DOE point nb = {}".format(num))
xt = lhs(num)
yt = ftest(xt)
# DOE for validation
rand = mixint.build_sampling_method(Random)
xv = rand(50)
yv = ftest(xv)
Out:
DOE point nb = 30
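The 30 points follow from the unfolded dimension. As we understand SMT's mixed-integer handling, ORD and FLOAT variables each keep one column while an ENUM with 4 levels is one-hot encoded into 4 columns, giving 1 + 1 + 4 = 6 and 6 × 5 = 30. A minimal sketch of that count, using plain strings as stand-ins for the SMT type constants (`unfolded_dimension` is a hypothetical helper, not the SMT method):

```python
# Plain strings stand in for SMT's ORD / FLOAT / ENUM constants.
xtypes = ["ORD", "FLOAT", ("ENUM", 4)]


def unfolded_dimension(types):
    """Count columns after one-hot encoding of ENUM variables (assumption)."""
    dim = 0
    for t in types:
        if isinstance(t, tuple) and t[0] == "ENUM":
            dim += t[1]  # one column per ENUM level
        else:
            dim += 1  # ORD and FLOAT keep a single column
    return dim


num = unfolded_dimension(xtypes) * 5
print("DOE point nb = {}".format(num))  # prints DOE point nb = 30
```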
Training of the SMT model for mixed variables
sm_mixed = mixint.build_surrogate_model(KRG())
sm_mixed.set_training_values(xt, yt)
sm_mixed.train()
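`build_surrogate_model` wraps `KRG` so that it can be trained directly on the folded points `xt`. Our understanding (an assumption, not something this example shows) is that the wrapper uses a continuous relaxation: each ENUM index is expanded into a one-hot vector before the underlying Kriging model sees it. The hypothetical `unfold` helper below sketches that mapping for a single point:

```python
# Plain strings stand in for SMT's ORD / FLOAT / ENUM constants.
xtypes = ["ORD", "FLOAT", ("ENUM", 4)]


def unfold(point, types):
    """Expand ENUM level indices of a folded point into one-hot columns."""
    out = []
    for value, t in zip(point, types):
        if isinstance(t, tuple) and t[0] == "ENUM":
            onehot = [0.0] * t[1]
            onehot[int(value)] = 1.0  # 1.0 at the active level, 0.0 elsewhere
            out.extend(onehot)
        else:
            out.append(float(value))  # ORD and FLOAT pass through unchanged
    return out


print(unfold([3, 1.5, 2], xtypes))  # prints [3.0, 1.5, 0.0, 0.0, 1.0, 0.0]
```

The unfolded point has 6 entries, matching the unfolded dimension used to size the DOE.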
Creation of OpenTURNS PythonFunction objects for prediction and conditional variance
otmixed = otsmt.smt2ot(sm_mixed)
otmixedprediction = otmixed.getPredictionFunction()
otmixedvariances = otmixed.getConditionalVarianceFunction()
print("Predicted values by Mixed Integer:", otmixedprediction(xv))
print("Predicted variances values by Mixed Integer:", otmixedvariances(xv))
Out:
___________________________________________________________________________
Evaluation
# eval points. : 50
Predicting ...
Predicting - done. Time (sec): 0.0003211
Prediction time/pt. (sec) : 0.0000064
Predicted values by Mixed Integer: [ y0 ]
0 : [ 28.8947 ]
1 : [ -1.37485 ]
2 : [ 12.4488 ]
3 : [ 15.4657 ]
4 : [ 9.14723 ]
5 : [ 31.8674 ]
6 : [ 10.2323 ]
7 : [ 6.72685 ]
8 : [ 65.8473 ]
9 : [ 18.37 ]
10 : [ 143.44 ]
11 : [ 8.99577 ]
12 : [ 78.4327 ]
13 : [ 76.9253 ]
14 : [ 65.0013 ]
15 : [ 32.2487 ]
16 : [ 38.7218 ]
17 : [ 7.6667 ]
18 : [ 63.503 ]
19 : [ 49.9074 ]
20 : [ 14.3401 ]
21 : [ 35.1745 ]
22 : [ 18.9982 ]
23 : [ 10.5035 ]
24 : [ 8.80002 ]
25 : [ 16.7934 ]
26 : [ 29.4148 ]
27 : [ 51.8151 ]
28 : [ 23.0123 ]
29 : [ 17.0259 ]
30 : [ 43.1987 ]
31 : [ 19.7929 ]
32 : [ 46.9404 ]
33 : [ 50.3424 ]
34 : [ 24.7844 ]
35 : [ 14.6656 ]
36 : [ 4.95574 ]
37 : [ 62.9422 ]
38 : [ 17.4567 ]
39 : [ 41.0852 ]
40 : [ 44.338 ]
41 : [ 19.0933 ]
42 : [ 11.9112 ]
43 : [ 53.518 ]
44 : [ 12.3303 ]
45 : [ 57.5819 ]
46 : [ 47.8387 ]
47 : [ 50.3167 ]
48 : [ 111.27 ]
49 : [ 9.04379 ]
Predicted variances values by Mixed Integer: [ y0 ]
0 : [ 0.00634394 ]
1 : [ 0.435442 ]
2 : [ 0.0292428 ]
3 : [ 0.435312 ]
4 : [ 0.906689 ]
5 : [ 0.0107925 ]
6 : [ 0.00110412 ]
7 : [ 0.19858 ]
8 : [ 0.00519545 ]
9 : [ 0.0063098 ]
10 : [ 0.000413624 ]
11 : [ 0.000462281 ]
12 : [ 0.0262213 ]
13 : [ 0.136762 ]
14 : [ 0.0179971 ]
15 : [ 4.29155 ]
16 : [ 0.047188 ]
17 : [ 0.00814363 ]
18 : [ 0.063756 ]
19 : [ 0.100801 ]
20 : [ 5.2714 ]
21 : [ 9.11241e-08 ]
22 : [ 0.195747 ]
23 : [ 12.177 ]
24 : [ 0.00453485 ]
25 : [ 0.00758543 ]
26 : [ 0.00998605 ]
27 : [ 0.00631828 ]
28 : [ 0.000565737 ]
29 : [ 0.187824 ]
30 : [ 5.55059e-05 ]
31 : [ 0.00115969 ]
32 : [ 0.000936271 ]
33 : [ 0.168667 ]
34 : [ 0.0153236 ]
35 : [ 0.0612026 ]
36 : [ 4.60271 ]
37 : [ 0.16814 ]
38 : [ 0.0010821 ]
39 : [ 1.61933e-08 ]
40 : [ 1.61538e-06 ]
41 : [ 0.862353 ]
42 : [ 0.0290022 ]
43 : [ 0.0688856 ]
44 : [ 0.0393679 ]
45 : [ 0.00117549 ]
46 : [ 6.26995e-07 ]
47 : [ 0.169151 ]
48 : [ 0.0907524 ]
49 : [ 2.0173e-05 ]
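The conditional variance can be turned into an approximate confidence interval if the Kriging predictive distribution is treated as Gaussian. The sketch below applies this to validation point 15, one of the higher-variance predictions listed above; the helper `gaussian_interval` and the 1.96 factor (about 95% coverage) are our choices, not part of otsmt:

```python
import math


def gaussian_interval(mean, variance, z=1.96):
    """Return (mean - z * sqrt(variance), mean + z * sqrt(variance)).

    Assumes a Gaussian predictive distribution; z = 1.96 gives roughly 95%.
    """
    half_width = z * math.sqrt(variance)
    return mean - half_width, mean + half_width


# Mean and variance of validation point 15, copied from the output above
low, high = gaussian_interval(32.2487, 4.29155)
print("[{:.2f}, {:.2f}]".format(low, high))  # prints [28.19, 36.31]
```

Points with near-zero variance (for example 21 or 39) give intervals that collapse onto the prediction, whereas point 23 (variance 12.177) deserves the most caution.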
Total running time of the script: ( 0 minutes 1.417 seconds)