Use of Mixed Variable Kriging

import numpy as np

from smt.applications.mixed_integer import MixedIntegerContext, FLOAT, ORD, ENUM
from smt.sampling_methods import LHS, Random
from smt.surrogate_models import KRG

import otsmt
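Definition of the initial data
# Variable types: an ordered integer, a continuous variable and a
# categorical variable with 4 levels
xtypes = [ORD, FLOAT, (ENUM, 4)]
# Bounds for the ORD and FLOAT variables, list of levels for the ENUM variable
xlimits = [[0, 5], [0.0, 4.0], ["blue", "red", "green", "yellow"]]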

def ftest(x):
    # x[:, 2] holds the index (0 to 3) of the categorical level
    return (x[:, 0] * x[:, 0] + x[:, 1] * x[:, 1]) * (x[:, 2] + 1)

# context to create consistent DOEs and surrogate
mixint = MixedIntegerContext(xtypes, xlimits)

# DOE for training
lhs = mixint.build_sampling_method(LHS, criterion="ese")

# 5 training points per dimension of the unfolded design space
num = mixint.get_unfolded_dimension() * 5
print("DOE point nb = {}".format(num))
xt = lhs(num)
yt = ftest(xt)

# DOE for validation
rand = mixint.build_sampling_method(Random)
xv = rand(50)
yv = ftest(xv)

Out:

DOE point nb = 30
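The figure of 30 points follows from the size of the unfolded design space. Assuming SMT's one-hot unfolding of categorical variables (consistent with the printed value), an ORD or FLOAT variable occupies one column and an (ENUM, n) variable occupies n columns, so this problem unfolds to 1 + 1 + 4 = 6 dimensions and the script requests 6 * 5 = 30 LHS points. The helper below is a hypothetical pure-Python re-implementation of that count, with string tags standing in for SMT's FLOAT/ORD/ENUM constants; it is not SMT's own code.

```python
# Hypothetical stand-ins for SMT's FLOAT / ORD / ENUM type markers.
FLOAT, ORD, ENUM = "float", "ord", "enum"


def unfolded_dimension(xtypes):
    """Count columns after one-hot unfolding of categorical variables.

    FLOAT and ORD variables keep a single column; an (ENUM, n) entry
    is expanded into n columns, one per level.
    """
    dim = 0
    for xtype in xtypes:
        if isinstance(xtype, tuple) and xtype[0] == ENUM:
            dim += xtype[1]  # one column per categorical level
        else:
            dim += 1  # FLOAT or ORD: a single column
    return dim


xtypes = [ORD, FLOAT, (ENUM, 4)]
dim = unfolded_dimension(xtypes)
num = dim * 5  # same "5 points per dimension" rule as the script
print("unfolded dimension = {}, DOE point nb = {}".format(dim, num))
```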
Training of the SMT model for mixed variables
sm_mixed = mixint.build_surrogate_model(KRG())
sm_mixed.set_training_values(xt, yt)
sm_mixed.train()
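A mixed-integer surrogate built this way does not train the Kriging model on the raw [ORD, FLOAT, ENUM] columns. A minimal sketch of the idea, assuming the continuous-relaxation treatment used by SMT 1.x: the categorical column, stored as a level index (0 for "blue" up to 3 for "yellow"), is expanded into a one-hot block, and a relaxed point is cast back to a valid mixed point by rounding the ORD column and keeping only the largest entry of the one-hot block. The functions `unfold`, `cast_to_discrete` and `fold` are hypothetical illustrations written for this 3-variable problem, not the SMT API.

```python
import numpy as np

N_LEVELS = 4  # levels of the (ENUM, 4) variable: blue, red, green, yellow


def unfold(x):
    """Hypothetical helper: [ord, float, enum_index] -> [ord, float, one-hot(4)]."""
    x = np.atleast_2d(np.asarray(x, dtype=float))
    onehot = np.zeros((x.shape[0], N_LEVELS))
    onehot[np.arange(x.shape[0]), x[:, 2].astype(int)] = 1.0
    return np.hstack([x[:, :2], onehot])


def cast_to_discrete(xu):
    """Hypothetical helper: round the ORD column and snap the one-hot block
    to its largest entry, so a relaxed point becomes a valid mixed point."""
    xu = np.array(xu, dtype=float)  # work on a copy
    xu[:, 0] = np.round(xu[:, 0])
    idx = np.argmax(xu[:, 2:], axis=1)
    xu[:, 2:] = 0.0
    xu[np.arange(xu.shape[0]), 2 + idx] = 1.0
    return xu


def fold(xu):
    """Hypothetical helper: [ord, float, one-hot(4)] -> [ord, float, enum_index]."""
    xu = np.atleast_2d(xu)
    return np.column_stack([xu[:, 0], xu[:, 1], np.argmax(xu[:, 2:], axis=1)])


x = np.array([[3.0, 1.5, 2.0]])  # ord = 3, float = 1.5, level "green"
xu = unfold(x)  # one-hot block [0, 0, 1, 0] marks "green"
relaxed = np.array([[2.6, 1.5, 0.1, 0.2, 0.6, 0.1]])
print(fold(cast_to_discrete(relaxed)))  # ord rounded to 3, level index 2
```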
Creation of OpenTURNS PythonFunctions for prediction and conditional variance
otmixed = otsmt.smt2ot(sm_mixed)
otmixedprediction = otmixed.getPredictionFunction()
otmixedvariances = otmixed.getConditionalVarianceFunction()

print('Predicted values by Mixed Integer:', otmixedprediction(xv))
print('Predicted variances values by Mixed Integer:', otmixedvariances(xv))

Out:

___________________________________________________________________________

 Evaluation

      # eval points. : 50

   Predicting ...
   Predicting - done. Time (sec):  0.0003211

   Prediction time/pt. (sec) :  0.0000064

Predicted values by Mixed Integer:      [ y0        ]
 0 : [  28.8947  ]
 1 : [  -1.37485 ]
 2 : [  12.4488  ]
 3 : [  15.4657  ]
 4 : [   9.14723 ]
 5 : [  31.8674  ]
 6 : [  10.2323  ]
 7 : [   6.72685 ]
 8 : [  65.8473  ]
 9 : [  18.37    ]
10 : [ 143.44    ]
11 : [   8.99577 ]
12 : [  78.4327  ]
13 : [  76.9253  ]
14 : [  65.0013  ]
15 : [  32.2487  ]
16 : [  38.7218  ]
17 : [   7.6667  ]
18 : [  63.503   ]
19 : [  49.9074  ]
20 : [  14.3401  ]
21 : [  35.1745  ]
22 : [  18.9982  ]
23 : [  10.5035  ]
24 : [   8.80002 ]
25 : [  16.7934  ]
26 : [  29.4148  ]
27 : [  51.8151  ]
28 : [  23.0123  ]
29 : [  17.0259  ]
30 : [  43.1987  ]
31 : [  19.7929  ]
32 : [  46.9404  ]
33 : [  50.3424  ]
34 : [  24.7844  ]
35 : [  14.6656  ]
36 : [   4.95574 ]
37 : [  62.9422  ]
38 : [  17.4567  ]
39 : [  41.0852  ]
40 : [  44.338   ]
41 : [  19.0933  ]
42 : [  11.9112  ]
43 : [  53.518   ]
44 : [  12.3303  ]
45 : [  57.5819  ]
46 : [  47.8387  ]
47 : [  50.3167  ]
48 : [ 111.27    ]
49 : [   9.04379 ]
Predicted variances values by Mixed Integer:      [ y0           ]
 0 : [  0.00634394  ]
 1 : [  0.435442    ]
 2 : [  0.0292428   ]
 3 : [  0.435312    ]
 4 : [  0.906689    ]
 5 : [  0.0107925   ]
 6 : [  0.00110412  ]
 7 : [  0.19858     ]
 8 : [  0.00519545  ]
 9 : [  0.0063098   ]
10 : [  0.000413624 ]
11 : [  0.000462281 ]
12 : [  0.0262213   ]
13 : [  0.136762    ]
14 : [  0.0179971   ]
15 : [  4.29155     ]
16 : [  0.047188    ]
17 : [  0.00814363  ]
18 : [  0.063756    ]
19 : [  0.100801    ]
20 : [  5.2714      ]
21 : [  9.11241e-08 ]
22 : [  0.195747    ]
23 : [ 12.177       ]
24 : [  0.00453485  ]
25 : [  0.00758543  ]
26 : [  0.00998605  ]
27 : [  0.00631828  ]
28 : [  0.000565737 ]
29 : [  0.187824    ]
30 : [  5.55059e-05 ]
31 : [  0.00115969  ]
32 : [  0.000936271 ]
33 : [  0.168667    ]
34 : [  0.0153236   ]
35 : [  0.0612026   ]
36 : [  4.60271     ]
37 : [  0.16814     ]
38 : [  0.0010821   ]
39 : [  1.61933e-08 ]
40 : [  1.61538e-06 ]
41 : [  0.862353    ]
42 : [  0.0290022   ]
43 : [  0.0688856   ]
44 : [  0.0393679   ]
45 : [  0.00117549  ]
46 : [  6.26995e-07 ]
47 : [  0.169151    ]
48 : [  0.0907524   ]
49 : [  2.0173e-05  ]
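The script computes the validation responses `yv` but never compares them with the predictions. One way to do so, sketched below, is to report the root-mean-square error and the fraction of validation points lying inside an approximate 95 % band, prediction +/- 1.96 standard deviations, treating the Kriging conditional variance as Gaussian. This assumes both OpenTURNS outputs are first converted to NumPy arrays (for example with `np.array(otmixedprediction(xv))`). `validation_summary` is a hypothetical helper, and the small arrays in the usage line are illustration values only, not results of this example.

```python
import numpy as np


def validation_summary(y_pred, y_var, y_true, z=1.96):
    """Hypothetical helper: RMSE and empirical coverage of pred +/- z * std."""
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    y_true = np.asarray(y_true, dtype=float).ravel()
    # Clip tiny negative variances that can appear from round-off.
    std = np.sqrt(np.clip(np.asarray(y_var, dtype=float).ravel(), 0.0, None))
    rmse = np.sqrt(np.mean((y_pred - y_true) ** 2))
    inside = np.abs(y_pred - y_true) <= z * std
    return rmse, inside.mean()


# Illustration values only; in the script one would call
# validation_summary(np.array(otmixedprediction(xv)),
#                    np.array(otmixedvariances(xv)), yv)
rmse, coverage = validation_summary([1.0, 2.0, 3.0], [0.25, 0.25, 0.0], [1.5, 2.0, 3.0])
print("RMSE = {:.4f}, coverage = {:.2f}".format(rmse, coverage))
```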

Total running time of the script: ( 0 minutes 1.417 seconds)