Robot - Feature engineering on sensor data
The purpose of this notebook is to illustrate how we can overcome the feature explosion problem based on an example dataset involving sensor data.
Summary:
- Prediction type: Regression
- Domain: Robotics
- Prediction target: The force vector on the robot's arm
- Population size: 15001
Author: Dr. Patrick Urbanke
The data set
To illustrate the problem, we use a data set related to robotics. When robots interact with humans, the most important thing is that they don't hurt people. To prevent such accidents, the force vector on the robot's arm is measured. However, measuring the force vector is expensive.
Therefore, we want to consider an alternative approach: we would like to predict the force vector from other sensor data that are less costly to measure. To do so, we use machine learning.
However, the data set contains measurements from almost 100 different sensors and we do not know which and how many sensors are relevant for predicting the force vector.
The data set has been generously provided by Erik Berger who originally collected it for his dissertation:
Berger, E. (2018). Behavior-Specific Proprioception Models for Robotic Force Estimation: A Machine Learning Approach. Freiberg, Germany: Technische Universitaet Bergakademie Freiberg.
Analysis
1. Loading data
We begin by installing and importing the required libraries, launching the getML engine and setting the project.
%pip install -q "getml==1.4.0" "numpy<2.0.0" "matplotlib~=3.9"
import getml
import matplotlib.pyplot as plt
%matplotlib inline
getml.engine.launch()
getml.engine.set_project('robot')
Note: you may need to restart the kernel to use updated packages.
Launching ./getML --allow-push-notifications=true --allow-remote-ips=false --home-directory=/home/alex/.local/lib/python3.10/site-packages/getml --in-memory=true --install=false --launch-browser=true --log=false in /home/alex/.local/lib/python3.10/site-packages/getml/.getML/getml-1.4.0-x64-community-edition-linux...
Launched the getML engine. The log output will be stored in /home/alex/.getML/logs/20240807164517.log.
Connected to project 'robot'
1.1 Download from source
data_all = getml.data.DataFrame.from_csv(
"https://static.getml.com/datasets/robotarm/robot-demo.csv",
"data_all"
)
Downloading robot-demo.csv... 100% |██████████| [elapsed: 00:01, remaining: 00:00]
data_all
name | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | f_x | f_y | f_z |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
role | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float | unused_float |
0 | 3.4098 | -0.3274 | 0.9604 | -3.7436 | -1.0191 | -6.0205 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 8.38e-17 | -4.8116 | -1.4033 | -0.1369 | 0.002472 | 0 | 9.803e-16 | -55.642 | -16.312 | -1.2042 | 0.02167 | 0 | 3.4098 | -0.3274 | 0.9605 | -3.7437 | -1.0191 | -6.0205 | 0 | 0 | 0 | 0 | 0 | 0 | 0.1233 | -6.5483 | -2.8045 | -0.8296 | 0.07625 | -0.1906 | 0.1211 | -6.5483 | -2.8157 | -0.8281 | 0.07015 | -0.1983 | 0.7699 | 0.41 | 0.08279 | -1.4094 | 0.786 | -0.3682 | 0 | 0 | 0 | 0 | 0 | 0 | -22.654 | -11.503 | -18.673 | -3.5155 | 5.8354 | -2.05 | 0.7699 | 0.41 | 0.08278 | -1.4094 | 0.786 | -0.3681 | 0 | 0 | 0 | 0 | 0 | 0 | 48.069 | 48.009 | 0.9668 | 47.834 | 47.925 | 47.818 | 47.834 | 47.955 | 47.971 | -11.03 | 6.9 | -7.33 |
1 | 3.4098 | -0.3274 | 0.9604 | -3.7436 | -1.0191 | -6.0205 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 8.38e-17 | -4.8116 | -1.4033 | -0.1369 | 0.002472 | 0 | 9.803e-16 | -55.642 | -16.312 | -1.2042 | 0.02167 | 0 | 3.4098 | -0.3274 | 0.9604 | -3.7437 | -1.0191 | -6.0205 | 0 | 0 | 0 | 0 | 0 | 0 | 0.1188 | -6.5506 | -2.8404 | -0.8281 | 0.06405 | -0.1998 | 0.1211 | -6.5483 | -2.8157 | -0.8281 | 0.07015 | -0.1983 | 0.7699 | 0.41 | 0.0828 | -1.4094 | 0.7859 | -0.3682 | 0 | 0 | 0 | 0 | 0 | 0 | -21.627 | -11.046 | -18.66 | -3.5395 | 5.7577 | -1.9805 | 0.7699 | 0.41 | 0.08278 | -1.4094 | 0.786 | -0.3681 | 0 | 0 | 0 | 0 | 0 | 0 | 48.009 | 48.009 | 0.8594 | 47.834 | 47.925 | 47.818 | 47.834 | 47.955 | 47.971 | -10.848 | 6.7218 | -7.4427 |
2 | 3.4098 | -0.3274 | 0.9604 | -3.7436 | -1.0191 | -6.0205 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 8.38e-17 | -4.8116 | -1.4033 | -0.1369 | 0.002472 | 0 | 9.803e-16 | -55.642 | -16.312 | -1.2042 | 0.02167 | 0 | 3.4098 | -0.3274 | 0.9605 | -3.7437 | -1.0191 | -6.0205 | 0 | 0 | 0 | 0 | 0 | 0 | 0.1099 | -6.5438 | -2.8 | -0.8205 | 0.07473 | -0.183 | 0.1211 | -6.5483 | -2.8157 | -0.8281 | 0.07015 | -0.1922 | 0.7699 | 0.41 | 0.08279 | -1.4094 | 0.7859 | -0.3682 | 0 | 0 | 0 | 0 | 0 | 0 | -23.843 | -12.127 | -18.393 | -3.6453 | 5.978 | -1.9978 | 0.7699 | 0.41 | 0.08278 | -1.4094 | 0.786 | -0.3681 | 0 | 0 | 0 | 0 | 0 | 0 | 48.009 | 48.069 | 0.931 | 47.879 | 47.925 | 47.818 | 47.834 | 47.955 | 47.971 | -10.666 | 6.5436 | -7.5555 |
3 | 3.4098 | -0.3274 | 0.9604 | -3.7436 | -1.0191 | -6.0205 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 8.38e-17 | -4.8116 | -1.4033 | -0.1369 | 0.002472 | 0 | 9.803e-16 | -55.642 | -16.312 | -1.2042 | 0.02167 | 0 | 3.4098 | -0.3273 | 0.9604 | -3.7437 | -1.0191 | -6.0205 | 0 | 0 | 0 | 0 | 0 | 0 | 0.1233 | -6.5483 | -2.8224 | -0.8266 | 0.07168 | -0.1998 | 0.1211 | -6.5483 | -2.8157 | -0.8281 | 0.07015 | -0.1967 | 0.7699 | 0.41 | 0.08275 | -1.4094 | 0.786 | -0.3681 | 0 | 0 | 0 | 0 | 0 | 0 | -21.772 | -10.872 | -18.691 | -3.5512 | 5.6648 | -1.9976 | 0.7699 | 0.41 | 0.08278 | -1.4094 | 0.786 | -0.3681 | 0 | 0 | 0 | 0 | 0 | 0 | 48.069 | 48.069 | 0.931 | 47.879 | 47.925 | 47.818 | 47.834 | 47.955 | 47.971 | -10.507 | 6.4533 | -7.65 |
4 | 3.4098 | -0.3274 | 0.9604 | -3.7436 | -1.0191 | -6.0205 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 8.38e-17 | -4.8116 | -1.4033 | -0.1369 | 0.002472 | 0 | 9.803e-16 | -55.642 | -16.312 | -1.2042 | 0.02167 | 0 | 3.4098 | -0.3274 | 0.9604 | -3.7437 | -1.0191 | -6.0205 | 0 | 0 | 0 | 0 | 0 | 0 | 0.1255 | -6.5394 | -2.8 | -0.8327 | 0.07473 | -0.1952 | 0.1211 | -6.5483 | -2.8157 | -0.8327 | 0.07015 | -0.1922 | 0.7699 | 0.41 | 0.08278 | -1.4094 | 0.786 | -0.3681 | 0 | 0 | 0 | 0 | 0 | 0 | -22.823 | -11.645 | -18.524 | -3.5305 | 5.8712 | -2.0096 | 0.7699 | 0.41 | 0.08278 | -1.4094 | 0.786 | -0.3681 | 0 | 0 | 0 | 0 | 0 | 0 | 48.069 | 48.069 | 0.8952 | 47.879 | 47.925 | 47.818 | 47.834 | 47.955 | 47.971 | -10.413 | 6.6267 | -7.69 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | |
14996 | 3.0837 | -0.8836 | 1.4501 | -2.2102 | -1.559 | -5.3265 | -0.03151 | -0.05375 | 0.04732 | 0.1482 | -0.05218 | 0.06706 | 0.2969 | 0.5065 | -0.4459 | -1.3963 | 0.4916 | -0.6319 | -0.3694 | -4.1879 | -1.1847 | -0.09441 | -0.1568 | 0.1898 | 1.1605 | -42.951 | -19.023 | -2.6343 | 0.1551 | -0.1338 | 3.0836 | -0.8836 | 1.4503 | -2.2101 | -1.5591 | -5.3263 | -0.03347 | -0.05585 | 0.04805 | 0.151 | -0.05513 | 0.07114 | -0.3564 | -6.0394 | -2.3001 | -0.2181 | -0.1159 | 0.09608 | -0.3632 | -6.0394 | -2.3023 | -0.212 | -0.125 | 0.1113 | 0.7116 | 0.06957 | 0.06036 | -0.8506 | 2.9515 | -0.03352 | -0.03558 | -0.03029 | 0.002444 | -0.04208 | 0.1458 | -0.1098 | -0.8784 | -0.07291 | -37.584 | 0.0001132 | -2.1031 | 0.03318 | 0.7117 | 0.0697 | 0.06044 | -0.8511 | 2.951 | -0.03356 | -0.03508 | -0.02849 | 0.001571 | -0.03951 | 0.1442 | -0.1036 | 48.069 | 48.009 | 0.8952 | 47.818 | 47.834 | 47.818 | 47.803 | 47.94 | 47.94 | 10.84 | -1.41 | 16.14 |
14997 | 3.0835 | -0.884 | 1.4505 | -2.2091 | -1.5594 | -5.326 | -0.02913 | -0.0497 | 0.04376 | 0.137 | -0.04825 | 0.062 | 0.2969 | 0.5065 | -0.4459 | -1.3963 | 0.4916 | -0.6319 | -0.3677 | -4.1837 | -1.1874 | -0.09682 | -0.1562 | 0.189 | 1.1592 | -42.937 | -19.023 | -2.6331 | 0.1545 | -0.1338 | 3.0833 | -0.8841 | 1.4507 | -2.209 | -1.5596 | -5.3258 | -0.02909 | -0.04989 | 0.04198 | 0.1481 | -0.05465 | 0.06249 | -0.3161 | -6.1179 | -2.253 | -0.3752 | -0.03965 | 0.08693 | -0.3273 | -6.1022 | -2.2597 | -0.366 | -0.05033 | 0.0915 | 0.7114 | 0.06932 | 0.06039 | -0.8497 | 2.953 | -0.03359 | -0.0335 | -0.02723 | 0.001208 | -0.04242 | 0.1428 | -0.0967 | -2.7137 | 0.8552 | -38.514 | -0.6088 | -3.2383 | -0.9666 | 0.7114 | 0.06948 | 0.06045 | -0.8503 | 2.9525 | -0.03359 | -0.03246 | -0.02633 | 0.001469 | -0.03657 | 0.1333 | -0.09571 | 48.009 | 48.009 | 0.8594 | 47.818 | 47.834 | 47.818 | 47.803 | 47.94 | 47.94 | 10.857 | -1.52 | 15.943 |
14998 | 3.0833 | -0.8844 | 1.4508 | -2.208 | -1.5598 | -5.3256 | -0.02676 | -0.04565 | 0.04019 | 0.1258 | -0.04431 | 0.05695 | 0.2969 | 0.5065 | -0.4459 | -1.3963 | 0.4916 | -0.6319 | -0.3659 | -4.1797 | -1.1901 | -0.09922 | -0.1555 | 0.1881 | 1.1579 | -42.924 | -19.023 | -2.6321 | 0.154 | -0.1338 | 3.0831 | -0.8844 | 1.451 | -2.2078 | -1.56 | -5.3253 | -0.02776 | -0.04382 | 0.03652 | 0.1295 | -0.05064 | 0.04818 | -0.343 | -6.2569 | -2.1566 | -0.3035 | 0.00305 | 0.1434 | -0.3385 | -6.2322 | -2.1589 | -0.302 | -0.00915 | 0.1571 | 0.7111 | 0.06912 | 0.06039 | -0.849 | 2.9544 | -0.0337 | -0.02911 | -0.02589 | 0.001292 | -0.04046 | 0.1246 | -0.08058 | 4.2749 | 1.0128 | -36.412 | -1.2811 | -0.4296 | -1.1013 | 0.7112 | 0.06928 | 0.06046 | -0.8495 | 2.9538 | -0.03362 | -0.02984 | -0.02417 | 0.001364 | -0.03362 | 0.1224 | -0.08786 | 48.009 | 48.009 | 0.931 | 47.818 | 47.834 | 47.818 | 47.803 | 47.94 | 47.94 | 10.89 | -1.74 | 15.55 |
14999 | 3.0831 | -0.8847 | 1.4511 | -2.2071 | -1.5602 | -5.3251 | -0.02438 | -0.0416 | 0.03662 | 0.1147 | -0.04038 | 0.0519 | 0.2969 | 0.5065 | -0.4459 | -1.3963 | 0.4916 | -0.6319 | -0.3642 | -4.1758 | -1.1928 | -0.1016 | -0.1548 | 0.1873 | 1.1568 | -42.912 | -19.023 | -2.6311 | 0.1535 | -0.1338 | 3.0829 | -0.8848 | 1.4513 | -2.2068 | -1.5604 | -5.3249 | -0.02149 | -0.04059 | 0.03417 | 0.1202 | -0.0395 | 0.04178 | -0.4237 | -6.2703 | -2.0939 | -0.302 | -0.01372 | 0.1739 | -0.4125 | -6.2569 | -2.0916 | -0.2943 | -0.02898 | 0.1891 | 0.7109 | 0.06894 | 0.06039 | -0.8484 | 2.9557 | -0.03384 | -0.02738 | -0.01982 | 0.001031 | -0.03028 | 0.1157 | -0.06702 | 11.518 | 1.5002 | -39.314 | -1.8671 | -0.3734 | -0.5733 | 0.7109 | 0.06909 | 0.06047 | -0.8488 | 2.955 | -0.03364 | -0.02721 | -0.02201 | 0.001255 | -0.03067 | 0.1115 | -0.08003 | 48.009 | 48.009 | 0.931 | 47.818 | 47.834 | 47.818 | 47.803 | 47.94 | 47.94 | 11.29 | -1.4601 | 15.743 |
15000 | 3.0829 | -0.885 | 1.4514 | -2.2062 | -1.5605 | -5.3247 | -0.02201 | -0.03755 | 0.03305 | 0.1035 | -0.03645 | 0.04684 | 0.2969 | 0.5065 | -0.4459 | -1.3963 | 0.4916 | -0.6319 | -0.3624 | -4.172 | -1.1955 | -0.1041 | -0.1542 | 0.1864 | 1.1558 | -42.901 | -19.023 | -2.6302 | 0.1531 | -0.1338 | 3.0827 | -0.8851 | 1.4516 | -2.2059 | -1.5607 | -5.3246 | -0.02096 | -0.03808 | 0.02958 | 0.1171 | -0.03289 | 0.03883 | -0.417 | -6.2434 | -2.058 | -0.4102 | -0.04728 | 0.1967 | -0.4237 | -6.2367 | -2.0714 | -0.4163 | -0.0671 | 0.2059 | 0.7107 | 0.06878 | 0.06041 | -0.8478 | 2.9567 | -0.03382 | -0.02535 | -0.01854 | 0.001614 | -0.02421 | 0.11 | -0.06304 | 15.099 | 2.936 | -39.068 | -1.9402 | 0.139 | -0.2674 | 0.7107 | 0.06893 | 0.06048 | -0.8482 | 2.9561 | -0.03367 | -0.02458 | -0.01986 | 0.001142 | -0.0277 | 0.1007 | -0.07221 | 48.009 | 48.069 | 0.8952 | 47.818 | 47.834 | 47.818 | 47.803 | 47.94 | 47.955 | 11.69 | -1.1801 | 15.937 |
15001 rows x 96 columns
memory usage: 11.52 MB
name: data_all
type: getml.DataFrame
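1.2 Prepare data for getML
The force vector consists of three components (f_x, f_y and f_z), meaning that we have three targets.
# The three components of the force vector are the targets.
data_all.set_role(["f_x", "f_y", "f_z"], getml.data.roles.target)
# All remaining (still unused) columns are sensor readings; use them as numerical inputs.
data_all.set_role(data_all.roles.unused, getml.data.roles.numerical)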
This is what the data set looks like:
data_all
name | f_x | f_y | f_z | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
role | target | target | target | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical | numerical |
0 | -11.03 | 6.9 | -7.33 | 3.4098 | -0.3274 | 0.9604 | -3.7436 | -1.0191 | -6.0205 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 8.38e-17 | -4.8116 | -1.4033 | -0.1369 | 0.002472 | 0 | 9.803e-16 | -55.642 | -16.312 | -1.2042 | 0.02167 | 0 | 3.4098 | -0.3274 | 0.9605 | -3.7437 | -1.0191 | -6.0205 | 0 | 0 | 0 | 0 | 0 | 0 | 0.1233 | -6.5483 | -2.8045 | -0.8296 | 0.07625 | -0.1906 | 0.1211 | -6.5483 | -2.8157 | -0.8281 | 0.07015 | -0.1983 | 0.7699 | 0.41 | 0.08279 | -1.4094 | 0.786 | -0.3682 | 0 | 0 | 0 | 0 | 0 | 0 | -22.654 | -11.503 | -18.673 | -3.5155 | 5.8354 | -2.05 | 0.7699 | 0.41 | 0.08278 | -1.4094 | 0.786 | -0.3681 | 0 | 0 | 0 | 0 | 0 | 0 | 48.069 | 48.009 | 0.9668 | 47.834 | 47.925 | 47.818 | 47.834 | 47.955 | 47.971 |
1 | -10.848 | 6.7218 | -7.4427 | 3.4098 | -0.3274 | 0.9604 | -3.7436 | -1.0191 | -6.0205 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 8.38e-17 | -4.8116 | -1.4033 | -0.1369 | 0.002472 | 0 | 9.803e-16 | -55.642 | -16.312 | -1.2042 | 0.02167 | 0 | 3.4098 | -0.3274 | 0.9604 | -3.7437 | -1.0191 | -6.0205 | 0 | 0 | 0 | 0 | 0 | 0 | 0.1188 | -6.5506 | -2.8404 | -0.8281 | 0.06405 | -0.1998 | 0.1211 | -6.5483 | -2.8157 | -0.8281 | 0.07015 | -0.1983 | 0.7699 | 0.41 | 0.0828 | -1.4094 | 0.7859 | -0.3682 | 0 | 0 | 0 | 0 | 0 | 0 | -21.627 | -11.046 | -18.66 | -3.5395 | 5.7577 | -1.9805 | 0.7699 | 0.41 | 0.08278 | -1.4094 | 0.786 | -0.3681 | 0 | 0 | 0 | 0 | 0 | 0 | 48.009 | 48.009 | 0.8594 | 47.834 | 47.925 | 47.818 | 47.834 | 47.955 | 47.971 |
2 | -10.666 | 6.5436 | -7.5555 | 3.4098 | -0.3274 | 0.9604 | -3.7436 | -1.0191 | -6.0205 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 8.38e-17 | -4.8116 | -1.4033 | -0.1369 | 0.002472 | 0 | 9.803e-16 | -55.642 | -16.312 | -1.2042 | 0.02167 | 0 | 3.4098 | -0.3274 | 0.9605 | -3.7437 | -1.0191 | -6.0205 | 0 | 0 | 0 | 0 | 0 | 0 | 0.1099 | -6.5438 | -2.8 | -0.8205 | 0.07473 | -0.183 | 0.1211 | -6.5483 | -2.8157 | -0.8281 | 0.07015 | -0.1922 | 0.7699 | 0.41 | 0.08279 | -1.4094 | 0.7859 | -0.3682 | 0 | 0 | 0 | 0 | 0 | 0 | -23.843 | -12.127 | -18.393 | -3.6453 | 5.978 | -1.9978 | 0.7699 | 0.41 | 0.08278 | -1.4094 | 0.786 | -0.3681 | 0 | 0 | 0 | 0 | 0 | 0 | 48.009 | 48.069 | 0.931 | 47.879 | 47.925 | 47.818 | 47.834 | 47.955 | 47.971 |
3 | -10.507 | 6.4533 | -7.65 | 3.4098 | -0.3274 | 0.9604 | -3.7436 | -1.0191 | -6.0205 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 8.38e-17 | -4.8116 | -1.4033 | -0.1369 | 0.002472 | 0 | 9.803e-16 | -55.642 | -16.312 | -1.2042 | 0.02167 | 0 | 3.4098 | -0.3273 | 0.9604 | -3.7437 | -1.0191 | -6.0205 | 0 | 0 | 0 | 0 | 0 | 0 | 0.1233 | -6.5483 | -2.8224 | -0.8266 | 0.07168 | -0.1998 | 0.1211 | -6.5483 | -2.8157 | -0.8281 | 0.07015 | -0.1967 | 0.7699 | 0.41 | 0.08275 | -1.4094 | 0.786 | -0.3681 | 0 | 0 | 0 | 0 | 0 | 0 | -21.772 | -10.872 | -18.691 | -3.5512 | 5.6648 | -1.9976 | 0.7699 | 0.41 | 0.08278 | -1.4094 | 0.786 | -0.3681 | 0 | 0 | 0 | 0 | 0 | 0 | 48.069 | 48.069 | 0.931 | 47.879 | 47.925 | 47.818 | 47.834 | 47.955 | 47.971 |
4 | -10.413 | 6.6267 | -7.69 | 3.4098 | -0.3274 | 0.9604 | -3.7436 | -1.0191 | -6.0205 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 8.38e-17 | -4.8116 | -1.4033 | -0.1369 | 0.002472 | 0 | 9.803e-16 | -55.642 | -16.312 | -1.2042 | 0.02167 | 0 | 3.4098 | -0.3274 | 0.9604 | -3.7437 | -1.0191 | -6.0205 | 0 | 0 | 0 | 0 | 0 | 0 | 0.1255 | -6.5394 | -2.8 | -0.8327 | 0.07473 | -0.1952 | 0.1211 | -6.5483 | -2.8157 | -0.8327 | 0.07015 | -0.1922 | 0.7699 | 0.41 | 0.08278 | -1.4094 | 0.786 | -0.3681 | 0 | 0 | 0 | 0 | 0 | 0 | -22.823 | -11.645 | -18.524 | -3.5305 | 5.8712 | -2.0096 | 0.7699 | 0.41 | 0.08278 | -1.4094 | 0.786 | -0.3681 | 0 | 0 | 0 | 0 | 0 | 0 | 48.069 | 48.069 | 0.8952 | 47.879 | 47.925 | 47.818 | 47.834 | 47.955 | 47.971 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | |
14996 | 10.84 | -1.41 | 16.14 | 3.0837 | -0.8836 | 1.4501 | -2.2102 | -1.559 | -5.3265 | -0.03151 | -0.05375 | 0.04732 | 0.1482 | -0.05218 | 0.06706 | 0.2969 | 0.5065 | -0.4459 | -1.3963 | 0.4916 | -0.6319 | -0.3694 | -4.1879 | -1.1847 | -0.09441 | -0.1568 | 0.1898 | 1.1605 | -42.951 | -19.023 | -2.6343 | 0.1551 | -0.1338 | 3.0836 | -0.8836 | 1.4503 | -2.2101 | -1.5591 | -5.3263 | -0.03347 | -0.05585 | 0.04805 | 0.151 | -0.05513 | 0.07114 | -0.3564 | -6.0394 | -2.3001 | -0.2181 | -0.1159 | 0.09608 | -0.3632 | -6.0394 | -2.3023 | -0.212 | -0.125 | 0.1113 | 0.7116 | 0.06957 | 0.06036 | -0.8506 | 2.9515 | -0.03352 | -0.03558 | -0.03029 | 0.002444 | -0.04208 | 0.1458 | -0.1098 | -0.8784 | -0.07291 | -37.584 | 0.0001132 | -2.1031 | 0.03318 | 0.7117 | 0.0697 | 0.06044 | -0.8511 | 2.951 | -0.03356 | -0.03508 | -0.02849 | 0.001571 | -0.03951 | 0.1442 | -0.1036 | 48.069 | 48.009 | 0.8952 | 47.818 | 47.834 | 47.818 | 47.803 | 47.94 | 47.94 |
14997 | 10.857 | -1.52 | 15.943 | 3.0835 | -0.884 | 1.4505 | -2.2091 | -1.5594 | -5.326 | -0.02913 | -0.0497 | 0.04376 | 0.137 | -0.04825 | 0.062 | 0.2969 | 0.5065 | -0.4459 | -1.3963 | 0.4916 | -0.6319 | -0.3677 | -4.1837 | -1.1874 | -0.09682 | -0.1562 | 0.189 | 1.1592 | -42.937 | -19.023 | -2.6331 | 0.1545 | -0.1338 | 3.0833 | -0.8841 | 1.4507 | -2.209 | -1.5596 | -5.3258 | -0.02909 | -0.04989 | 0.04198 | 0.1481 | -0.05465 | 0.06249 | -0.3161 | -6.1179 | -2.253 | -0.3752 | -0.03965 | 0.08693 | -0.3273 | -6.1022 | -2.2597 | -0.366 | -0.05033 | 0.0915 | 0.7114 | 0.06932 | 0.06039 | -0.8497 | 2.953 | -0.03359 | -0.0335 | -0.02723 | 0.001208 | -0.04242 | 0.1428 | -0.0967 | -2.7137 | 0.8552 | -38.514 | -0.6088 | -3.2383 | -0.9666 | 0.7114 | 0.06948 | 0.06045 | -0.8503 | 2.9525 | -0.03359 | -0.03246 | -0.02633 | 0.001469 | -0.03657 | 0.1333 | -0.09571 | 48.009 | 48.009 | 0.8594 | 47.818 | 47.834 | 47.818 | 47.803 | 47.94 | 47.94 |
14998 | 10.89 | -1.74 | 15.55 | 3.0833 | -0.8844 | 1.4508 | -2.208 | -1.5598 | -5.3256 | -0.02676 | -0.04565 | 0.04019 | 0.1258 | -0.04431 | 0.05695 | 0.2969 | 0.5065 | -0.4459 | -1.3963 | 0.4916 | -0.6319 | -0.3659 | -4.1797 | -1.1901 | -0.09922 | -0.1555 | 0.1881 | 1.1579 | -42.924 | -19.023 | -2.6321 | 0.154 | -0.1338 | 3.0831 | -0.8844 | 1.451 | -2.2078 | -1.56 | -5.3253 | -0.02776 | -0.04382 | 0.03652 | 0.1295 | -0.05064 | 0.04818 | -0.343 | -6.2569 | -2.1566 | -0.3035 | 0.00305 | 0.1434 | -0.3385 | -6.2322 | -2.1589 | -0.302 | -0.00915 | 0.1571 | 0.7111 | 0.06912 | 0.06039 | -0.849 | 2.9544 | -0.0337 | -0.02911 | -0.02589 | 0.001292 | -0.04046 | 0.1246 | -0.08058 | 4.2749 | 1.0128 | -36.412 | -1.2811 | -0.4296 | -1.1013 | 0.7112 | 0.06928 | 0.06046 | -0.8495 | 2.9538 | -0.03362 | -0.02984 | -0.02417 | 0.001364 | -0.03362 | 0.1224 | -0.08786 | 48.009 | 48.009 | 0.931 | 47.818 | 47.834 | 47.818 | 47.803 | 47.94 | 47.94 |
14999 | 11.29 | -1.4601 | 15.743 | 3.0831 | -0.8847 | 1.4511 | -2.2071 | -1.5602 | -5.3251 | -0.02438 | -0.0416 | 0.03662 | 0.1147 | -0.04038 | 0.0519 | 0.2969 | 0.5065 | -0.4459 | -1.3963 | 0.4916 | -0.6319 | -0.3642 | -4.1758 | -1.1928 | -0.1016 | -0.1548 | 0.1873 | 1.1568 | -42.912 | -19.023 | -2.6311 | 0.1535 | -0.1338 | 3.0829 | -0.8848 | 1.4513 | -2.2068 | -1.5604 | -5.3249 | -0.02149 | -0.04059 | 0.03417 | 0.1202 | -0.0395 | 0.04178 | -0.4237 | -6.2703 | -2.0939 | -0.302 | -0.01372 | 0.1739 | -0.4125 | -6.2569 | -2.0916 | -0.2943 | -0.02898 | 0.1891 | 0.7109 | 0.06894 | 0.06039 | -0.8484 | 2.9557 | -0.03384 | -0.02738 | -0.01982 | 0.001031 | -0.03028 | 0.1157 | -0.06702 | 11.518 | 1.5002 | -39.314 | -1.8671 | -0.3734 | -0.5733 | 0.7109 | 0.06909 | 0.06047 | -0.8488 | 2.955 | -0.03364 | -0.02721 | -0.02201 | 0.001255 | -0.03067 | 0.1115 | -0.08003 | 48.009 | 48.009 | 0.931 | 47.818 | 47.834 | 47.818 | 47.803 | 47.94 | 47.94 |
15000 | 11.69 | -1.1801 | 15.937 | 3.0829 | -0.885 | 1.4514 | -2.2062 | -1.5605 | -5.3247 | -0.02201 | -0.03755 | 0.03305 | 0.1035 | -0.03645 | 0.04684 | 0.2969 | 0.5065 | -0.4459 | -1.3963 | 0.4916 | -0.6319 | -0.3624 | -4.172 | -1.1955 | -0.1041 | -0.1542 | 0.1864 | 1.1558 | -42.901 | -19.023 | -2.6302 | 0.1531 | -0.1338 | 3.0827 | -0.8851 | 1.4516 | -2.2059 | -1.5607 | -5.3246 | -0.02096 | -0.03808 | 0.02958 | 0.1171 | -0.03289 | 0.03883 | -0.417 | -6.2434 | -2.058 | -0.4102 | -0.04728 | 0.1967 | -0.4237 | -6.2367 | -2.0714 | -0.4163 | -0.0671 | 0.2059 | 0.7107 | 0.06878 | 0.06041 | -0.8478 | 2.9567 | -0.03382 | -0.02535 | -0.01854 | 0.001614 | -0.02421 | 0.11 | -0.06304 | 15.099 | 2.936 | -39.068 | -1.9402 | 0.139 | -0.2674 | 0.7107 | 0.06893 | 0.06048 | -0.8482 | 2.9561 | -0.03367 | -0.02458 | -0.01986 | 0.001142 | -0.0277 | 0.1007 | -0.07221 | 48.009 | 48.069 | 0.8952 | 47.818 | 47.834 | 47.818 | 47.803 | 47.94 | 47.955 |
15001 rows x 96 columns
memory usage: 11.52 MB
name: data_all
type: getml.DataFrame
1.3 Separate data into a training and testing set
We also want to separate the data set into a training and a testing set. Because the rows are ordered in time, we split by row ID rather than at random: the first 10,500 measurements are used for training and the remaining 4,501 for testing.
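The split can be pictured as a simple threshold on the row index. The sketch below is a plain NumPy illustration under the assumption that getml.data.split.time(data_all, "rowid", test=10500) assigns every row with rowid >= 10500 to the test set, which matches the subset sizes (10500 and 4501) reported further down. The helper time_split is hypothetical and not part of getML. A random split would likely place nearly identical neighbouring measurements in both sets and overstate the test accuracy; a time split avoids that.

```python
import numpy as np

def time_split(n_rows, test_start):
    """Label rows 'train' before `test_start` and 'test' from there on.

    Assumes the rows are in chronological order, so rowid acts as the time stamp.
    """
    rowid = np.arange(n_rows)
    return np.where(rowid < test_start, "train", "test")

labels = time_split(15001, 10500)
print((labels == "train").sum(), (labels == "test").sum())
```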
split = getml.data.split.time(data_all, "rowid", test=10500)
split
0 | train |
---|---|
1 | train |
2 | train |
3 | train |
4 | train |
... |
15001 rows
type: StringColumnView
time_series = getml.data.TimeSeries(
population=data_all,
split=split,
time_stamps="rowid",
lagged_targets=False,
memory=30,
)
time_series
| | data frames | staging table |
---|---|---|
0 | population | POPULATION__STAGING_TABLE_1 |
1 | data_all | DATA_ALL__STAGING_TABLE_2 |
| | subset | name | rows | type |
---|---|---|---|---|
0 | test | data_all | 4501 | View |
1 | train | data_all | 10500 | View |
| | name | rows | type |
---|---|---|---|
0 | data_all | 15001 | View |
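The TimeSeries above is created with memory=30. Judging from the SQL that getML generates for the most important feature (section 2.7), each population row is joined with the rows of data_all whose rowid lies in the 30-row window ending at its own rowid. The following minimal sketch spells out that window in plain Python. The helper window_rowids is our own illustration, not a getML function. It also shows that the first test rows look back into the last training rows. As we understand the getML API, lagged_targets=False keeps past values of f_x, f_y and f_z out of the features, so only sensor inputs cross that boundary.

```python
def window_rowids(rowid, memory=30):
    """Row ids of data_all that the population row `rowid` is joined with.

    Assumes the join condition
        t2.rowid <= t1.rowid AND t2.rowid + memory > t1.rowid,
    i.e. the current row plus the (memory - 1) rows before it.
    """
    first = max(0, rowid - memory + 1)
    return list(range(first, rowid + 1))

print(window_rowids(0))           # [0]
print(window_rowids(10500)[:3])   # [10471, 10472, 10473]
print(len(window_rowids(10500)))  # 30
```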
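# FastProp learns num_features=10 features from the 30-row window defined in time_series.
fast_prop = getml.feature_learning.FastProp(
loss_function=getml.feature_learning.loss_functions.SquareLoss,
num_features=10,
)
# XGBoost is trained on the learned features (once per target, as the fit log shows).
xgboost = getml.predictors.XGBoostRegressor()
pipe1 = getml.pipeline.Pipeline(
data_model=time_series.data_model,
feature_learners=[fast_prop],
predictors=xgboost,
)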
It is always a good idea to check the pipeline for any potential issues.
pipe1.check(time_series.train)
Checking data model... Staging... 100% |██████████| [elapsed: 00:00, remaining: 00:00] Checking... 100% |██████████| [elapsed: 00:00, remaining: 00:00] OK.
2.2 Fitting the pipeline
pipe1.fit(time_series.train)
Checking data model...
Staging... 100% |██████████| [elapsed: 00:00, remaining: 00:00]
OK.
Staging... 100% |██████████| [elapsed: 00:00, remaining: 00:00]
FastProp: Trying 1130 features... 100% |██████████| [elapsed: 00:02, remaining: 00:00]
FastProp: Building features... 100% |██████████| [elapsed: 00:00, remaining: 00:00]
XGBoost: Training as predictor... 100% |██████████| [elapsed: 00:04, remaining: 00:00]
XGBoost: Training as predictor... 100% |██████████| [elapsed: 00:04, remaining: 00:00]
XGBoost: Training as predictor... 100% |██████████| [elapsed: 00:03, remaining: 00:00]
Trained pipeline.
Time taken: 0h:0m:13.011683
Pipeline(data_model='population', feature_learners=['FastProp'], feature_selectors=[], include_categorical=False, loss_function='SquareLoss', peripheral=['data_all'], predictors=['XGBoostRegressor'], preprocessors=[], share_selected_features=0.5, tags=['container-pcWY0M'])
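The fit log reports that FastProp tried 1130 candidate features before building the num_features=10 it keeps. The toy sketch below illustrates the general idea of propositionalization on a time window. It applies every aggregation to every column over a trailing window, then keeps the candidates that correlate most strongly with the target. The aggregation set, the correlation-based selection and all helper names are our own assumptions for illustration. This is not FastProp's actual algorithm.

```python
import numpy as np

# Hypothetical set of aggregations; FastProp's real set is larger.
AGGREGATIONS = {"avg": np.mean, "max": np.max, "min": np.min, "sum": np.sum}

def candidate_features(X, memory=30):
    """Apply every aggregation to every column over a trailing window of `memory` rows."""
    n_rows, n_cols = X.shape
    columns, names = [], []
    for j in range(n_cols):
        for agg_name, agg in AGGREGATIONS.items():
            values = [agg(X[max(0, i - memory + 1): i + 1, j]) for i in range(n_rows)]
            columns.append(np.asarray(values, dtype=float))
            names.append(f"{agg_name}({j})")
    return np.column_stack(columns), names

def select_top_k(F, names, y, k=10):
    """Keep the k candidates whose absolute correlation with y is largest."""
    scores = np.zeros(F.shape[1])
    for j in range(F.shape[1]):
        if F[:, j].std() > 0:  # constant columns carry no signal
            scores[j] = abs(np.corrcoef(F[:, j], y)[0, 1])
    order = np.argsort(scores)[::-1][:k]
    return F[:, order], [names[j] for j in order]

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
# Synthetic target: trailing 30-row mean of column 2 plus a little noise.
y = np.array([X[max(0, i - 29): i + 1, 2].mean() for i in range(200)])
y = y + 0.01 * rng.normal(size=200)

F, names = candidate_features(X)
F_top, top_names = select_top_k(F, names, y, k=3)
print(F.shape, top_names)
```

With 5 columns and 4 aggregations, the sketch produces 20 candidates. The number grows with both the column count and the size of the aggregation set, which is why the real engine tries over a thousand.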
2.3 Evaluating the pipeline
pipe1.score(time_series.test)
Staging... 100% |██████████| [elapsed: 00:00, remaining: 00:00] Preprocessing... 100% |██████████| [elapsed: 00:00, remaining: 00:00] FastProp: Building features... 100% |██████████| [elapsed: 00:00, remaining: 00:00]
| | date time | set used | target | mae | rmse | rsquared |
---|---|---|---|---|---|---|
0 | 2024-08-07 16:45:34 | train | f_x | 0.4403 | 0.58 | 0.9962 |
1 | 2024-08-07 16:45:34 | train | f_y | 0.5168 | 0.6813 | 0.9893 |
2 | 2024-08-07 16:45:34 | train | f_z | 0.2918 | 0.385 | 0.9986 |
3 | 2024-08-07 16:45:34 | test | f_x | 0.5605 | 0.7319 | 0.995 |
4 | 2024-08-07 16:45:34 | test | f_y | 0.5653 | 0.7532 | 0.9871 |
5 | 2024-08-07 16:45:34 | test | f_z | 0.3131 | 0.4071 | 0.9984 |
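To read the table, the following sketch spells out the three metrics in NumPy. MAE and RMSE are in the same unit as the force components. For rsquared we use the textbook coefficient of determination; getML's own rsquared may be defined slightly differently (for example as a squared correlation), so treat this as an assumption. The test scores stay close to the training scores for all three targets (for f_x, 0.995 vs. 0.9962), which suggests little overfitting.

```python
import numpy as np

def mae(y_true, y_pred):
    """Mean absolute error."""
    return float(np.mean(np.abs(y_true - y_pred)))

def rmse(y_true, y_pred):
    """Root mean squared error, in the same unit as the target."""
    return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))

def rsquared(y_true, y_pred):
    """Textbook coefficient of determination: 1 - SS_res / SS_tot."""
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)
    return float(1.0 - ss_res / ss_tot)

# Toy values, not taken from the robot data.
y_true = np.array([1.0, 2.0, 3.0, 4.0])
y_pred = np.array([1.5, 2.0, 2.5, 4.0])
print(mae(y_true, y_pred), rmse(y_true, y_pred), rsquared(y_true, y_pred))
```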
2.4 Feature importances
It is always a good idea to study the features the relational learning algorithm has extracted.
The feature importance is calculated by XGBoost based on the improvement of the optimization criterion at each split in the decision trees. The importances are normalized so that they sum to 100%.
Also note that we have three different targets (f_x, f_y and f_z) and that different features are relevant for different targets.
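The following minimal sketch shows gain-based, normalized importances. It assumes a log of (feature, gain) pairs, one per split across all trees; the split log and the helper normalized_importances are hypothetical. It sums the loss reduction per feature and divides by the grand total, so the importances add up to 1 (100%). XGBoost and getML compute this internally and may differ in detail.

```python
from collections import defaultdict

def normalized_importances(splits):
    """Sum the gain per feature over all splits and scale the totals to sum to 1.

    `splits` is a list of (feature_name, gain) pairs, one per split in the trees.
    """
    totals = defaultdict(float)
    for feature, gain in splits:
        totals[feature] += gain
    grand_total = sum(totals.values())
    return {feature: gain / grand_total for feature, gain in totals.items()}

# Hypothetical split log from a small ensemble.
splits = [("feature_1_1", 6.0), ("feature_1_2", 2.0), ("feature_1_1", 1.0), ("feature_1_3", 1.0)]
importances = normalized_importances(splits)
print(importances)  # {'feature_1_1': 0.7, 'feature_1_2': 0.2, 'feature_1_3': 0.1}
```

Because a separate predictor is trained per target, each of f_x, f_y and f_z gets its own set of importances. This is why the plots below differ.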
plt.subplots(figsize=(20, 10))
names, importances = pipe1.features.importances(target_num=0)
plt.bar(names[0:30], importances[0:30], color='#6829c2')
plt.title("feature importances for the x-component", size=20)
plt.grid(True)
plt.xlabel("features")
plt.ylabel("importance")
plt.xticks(rotation='vertical')
plt.show()
plt.subplots(figsize=(20, 10))
names, importances = pipe1.features.importances(target_num=1)
plt.bar(names[0:30], importances[0:30], color='#6829c2')
plt.title("feature importances for the y-component", size=20)
plt.grid(True)
plt.xlabel("features")
plt.ylabel("importance")
plt.xticks(rotation='vertical')
plt.show()
plt.subplots(figsize=(20, 10))
names, importances = pipe1.features.importances(target_num=2)
plt.bar(names[0:30], importances[0:30], color='#6829c2')
plt.title("feature importances for the z-component", size=20)
plt.grid(True)
plt.xlabel("features")
plt.ylabel("importance")
plt.xticks(rotation='vertical')
plt.show()
2.5 Column importances
Because getML is a tool for relational learning, we can also calculate the importances of the original columns, using methods similar to those we used for the feature importances.
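One simple way to move from feature importances to column importances is to pass each feature's importance back to the columns it reads. The sketch below assumes an equal split among those columns. That scheme, the helper name and the example mapping are our own assumptions, not necessarily how getML computes pipe1.columns.importances. The only grounded detail is that feature_1_1 averages column "7", as the SQL in section 2.7 shows. The other features and their columns are made up.

```python
from collections import defaultdict

def column_importances(feature_importances, feature_columns):
    """Distribute each feature's importance equally among the columns it uses.

    feature_importances: {feature_name: importance}, summing to 1.
    feature_columns: {feature_name: [names of the columns the feature reads]}.
    """
    result = defaultdict(float)
    for feature, importance in feature_importances.items():
        cols = feature_columns[feature]
        for col in cols:
            result[col] += importance / len(cols)
    return dict(result)

feature_importances = {"feature_1_1": 0.7, "feature_1_2": 0.2, "feature_1_3": 0.1}
feature_columns = {"feature_1_1": ["7"], "feature_1_2": ["7", "42"], "feature_1_3": ["99"]}
col_imp = column_importances(feature_importances, feature_columns)
print(col_imp)
```

The column importances still sum to 1, so they can be compared directly across columns, and a column used by several features accumulates their shares.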
plt.subplots(figsize=(20, 10))
names, importances = pipe1.columns.importances(target_num=0)
plt.bar(names[0:30], importances[0:30], color='#6829c2')
plt.title("column importances for the x-component", size=20)
plt.grid(True)
plt.xlabel("column")
plt.ylabel("importance")
plt.xticks(rotation='vertical')
plt.show()
plt.subplots(figsize=(20, 10))
names, importances = pipe1.columns.importances(target_num=1)
plt.bar(names[0:30], importances[0:30], color='#6829c2')
plt.title("column importances for the y-component", size=20)
plt.grid(True)
plt.xlabel("column")
plt.ylabel("importance")
plt.xticks(rotation='vertical')
plt.show()
plt.subplots(figsize=(20, 10))
names, importances = pipe1.columns.importances(target_num=2)
plt.bar(names[0:30], importances[0:30], color='#6829c2')
plt.title("column importances for the z-component", size=20)
plt.grid(True)
plt.xlabel("column")
plt.ylabel("importance")
plt.xticks(rotation='vertical')
plt.show()
2.6 Visualizing the predictions
Sometimes a picture is worth a thousand words. We therefore visualize our predictions on the testing set.
f_x = time_series.test.population["f_x"].to_numpy()
f_y = time_series.test.population["f_y"].to_numpy()
f_z = time_series.test.population["f_z"].to_numpy()
predictions = pipe1.predict(time_series.test)
Staging... 100% |██████████| [elapsed: 00:00, remaining: 00:00] Preprocessing... 100% |██████████| [elapsed: 00:00, remaining: 00:00] FastProp: Building features... 100% |██████████| [elapsed: 00:00, remaining: 00:00]
col_data = 'black'
col_getml = 'darkviolet'
col_getml_alt = 'coral'
plt.subplots(figsize=(20, 10))
plt.title("x-component of the force vector", size=20)
plt.plot(f_x, label="ground truth", color=col_data)
plt.plot(predictions[:, 0], label="prediction", color=col_getml)
plt.legend(loc="upper right", fontsize=16)
plt.show()
plt.subplots(figsize=(20, 10))
plt.title("y-component of the force vector", size=20)
plt.plot(f_y, label="ground truth", color=col_data)
plt.plot(predictions[:, 1], label="prediction", color=col_getml)
plt.legend(loc="upper right", fontsize=16)
plt.show()
plt.subplots(figsize=(20, 10))
plt.title("z-component of the force vector", size=20)
plt.plot(f_z, label="ground truth", color=col_data)
plt.plot(predictions[:, 2], label="prediction", color=col_getml)
plt.legend(loc="upper right", fontsize=16)
plt.show()
2.7 Features
The most important feature looks as follows:
pipe1.features.to_sql()[pipe1.features.sort(by="importances")[0].name]
DROP TABLE IF EXISTS "FEATURE_1_1";
CREATE TABLE "FEATURE_1_1" AS
SELECT AVG( t2."7" ) AS "feature_1_1",
t1.rowid AS rownum
FROM "POPULATION__STAGING_TABLE_1" t1
INNER JOIN "DATA_ALL__STAGING_TABLE_2" t2
ON 1 = 1
WHERE t2."rowid" <= t1."rowid"
AND ( t2."( rowid + 30.000000 )" > t1."rowid" OR t2."( rowid + 30.000000 )" IS NULL )
GROUP BY t1.rowid;
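The generated SQL boils down to a trailing average. For each population row, it averages column "7" over the current row and the 29 rows before it. The minimal sketch below reproduces that logic in NumPy. The helper trailing_avg is our own, and NULL handling from the SQL is ignored. Applied to data_all["7"].to_numpy() with memory=30, it should give the values of feature_1_1, assuming the rows are in rowid order.

```python
import numpy as np

def trailing_avg(values, memory=30):
    """AVG over the current row and the (memory - 1) rows before it.

    Mirrors the WHERE clause of FEATURE_1_1:
        t2.rowid <= t1.rowid AND t2.rowid + memory > t1.rowid
    Rows near the start average over fewer rows. NULL handling is ignored.
    """
    values = np.asarray(values, dtype=float)
    out = np.empty_like(values)
    for i in range(len(values)):
        out[i] = values[max(0, i - memory + 1): i + 1].mean()
    return out

print(trailing_avg([1.0, 2.0, 3.0, 4.0], memory=2).tolist())  # [1.0, 1.5, 2.5, 3.5]
```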
getml.engine.shutdown()
As the scores and the plots above show, the predictions are very accurate, with an R-squared of at least 0.987 on the testing set for all three components. This suggests that it is feasible to predict the force vector from other sensor data.
3. Conclusion
The purpose of this notebook has been to illustrate the curse of dimensionality that arises when engineering features from data sets with many columns.
The most important thing to remember is that this problem exists regardless of whether you engineer your features manually or algorithmically. Whether you like it or not, if you write your features in the traditional way, your search space grows quadratically with the number of columns.
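The quadratic growth can be made concrete with a short count. The sketch assumes that a traditional feature applies one aggregation to one column and is filtered by a condition on a different column. The choice of 5 aggregations is hypothetical. Under these assumptions, the number of candidates is n_aggregations × n × (n − 1), so doubling the number of columns roughly quadruples the search space. The robot data has 93 sensor columns (96 columns minus the 3 targets). Allowing several threshold values per condition would only multiply these counts further.

```python
def num_candidate_features(n_columns, n_aggregations):
    """Features of the form AGG(column_a) WHERE <condition on column_b>, with a != b."""
    return n_aggregations * n_columns * (n_columns - 1)

for n in (10, 93, 186):
    print(n, num_candidate_features(n, n_aggregations=5))
```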