Multilayer Perceptron (MLP) is a type of neural network that can be used for regression and classification.
MLPs consist of several fully connected hidden layers with non-linear activation functions. In the case of classification, the final layer of the neural net has as many nodes as classes, and the output of the neural net can be interpreted as the probability that a given input feature belongs to a specific class.
MLP can be used with or without mini-batching. Mini-batching can perform better than the default MADlib optimizer because it uses more than one training example at a time, which typically results in faster and smoother convergence [3].
mlp_classification( source_table, output_table, independent_varname, dependent_varname, hidden_layer_sizes, optimizer_params, activation, weights, warm_start, verbose, grouping_col )
Arguments
source_table
TEXT. Name of the table containing the training data. If you are using mini-batching, this is the name of the output table from the mini-batch preprocessor.
output_table
TEXT. Name of the output table containing the model. Details of the output table are shown below.
independent_varname
TEXT. Expression list to evaluate for the independent variables. It should be a numeric array expression. If you are using mini-batching, set this parameter to 'independent_varname', which is the hardcoded name of the column from the mini-batch preprocessor containing the packed independent variables.
dependent_varname
TEXT. Name of the dependent variable column. For classification, supported types are: text, varchar, character varying, char, character, integer, smallint, bigint, and boolean. If you are using mini-batching, set this parameter to 'dependent_varname', which is the hardcoded name of the column from the mini-batch preprocessor containing the packed dependent variables.
hidden_layer_sizes
INTEGER[], default: ARRAY[100]. The number of neurons in each hidden layer. The length of this array determines the number of hidden layers. For example, ARRAY[5,10] means 2 hidden layers, one with 5 neurons and the other with 10 neurons. Use ARRAY[]::INTEGER[] for no hidden layers.
optimizer_params
TEXT, default: NULL. Parameters for optimization in a comma-separated string of key-value pairs. See the description below for details.
activation
TEXT, default: 'sigmoid'. Activation function. Currently three functions are supported: 'sigmoid' (default), 'relu', and 'tanh'. The text can be any prefix of the three strings; for example, specifying 's' will use sigmoid activation.
warm_start
BOOLEAN, default: FALSE. Initialize neural network weights with the coefficients from the last call of the training function. If set to true, neural network weights will be initialized from the output_table generated by the previous run. Note that all parameters other than optimizer_params and verbose must remain constant between calls when warm_start is used.
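As a minimal sketch of how these arguments fit together, the call below trains a classifier with two hidden layers and ReLU activation. It assumes a table iris_data with a numeric array column attributes and a text column class_text, as created in the examples later on this page; the output table name mlp_model_two_layer is made up for illustration.

```sql
-- Hypothetical output table name; iris_data is assumed to exist already.
DROP TABLE IF EXISTS mlp_model_two_layer,
                     mlp_model_two_layer_summary,
                     mlp_model_two_layer_standardization;
SELECT madlib.mlp_classification(
    'iris_data',            -- source_table
    'mlp_model_two_layer',  -- output_table (hypothetical name)
    'attributes',           -- independent_varname: numeric array expression
    'class_text',           -- dependent_varname
    ARRAY[5,10],            -- hidden_layer_sizes: 2 hidden layers, 5 and 10 neurons
    NULL,                   -- optimizer_params: NULL uses the defaults
    'relu',                 -- activation
    NULL,                   -- weights: default weight
    FALSE,                  -- warm_start
    FALSE                   -- verbose
);
```

According to the activation description above, a prefix such as 'r' should select the same function as 'relu'.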
Output tables
The model table produced by MLP contains the following columns:
Column | Description |
---|---|
coeff | FLOAT8[]. Flat array containing the weights of the neural net. |
num_iterations | INTEGER. Number of iterations completed by the stochastic gradient descent algorithm. The algorithm either converged in this number of iterations or hit the maximum number specified in the optimization parameters. |
loss | FLOAT8. The cross entropy loss over the training data. See Technical Background section below for more details. |
grouping columns | If grouping_col is specified during training, a column for each grouping column is created. |
A summary table named <output_table>_summary is also created, which has the following columns:
Column | Description |
---|---|
source_table | The source table. |
independent_varname | The independent variables. |
dependent_varname | The dependent variable. |
tolerance | The tolerance as given in optimizer_params. |
learning_rate_init | The initial learning rate as given in optimizer_params. |
learning_rate_policy | The learning rate policy as given in optimizer_params. |
momentum | Momentum value as given in optimizer_params. |
nesterov | Nesterov value as given in optimizer_params. |
n_iterations | The number of iterations run. |
n_tries | The number of tries as given in optimizer_params. |
layer_sizes | The number of units in each layer including the input and output layers. |
activation | The activation function. |
is_classification | True if the model was trained for classification, False if it was trained for regression. |
classes | The classes which were trained against (empty for regression). |
weights | The weight column used during training for giving different weights to different rows. |
grouping_col | NULL if no grouping_col was specified during training; otherwise, a comma-separated list of grouping column names. |
A standardization table named <output_table>_standardization is also created, which has the following columns:
Column | Description |
---|---|
mean | The mean for all input features (used for normalization). |
std | The standard deviation for all input features (used for normalization). |
grouping columns | If grouping_col is specified during training, a column for each grouping column is created. |
mlp_regression( source_table, output_table, independent_varname, dependent_varname, hidden_layer_sizes, optimizer_params, activation, weights, warm_start, verbose, grouping_col )
Arguments
Parameters for regression are largely the same as for classification. In the model table, the loss refers to mean square error instead of cross entropy loss. In the summary table, there is no classes column. The following arguments have specifications which differ from mlp_classification:
'solver = <value>, learning_rate_init = <value>, learning_rate_policy = <value>, gamma = <value>, power = <value>, iterations_per_step = <value>, n_iterations = <value>, n_tries = <value>, lambda = <value>, tolerance = <value>, batch_size = <value>, n_epochs = <value>, momentum = <value>, nesterov = <value>, rho = <value>, beta1 = <value>, beta2 = <value>, eps = <value>'
Optimizer Parameters
solver
Default: sgd. One of 'sgd', 'rmsprop' or 'adam' or any prefix of these (e.g., 'rmsp' means 'rmsprop'). These are defined below:
learning_rate_init
Default: 0.001. Also known as the learning rate. A small value is usually desirable to ensure convergence, while a large value provides more room for progress during training. Since the best value depends on the condition number of the data, in practice one often tunes this parameter.
learning_rate_policy
Default: constant. One of 'constant', 'exp', 'inv' or 'step' or any prefix of these (e.g., 's' means 'step'). These are defined below, where 'iter' is the current iteration:
gamma
Default: 0.1. Decay rate for the learning rate when learning_rate_policy is 'exp' or 'step'.
power
Default: 0.5. Exponent for learning_rate_policy = 'inv'.
iterations_per_step
Default: 100. Number of iterations to run before decreasing the learning rate by a factor of gamma. Valid for learning_rate_policy = 'step'.
n_iterations
Default: 100. The maximum number of iterations allowed.
n_tries
Default: 1. Number of times to retrain the network with randomly initialized neural network weights.
lambda
Default: 0. The regularization coefficient for L2 regularization.
tolerance
Default: 0.001. The criterion to end iterations. Training stops when the difference between the training models of two consecutive iterations is smaller than tolerance, or when the iteration number is larger than n_iterations. If you want to run the full number of iterations specified in n_iterations, set tolerance=0.0.
batch_size
Default: min(200, buffer_size), where buffer_size is set in the mini-batch preprocessor. The 'batch_size' is the size of the mini-batch used in the optimizer. This parameter is only used in the case of mini-batching.
n_epochs
Default: 1. Represents the number of times each batch is used by the optimizer. This parameter is only used in the case of mini-batching.
momentum
Default: 0.9. Momentum can help accelerate learning and avoid local minima when using gradient descent. Value must be in the range 0 to 1, where 0 means no momentum.
nesterov
Default: TRUE. Only used when the 'momentum' parameter is > 0. Nesterov momentum can provide better results than classical momentum alone, due to its look-ahead characteristics. In classical momentum, we correct the velocity and then update the model with that velocity. In the Nesterov Accelerated Gradient method, we first move the model in the direction of the velocity, compute the gradient using this updated model, and then add this gradient back into the model. The main difference is that classical momentum computes the gradient before updating the model, whereas Nesterov momentum first updates the model and then computes the gradient from the updated position.
rho
Default: 0.9. Moving average parameter for the RMSprop solver.
beta1
Default: 0.9. The exponential decay rate for the first moment estimates.
beta2
Default: 0.999. The exponential decay rate for the second moment estimates.
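To make the interaction of the learning rate policy with gamma, power and iterations_per_step more concrete, the query below tabulates a learning rate at a few iterations using the default values listed above. It is only a sketch: the four formulas are assumptions inferred from the parameter descriptions, not taken from the MADlib source, and the query uses plain PostgreSQL functions rather than any MADlib call. The column name inv_power stands in for the 'power' parameter to avoid clashing with the power() function.

```sql
-- Assumed formulas (inferred from the descriptions, not confirmed against MADlib):
--   constant: lr = learning_rate_init
--   exp:      lr = learning_rate_init * gamma ^ iter
--   inv:      lr = learning_rate_init * (iter + 1) ^ (-power)
--   step:     lr = learning_rate_init * gamma ^ floor(iter / iterations_per_step)
WITH params AS (
    SELECT 0.001::float8 AS learning_rate_init,   -- default learning_rate_init
           0.1::float8   AS gamma,                -- default gamma
           0.5::float8   AS inv_power,            -- default 'power'
           100::float8   AS iterations_per_step   -- default iterations_per_step
)
SELECT iter,
       learning_rate_init                                                    AS lr_constant,
       learning_rate_init * power(gamma, iter)                               AS lr_exp,
       learning_rate_init * power(iter + 1, -inv_power)                      AS lr_inv,
       learning_rate_init * power(gamma, floor(iter / iterations_per_step))  AS lr_step
FROM params, generate_series(0, 200, 50) AS iter
ORDER BY iter;
```

Under these assumptions the 'step' column stays flat for iterations_per_step iterations and then drops by a factor of gamma, while 'exp' decays at every iteration.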
mlp_predict( model_table, data_table, id_col_name, output_table, pred_type )
Arguments
model_table
TEXT. Model table produced by the training function.
data_table
TEXT. Name of the table containing the data for prediction. This table is expected to contain the same input features that were used during training. The table should also contain id_col_name used for identifying each row.
id_col_name
TEXT. The name of the id column in data_table.
Column | Description |
---|---|
id | Gives the 'id' for each prediction, corresponding to each row from the data_table. |
estimated_COL_NAME | (For pred_type='response') The estimated class for classification or value for regression, where COL_NAME is the name of the column to be predicted from training data. |
prob_CLASS | (For pred_type='prob' for classification) The probability of a given class CLASS as given by softmax. There will be one column for each class in the training data. |
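For a classification model, the per-class probabilities described above can be requested by passing 'prob' as pred_type. The following is a minimal sketch that assumes a trained model table mlp_model and a data table iris_data with an id column, as in the examples on this page; the output table name mlp_prediction_prob is hypothetical.

```sql
DROP TABLE IF EXISTS mlp_prediction_prob;
SELECT madlib.mlp_predict(
    'mlp_model',            -- Model table (assumed to exist)
    'iris_data',            -- Data table to score (assumed to exist)
    'id',                   -- Id column in the data table
    'mlp_prediction_prob',  -- Output table (hypothetical name)
    'prob'                  -- Output class probabilities instead of a class label
);
-- Per the table above, one prob_CLASS column per class is expected.
SELECT * FROM mlp_prediction_prob ORDER BY id;
```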
DROP TABLE IF EXISTS iris_data; CREATE TABLE iris_data( id serial, attributes numeric[], class_text varchar, class integer, state varchar ); INSERT INTO iris_data(id, attributes, class_text, class, state) VALUES (1,ARRAY[5.0,3.2,1.2,0.2],'Iris_setosa',1,'Alaska'), (2,ARRAY[5.5,3.5,1.3,0.2],'Iris_setosa',1,'Alaska'), (3,ARRAY[4.9,3.1,1.5,0.1],'Iris_setosa',1,'Alaska'), (4,ARRAY[4.4,3.0,1.3,0.2],'Iris_setosa',1,'Alaska'), (5,ARRAY[5.1,3.4,1.5,0.2],'Iris_setosa',1,'Alaska'), (6,ARRAY[5.0,3.5,1.3,0.3],'Iris_setosa',1,'Alaska'), (7,ARRAY[4.5,2.3,1.3,0.3],'Iris_setosa',1,'Alaska'), (8,ARRAY[4.4,3.2,1.3,0.2],'Iris_setosa',1,'Alaska'), (9,ARRAY[5.0,3.5,1.6,0.6],'Iris_setosa',1,'Alaska'), (10,ARRAY[5.1,3.8,1.9,0.4],'Iris_setosa',1,'Alaska'), (11,ARRAY[4.8,3.0,1.4,0.3],'Iris_setosa',1,'Alaska'), (12,ARRAY[5.1,3.8,1.6,0.2],'Iris_setosa',1,'Alaska'), (13,ARRAY[5.7,2.8,4.5,1.3],'Iris_versicolor',2,'Alaska'), (14,ARRAY[6.3,3.3,4.7,1.6],'Iris_versicolor',2,'Alaska'), (15,ARRAY[4.9,2.4,3.3,1.0],'Iris_versicolor',2,'Alaska'), (16,ARRAY[6.6,2.9,4.6,1.3],'Iris_versicolor',2,'Alaska'), (17,ARRAY[5.2,2.7,3.9,1.4],'Iris_versicolor',2,'Alaska'), (18,ARRAY[5.0,2.0,3.5,1.0],'Iris_versicolor',2,'Alaska'), (19,ARRAY[5.9,3.0,4.2,1.5],'Iris_versicolor',2,'Alaska'), (20,ARRAY[6.0,2.2,4.0,1.0],'Iris_versicolor',2,'Alaska'), (21,ARRAY[6.1,2.9,4.7,1.4],'Iris_versicolor',2,'Alaska'), (22,ARRAY[5.6,2.9,3.6,1.3],'Iris_versicolor',2,'Alaska'), (23,ARRAY[6.7,3.1,4.4,1.4],'Iris_versicolor',2,'Alaska'), (24,ARRAY[5.6,3.0,4.5,1.5],'Iris_versicolor',2,'Alaska'), (25,ARRAY[5.8,2.7,4.1,1.0],'Iris_versicolor',2,'Alaska'), (26,ARRAY[6.2,2.2,4.5,1.5],'Iris_versicolor',2,'Alaska'), (27,ARRAY[5.6,2.5,3.9,1.1],'Iris_versicolor',2,'Alaska'), (28,ARRAY[5.0,3.4,1.5,0.2],'Iris_setosa',1,'Tennessee'), (29,ARRAY[4.4,2.9,1.4,0.2],'Iris_setosa',1,'Tennessee'), (30,ARRAY[4.9,3.1,1.5,0.1],'Iris_setosa',1,'Tennessee'), (31,ARRAY[5.4,3.7,1.5,0.2],'Iris_setosa',1,'Tennessee'), 
(32,ARRAY[4.8,3.4,1.6,0.2],'Iris_setosa',1,'Tennessee'), (33,ARRAY[4.8,3.0,1.4,0.1],'Iris_setosa',1,'Tennessee'), (34,ARRAY[4.3,3.0,1.1,0.1],'Iris_setosa',1,'Tennessee'), (35,ARRAY[5.8,4.0,1.2,0.2],'Iris_setosa',1,'Tennessee'), (36,ARRAY[5.7,4.4,1.5,0.4],'Iris_setosa',1,'Tennessee'), (37,ARRAY[5.4,3.9,1.3,0.4],'Iris_setosa',1,'Tennessee'), (38,ARRAY[6.0,2.9,4.5,1.5],'Iris_versicolor',2,'Tennessee'), (39,ARRAY[5.7,2.6,3.5,1.0],'Iris_versicolor',2,'Tennessee'), (40,ARRAY[5.5,2.4,3.8,1.1],'Iris_versicolor',2,'Tennessee'), (41,ARRAY[5.5,2.4,3.7,1.0],'Iris_versicolor',2,'Tennessee'), (42,ARRAY[5.8,2.7,3.9,1.2],'Iris_versicolor',2,'Tennessee'), (43,ARRAY[6.0,2.7,5.1,1.6],'Iris_versicolor',2,'Tennessee'), (44,ARRAY[5.4,3.0,4.5,1.5],'Iris_versicolor',2,'Tennessee'), (45,ARRAY[6.0,3.4,4.5,1.6],'Iris_versicolor',2,'Tennessee'), (46,ARRAY[6.7,3.1,4.7,1.5],'Iris_versicolor',2,'Tennessee'), (47,ARRAY[6.3,2.3,4.4,1.3],'Iris_versicolor',2,'Tennessee'), (48,ARRAY[5.6,3.0,4.1,1.3],'Iris_versicolor',2,'Tennessee'), (49,ARRAY[5.5,2.5,4.0,1.3],'Iris_versicolor',2,'Tennessee'), (50,ARRAY[5.5,2.6,4.4,1.2],'Iris_versicolor',2,'Tennessee'), (51,ARRAY[6.1,3.0,4.6,1.4],'Iris_versicolor',2,'Tennessee'), (52,ARRAY[5.8,2.6,4.0,1.2],'Iris_versicolor',2,'Tennessee');
DROP TABLE IF EXISTS mlp_model, mlp_model_summary, mlp_model_standardization;
-- Set seed so results are reproducible
SELECT setseed(0);
SELECT madlib.mlp_classification(
    'iris_data',      -- Source table
    'mlp_model',      -- Destination table
    'attributes',     -- Input features
    'class_text',     -- Label
    ARRAY[5],         -- Number of units per layer
    'learning_rate_init=0.003, n_iterations=500, tolerance=0',  -- Optimizer params
    'tanh',           -- Activation function
    NULL,             -- Default weight (1)
    FALSE,            -- No warm start
    FALSE             -- Not verbose
);
View the model:
\x on
SELECT * FROM mlp_model;
-[ RECORD 1 ]--+------------------------------------------------------------------------------------
coeff          | {-0.40378996718,0.0157490328855,-0.298904053444,-0.984152185093,-0.657684089715 ...
loss           | 0.0103518565103
num_iterations | 500
View the model summary table:
SELECT * FROM mlp_model_summary;
-[ RECORD 1 ]--------+------------------------------
source_table         | iris_data
independent_varname  | attributes
dependent_varname    | class_text
dependent_vartype    | character varying
tolerance            | 0
learning_rate_init   | 0.003
learning_rate_policy | constant
momentum             | 0.9
nesterov             | t
n_iterations         | 500
n_tries              | 1
layer_sizes          | {4,5,2}
activation           | tanh
is_classification    | t
classes              | {Iris_setosa,Iris_versicolor}
weights              | 1
grouping_col         | NULL
View the model standardization table:
SELECT * FROM mlp_model_standardization;
-[ RECORD 1 ]------------------------------------------------------------------
mean | {5.45961538461539,2.99807692307692,3.025,0.851923076923077}
std  | {0.598799958694505,0.498262513685689,1.41840579525043,0.550346179381454}
DROP TABLE IF EXISTS mlp_prediction;
\x off
SELECT madlib.mlp_predict(
    'mlp_model',       -- Model table
    'iris_data',       -- Test data table
    'id',              -- Id column in test table
    'mlp_prediction',  -- Output table for predictions
    'response'         -- Output classes, not probabilities
);
SELECT * FROM mlp_prediction JOIN iris_data USING (id) ORDER BY id;
 id | estimated_class_text |    attributes     |   class_text    | class |   state
----+----------------------+-------------------+-----------------+-------+-----------
  1 | Iris_setosa          | {5.0,3.2,1.2,0.2} | Iris_setosa     |     1 | Alaska
  2 | Iris_setosa          | {5.5,3.5,1.3,0.2} | Iris_setosa     |     1 | Alaska
  3 | Iris_setosa          | {4.9,3.1,1.5,0.1} | Iris_setosa     |     1 | Alaska
  4 | Iris_setosa          | {4.4,3.0,1.3,0.2} | Iris_setosa     |     1 | Alaska
  5 | Iris_setosa          | {5.1,3.4,1.5,0.2} | Iris_setosa     |     1 | Alaska
  6 | Iris_setosa          | {5.0,3.5,1.3,0.3} | Iris_setosa     |     1 | Alaska
  7 | Iris_setosa          | {4.5,2.3,1.3,0.3} | Iris_setosa     |     1 | Alaska
  8 | Iris_setosa          | {4.4,3.2,1.3,0.2} | Iris_setosa     |     1 | Alaska
  9 | Iris_setosa          | {5.0,3.5,1.6,0.6} | Iris_setosa     |     1 | Alaska
 10 | Iris_setosa          | {5.1,3.8,1.9,0.4} | Iris_setosa     |     1 | Alaska
 11 | Iris_setosa          | {4.8,3.0,1.4,0.3} | Iris_setosa     |     1 | Alaska
 12 | Iris_setosa          | {5.1,3.8,1.6,0.2} | Iris_setosa     |     1 | Alaska
 13 | Iris_versicolor      | {5.7,2.8,4.5,1.3} | Iris_versicolor |     2 | Alaska
 14 | Iris_versicolor      | {6.3,3.3,4.7,1.6} | Iris_versicolor |     2 | Alaska
 15 | Iris_versicolor      | {4.9,2.4,3.3,1.0} | Iris_versicolor |     2 | Alaska
 16 | Iris_versicolor      | {6.6,2.9,4.6,1.3} | Iris_versicolor |     2 | Alaska
 17 | Iris_versicolor      | {5.2,2.7,3.9,1.4} | Iris_versicolor |     2 | Alaska
 18 | Iris_versicolor      | {5.0,2.0,3.5,1.0} | Iris_versicolor |     2 | Alaska
 19 | Iris_versicolor      | {5.9,3.0,4.2,1.5} | Iris_versicolor |     2 | Alaska
 20 | Iris_versicolor      | {6.0,2.2,4.0,1.0} | Iris_versicolor |     2 | Alaska
 21 | Iris_versicolor      | {6.1,2.9,4.7,1.4} | Iris_versicolor |     2 | Alaska
 22 | Iris_versicolor      | {5.6,2.9,3.6,1.3} | Iris_versicolor |     2 | Alaska
 23 | Iris_versicolor      | {6.7,3.1,4.4,1.4} | Iris_versicolor |     2 | Alaska
 24 | Iris_versicolor      | {5.6,3.0,4.5,1.5} | Iris_versicolor |     2 | Alaska
 25 | Iris_versicolor      | {5.8,2.7,4.1,1.0} | Iris_versicolor |     2 | Alaska
 26 | Iris_versicolor      | {6.2,2.2,4.5,1.5} | Iris_versicolor |     2 | Alaska
 27 | Iris_versicolor      | {5.6,2.5,3.9,1.1} | Iris_versicolor |     2 | Alaska
 28 | Iris_setosa          | {5.0,3.4,1.5,0.2} | Iris_setosa     |     1 | Tennessee
 29 | Iris_setosa          | {4.4,2.9,1.4,0.2} | Iris_setosa     |     1 | Tennessee
 30 | Iris_setosa          | {4.9,3.1,1.5,0.1} | Iris_setosa     |     1 | Tennessee
 31 | Iris_setosa          | {5.4,3.7,1.5,0.2} | Iris_setosa     |     1 | Tennessee
 32 | Iris_setosa          | {4.8,3.4,1.6,0.2} | Iris_setosa     |     1 | Tennessee
 33 | Iris_setosa          | {4.8,3.0,1.4,0.1} | Iris_setosa     |     1 | Tennessee
 34 | Iris_setosa          | {4.3,3.0,1.1,0.1} | Iris_setosa     |     1 | Tennessee
 35 | Iris_setosa          | {5.8,4.0,1.2,0.2} | Iris_setosa     |     1 | Tennessee
 36 | Iris_setosa          | {5.7,4.4,1.5,0.4} | Iris_setosa     |     1 | Tennessee
 37 | Iris_setosa          | {5.4,3.9,1.3,0.4} | Iris_setosa     |     1 | Tennessee
 38 | Iris_versicolor      | {6.0,2.9,4.5,1.5} | Iris_versicolor |     2 | Tennessee
 39 | Iris_versicolor      | {5.7,2.6,3.5,1.0} | Iris_versicolor |     2 | Tennessee
 40 | Iris_versicolor      | {5.5,2.4,3.8,1.1} | Iris_versicolor |     2 | Tennessee
 41 | Iris_versicolor      | {5.5,2.4,3.7,1.0} | Iris_versicolor |     2 | Tennessee
 42 | Iris_versicolor      | {5.8,2.7,3.9,1.2} | Iris_versicolor |     2 | Tennessee
 43 | Iris_versicolor      | {6.0,2.7,5.1,1.6} | Iris_versicolor |     2 | Tennessee
 44 | Iris_versicolor      | {5.4,3.0,4.5,1.5} | Iris_versicolor |     2 | Tennessee
 45 | Iris_versicolor      | {6.0,3.4,4.5,1.6} | Iris_versicolor |     2 | Tennessee
 46 | Iris_versicolor      | {6.7,3.1,4.7,1.5} | Iris_versicolor |     2 | Tennessee
 47 | Iris_versicolor      | {6.3,2.3,4.4,1.3} | Iris_versicolor |     2 | Tennessee
 48 | Iris_versicolor      | {5.6,3.0,4.1,1.3} | Iris_versicolor |     2 | Tennessee
 49 | Iris_versicolor      | {5.5,2.5,4.0,1.3} | Iris_versicolor |     2 | Tennessee
 50 | Iris_versicolor      | {5.5,2.6,4.4,1.2} | Iris_versicolor |     2 | Tennessee
 51 | Iris_versicolor      | {6.1,3.0,4.6,1.4} | Iris_versicolor |     2 | Tennessee
 52 | Iris_versicolor      | {5.8,2.6,4.0,1.2} | Iris_versicolor |     2 | Tennessee
(52 rows)
Count the misclassifications:
SELECT COUNT(*) FROM mlp_prediction JOIN iris_data USING (id) WHERE mlp_prediction.estimated_class_text != iris_data.class_text;
 count
-------+
     0
DROP TABLE IF EXISTS iris_data_packed, iris_data_packed_summary, iris_data_packed_standardization;
SELECT madlib.minibatch_preprocessor(
    'iris_data',         -- Source table
    'iris_data_packed',  -- Output table
    'class_text',        -- Dependent variable
    'attributes'         -- Independent variables
);
DROP TABLE IF EXISTS mlp_model, mlp_model_summary, mlp_model_standardization;
-- Set seed so results are reproducible
SELECT setseed(0);
SELECT madlib.mlp_classification(
    'iris_data_packed',     -- Output table from mini-batch preprocessor
    'mlp_model',            -- Destination table
    'independent_varname',  -- Hardcode to this, from table iris_data_packed
    'dependent_varname',    -- Hardcode to this, from table iris_data_packed
    ARRAY[5],               -- Number of units per layer
    'learning_rate_init=0.1, n_iterations=500, tolerance=0',  -- Optimizer params
    'tanh',                 -- Activation function
    NULL,                   -- Default weight (1)
    FALSE,                  -- No warm start
    FALSE                   -- Not verbose
);
View the model:
\x on
SELECT * FROM mlp_model;
-[ RECORD 1 ]--+------------------------------------------------------------------------------------
coeff          | {-0.0780564661828377,-0.0781452670639994,0.3083605989842 ...
loss           | 0.00563534904146765
num_iterations | 500
DROP TABLE IF EXISTS mlp_prediction;
\x off
SELECT madlib.mlp_predict(
    'mlp_model',       -- Model table
    'iris_data',       -- Test data table
    'id',              -- Id column in test table
    'mlp_prediction',  -- Output table for predictions
    'response'         -- Output classes, not probabilities
);
SELECT * FROM mlp_prediction JOIN iris_data USING (id) ORDER BY id;
 id | estimated_class_text |    attributes     |   class_text    | class |   state
----+----------------------+-------------------+-----------------+-------+-----------
  1 | Iris_setosa          | {5.0,3.2,1.2,0.2} | Iris_setosa     |     1 | Alaska
  2 | Iris_setosa          | {5.5,3.5,1.3,0.2} | Iris_setosa     |     1 | Alaska
  3 | Iris_setosa          | {4.9,3.1,1.5,0.1} | Iris_setosa     |     1 | Alaska
  4 | Iris_setosa          | {4.4,3.0,1.3,0.2} | Iris_setosa     |     1 | Alaska
  5 | Iris_setosa          | {5.1,3.4,1.5,0.2} | Iris_setosa     |     1 | Alaska
  6 | Iris_setosa          | {5.0,3.5,1.3,0.3} | Iris_setosa     |     1 | Alaska
  7 | Iris_setosa          | {4.5,2.3,1.3,0.3} | Iris_setosa     |     1 | Alaska
  8 | Iris_setosa          | {4.4,3.2,1.3,0.2} | Iris_setosa     |     1 | Alaska
  9 | Iris_setosa          | {5.0,3.5,1.6,0.6} | Iris_setosa     |     1 | Alaska
 10 | Iris_setosa          | {5.1,3.8,1.9,0.4} | Iris_setosa     |     1 | Alaska
 11 | Iris_setosa          | {4.8,3.0,1.4,0.3} | Iris_setosa     |     1 | Alaska
 12 | Iris_setosa          | {5.1,3.8,1.6,0.2} | Iris_setosa     |     1 | Alaska
 13 | Iris_versicolor      | {5.7,2.8,4.5,1.3} | Iris_versicolor |     2 | Alaska
 14 | Iris_versicolor      | {6.3,3.3,4.7,1.6} | Iris_versicolor |     2 | Alaska
 15 | Iris_versicolor      | {4.9,2.4,3.3,1.0} | Iris_versicolor |     2 | Alaska
 16 | Iris_versicolor      | {6.6,2.9,4.6,1.3} | Iris_versicolor |     2 | Alaska
 17 | Iris_versicolor      | {5.2,2.7,3.9,1.4} | Iris_versicolor |     2 | Alaska
 18 | Iris_versicolor      | {5.0,2.0,3.5,1.0} | Iris_versicolor |     2 | Alaska
 19 | Iris_versicolor      | {5.9,3.0,4.2,1.5} | Iris_versicolor |     2 | Alaska
 20 | Iris_versicolor      | {6.0,2.2,4.0,1.0} | Iris_versicolor |     2 | Alaska
 21 | Iris_versicolor      | {6.1,2.9,4.7,1.4} | Iris_versicolor |     2 | Alaska
 22 | Iris_versicolor      | {5.6,2.9,3.6,1.3} | Iris_versicolor |     2 | Alaska
 23 | Iris_versicolor      | {6.7,3.1,4.4,1.4} | Iris_versicolor |     2 | Alaska
 24 | Iris_versicolor      | {5.6,3.0,4.5,1.5} | Iris_versicolor |     2 | Alaska
 25 | Iris_versicolor      | {5.8,2.7,4.1,1.0} | Iris_versicolor |     2 | Alaska
 26 | Iris_versicolor      | {6.2,2.2,4.5,1.5} | Iris_versicolor |     2 | Alaska
 27 | Iris_versicolor      | {5.6,2.5,3.9,1.1} | Iris_versicolor |     2 | Alaska
 28 | Iris_setosa          | {5.0,3.4,1.5,0.2} | Iris_setosa     |     1 | Tennessee
 29 | Iris_setosa          | {4.4,2.9,1.4,0.2} | Iris_setosa     |     1 | Tennessee
 30 | Iris_setosa          | {4.9,3.1,1.5,0.1} | Iris_setosa     |     1 | Tennessee
 31 | Iris_setosa          | {5.4,3.7,1.5,0.2} | Iris_setosa     |     1 | Tennessee
 32 | Iris_setosa          | {4.8,3.4,1.6,0.2} | Iris_setosa     |     1 | Tennessee
 33 | Iris_setosa          | {4.8,3.0,1.4,0.1} | Iris_setosa     |     1 | Tennessee
 34 | Iris_setosa          | {4.3,3.0,1.1,0.1} | Iris_setosa     |     1 | Tennessee
 35 | Iris_setosa          | {5.8,4.0,1.2,0.2} | Iris_setosa     |     1 | Tennessee
 36 | Iris_setosa          | {5.7,4.4,1.5,0.4} | Iris_setosa     |     1 | Tennessee
 37 | Iris_setosa          | {5.4,3.9,1.3,0.4} | Iris_setosa     |     1 | Tennessee
 38 | Iris_versicolor      | {6.0,2.9,4.5,1.5} | Iris_versicolor |     2 | Tennessee
 39 | Iris_versicolor      | {5.7,2.6,3.5,1.0} | Iris_versicolor |     2 | Tennessee
 40 | Iris_versicolor      | {5.5,2.4,3.8,1.1} | Iris_versicolor |     2 | Tennessee
 41 | Iris_versicolor      | {5.5,2.4,3.7,1.0} | Iris_versicolor |     2 | Tennessee
 42 | Iris_versicolor      | {5.8,2.7,3.9,1.2} | Iris_versicolor |     2 | Tennessee
 43 | Iris_versicolor      | {6.0,2.7,5.1,1.6} | Iris_versicolor |     2 | Tennessee
 44 | Iris_versicolor      | {5.4,3.0,4.5,1.5} | Iris_versicolor |     2 | Tennessee
 45 | Iris_versicolor      | {6.0,3.4,4.5,1.6} | Iris_versicolor |     2 | Tennessee
 46 | Iris_versicolor      | {6.7,3.1,4.7,1.5} | Iris_versicolor |     2 | Tennessee
 47 | Iris_versicolor      | {6.3,2.3,4.4,1.3} | Iris_versicolor |     2 | Tennessee
 48 | Iris_versicolor      | {5.6,3.0,4.1,1.3} | Iris_versicolor |     2 | Tennessee
 49 | Iris_versicolor      | {5.5,2.5,4.0,1.3} | Iris_versicolor |     2 | Tennessee
 50 | Iris_versicolor      | {5.5,2.6,4.4,1.2} | Iris_versicolor |     2 | Tennessee
 51 | Iris_versicolor      | {6.1,3.0,4.6,1.4} | Iris_versicolor |     2 | Tennessee
 52 | Iris_versicolor      | {5.8,2.6,4.0,1.2} | Iris_versicolor |     2 | Tennessee
(52 rows)
Count the misclassifications:
SELECT COUNT(*) FROM mlp_prediction JOIN iris_data USING (id) WHERE mlp_prediction.estimated_class_text != iris_data.class_text;
 count
-------+
     0
DROP TABLE IF EXISTS mlp_model, mlp_model_summary, mlp_model_standardization;
-- Set seed so results are reproducible
SELECT setseed(0);
SELECT madlib.mlp_classification(
    'iris_data',      -- Source table
    'mlp_model',      -- Destination table
    'attributes',     -- Input features
    'class_text',     -- Label
    ARRAY[5],         -- Number of units per layer
    'learning_rate_init=0.003, n_iterations=50, tolerance=0, n_tries=3',  -- Optimizer params, with n_tries
    'tanh',           -- Activation function
    NULL,             -- Default weight (1)
    FALSE,            -- No warm start
    FALSE             -- Not verbose
);
View the model:
\x on
SELECT * FROM mlp_model;
-[ RECORD 1 ]--+------------------------------------------------------------------------------------
coeff          | {0.000156316559088915,0.131131017223563,-0.293990512682215 ...
loss           | 0.142238768280717
num_iterations | 50
SELECT madlib.mlp_classification(
    'iris_data',      -- Source table
    'mlp_model',      -- Destination table
    'attributes',     -- Input features
    'class_text',     -- Label
    ARRAY[5],         -- Number of units per layer
    'learning_rate_init=0.003, n_iterations=450, tolerance=0',  -- Optimizer params
    'tanh',           -- Activation function
    NULL,             -- Default weight (1)
    TRUE,             -- Warm start
    FALSE             -- Not verbose
);
View the model:
\x on
SELECT * FROM mlp_model;
-[ RECORD 1 ]--+------------------------------------------------------------------------------------
coeff          | {0.0883013960215441,0.235944854050211,-0.398126039487036 ...
loss           | 0.00818899646775659
num_iterations | 450
Notice that the loss is lower compared to the previous example, despite having the same values for every other parameter. This is because the algorithm learned three different models starting with a different set of initial weights for the coefficients, and chose the best model among them as the initial weights for the coefficients when run with warm start.
DROP TABLE IF EXISTS mlp_model, mlp_model_summary, mlp_model_standardization;
SELECT madlib.mlp_classification(
    'iris_data_packed',     -- Output table from mini-batch preprocessor
    'mlp_model',            -- Destination table
    'independent_varname',  -- Hardcode to this, from table iris_data_packed
    'dependent_varname',    -- Hardcode to this, from table iris_data_packed
    ARRAY[5],               -- Number of units per layer
    'learning_rate_init=0.1, n_iterations=500, tolerance=0.0001, solver=adam',  -- Optimizer params
    'tanh',                 -- Activation function
    NULL,                   -- Default weight (1)
    FALSE,                  -- No warm start
    FALSE                   -- Not verbose
);
View the model:
\x on
SELECT * FROM mlp_model;
-[ RECORD 1 ]--+------------------------------------------------------------------------------------
coeff          | {5.39258022025872,0.674679083739714,-2.59002712311116 ...
loss           | 0.155612432637527
num_iterations | 500
DROP TABLE IF EXISTS mlp_model_group, mlp_model_group_summary, mlp_model_group_standardization;
-- Set seed so results are reproducible
SELECT setseed(0);
SELECT madlib.mlp_classification(
    'iris_data',        -- Source table
    'mlp_model_group',  -- Destination table
    'attributes',       -- Input features
    'class_text',       -- Label
    ARRAY[5],           -- Number of units per layer
    'learning_rate_init=0.003, n_iterations=500, tolerance=0',  -- Optimizer params
    'tanh',             -- Activation function
    NULL,               -- Default weight (1)
    FALSE,              -- No warm start
    FALSE,              -- Not verbose
    'state'             -- Grouping column
);
View the model:
\x on
SELECT * FROM mlp_model_group ORDER BY state;
-[ RECORD 1 ]--+------------------------------------------------------------------------------------
state          | Alaska
coeff          | {-0.51246602223,-0.78952457411,0.454192045225,0.223214894458,0.188804700547 ...
loss           | 0.0225081995679
num_iterations | 500
-[ RECORD 2 ]--+------------------------------------------------------------------------------------
state          | Tennessee
coeff          | {-0.215009937565,0.116581594162,-0.397643279185,0.919193295184,-0.0811341736111 ...
loss           | 0.0182854983946
num_iterations | 500
A separate model is learnt for each state, and the result table displays the name of the state (grouping column) associated with the model.
\x off
DROP TABLE IF EXISTS mlp_prediction;
SELECT madlib.mlp_predict(
    'mlp_model_group',  -- Model table
    'iris_data',        -- Test data table
    'id',               -- Id column in test table
    'mlp_prediction',   -- Output table for predictions
    'response'          -- Output classes, not probabilities
);
SELECT * FROM mlp_prediction JOIN iris_data USING (state,id) ORDER BY state, id;
Result for the classification model:
   state   | id | estimated_class_text |    attributes     |   class_text    | class
-----------+----+----------------------+-------------------+-----------------+-------
 Alaska    |  1 | Iris_setosa          | {5.0,3.2,1.2,0.2} | Iris_setosa     |     1
 Alaska    |  2 | Iris_setosa          | {5.5,3.5,1.3,0.2} | Iris_setosa     |     1
 Alaska    |  3 | Iris_setosa          | {4.9,3.1,1.5,0.1} | Iris_setosa     |     1
 Alaska    |  4 | Iris_setosa          | {4.4,3.0,1.3,0.2} | Iris_setosa     |     1
 Alaska    |  5 | Iris_setosa          | {5.1,3.4,1.5,0.2} | Iris_setosa     |     1
 Alaska    |  6 | Iris_setosa          | {5.0,3.5,1.3,0.3} | Iris_setosa     |     1
 Alaska    |  7 | Iris_setosa          | {4.5,2.3,1.3,0.3} | Iris_setosa     |     1
 Alaska    |  8 | Iris_setosa          | {4.4,3.2,1.3,0.2} | Iris_setosa     |     1
 Alaska    |  9 | Iris_setosa          | {5.0,3.5,1.6,0.6} | Iris_setosa     |     1
 Alaska    | 10 | Iris_setosa          | {5.1,3.8,1.9,0.4} | Iris_setosa     |     1
 Alaska    | 11 | Iris_setosa          | {4.8,3.0,1.4,0.3} | Iris_setosa     |     1
 Alaska    | 12 | Iris_setosa          | {5.1,3.8,1.6,0.2} | Iris_setosa     |     1
 Alaska    | 13 | Iris_versicolor      | {5.7,2.8,4.5,1.3} | Iris_versicolor |     2
 Alaska    | 14 | Iris_versicolor      | {6.3,3.3,4.7,1.6} | Iris_versicolor |     2
 Alaska    | 15 | Iris_versicolor      | {4.9,2.4,3.3,1.0} | Iris_versicolor |     2
 Alaska    | 16 | Iris_versicolor      | {6.6,2.9,4.6,1.3} | Iris_versicolor |     2
 Alaska    | 17 | Iris_versicolor      | {5.2,2.7,3.9,1.4} | Iris_versicolor |     2
 Alaska    | 18 | Iris_versicolor      | {5.0,2.0,3.5,1.0} | Iris_versicolor |     2
 Alaska    | 19 | Iris_versicolor      | {5.9,3.0,4.2,1.5} | Iris_versicolor |     2
 Alaska    | 20 | Iris_versicolor      | {6.0,2.2,4.0,1.0} | Iris_versicolor |     2
 Alaska    | 21 | Iris_versicolor      | {6.1,2.9,4.7,1.4} | Iris_versicolor |     2
 Alaska    | 22 | Iris_versicolor      | {5.6,2.9,3.6,1.3} | Iris_versicolor |     2
 Alaska    | 23 | Iris_versicolor      | {6.7,3.1,4.4,1.4} | Iris_versicolor |     2
 Alaska    | 24 | Iris_versicolor      | {5.6,3.0,4.5,1.5} | Iris_versicolor |     2
 Alaska    | 25 | Iris_versicolor      | {5.8,2.7,4.1,1.0} | Iris_versicolor |     2
 Alaska    | 26 | Iris_versicolor      | {6.2,2.2,4.5,1.5} | Iris_versicolor |     2
 Alaska    | 27 | Iris_versicolor      | {5.6,2.5,3.9,1.1} | Iris_versicolor |     2
 Tennessee | 28 | Iris_setosa          | {5.0,3.4,1.5,0.2} | Iris_setosa     |     1
 Tennessee | 29 | Iris_setosa          | {4.4,2.9,1.4,0.2} | Iris_setosa     |     1
 Tennessee | 30 | Iris_setosa          | {4.9,3.1,1.5,0.1} | Iris_setosa     |     1
 Tennessee | 31 | Iris_setosa          | {5.4,3.7,1.5,0.2} | Iris_setosa     |     1
 Tennessee | 32 | Iris_setosa          | {4.8,3.4,1.6,0.2} | Iris_setosa     |     1
 Tennessee | 33 | Iris_setosa          | {4.8,3.0,1.4,0.1} | Iris_setosa     |     1
 Tennessee | 34 | Iris_setosa          | {4.3,3.0,1.1,0.1} | Iris_setosa     |     1
 Tennessee | 35 | Iris_setosa          | {5.8,4.0,1.2,0.2} | Iris_setosa     |     1
 Tennessee | 36 | Iris_setosa          | {5.7,4.4,1.5,0.4} | Iris_setosa     |     1
 Tennessee | 37 | Iris_setosa          | {5.4,3.9,1.3,0.4} | Iris_setosa     |     1
 Tennessee | 38 | Iris_versicolor      | {6.0,2.9,4.5,1.5} | Iris_versicolor |     2
 Tennessee | 39 | Iris_versicolor      | {5.7,2.6,3.5,1.0} | Iris_versicolor |     2
 Tennessee | 40 | Iris_versicolor      | {5.5,2.4,3.8,1.1} | Iris_versicolor |     2
 Tennessee | 41 | Iris_versicolor      | {5.5,2.4,3.7,1.0} | Iris_versicolor |     2
 Tennessee | 42 | Iris_versicolor      | {5.8,2.7,3.9,1.2} | Iris_versicolor |     2
 Tennessee | 43 | Iris_versicolor      | {6.0,2.7,5.1,1.6} | Iris_versicolor |     2
 Tennessee | 44 | Iris_versicolor      | {5.4,3.0,4.5,1.5} | Iris_versicolor |     2
 Tennessee | 45 | Iris_versicolor      | {6.0,3.4,4.5,1.6} | Iris_versicolor |     2
 Tennessee | 46 | Iris_versicolor      | {6.7,3.1,4.7,1.5} | Iris_versicolor |     2
 Tennessee | 47 | Iris_versicolor      | {6.3,2.3,4.4,1.3} | Iris_versicolor |     2
 Tennessee | 48 | Iris_versicolor      | {5.6,3.0,4.1,1.3} | Iris_versicolor |     2
 Tennessee | 49 | Iris_versicolor      | {5.5,2.5,4.0,1.3} | Iris_versicolor |     2
 Tennessee | 50 | Iris_versicolor      | {5.5,2.6,4.4,1.2} | Iris_versicolor |     2
 Tennessee | 51 | Iris_versicolor      | {6.1,3.0,4.6,1.4} | Iris_versicolor |     2
 Tennessee | 52 | Iris_versicolor      | {5.8,2.6,4.0,1.2} | Iris_versicolor |     2
(52 rows)
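As a possible follow-up to the grouped prediction, the misclassification count used in the earlier examples can be broken down per group. This is a sketch that assumes the mlp_prediction table produced from mlp_model_group carries the state column (as the join above implies); the aliases n_rows and n_misclassified are made up for illustration.

```sql
-- Per-state row count and number of rows where the predicted class
-- differs from the true class.
SELECT state,
       COUNT(*) AS n_rows,
       SUM(CASE WHEN mlp_prediction.estimated_class_text != iris_data.class_text
                THEN 1 ELSE 0 END) AS n_misclassified
FROM mlp_prediction JOIN iris_data USING (state, id)
GROUP BY state
ORDER BY state;
```

Using a conditional sum rather than a WHERE filter keeps groups with no misclassifications in the result.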
DROP TABLE IF EXISTS lin_housing;
CREATE TABLE lin_housing (id serial, x numeric[], zipcode int, y float8);
INSERT INTO lin_housing(id, x, zipcode, y) VALUES
(1,ARRAY[1,0.00632,18.00,2.310,0,0.5380,6.5750,65.20,4.0900,1,296.0,15.30,396.90,4.98],94016,24.00),
(2,ARRAY[1,0.02731,0.00,7.070,0,0.4690,6.4210,78.90,4.9671,2,242.0,17.80,396.90,9.14],94016,21.60),
(3,ARRAY[1,0.02729,0.00,7.070,0,0.4690,7.1850,61.10,4.9671,2,242.0,17.80,392.83,4.03],94016,34.70),
(4,ARRAY[1,0.03237,0.00,2.180,0,0.4580,6.9980,45.80,6.0622,3,222.0,18.70,394.63,2.94],94016,33.40),
(5,ARRAY[1,0.06905,0.00,2.180,0,0.4580,7.1470,54.20,6.0622,3,222.0,18.70,396.90,5.33],94016,36.20),
(6,ARRAY[1,0.02985,0.00,2.180,0,0.4580,6.4300,58.70,6.0622,3,222.0,18.70,394.12,5.21],94016,28.70),
(7,ARRAY[1,0.08829,12.50,7.870,0,0.5240,6.0120,66.60,5.5605,5,311.0,15.20,395.60,12.43],94016,22.90),
(8,ARRAY[1,0.14455,12.50,7.870,0,0.5240,6.1720,96.10,5.9505,5,311.0,15.20,396.90,19.15],94016,27.10),
(9,ARRAY[1,0.21124,12.50,7.870,0,0.5240,5.6310,100.00,6.0821,5,311.0,15.20,386.63,29.93],94016,16.50),
(10,ARRAY[1,0.17004,12.50,7.870,0,0.5240,6.0040,85.90,6.5921,5,311.0,15.20,386.71,17.10],94016,18.90),
(11,ARRAY[1,0.22489,12.50,7.870,0,0.5240,6.3770,94.30,6.3467,5,311.0,15.20,392.52,20.45],94016,15.00),
(12,ARRAY[1,0.11747,12.50,7.870,0,0.5240,6.0090,82.90,6.2267,5,311.0,15.20,396.90,13.27],20001,18.90),
(13,ARRAY[1,0.09378,12.50,7.870,0,0.5240,5.8890,39.00,5.4509,5,311.0,15.20,390.50,15.71],20001,21.70),
(14,ARRAY[1,0.62976,0.00,8.140,0,0.5380,5.9490,61.80,4.7075,4,307.0,21.00,396.90,8.26],20001,20.40),
(15,ARRAY[1,0.63796,0.00,8.140,0,0.5380,6.0960,84.50,4.4619,4,307.0,21.00,380.02,10.26],20001,18.20),
(16,ARRAY[1,0.62739,0.00,8.140,0,0.5380,5.8340,56.50,4.4986,4,307.0,21.00,395.62,8.47],20001,19.90),
(17,ARRAY[1,1.05393,0.00,8.140,0,0.5380,5.9350,29.30,4.4986,4,307.0,21.00,386.85,6.58],20001,23.10),
(18,ARRAY[1,0.78420,0.00,8.140,0,0.5380,5.9900,81.70,4.2579,4,307.0,21.00,386.75,14.67],20001,17.50),
(19,ARRAY[1,0.80271,0.00,8.140,0,0.5380,5.4560,36.60,3.7965,4,307.0,21.00,288.99,11.69],20001,20.20),
(20,ARRAY[1,0.72580,0.00,8.140,0,0.5380,5.7270,69.50,3.7965,4,307.0,21.00,390.95,11.28],20001,18.20);
DROP TABLE IF EXISTS mlp_regress, mlp_regress_summary, mlp_regress_standardization;
SELECT setseed(0);
SELECT madlib.mlp_regression(
    'lin_housing',    -- Source table
    'mlp_regress',    -- Destination table
    'x',              -- Input features
    'y',              -- Dependent variable
    ARRAY[25,25],     -- Number of units per layer
    'learning_rate_init=0.001, n_iterations=500, lambda=0.001, tolerance=0', -- Optimizer params
    'relu',           -- Activation function
    NULL,             -- Default weight (1)
    FALSE,            -- No warm start
    FALSE             -- Not verbose
);
View the model:
\x on
SELECT * FROM mlp_regress;
-[ RECORD 1 ]--+-------------------------------------------------------------------------------------
coeff          | {-0.250057620174,0.0630805938982,-0.290635490112,-0.382966162592,-0.212206338909...
loss           | 1.07042781236
num_iterations | 500
DROP TABLE IF EXISTS mlp_regress_prediction;
SELECT madlib.mlp_predict(
    'mlp_regress',               -- Model table
    'lin_housing',               -- Test data table
    'id',                        -- Id column in test table
    'mlp_regress_prediction',    -- Output table for predictions
    'response'                   -- Output values, not probabilities
);
View results:
\x off
SELECT * FROM lin_housing JOIN mlp_regress_prediction USING (id) ORDER BY id;
 id |                                        x                                         | zipcode |  y   |   estimated_y
----+----------------------------------------------------------------------------------+---------+------+------------------
  1 | {1,0.00632,18.00,2.310,0,0.5380,6.5750,65.20,4.0900,1,296.0,15.30,396.90,4.98}   |   94016 |   24 | 23.9989087488259
  2 | {1,0.02731,0.00,7.070,0,0.4690,6.4210,78.90,4.9671,2,242.0,17.80,396.90,9.14}    |   94016 | 21.6 | 21.5983177932005
  3 | {1,0.02729,0.00,7.070,0,0.4690,7.1850,61.10,4.9671,2,242.0,17.80,392.83,4.03}    |   94016 | 34.7 | 34.7102398021623
  4 | {1,0.03237,0.00,2.180,0,0.4580,6.9980,45.80,6.0622,3,222.0,18.70,394.63,2.94}    |   94016 | 33.4 | 33.4221257351015
  5 | {1,0.06905,0.00,2.180,0,0.4580,7.1470,54.20,6.0622,3,222.0,18.70,396.90,5.33}    |   94016 | 36.2 | 36.1523886001663
  6 | {1,0.02985,0.00,2.180,0,0.4580,6.4300,58.70,6.0622,3,222.0,18.70,394.12,5.21}    |   94016 | 28.7 |  28.723894783928
  7 | {1,0.08829,12.50,7.870,0,0.5240,6.0120,66.60,5.5605,5,311.0,15.20,395.60,12.43}  |   94016 | 22.9 | 22.6515242795835
  8 | {1,0.14455,12.50,7.870,0,0.5240,6.1720,96.10,5.9505,5,311.0,15.20,396.90,19.15}  |   94016 | 27.1 | 25.7615314879354
  9 | {1,0.21124,12.50,7.870,0,0.5240,5.6310,100.00,6.0821,5,311.0,15.20,386.63,29.93} |   94016 | 16.5 | 15.7368298351732
 10 | {1,0.17004,12.50,7.870,0,0.5240,6.0040,85.90,6.5921,5,311.0,15.20,386.71,17.10}  |   94016 | 18.9 | 16.8850496141437
 11 | {1,0.22489,12.50,7.870,0,0.5240,6.3770,94.30,6.3467,5,311.0,15.20,392.52,20.45}  |   94016 |   15 | 14.9150416339458
 12 | {1,0.11747,12.50,7.870,0,0.5240,6.0090,82.90,6.2267,5,311.0,15.20,396.90,13.27}  |   20001 | 18.9 | 19.4541629864106
 13 | {1,0.09378,12.50,7.870,0,0.5240,5.8890,39.00,5.4509,5,311.0,15.20,390.50,15.71}  |   20001 | 21.7 |  21.715554997762
 14 | {1,0.62976,0.00,8.140,0,0.5380,5.9490,61.80,4.7075,4,307.0,21.00,396.90,8.26}    |   20001 | 20.4 | 20.3181247234996
 15 | {1,0.63796,0.00,8.140,0,0.5380,6.0960,84.50,4.4619,4,307.0,21.00,380.02,10.26}   |   20001 | 18.2 | 18.5026399122209
 16 | {1,0.62739,0.00,8.140,0,0.5380,5.8340,56.50,4.4986,4,307.0,21.00,395.62,8.47}    |   20001 | 19.9 | 19.9131696333521
 17 | {1,1.05393,0.00,8.140,0,0.5380,5.9350,29.30,4.4986,4,307.0,21.00,386.85,6.58}    |   20001 | 23.1 | 23.1757650468106
 18 | {1,0.78420,0.00,8.140,0,0.5380,5.9900,81.70,4.2579,4,307.0,21.00,386.75,14.67}   |   20001 | 17.5 | 17.2671872543377
 19 | {1,0.80271,0.00,8.140,0,0.5380,5.4560,36.60,3.7965,4,307.0,21.00,288.99,11.69}   |   20001 | 20.2 | 20.1073474558796
 20 | {1,0.72580,0.00,8.140,0,0.5380,5.7270,69.50,3.7965,4,307.0,21.00,390.95,11.28}   |   20001 | 18.2 | 18.2143446340975
(20 rows)
RMS error:
SELECT SQRT(AVG((y-estimated_y)*(y-estimated_y))) as rms_error FROM lin_housing JOIN mlp_regress_prediction USING (id);
    rms_error
------------------+
 0.544960829104004
DROP TABLE IF EXISTS lin_housing_packed, lin_housing_packed_summary, lin_housing_packed_standardization;
SELECT madlib.minibatch_preprocessor(
    'lin_housing',           -- Source table
    'lin_housing_packed',    -- Output table
    'y',                     -- Dependent variable
    'x'                      -- Independent variables
);
DROP TABLE IF EXISTS mlp_regress, mlp_regress_summary, mlp_regress_standardization;
SELECT setseed(0);
SELECT madlib.mlp_regression(
    'lin_housing_packed',     -- Source table
    'mlp_regress',            -- Destination table
    'independent_varname',    -- Hardcode to this, from table lin_housing_packed
    'dependent_varname',      -- Hardcode to this, from table lin_housing_packed
    ARRAY[25,25],             -- Number of units per layer
    'learning_rate_init=0.01, n_iterations=500, lambda=0.001, tolerance=0', -- Optimizer params
    'tanh',                   -- Activation function
    NULL,                     -- Default weight (1)
    FALSE,                    -- No warm start
    FALSE                     -- Not verbose
);
View model:
\x on
SELECT * FROM mlp_regress;
-[ RECORD 1 ]--+-------------------------------------------------------------
coeff          | {0.0395865908810001,-0.164860448878703,-0.132787863194324...
loss           | 0.0442383714892138
num_iterations | 500
DROP TABLE IF EXISTS mlp_regress_prediction;
SELECT madlib.mlp_predict(
    'mlp_regress',               -- Model table
    'lin_housing',               -- Test data table
    'id',                        -- Id column in test table
    'mlp_regress_prediction',    -- Output table for predictions
    'response'                   -- Output values, not probabilities
);
\x off
SELECT *, ABS(y-estimated_y) as abs_diff FROM lin_housing JOIN mlp_regress_prediction USING (id) ORDER BY id;
 id |                                        x                                         | zipcode |  y   | zipcode |   estimated_y    |      abs_diff
----+----------------------------------------------------------------------------------+---------+------+---------+------------------+--------------------
  1 | {1,0.00632,18.00,2.310,0,0.5380,6.5750,65.20,4.0900,1,296.0,15.30,396.90,4.98}   |   94016 |   24 |   94016 | 23.9714991250013 | 0.0285008749987092
  2 | {1,0.02731,0.00,7.070,0,0.4690,6.4210,78.90,4.9671,2,242.0,17.80,396.90,9.14}    |   94016 | 21.6 |   94016 | 22.3655180133895 |  0.765518013389535
  3 | {1,0.02729,0.00,7.070,0,0.4690,7.1850,61.10,4.9671,2,242.0,17.80,392.83,4.03}    |   94016 | 34.7 |   94016 | 33.8620767428645 |  0.837923257135465
  4 | {1,0.03237,0.00,2.180,0,0.4580,6.9980,45.80,6.0622,3,222.0,18.70,394.63,2.94}    |   94016 | 33.4 |   94016 | 35.3094157686524 |   1.90941576865244
  5 | {1,0.06905,0.00,2.180,0,0.4580,7.1470,54.20,6.0622,3,222.0,18.70,396.90,5.33}    |   94016 | 36.2 |   94016 | 35.0379122731818 |   1.16208772681817
  6 | {1,0.02985,0.00,2.180,0,0.4580,6.4300,58.70,6.0622,3,222.0,18.70,394.12,5.21}    |   94016 | 28.7 |   94016 | 27.5207943492151 |   1.17920565078487
  7 | {1,0.08829,12.50,7.870,0,0.5240,6.0120,66.60,5.5605,5,311.0,15.20,395.60,12.43}  |   94016 | 22.9 |   94016 | 24.9841422781166 |    2.0841422781166
  8 | {1,0.14455,12.50,7.870,0,0.5240,6.1720,96.10,5.9505,5,311.0,15.20,396.90,19.15}  |   94016 | 27.1 |   94016 | 24.5403994064793 |   2.55960059352067
  9 | {1,0.21124,12.50,7.870,0,0.5240,5.6310,100.00,6.0821,5,311.0,15.20,386.63,29.93} |   94016 | 16.5 |   94016 | 17.2588278443879 |   0.75882784438787
 10 | {1,0.17004,12.50,7.870,0,0.5240,6.0040,85.90,6.5921,5,311.0,15.20,386.71,17.10}  |   94016 | 18.9 |   94016 | 17.0600407532569 |    1.8399592467431
 11 | {1,0.22489,12.50,7.870,0,0.5240,6.3770,94.30,6.3467,5,311.0,15.20,392.52,20.45}  |   94016 |   15 |   94016 | 15.2284207930287 |  0.228420793028732
 12 | {1,0.11747,12.50,7.870,0,0.5240,6.0090,82.90,6.2267,5,311.0,15.20,396.90,13.27}  |   20001 | 18.9 |   20001 | 19.2272848285357 |  0.327284828535671
 13 | {1,0.09378,12.50,7.870,0,0.5240,5.8890,39.00,5.4509,5,311.0,15.20,390.50,15.71}  |   20001 | 21.7 |   20001 | 21.3979318641202 |  0.302068135879811
 14 | {1,0.62976,0.00,8.140,0,0.5380,5.9490,61.80,4.7075,4,307.0,21.00,396.90,8.26}    |   20001 | 20.4 |   20001 | 19.7743403979155 |  0.625659602084532
 15 | {1,0.63796,0.00,8.140,0,0.5380,6.0960,84.50,4.4619,4,307.0,21.00,380.02,10.26}   |   20001 | 18.2 |   20001 | 18.7400800902121 |  0.540080090212125
 16 | {1,0.62739,0.00,8.140,0,0.5380,5.8340,56.50,4.4986,4,307.0,21.00,395.62,8.47}    |   20001 | 19.9 |   20001 | 19.6187933144569 |  0.281206685543061
 17 | {1,1.05393,0.00,8.140,0,0.5380,5.9350,29.30,4.4986,4,307.0,21.00,386.85,6.58}    |   20001 | 23.1 |   20001 | 23.3492239648177 |  0.249223964817737
 18 | {1,0.78420,0.00,8.140,0,0.5380,5.9900,81.70,4.2579,4,307.0,21.00,386.75,14.67}   |   20001 | 17.5 |   20001 | 17.0806608347814 |  0.419339165218577
 19 | {1,0.80271,0.00,8.140,0,0.5380,5.4560,36.60,3.7965,4,307.0,21.00,288.99,11.69}   |   20001 | 20.2 |   20001 | 20.1559086626409 |  0.044091337359113
 20 | {1,0.72580,0.00,8.140,0,0.5380,5.7270,69.50,3.7965,4,307.0,21.00,390.95,11.28}   |   20001 | 18.2 |   20001 | 18.6980897920022 |  0.498089792002183
(20 rows)
RMS error:
SELECT SQRT(AVG((y-estimated_y)*(y-estimated_y))) as rms_error FROM lin_housing JOIN mlp_regress_prediction USING (id);
     rms_error
-------------------+
 0.912158035902468
(1 row)
Note that the results you get for all examples may vary with the database you are using.
To train a neural net, the loss function is minimized using stochastic gradient descent. For classification, the loss function is cross entropy; for regression, it is mean squared error. Weights in the neural net are updated via backpropagation, which uses dynamic programming to compute the partial derivative of the overall loss with respect to each weight. This partial derivative incorporates the activation function, which is why the activation function must be differentiable.
For an overview of multilayer perceptrons, see [1].
For details on backpropagation, see [2].
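The training procedure described above can be written out in standard textbook notation. The following is a sketch under stated assumptions, not a statement of MADlib's internal formulas: the symbols (N examples, K classes, weights W, biases b, activation \phi, learning rate \eta) are introduced here for illustration, the 1/2 and 1/N scaling factors are one common convention, and the regularization term controlled by lambda is omitted.

```latex
% Assumed notation (not taken from MADlib): N examples, K classes,
% layers l = 1..M, activation \phi, learning rate \eta,
% \delta^{(l)} = \partial E / \partial z^{(l)} for a single example.
\begin{aligned}
\text{Regression (mean squared error):}\quad
  & E = \frac{1}{2N}\sum_{i=1}^{N} \lVert \hat{y}_i - y_i \rVert^2 \\
\text{Classification (cross entropy):}\quad
  & E = -\frac{1}{N}\sum_{i=1}^{N}\sum_{k=1}^{K} y_{ik}\,\log \hat{y}_{ik} \\
\text{Forward pass:}\quad
  & z^{(l)} = W^{(l)} a^{(l-1)} + b^{(l)}, \qquad a^{(l)} = \phi\!\left(z^{(l)}\right) \\
\text{Backward recursion } (l < M)\text{:}\quad
  & \delta^{(l)} = \left(W^{(l+1)}\right)^{\top} \delta^{(l+1)} \odot \phi'\!\left(z^{(l)}\right) \\
\text{Gradients:}\quad
  & \frac{\partial E}{\partial W^{(l)}} = \delta^{(l)} \left(a^{(l-1)}\right)^{\top}, \qquad
    \frac{\partial E}{\partial b^{(l)}} = \delta^{(l)} \\
\text{Gradient descent update:}\quad
  & W^{(l)} \leftarrow W^{(l)} - \eta\,\frac{\partial E}{\partial W^{(l)}}
\end{aligned}
```

The backward recursion is where the dynamic programming enters: each \delta^{(l)} reuses \delta^{(l+1)} instead of recomputing it, and the factor \phi'(z^{(l)}) is the reason the activation function needs a derivative.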
On the effect of database cluster size: as the cluster size increases, the per-iteration loss will be higher because the model sees only 1/n of the data, where n is the number of segments. However, each iteration runs faster than on a single node because it traverses only 1/n of the data. For large data sets, all else being equal, a bigger cluster will reach a given accuracy faster than a single node, although it may take more iterations to reach that accuracy.
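To find the value of n that applies to a particular installation, the primary segments can be counted from the system catalog. This is a minimal sketch that assumes a Greenplum database, where the gp_segment_configuration catalog table exists; it does not apply to single-node PostgreSQL, where there is effectively one segment.

```sql
-- Assumes Greenplum. Counts primary segments only:
-- role = 'p' selects primaries (not mirrors), and
-- content >= 0 excludes the master/coordinator (content = -1).
SELECT count(*) AS n_segments
FROM gp_segment_configuration
WHERE role = 'p' AND content >= 0;
```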
[1] https://en.wikipedia.org/wiki/Multilayer_perceptron
[2] Yu Hen Hu. "Lecture 11. MLP (III): Back-Propagation." University of Wisconsin Madison: Computer-Aided Engineering. Web. 12 July 2017, http://homepages.cae.wisc.edu/~ece539/videocourse/notes/pdf/lec%2011%20MLP%20(3)%20BP.pdf
[3] "Neural Networks for Machine Learning", Lectures 6a and 6b on mini-batch gradient descent, Geoffrey Hinton with Nitish Srivastava and Kevin Swersky, http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf
[4] Kingma, D. P., & Ba, J. L. (2015), "Adam: a Method for Stochastic Optimization," International Conference on Learning Representations, 1–13.
File mlp.sql_in documenting the training function