American Institute of Mathematical Sciences

October  2021, 8(4): 495-520. doi: 10.3934/jcd.2021018

Classification with Runge-Kutta networks and feature space augmentation

1. Institut für Mathematik, Humboldt-Universität zu Berlin, Unter den Linden 6, 10099 Berlin, Germany
2. Martin-Luther-Universität Halle-Wittenberg, Theodor-Lieser-Str. 5, 06120 Halle, Germany

* Corresponding author: Axel Kröner

Received: April 2021. Revised: September 2021. Published: October 2021. Early access: November 2021.

Fund Project: The second author is supported by DAAD project 57570343.

In this paper we combine the Runge-Kutta network approach considered in [Benning et al., J. Comput. Dynamics, 6, 2019] with the technique of augmenting the input space introduced in [Dupont et al., NeurIPS, 2019] to obtain network architectures that show better numerical performance of deep neural networks in point and image classification problems. The approach is illustrated with several examples implemented in PyTorch.
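The networks in the paper are implemented in PyTorch; the sketch below illustrates the two ingredients in plain NumPy to stay self-contained: a classic RK4 step applied to layer dynamics of the form $f(x) = \tanh(Wx + b)$, and ANODE-style augmentation by zero-padding the input features to the wider dimension $\hat{d}$. All names, weights and sizes here are illustrative choices, not taken from the authors' code.

```python
import numpy as np

def rk4_step(f, x, h):
    """One classic Runge-Kutta-4 step for the layer dynamics dx/dt = f(x)."""
    k1 = f(x)
    k2 = f(x + 0.5 * h * k1)
    k3 = f(x + 0.5 * h * k2)
    k4 = f(x + h * k3)
    return x + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)

def augment(x, d_hat):
    """ANODE-style augmentation: zero-pad the features to width d_hat."""
    pad = d_hat - x.shape[-1]
    return np.concatenate([x, np.zeros(x.shape[:-1] + (pad,))], axis=-1)

rng = np.random.default_rng(0)
d_hat, L = 3, 100                           # augmented width and depth (illustrative)
W = 0.1 * rng.standard_normal((d_hat, d_hat))
b = np.zeros(d_hat)
f = lambda x: np.tanh(x @ W.T + b)          # one shared layer for simplicity

x = augment(np.array([[0.3, -0.7]]), d_hat) # a 2D input lifted to 3D
for _ in range(L):                          # L RK4 steps = a network of depth L
    x = rk4_step(f, x, 1.0 / L)
```

Replacing `rk4_step` by the explicit Euler update `x + h * f(x)` yields the forward-Euler variant corresponding to the EulerNet architecture.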

Citation: Elisa Giesecke, Axel Kröner. Classification with Runge-Kutta networks and feature space augmentation. Journal of Computational Dynamics, 2021, 8 (4) : 495-520. doi: 10.3934/jcd.2021018
Butcher tableaus: (from left to right) general form, forward Euler and classic RK4.
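In standard notation, the three tableaus of the figure read as follows (general form with node vector $c$, coefficient matrix $A$ and weight vector $b$; then forward Euler; then classic RK4):

```latex
\begin{array}{c|c}
c & A \\ \hline
  & b^{\mathsf T}
\end{array}
\qquad
\begin{array}{c|c}
0 & 0 \\ \hline
  & 1
\end{array}
\qquad
\begin{array}{c|cccc}
0   &     &     &     & \\
1/2 & 1/2 &     &     & \\
1/2 & 0   & 1/2 &     & \\
1   & 0   & 0   & 1   & \\ \hline
    & 1/6 & 1/3 & 1/3 & 1/6
\end{array}
```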
Two-dimensional datasets for binary point classification with 1500 samples each: donut 1D and donut 2D (top), squares and spiral (bottom).
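A dataset of the donut type shown here can be generated with a short script. The function name, radii and noise level below are illustrative choices, not the paper's exact construction.

```python
import numpy as np

def make_donut_2d(n=1500, r_inner=1.0, r_outer=2.0, noise=0.1, seed=0):
    """Two noisy concentric rings in the plane; class 0 = inner, class 1 = outer."""
    rng = np.random.default_rng(seed)
    y = rng.integers(0, 2, size=n)                 # binary labels
    r = np.where(y == 0, r_inner, r_outer)         # ring radius per sample
    theta = rng.uniform(0.0, 2.0 * np.pi, size=n)  # angle on the ring
    x = np.stack([r * np.cos(theta), r * np.sin(theta)], axis=1)
    return x + noise * rng.standard_normal((n, 2)), y
```

`X, y = make_donut_2d()` yields 1500 labelled points, matching the sample count in the caption; the two classes are not linearly separable in 2D, which is what motivates the space augmentation.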
Classification of donut_2D with an RK4Net of width $\hat{d} = 2$, corresponding to the NODE approach (top), and of width $\hat{d} = 3$, i.e. with space augmentation as in the ANODE approach (bottom), both of depth $L = 100$ with $\tanh$ activation. The plots show (from left to right) the trajectories of the features starting at the small dot and terminating at the large dot, their final transformation in the output layer, and the resulting prediction with a coloured background according to the network's classification.
Classification of squares with an RK4Net of width $\hat{d} = 2$, corresponding to the NODE approach (top), and of width $\hat{d} = 3$, i.e. with space augmentation as in the ANODE approach (bottom), both of depth $L = 100$ with $\tanh$ activation. The plots show (from left to right) the trajectories of the features starting at the small dot and terminating at the large dot, their final transformation in the output layer, and the resulting prediction with a coloured background according to the network's classification.
Feature transformation of spiral with StandardNet (top) and RK4Net (bottom) of width $\hat{d} = 16$, depth $L = 20$ and $\tanh$ activation. The plots show (from left to right) the features in the input layer, the hidden layers and the output layer.
Prediction of donut_1D with RK4Net of width $\hat{d} = 16$, depth $L = 20$ and $\tanh$ activation.
Accuracy (left) and cost (right) over the course of epochs on donut_1D with RK4Net of width $\hat{d} = 16$ and depth $L = 20$. Solid lines represent metrics on validation and dotted lines on training data.
Donut and squares datasets of different dimensionality and with varying numbers of classes, used to compare network performance between binary and multiclass classification (first column) as well as between 2D and 3D input spaces (second and third columns).
Repetitions with random initializations for RK4Net with width $\hat{d} = 16$, depth $L = 100$ and $\tanh$ activation, on donut 2D & 6C. The plots show (upper row) the feature transformation in the output layer reduced by PCA to 3D, and (lower row) the resulting prediction underlaid with a coloured background according to the network's classification.
Classification of donut 2D & 6C with network width $\hat{d} = 16$, depth $L = 5$ for StandardNet and $L = 100$ for EulerNet and RK4Net, and $\tanh$ activation. The plots show (from left to right) the feature transformation in the output layer reduced by PCA to 3D and 2D, and the resulting prediction. Two dimensional plots are underlaid with a coloured background according to the network's classification.
Classification of squares 2D & 4C with network width $\hat{d} = 16$, depth $L = 5$ for StandardNet and $L = 100$ for EulerNet and RK4Net, and $\tanh$ activation. The plots show (from left to right) the feature transformation in the output layer reduced by PCA to 3D and 2D, and the resulting prediction. Two dimensional plots are underlaid with a coloured background according to the network's classification.
Validation accuracy (left) and cost (right) over the course of epochs on donut 3D & 6C with network width $\hat{d} = 16$, depth $L = 5$ for StandardNet and $L = 100$ for EulerNet and RK4Net, and $\tanh$ activation. Solid line represents the mean and shaded area the standard deviation over repetitions.
Validation accuracy (left) and cost (right) over the course of epochs on squares 3D & 4C with network width $\hat{d} = 16$, depth $L = 5$ for StandardNet and $L = 100$ for EulerNet and RK4Net, and $\tanh$ activation. Solid line represents the mean and shaded area the standard deviation over repetitions.
Sample images of MNIST with true label and the prediction produced by an RK4Net with width $\hat{d} = 30^2$, depth $L = 100$ and $\tanh$ activation.
Sample images of Fashion-MNIST with true label and the prediction produced by an RK4Net with width $\hat{d} = 30^2$, depth $L = 100$ and $\tanh$ activation.
Feature transformation in the output layer of StandardNet (left) and RK4Net (right) of Fashion-MNIST images reduced by PCA to 3D. Each color represents one article class.
Accuracy (left) and cost (right) over the course of epochs on MNIST with network width $\hat{d} = 30^2$, depth $L = 5$ for StandardNet and $L = 100$ for EulerNet and RK4Net, and $\tanh$ activation. Solid lines represent metrics on validation and dotted lines on training data.
Accuracy (left) and cost (right) over the course of epochs on Fashion-MNIST with network width $\hat{d} = 30^2$, depth $L = 5$ for StandardNet and $L = 100$ for EulerNet and RK4Net, and $\tanh$ activation. Solid lines represent metrics on validation and dotted lines on training data.
Mean of training (upper row) and validation (lower row) accuracy (%) over four repetitions on spiral with network width $\hat{d} = 16$ and $\tanh$ activation.
depth L                  1      3      5     10     20     40    100
StandardNet (train)  92.73  92.87  98.12  97.52  67.62  51.08  50.67
StandardNet (val)    91.88  92.50  98.10  97.45  66.87  48.92  49.33
RK4Net (train)       75.60  91.42  97.90  99.77  99.93  99.73  99.95
RK4Net (val)         75.12  90.68  97.33  99.47  99.70  99.50  99.75
Mean of training (upper row) and validation (lower row) cost ($\times 10^{-1}$) over four repetitions on spiral with network width $\hat{d} = 16$ and $\tanh$ activation.
depth L                  1      3      5     10     20     40    100
StandardNet (train)   2.23   1.38   0.66   0.77   6.09   6.93   6.93
StandardNet (val)     2.33   1.53   0.67   0.77   6.13   6.94   6.93
RK4Net (train)        4.32   2.68   0.98   0.16   0.04   0.10   0.01
RK4Net (val)          4.39   2.69   1.06   0.28   0.13   0.12   0.12
Variability of accuracy (%) and cost ($\times 10^{-1}$) over four repetitions for RK4Net with width $\hat{d} = 16$, depth $L = 100$ and $\tanh$ activation, on donut 2D & 6C.
                     training accuracy   validation accuracy   training cost   validation cost
mean                             77.13                 74.92            5.13              5.59
standard deviation                0.76                  0.89            0.08              0.16
Mean of validation accuracy (%, upper row) and cost ($\times 10^{-1}$, lower row) over four repetitions with network width $\hat{d} = 16$, depth $L = 5$ for StandardNet and $L = 100$ for EulerNet and RK4Net, and $\tanh$ activation.
                     donut 3D & 2C   donut 3D & 3C   donut 2D & 6C   donut 3D & 6C   squares 2D & 4C   squares 3D & 4C
StandardNet (acc)            92.37           87.75           75.12           73.00             94.12             89.68
StandardNet (cost)            1.71            2.85            5.60            5.86              1.57              3.03
EulerNet (acc)               91.88           88.30           74.87           74.63             93.35             89.48
EulerNet (cost)               1.84            2.75            5.56            5.67              1.66              2.71
RK4Net (acc)                 92.73           87.13           74.92           74.88             93.20             89.37
RK4Net (cost)                 1.72            2.95            5.59            5.73              1.64              2.81
Mean and standard deviation of accuracy (%) and cost ($\times 10^{-1}$) over four repetitions for non-augmented (upper row) and augmented (lower row) RK4Net with depth $L = 100$ and $\tanh$ activation on MNIST.
width $\hat{d}$   training accuracy   validation accuracy   training cost     validation cost
$28^2$            $97.70 \pm 2.80$    $87.27 \pm 2.93$      $0.78 \pm 0.95$   $7.71 \pm 1.62$
$30^2$            $99.77 \pm 0.40$    $90.40 \pm 1.08$      $0.10 \pm 0.17$   $5.36 \pm 0.46$
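The jump from width $28^2$ to $30^2$ corresponds to embedding each $28 \times 28$ image in a $30 \times 30$ feature grid. One plausible realization is zero-padding, as sketched below; the helper is an illustrative assumption, not the authors' implementation.

```python
import numpy as np

def augment_image(img, target=30):
    """Zero-pad a square image to target x target, centering the original pixels,
    and return the flattened feature vector of width target**2."""
    d = img.shape[0]
    lo = (target - d) // 2
    out = np.zeros((target, target), dtype=img.dtype)
    out[lo:lo + d, lo:lo + d] = img
    return out.reshape(-1)
```

For a $28 \times 28$ MNIST image this returns a vector of length $900 = 30^2$, with the extra border coordinates initialized to zero as in the point-classification case.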
Mean and standard deviation of validation accuracy (%, upper row) and cost ($\times 10^{-1}$, lower row) over four repetitions with network width $\hat{d} = 30^2$, depth $L = 5$ for StandardNet and $L = 100$ for EulerNet and RK4Net, and $\tanh$ activation.
                     MNIST               Fashion-MNIST
StandardNet (acc)    $85.67 \pm 0.78$    $61.23 \pm 6.00$
StandardNet (cost)   $8.95 \pm 0.92$     $11.41 \pm 1.37$
EulerNet (acc)       $90.98 \pm 0.48$    $77.62 \pm 2.57$
EulerNet (cost)      $5.71 \pm 0.36$     $9.52 \pm 1.87$
RK4Net (acc)         $90.40 \pm 1.08$    $79.13 \pm 1.57$
RK4Net (cost)        $5.36 \pm 0.46$     $8.24 \pm 0.60$
References:

[1] Da Xu. Numerical solutions of viscoelastic bending wave equations with two term time kernels by Runge-Kutta convolution quadrature. Discrete & Continuous Dynamical Systems - B, 2017, 22 (6): 2389-2416. doi: 10.3934/dcdsb.2017122
[2] Sihong Shao, Huazhong Tang. Higher-order accurate Runge-Kutta discontinuous Galerkin methods for a nonlinear Dirac model. Discrete & Continuous Dynamical Systems - B, 2006, 6 (3): 623-640. doi: 10.3934/dcdsb.2006.6.623
[3] Yuantian Xia, Juxiang Zhou, Tianwei Xu, Wei Gao. An improved deep convolutional neural network model with kernel loss function in image classification. Mathematical Foundations of Computing, 2020, 3 (1): 51-64. doi: 10.3934/mfc.2020005
[4] Seonho Park, Maciej Rysz, Kaitlin L. Fair, Panos M. Pardalos. Synthetic-Aperture Radar image based positioning in GPS-denied environments using Deep Cosine Similarity Neural Networks. Inverse Problems & Imaging, 2021, 15 (4): 763-785. doi: 10.3934/ipi.2021013
[5] Lars Grüne. Computing Lyapunov functions using deep neural networks. Journal of Computational Dynamics, 2021, 8 (2): 131-152. doi: 10.3934/jcd.2021006
[6] Miria Feng, Wenying Feng. Evaluation of parallel and sequential deep learning models for music subgenre classification. Mathematical Foundations of Computing, 2021, 4 (2): 131-143. doi: 10.3934/mfc.2021008
[7] Zbigniew Gomolka, Boguslaw Twarog, Jacek Bartman. Improvement of image processing by using homogeneous neural networks with fractional derivatives theorem. Conference Publications, 2011, 2011 (Special): 505-514. doi: 10.3934/proc.2011.2011.505
[8] Antonella Zanna. Symplectic P-stable additive Runge-Kutta methods. Journal of Computational Dynamics, 2021. doi: 10.3934/jcd.2021030
[9] Ruhua Wang, Senjian An, Wanquan Liu, Ling Li. Fixed-point algorithms for inverse of residual rectifier neural networks. Mathematical Foundations of Computing, 2021, 4 (1): 31-44. doi: 10.3934/mfc.2020024
[10] H. N. Mhaskar, T. Poggio. Function approximation by deep networks. Communications on Pure & Applied Analysis, 2020, 19 (8): 4085-4095. doi: 10.3934/cpaa.2020181
[11] Antonia Katzouraki, Tania Stathaki. Intelligent traffic control on internet-like topologies - integration of graph principles to the classic Runge-Kutta method. Conference Publications, 2009, 2009 (Special): 404-415. doi: 10.3934/proc.2009.2009.404
[12] Wenjuan Zhai, Bingzhen Chen. A fourth order implicit symmetric and symplectic exponentially fitted Runge-Kutta-Nyström method for solving oscillatory problems. Numerical Algebra, Control & Optimization, 2019, 9 (1): 71-84. doi: 10.3934/naco.2019006
[13] Christopher Oballe, David Boothe, Piotr J. Franaszczuk, Vasileios Maroulas. ToFU: Topology functional units for deep learning. Foundations of Data Science, 2021. doi: 10.3934/fods.2021021
[14] Ziju Shen, Yufei Wang, Dufan Wu, Xu Yang, Bin Dong. Learning to scan: A deep reinforcement learning approach for personalized scanning in CT imaging. Inverse Problems & Imaging, 2022, 16 (1): 179-195. doi: 10.3934/ipi.2021045
[15] Ying Sue Huang. Resynchronization of delayed neural networks. Discrete & Continuous Dynamical Systems, 2001, 7 (2): 397-401. doi: 10.3934/dcds.2001.7.397
[16] Hyeontae Jo, Hwijae Son, Hyung Ju Hwang, Eun Heui Kim. Deep neural network approach to forward-inverse problems. Networks & Heterogeneous Media, 2020, 15 (2): 247-259. doi: 10.3934/nhm.2020011
[17] Zheng Chen, Liu Liu, Lin Mu. Solving the linear transport equation by a deep neural network approach. Discrete & Continuous Dynamical Systems - S, 2021. doi: 10.3934/dcdss.2021070
[18] Martin Benning, Elena Celledoni, Matthias J. Ehrhardt, Brynjulf Owren, Carola-Bibiane Schönlieb. Deep learning as optimal control problems: Models and numerical methods. Journal of Computational Dynamics, 2019, 6 (2): 171-198. doi: 10.3934/jcd.2019009
[19] Nicholas Geneva, Nicholas Zabaras. Multi-fidelity generative deep learning turbulent flows. Foundations of Data Science, 2020, 2 (4): 391-428. doi: 10.3934/fods.2020019
[20] Govinda Anantha Padmanabha, Nicholas Zabaras. A Bayesian multiscale deep learning framework for flows in random media. Foundations of Data Science, 2021, 3 (2): 251-303. doi: 10.3934/fods.2021016
