diff --git a/1_to_1_multi_layer.ipynb b/1_to_1_multi_layer.ipynb index ce6cd94..f9c5ab7 100644 --- a/1_to_1_multi_layer.ipynb +++ b/1_to_1_multi_layer.ipynb @@ -23,7 +23,8 @@ "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import tensorflow as tf\n", - "from tensorflow.python.framework import ops" + "from tensorflow.python.framework import ops\n", + "from sklearn import preprocessing" ] }, { @@ -35,7 +36,8 @@ "#import data as array\n", "# 8 hits with x,y,z\n", "\n", - "testset = pd.read_pickle('matched_8hittracks.pkl')" + "testset = pd.read_pickle('matched_8hittracks.pkl')\n", + "#print(testset)" ] }, { @@ -56,7 +58,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -71,12 +73,23 @@ " for t in range(timesteps):\n", " arr[i,t,0:3] = arr_orig[i,3*t:3*t+3]\n", " \n", + " return arr\n", + "\n", + "def reshapor_inv(array_shaped):\n", + " timesteps = int(array_shaped.shape[1])\n", + " num_examples = int(array_shaped.shape[0])\n", + " arr = np.zeros((num_examples, timesteps*3))\n", + " \n", + " for i in range(num_examples):\n", + " for t in range(timesteps):\n", + " arr[i,3*t:3*t+3] = array_shaped[i,t,:]\n", + " \n", " return arr" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -102,51 +115,21 @@ " train_set[i,:] += dataset[i,:]\n", " else:\n", " test_set[i - train_set_size,:] += dataset[i,:]\n", - " \n", - " \n", - " train_set = reshapor(train_set)\n", - " test_set = reshapor(test_set)\n", - " \n", + " \n", " return train_set, test_set\n", " " ] }, { "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "train_set, test_set = create_random_sets(tset, 0.99)\n", - "#print(test_set.shape, train_set.shape, reshapor(tset).shape)\n", - "#print(test_set[0,:,:])" ] }, { - "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ - "### create target array of shape (num_examples, 4 timesteps, 3 = n_inputs), inputt array of shape (num_examples, 4 timesteps, 12 = n_inputs)###\n", "\n", - "def target_and_input(data_set):\n", - " \n", - " num_ex = data_set.shape[0]\n", - " inputt = np.zeros((num_ex, 4, 12))\n", - " target = np.zeros((num_ex, 4, 3))\n", - " \n", - " \n", - " for i in range(4):\n", - " target[:,i,:] = data_set[:,4+i,:]\n", - " for f in range(4):\n", - " inputt[:,i,3*f:3*f+3] = data_set[:,i+f,:]\n", - " \n", - " \n", - " \n", - " \n", - " return inputt, target\n", - " " + "train_set, test_set = create_random_sets(tset, 0.99)\n", + "\n", + "#print(test_set.shape, train_set.shape, reshapor(tset).shape)\n", + "#print(test_set[0,:,:])" ] }, { @@ -155,10 +138,38 @@ "metadata": {}, "outputs": [], "source": [ - "inputt_train, target_train = target_and_input(train_set)\n", - "inputt_test, target_test = target_and_input(test_set)\n", - "#print(inputt_train[0,:,:])\n", - "#print(target_train[0,:,:])" + "#Normalize the data: advanced version using scikit-learn\n", + "\n", + "#set the transformation based on the training set\n", + "def set_min_max_scalor(arr, feature_range= (-1,1)):\n", + " min_max_scalor = preprocessing.MinMaxScaler(feature_range=feature_range)\n", + " if len(arr.shape) == 3:\n", + " arr = reshapor(min_max_scalor.fit_transform(reshapor_inv(arr))) \n", + " else:\n", + " arr = min_max_scalor.fit_transform(arr)\n", + " return min_max_scalor\n", + "\n", + "min_max_scalor = set_min_max_scalor(train_set)\n", + "\n", + "\n", + "#transform data\n", + "def min_max_scaler(arr, min_max_scalor= min_max_scalor):\n", + " \n", + " if len(arr.shape) == 3:\n", + " arr = reshapor(min_max_scalor.transform(reshapor_inv(arr))) \n", + " else:\n", + " arr = min_max_scalor.transform(arr)\n", + " \n", + " return arr\n", + " \n", + "#inverse transformation\n", + "def min_max_scaler_inv(arr, min_max_scalor= min_max_scalor):\n", + " if len(arr.shape) == 3:\n", + " arr = reshapor(min_max_scalor.inverse_transform(reshapor_inv(arr)))\n", + " else:\n", + " arr = min_max_scalor.inverse_transform(arr)\n", + " \n", + " return arr" ] }, {
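For context on the normalization hunk above: `set_min_max_scalor` fits a single `MinMaxScaler` on the flattened `(num_examples, timesteps*3)` view of the training set, so every flat column gets its own scale and offset — one per (hit, coordinate) pair, i.e. 24 of them for 8-hit tracks — and `reshapor`/`reshapor_inv` bridge between the 3D arrays and the 2D view the scaler expects. A minimal self-contained sketch of that round trip on dummy data (nothing beyond numpy and scikit-learn is assumed; the vectorized `reshape` calls are equivalent to the element-wise loops in `reshapor`/`reshapor_inv`):

```python
import numpy as np
from sklearn import preprocessing

# Dummy stand-in for the training tracks: 10 tracks x 8 hits x (x, y, z).
tracks = np.random.uniform(-100.0, 100.0, size=(10, 8, 3))

# Fit on the flattened (N, 24) view: one scale/offset per flat column,
# i.e. per (hit, coordinate) pair -- not one per coordinate axis.
scaler = preprocessing.MinMaxScaler(feature_range=(-1, 1))
scaler.fit(tracks.reshape(10, 24))

# Transform and invert through the same flattened view.
scaled = scaler.transform(tracks.reshape(10, 24)).reshape(10, 8, 3)
restored = scaler.inverse_transform(scaled.reshape(10, 24)).reshape(10, 8, 3)

assert scaled.min() >= -1.0 and scaled.max() <= 1.0
assert np.allclose(restored, tracks)  # the reshape/scale round trip is lossless
```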
@@ -167,6 +178,41 @@ "metadata": {}, "outputs": [], "source": [ + "train_set = reshapor(train_set)\n", + "test_set = reshapor(test_set)\n", + "\n", + "#print(train_set[0,:,:])" ] }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "train_set = min_max_scaler(train_set)\n", + "test_set = min_max_scaler(test_set)\n", + "\n", + "#print(train_set[0,:,:])" ] }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "#train_set = min_max_scaler_inv(train_set)\n", + "\n", + "#print(train_set[0,:,:])" ] }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "###create random mini_batches###\n", "\n", "\n", @@ -208,7 +254,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -223,7 +269,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ @@ -259,28 +305,48 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "class RNNPlacePrediction():\n", " \n", " \n", - " def __init__(self, time_steps, future_steps, ninputs, ncells, num_output, cell_type=\"basic_rnn\"):\n", + " def __init__(self, time_steps, future_steps, ninputs, ncells, num_output, cell_type=\"basic_rnn\", activation=\"relu\"):\n", " \n", " self.nsteps = time_steps\n", " self.future_steps = future_steps\n", " self.ninputs = ninputs\n", " self.ncells = ncells\n", " self.num_output = num_output\n", - " self._ = cell_type\n", + " self._ = cell_type #later used to create folder name\n", + " self.__ = activation #later used to create folder name\n", " \n", " #### The input is of shape (num_examples, time_steps, ninputs)\n", " #### ninputs is the dimensionality (number of features) of the time series (here coordinates)\n", " self.X = tf.placeholder(dtype=tf.float32, shape=(None, time_steps, ninputs))\n", " self.Y = tf.placeholder(dtype=tf.float32, shape=(None, time_steps, ninputs))\n", "\n", + " \n", + " #Check if activation function is valid and set activation\n", + " if activation==\"relu\":\n", + " self.activation = tf.nn.relu\n", + " \n", + " elif activation==\"tanh\":\n", + " self.activation = tf.nn.tanh\n", + " \n", + " elif activation==\"leaky_relu\":\n", + " self.activation = tf.nn.leaky_relu\n", + " \n", + " elif activation==\"elu\":\n", + " self.activation = tf.nn.elu\n", + " \n", + " else:\n", + " raise ValueError(\"Wrong rnn activation function: {}\".format(activation))\n", " \n", " \n", + " \n", + " #Check if cell type is valid and set cell_type\n", " if cell_type==\"basic_rnn\":\n", " self.cell_type = tf.contrib.rnn.BasicRNNCell\n", " \n", @@ -290,12 +356,11 @@ " elif cell_type==\"GRU\":\n", " self.cell_type = tf.contrib.rnn.GRUCell\n", " \n", - " else: # JONAS\n", + " else:\n", " raise ValueError(\"Wrong rnn cell type: {}\".format(cell_type))\n", " \n", " \n", - " #Check Input of ncells\n", - " \n", + " #Check Input of ncells \n", " if (type(self.ncells) == int):\n", " self.ncells = [self.ncells]\n", " \n", @@ -306,8 +371,12 @@ " if type(self.ncells[_]) != int:\n", " raise ValueError(\"Wrong type of Input for ncells\")\n", " \n", + " self.activationlist = []\n", + " for _ in range(len(self.ncells)-1):\n", + " self.activationlist.append(self.activation)\n", + " self.activationlist.append(tf.nn.tanh)\n", " \n", - " self.cell = tf.contrib.rnn.MultiRNNCell([self.cell_type(num_units=self.ncells[layer], activation=tf.nn.relu)\n", + " self.cell = tf.contrib.rnn.MultiRNNCell([self.cell_type(num_units=self.ncells[layer], activation=self.activationlist[layer])\n", " for layer in range(len(self.ncells))])\n", " \n", " \n",
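One design detail in the hunk above that is easy to miss: every stacked layer gets the constructor's `activation`, except the last, which is pinned to `tf.nn.tanh` — presumably so the final layer's (-1, 1) output range matches the `feature_range=(-1, 1)` used when MinMax-scaling the coordinates. A TensorFlow-free sketch of the list being built for the configuration this notebook trains later (`ncells=[50,40,30,20,10]`, `activation="leaky_relu"`):

```python
# Plain strings stand in for the tf.nn functions so this runs without TensorFlow.
ncells = [50, 40, 30, 20, 10]   # layer sizes used later in this notebook
activation = "leaky_relu"       # constructor argument

activationlist = [activation for _ in range(len(ncells) - 1)]
activationlist.append("tanh")   # last layer pinned to tanh

print(list(zip(ncells, activationlist)))
# [(50, 'leaky_relu'), (40, 'leaky_relu'), (30, 'leaky_relu'),
#  (20, 'leaky_relu'), (10, 'tanh')]
```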
{}\".format(cell_type))\n", " \n", " \n", - " #Check Input of ncells\n", - " \n", + " #Check Input of ncells \n", " if (type(self.ncells) == int):\n", " self.ncells = [self.ncells]\n", " \n", @@ -306,8 +371,12 @@ " if type(self.ncells[_]) != int:\n", " raise ValueError(\"Wrong type of Input for ncells\")\n", " \n", + " self.activationlist = []\n", + " for _ in range(len(self.ncells)-1):\n", + " self.activationlist.append(self.activation)\n", + " self.activationlist.append(tf.nn.tanh)\n", " \n", - " self.cell = tf.contrib.rnn.MultiRNNCell([self.cell_type(num_units=self.ncells[layer], activation=tf.nn.relu)\n", + " self.cell = tf.contrib.rnn.MultiRNNCell([self.cell_type(num_units=self.ncells[layer], activation=self.activationlist[layer])\n", " for layer in range(len(self.ncells))])\n", " \n", " \n", @@ -343,12 +412,12 @@ " \n", " \n", " \n", - " def fit(self, minibatches, epochs, print_step, checkpoint = 10, patience = 15):\n", + " def fit(self, minibatches, epochs, print_step, checkpoint = 5, patience = 200):\n", " self.loss_list = []\n", " patience_cnt = 0\n", " epoche_save = 0\n", " \n", - " folder = \"./rnn_model_\" + str(self._) + str(self.ncells).replace(\" \",\"\") + \"c\" \"_checkpoint/rnn_basic\"\n", + " folder = \"./rnn_model_\" + str(self._)+ \"_\" + self.__ + \"_\" + str(self.ncells).replace(\" \",\"\") + \"c\" + \"_checkpoint/rnn_basic\"\n", " \n", " for iep in range(epochs):\n", " loss = 0\n", @@ -364,7 +433,7 @@ " \n", " loss += self.sess.run(self.cost, feed_dict={self.X:train, self.Y:target})\n", " \n", - " #Normalize loss over number of batches\n", + " #Normalize loss over number of batches and scale it back before normaliziation\n", " loss /= batches\n", " self.loss_list.append(loss)\n", " \n", @@ -377,25 +446,26 @@ " epoche_save = iep\n", " \n", " #early stopping with patience\n", - " if iep > 1 and abs(self.loss_list[iep]-self.loss_list[iep-1]) < 0.005:\n", + " if iep > 1 and abs(self.loss_list[iep]-self.loss_list[iep-1]) < 2/1000000:\n", " patience_cnt += 1\n", " #print(\"Patience now at: \", patience_cnt, \" of \", patience)\n", " \n", " if patience_cnt + 1 > patience:\n", - " print(\"Early stopping at epoch \", iep, \", difference: \", abs(self.loss_list[iep]-self.loss_list[iep-1]))\n", + " print(\"\\n\", \"Early stopping at epoch \", iep, \", difference: \", abs(self.loss_list[iep]-self.loss_list[iep-1]))\n", " print(\"Cost: \",loss)\n", " break\n", " \n", + " #Note that the loss here is multiplied with 1000 for easier reading\n", " if iep%print_step==0:\n", " print(\"Epoch number \",iep)\n", - " print(\"Cost: \",loss)\n", + " print(\"Cost: \",loss*1000, \"e-3\")\n", " print(\"Patience: \",patience_cnt, \"/\", patience)\n", " print(\"Last checkpoint at: Epoch \", epoche_save, \"\\n\")\n", " \n", " #Set model back to the last checkpoint if performance was better\n", " if self.loss_list[epoche_save] < self.loss_list[iep]:\n", " self.load(folder)\n", - " print(\"Last checkpoint at epoch \", epoche_save, \" loaded\")\n", + " print(\"\\n\", \"Last checkpoint at epoch \", epoche_save, \" loaded\")\n", " print(\"Performance at last checkpoint is \" ,self.loss_list[iep] - self.loss_list[epoche_save], \" better\" )\n", " \n", " \n", @@ -415,7 +485,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ @@ -432,7 +502,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -448,12 +518,12 @@ "source": [ "tf.reset_default_graph()\n", "rnn = 
RNNPlacePrediction(time_steps=timesteps, future_steps=future_steps, ninputs=ninputs, \n", - " ncells=ncells, num_output=num_output, cell_type=\"lstm\")" + " ncells=ncells, num_output=num_output, cell_type=\"lstm\", activation=\"leaky_relu\")" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ @@ -462,7 +532,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 21, "metadata": { "scrolled": true }, @@ -472,27 +542,604 @@ "output_type": "stream", "text": [ "Epoch number 0\n", - "Cost: 3.893127314587857\n", - "Patience: 0 / 5\n", + "Cost: 3770.231458734959 e4\n", + "Patience: 0 / 200\n", "Last checkpoint at: Epoch 0 \n", "\n", - "Early stopping at epoch 10 , difference: 0.002693013941988287\n", - "Cost: 3.91306800537921\n", - "INFO:tensorflow:Restoring parameters from ./rnn_model_lstm[50,40,30,20,10]c_checkpoint/rnn_basic\n", - "Last checkpoint at epoch 0 loaded\n", - "Performance at last checkpoint is 0.019940690791353077 better\n" + "Epoch number 5\n", + "Cost: 1649.7736788810569 e4\n", + "Patience: 0 / 200\n", + "Last checkpoint at: Epoch 5 \n", + "\n", + "Epoch number 10\n", + "Cost: 625.2868418046768 e4\n", + "Patience: 0 / 200\n", + "Last checkpoint at: Epoch 10 \n", + "\n", + "Epoch number 15\n", + "Cost: 294.9610768639027 e4\n", + "Patience: 0 / 200\n", + "Last checkpoint at: Epoch 15 \n", + "\n", + "Epoch number 20\n", + "Cost: 209.0108957379422 e4\n", + "Patience: 0 / 200\n", + "Last checkpoint at: Epoch 20 \n", + "\n", + "Epoch number 25\n", + "Cost: 174.1866168982171 e4\n", + "Patience: 0 / 200\n", + "Last checkpoint at: Epoch 25 \n", + "\n", + "Epoch number 30\n", + "Cost: 149.8719225538538 e4\n", + "Patience: 0 / 200\n", + "Last checkpoint at: Epoch 30 \n", + "\n", + "Epoch number 35\n", + "Cost: 131.33942407179387 e4\n", + "Patience: 0 / 200\n", + "Last checkpoint at: Epoch 35 \n", + "\n", + "Epoch number 40\n", + "Cost: 115.83642023516462 e4\n", + "Patience: 0 / 200\n", + "Last checkpoint at: Epoch 40 \n", + "\n", + "Epoch number 45\n", + "Cost: 107.55172256935151 e4\n", + "Patience: 0 / 200\n", + "Last checkpoint at: Epoch 45 \n", + "\n", + "Epoch number 50\n", + "Cost: 98.54952309359895 e4\n", + "Patience: 0 / 200\n", + "Last checkpoint at: Epoch 50 \n", + "\n", + "Epoch number 55\n", + "Cost: 95.66065657170529 e4\n", + "Patience: 0 / 200\n", + "Last checkpoint at: Epoch 55 \n", + "\n", + "Epoch number 60\n", + "Cost: 90.34742145462239 e4\n", + "Patience: 1 / 200\n", + "Last checkpoint at: Epoch 60 \n", + "\n", + "Epoch number 65\n", + "Cost: 84.77292855844853 e4\n", + "Patience: 2 / 200\n", + "Last checkpoint at: Epoch 65 \n", + "\n", + "Epoch number 70\n", + "Cost: 78.54001398416275 e4\n", + "Patience: 3 / 200\n", + "Last checkpoint at: Epoch 70 \n", + "\n", + "Epoch number 75\n", + "Cost: 75.23123551397882 e4\n", + "Patience: 3 / 200\n", + "Last checkpoint at: Epoch 75 \n", + "\n", + "Epoch number 80\n", + "Cost: 73.33986362085697 e4\n", + "Patience: 4 / 200\n", + "Last checkpoint at: Epoch 80 \n", + "\n", + "Epoch number 85\n", + "Cost: 69.12997319422504 e4\n", + "Patience: 5 / 200\n", + "Last checkpoint at: Epoch 85 \n", + "\n", + "Epoch number 90\n", + "Cost: 65.79162087291479 e4\n", + "Patience: 5 / 200\n", + "Last checkpoint at: Epoch 90 \n", + "\n", + "Epoch number 95\n", + "Cost: 61.82488113483216 e4\n", + "Patience: 6 / 200\n", + "Last checkpoint at: Epoch 95 \n", + "\n", + "Epoch number 100\n", + "Cost: 59.33671109774646 e4\n", + "Patience: 8 / 200\n", + "Last checkpoint 
at: Epoch 100 \n", + "\n", + "Epoch number 105\n", + "Cost: 57.19678456637453 e4\n", + "Patience: 9 / 200\n", + "Last checkpoint at: Epoch 105 \n", + "\n", + "Epoch number 110\n", + "Cost: 55.66507161773266 e4\n", + "Patience: 10 / 200\n", + "Last checkpoint at: Epoch 110 \n", + "\n", + "Epoch number 115\n", + "Cost: 54.365597526602286 e4\n", + "Patience: 13 / 200\n", + "Last checkpoint at: Epoch 115 \n", + "\n", + "Epoch number 120\n", + "Cost: 52.487826807067755 e4\n", + "Patience: 14 / 200\n", + "Last checkpoint at: Epoch 120 \n", + "\n", + "Epoch number 125\n", + "Cost: 51.60155072015651 e4\n", + "Patience: 17 / 200\n", + "Last checkpoint at: Epoch 125 \n", + "\n", + "Epoch number 130\n", + "Cost: 51.004822227232 e4\n", + "Patience: 20 / 200\n", + "Last checkpoint at: Epoch 130 \n", + "\n", + "Epoch number 135\n", + "Cost: 49.656663347590474 e4\n", + "Patience: 22 / 200\n", + "Last checkpoint at: Epoch 135 \n", + "\n", + "Epoch number 140\n", + "Cost: 49.04315717756114 e4\n", + "Patience: 26 / 200\n", + "Last checkpoint at: Epoch 140 \n", + "\n", + "Epoch number 145\n", + "Cost: 48.333713487583275 e4\n", + "Patience: 29 / 200\n", + "Last checkpoint at: Epoch 145 \n", + "\n", + "Epoch number 150\n", + "Cost: 47.4689517447606 e4\n", + "Patience: 33 / 200\n", + "Last checkpoint at: Epoch 150 \n", + "\n", + "Epoch number 155\n", + "Cost: 46.82262457827938 e4\n", + "Patience: 38 / 200\n", + "Last checkpoint at: Epoch 155 \n", + "\n", + "Epoch number 160\n", + "Cost: 46.189470573308625 e4\n", + "Patience: 43 / 200\n", + "Last checkpoint at: Epoch 160 \n", + "\n", + "Epoch number 165\n", + "Cost: 45.566867759570165 e4\n", + "Patience: 48 / 200\n", + "Last checkpoint at: Epoch 165 \n", + "\n", + "Epoch number 170\n", + "Cost: 45.00874754120695 e4\n", + "Patience: 53 / 200\n", + "Last checkpoint at: Epoch 170 \n", + "\n", + "Epoch number 175\n", + "Cost: 44.46649339367101 e4\n", + "Patience: 58 / 200\n", + "Last checkpoint at: Epoch 175 \n", + "\n", + "Epoch number 180\n", + "Cost: 43.92929008587244 e4\n", + "Patience: 63 / 200\n", + "Last checkpoint at: Epoch 180 \n", + "\n", + "Epoch number 185\n", + "Cost: 43.44754183585656 e4\n", + "Patience: 68 / 200\n", + "Last checkpoint at: Epoch 185 \n", + "\n", + "Epoch number 190\n", + "Cost: 42.95319576371223 e4\n", + "Patience: 73 / 200\n", + "Last checkpoint at: Epoch 190 \n", + "\n", + "Epoch number 195\n", + "Cost: 42.52819289909082 e4\n", + "Patience: 78 / 200\n", + "Last checkpoint at: Epoch 195 \n", + "\n", + "Epoch number 200\n", + "Cost: 41.93341770665126 e4\n", + "Patience: 83 / 200\n", + "Last checkpoint at: Epoch 200 \n", + "\n", + "Epoch number 205\n", + "Cost: 41.554861285902085 e4\n", + "Patience: 88 / 200\n", + "Last checkpoint at: Epoch 205 \n", + "\n", + "Epoch number 210\n", + "Cost: 41.090038733834284 e4\n", + "Patience: 93 / 200\n", + "Last checkpoint at: Epoch 210 \n", + "\n", + "Epoch number 215\n", + "Cost: 40.845294889221165 e4\n", + "Patience: 98 / 200\n", + "Last checkpoint at: Epoch 215 \n", + "\n", + "Epoch number 220\n", + "Cost: 40.25109122170412 e4\n", + "Patience: 103 / 200\n", + "Last checkpoint at: Epoch 220 \n", + "\n", + "Epoch number 225\n", + "Cost: 39.58158948002977 e4\n", + "Patience: 108 / 200\n", + "Last checkpoint at: Epoch 225 \n", + "\n", + "Epoch number 230\n", + "Cost: 38.97598008327979 e4\n", + "Patience: 113 / 200\n", + "Last checkpoint at: Epoch 230 \n", + "\n", + "Epoch number 235\n", + "Cost: 38.51150915502234 e4\n", + "Patience: 118 / 200\n", + "Last checkpoint at: Epoch 235 \n", + "\n", + "Epoch 
number 240\n", + "Cost: 38.299499218292695 e4\n", + "Patience: 123 / 200\n", + "Last checkpoint at: Epoch 240 \n", + "\n", + "Epoch number 245\n", + "Cost: 37.74655878821269 e4\n", + "Patience: 128 / 200\n", + "Last checkpoint at: Epoch 245 \n", + "\n", + "Epoch number 250\n", + "Cost: 37.40582783567778 e4\n", + "Patience: 133 / 200\n", + "Last checkpoint at: Epoch 250 \n", + "\n", + "Epoch number 255\n", + "Cost: 37.24810196720856 e4\n", + "Patience: 138 / 200\n", + "Last checkpoint at: Epoch 255 \n", + "\n", + "Epoch number 260\n", + "Cost: 37.280498320197175 e4\n", + "Patience: 143 / 200\n", + "Last checkpoint at: Epoch 255 \n", + "\n", + "Epoch number 265\n", + "Cost: 36.25094043487247 e4\n", + "Patience: 147 / 200\n", + "Last checkpoint at: Epoch 265 \n", + "\n", + "Epoch number 270\n", + "Cost: 36.03106825315255 e4\n", + "Patience: 152 / 200\n", + "Last checkpoint at: Epoch 270 \n", + "\n", + "Epoch number 275\n", + "Cost: 35.67509779191398 e4\n", + "Patience: 156 / 200\n", + "Last checkpoint at: Epoch 275 \n", + "\n", + "Epoch number 280\n", + "Cost: 35.42137842506487 e4\n", + "Patience: 161 / 200\n", + "Last checkpoint at: Epoch 280 \n", + "\n", + "Epoch number 285\n", + "Cost: 35.79035718390282 e4\n", + "Patience: 164 / 200\n", + "Last checkpoint at: Epoch 280 \n", + "\n", + "Epoch number 290\n", + "Cost: 33.758991754594 e4\n", + "Patience: 165 / 200\n", + "Last checkpoint at: Epoch 290 \n", + "\n", + "Epoch number 295\n", + "Cost: 34.39420328891658 e4\n", + "Patience: 166 / 200\n", + "Last checkpoint at: Epoch 290 \n", + "\n", + "Epoch number 300\n", + "Cost: 33.66679522862777 e4\n", + "Patience: 166 / 200\n", + "Last checkpoint at: Epoch 300 \n", + "\n", + "Epoch number 305\n", + "Cost: 34.23552023880976 e4\n", + "Patience: 167 / 200\n", + "Last checkpoint at: Epoch 300 \n", + "\n", + "Epoch number 310\n", + "Cost: 33.27848409560132 e4\n", + "Patience: 168 / 200\n", + "Last checkpoint at: Epoch 310 \n", + "\n", + "Epoch number 315\n", + "Cost: 32.72916789741275 e4\n", + "Patience: 171 / 200\n", + "Last checkpoint at: Epoch 315 \n", + "\n", + "Epoch number 320\n", + "Cost: 32.42362023113255 e4\n", + "Patience: 173 / 200\n", + "Last checkpoint at: Epoch 320 \n", + "\n", + "Epoch number 325\n", + "Cost: 33.13556412591579 e4\n", + "Patience: 173 / 200\n", + "Last checkpoint at: Epoch 320 \n", + "\n", + "Epoch number 330\n", + "Cost: 34.35548811041294 e4\n", + "Patience: 173 / 200\n", + "Last checkpoint at: Epoch 320 \n", + "\n", + "Epoch number 335\n", + "Cost: 31.17884152588692 e4\n", + "Patience: 174 / 200\n", + "Last checkpoint at: Epoch 335 \n", + "\n", + "Epoch number 340\n", + "Cost: 33.64366251341206 e4\n", + "Patience: 174 / 200\n", + "Last checkpoint at: Epoch 335 \n", + "\n", + "Epoch number 345\n", + "Cost: 32.388941939682404 e4\n", + "Patience: 175 / 200\n", + "Last checkpoint at: Epoch 335 \n", + "\n", + "Epoch number 350\n", + "Cost: 29.8897856648298 e4\n", + "Patience: 175 / 200\n", + "Last checkpoint at: Epoch 350 \n", + "\n", + "Epoch number 355\n", + "Cost: 30.779531522792706 e4\n", + "Patience: 176 / 200\n", + "Last checkpoint at: Epoch 350 \n", + "\n", + "Epoch number 360\n", + "Cost: 32.77950439641767 e4\n", + "Patience: 177 / 200\n", + "Last checkpoint at: Epoch 350 \n", + "\n", + "Epoch number 365\n", + "Cost: 34.279519781232516 e4\n", + "Patience: 177 / 200\n", + "Last checkpoint at: Epoch 350 \n", + "\n", + "Epoch number 370\n", + "Cost: 29.02430596147129 e4\n", + "Patience: 177 / 200\n", + "Last checkpoint at: Epoch 370 \n", + "\n", + "Epoch number 375\n", 
+ "Cost: 31.375054398828997 e4\n", + "Patience: 178 / 200\n", + "Last checkpoint at: Epoch 370 \n", + "\n", + "Epoch number 380\n", + "Cost: 33.813590144223355 e4\n", + "Patience: 178 / 200\n", + "Last checkpoint at: Epoch 370 \n", + "\n", + "Epoch number 385\n", + "Cost: 28.6719871268786 e4\n", + "Patience: 178 / 200\n", + "Last checkpoint at: Epoch 385 \n", + "\n", + "Epoch number 390\n", + "Cost: 31.848519872081408 e4\n", + "Patience: 179 / 200\n", + "Last checkpoint at: Epoch 385 \n", + "\n", + "Epoch number 395\n", + "Cost: 29.007866582337847 e4\n", + "Patience: 181 / 200\n", + "Last checkpoint at: Epoch 385 \n", + "\n", + "Epoch number 400\n", + "Cost: 33.16965553552863 e4\n", + "Patience: 181 / 200\n", + "Last checkpoint at: Epoch 385 \n", + "\n", + "Epoch number 405\n", + "Cost: 32.650657305295795 e4\n", + "Patience: 181 / 200\n", + "Last checkpoint at: Epoch 385 \n", + "\n", + "Epoch number 410\n", + "Cost: 28.816359365319318 e4\n", + "Patience: 181 / 200\n", + "Last checkpoint at: Epoch 385 \n", + "\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch number 415\n", + "Cost: 29.141941761716886 e4\n", + "Patience: 181 / 200\n", + "Last checkpoint at: Epoch 385 \n", + "\n", + "Epoch number 420\n", + "Cost: 30.577135856877614 e4\n", + "Patience: 182 / 200\n", + "Last checkpoint at: Epoch 385 \n", + "\n", + "Epoch number 425\n", + "Cost: 29.400000695456217 e4\n", + "Patience: 183 / 200\n", + "Last checkpoint at: Epoch 385 \n", + "\n", + "Epoch number 430\n", + "Cost: 26.99479599423865 e4\n", + "Patience: 183 / 200\n", + "Last checkpoint at: Epoch 430 \n", + "\n", + "Epoch number 435\n", + "Cost: 30.304402994744958 e4\n", + "Patience: 184 / 200\n", + "Last checkpoint at: Epoch 430 \n", + "\n", + "Epoch number 440\n", + "Cost: 29.647010675770172 e4\n", + "Patience: 184 / 200\n", + "Last checkpoint at: Epoch 430 \n", + "\n", + "Epoch number 445\n", + "Cost: 27.00613232012442 e4\n", + "Patience: 185 / 200\n", + "Last checkpoint at: Epoch 430 \n", + "\n", + "Epoch number 450\n", + "Cost: 27.036350567210864 e4\n", + "Patience: 186 / 200\n", + "Last checkpoint at: Epoch 430 \n", + "\n", + "Epoch number 455\n", + "Cost: 27.08697458729148 e4\n", + "Patience: 187 / 200\n", + "Last checkpoint at: Epoch 430 \n", + "\n", + "Epoch number 460\n", + "Cost: 28.004820329791055 e4\n", + "Patience: 188 / 200\n", + "Last checkpoint at: Epoch 430 \n", + "\n", + "Epoch number 465\n", + "Cost: 26.3666685551722 e4\n", + "Patience: 188 / 200\n", + "Last checkpoint at: Epoch 465 \n", + "\n", + "Epoch number 470\n", + "Cost: 26.36444576560183 e4\n", + "Patience: 188 / 200\n", + "Last checkpoint at: Epoch 470 \n", + "\n", + "Epoch number 475\n", + "Cost: 31.123574119695324 e4\n", + "Patience: 188 / 200\n", + "Last checkpoint at: Epoch 470 \n", + "\n", + "Epoch number 480\n", + "Cost: 27.53822227068087 e4\n", + "Patience: 189 / 200\n", + "Last checkpoint at: Epoch 470 \n", + "\n", + "Epoch number 485\n", + "Cost: 26.472763485334657 e4\n", + "Patience: 189 / 200\n", + "Last checkpoint at: Epoch 470 \n", + "\n", + "Epoch number 490\n", + "Cost: 25.98736776990142 e4\n", + "Patience: 190 / 200\n", + "Last checkpoint at: Epoch 490 \n", + "\n", + "Epoch number 495\n", + "Cost: 25.32091308781441 e4\n", + "Patience: 191 / 200\n", + "Last checkpoint at: Epoch 495 \n", + "\n", + "Epoch number 500\n", + "Cost: 26.51548171614079 e4\n", + "Patience: 191 / 200\n", + "Last checkpoint at: Epoch 495 \n", + "\n", + "Epoch number 505\n", + "Cost: 25.78474184934129 e4\n", + "Patience: 191 / 200\n", + 
"Last checkpoint at: Epoch 495 \n", + "\n", + "Epoch number 510\n", + "Cost: 26.016250708477294 e4\n", + "Patience: 191 / 200\n", + "Last checkpoint at: Epoch 495 \n", + "\n", + "Epoch number 515\n", + "Cost: 28.13248825754891 e4\n", + "Patience: 191 / 200\n", + "Last checkpoint at: Epoch 495 \n", + "\n", + "Epoch number 520\n", + "Cost: 28.441735156910852 e4\n", + "Patience: 191 / 200\n", + "Last checkpoint at: Epoch 495 \n", + "\n", + "Epoch number 525\n", + "Cost: 25.8854781079324 e4\n", + "Patience: 193 / 200\n", + "Last checkpoint at: Epoch 495 \n", + "\n", + "Epoch number 530\n", + "Cost: 25.448204473929202 e4\n", + "Patience: 193 / 200\n", + "Last checkpoint at: Epoch 495 \n", + "\n", + "Epoch number 535\n", + "Cost: 26.26546668483222 e4\n", + "Patience: 193 / 200\n", + "Last checkpoint at: Epoch 495 \n", + "\n", + "Epoch number 540\n", + "Cost: 24.608338271525312 e4\n", + "Patience: 196 / 200\n", + "Last checkpoint at: Epoch 540 \n", + "\n", + "Epoch number 545\n", + "Cost: 25.521852422822665 e4\n", + "Patience: 196 / 200\n", + "Last checkpoint at: Epoch 540 \n", + "\n", + "Epoch number 550\n", + "Cost: 24.915404786217085 e4\n", + "Patience: 198 / 200\n", + "Last checkpoint at: Epoch 540 \n", + "\n", + "Epoch number 555\n", + "Cost: 25.868487217404105 e4\n", + "Patience: 198 / 200\n", + "Last checkpoint at: Epoch 540 \n", + "\n", + "Epoch number 560\n", + "Cost: 27.24954412576366 e4\n", + "Patience: 199 / 200\n", + "Last checkpoint at: Epoch 540 \n", + "\n", + "\n", + " Early stopping at epoch 565 , difference: 2.3366942843223992e-05\n", + "Cost: 0.002444114783739156\n" ] } ], "source": [ - "rnn.fit(minibatches, epochs=22, print_step=10)" + "rnn.fit(minibatches, epochs = 5000, print_step=5)" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "execution_count": 22, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYsAAAEKCAYAAADjDHn2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3X2QXXd93/H35z7v6vlhMbZkWTKI1nIgdljkJARIgnFEkrFpArEoaZ3WrYcMnpDSJthDxrROM03IDAEGN8FpNG0oRDyFRkOVGMcGppQYtMbGRgbXsmLsRTKWrWdpH+7Dt3+cs7tHV/fuXT0crXb385q5c8/53XPu/s7qaj/39/udc36KCMzMzKZTmO0KmJnZxc9hYWZmPTkszMysJ4eFmZn15LAwM7OeHBZmZtaTw8LMzHpyWJiZWU8OCzMz66k02xU4X1avXh3r16+f7WqYmc0pDz/88IsRMdBru3kTFuvXr2doaGi2q2FmNqdI+sFMtnM3lJmZ9eSwMDOznhwWZmbWk8PCzMx6cliYmVlPDgszM+vJYWFmZj0t+LA4Mdbgw19+kkeePTTbVTEzu2gt+LAYa7T42IN7eGz4yGxXxczsorXgw6JYEAD1ZmuWa2JmdvFa8GFRLiZh0WzFLNfEzOziteDDYqJl0XBYmJl1teDDolRIfgWNpsPCzKybBR8WxYKQoNnymIWZWTcLPiwASgVRdzeUmVlXDguSrigPcJuZdeewIG1Z+NRZM7OuHBZAqSi3LMzMpuGwAIqFAnWfDWVm1lWuYSFpi6QnJe2RdEeH198t6XFJj0r6uqRNafl6SSNp+aOS/izPepYK8tlQZmbTKOX1xpKKwD3AW4BhYJekHRHxRGazT0fEn6Xb3wh8GNiSvvZ0RFyTV/2ySkX5ojwzs2nk2bLYDOyJiL0RMQ5sB27KbhARRzOri4BZ+YtdKsgX5ZmZTSPPsFgDPJdZH07LTiHpPZKeBj4E/FbmpQ2SHpH0NUlvyLGelIo+ddbMbDp5hoU6lJ32Fzki7omIVwDvB34vLd4PrIuIa4H3AZ+WtPS0HyDdJmlI0tCBAwfOuqI+ddbMbHp5hsUwcHlmfS2wb5rttwNvA4iIsYh4KV1+GHgaeFX7DhFxb0QMRsTgwMDAWVfUp86amU0vz7DYBWyUtEFSBdgK7MhuIGljZvWXgKfS8oF0gBxJVwIbgb15VbRYKPh2H2Zm08jtbKiIaEi6HbgPKALbImK3pLuBoYjYAdwu6XqgDhwCbkl3fyNwt6QG0ATeHREH86qrT501M5tebmEBEBE7gZ1tZXdllt/bZb8vAF/Is25ZPhvKzGx6voIbX2dhZtaLw4LkrrMOCzOz7hwWTHRDeczCzKwbhwU+ddbMrBeHBUk3lC/KMzPrzmFBMg+3WxZmZt05LPDZUGZmvTgs8HUWZma9OCxI7jrrloWZWXcOC9KWhW/3YWbWlcOC5GyopruhzMy6cliQDHDX3bIwM+vKYYFPnTUz68VhAZQLPnXWzGw6DguSyY8icOvCzKwLhwXJmAXgM6LMzLpwWJCcOgv4wjwzsy5yDQtJWyQ9KWmPpDs6vP5uSY9LelTS1yVtyrx2Z7rfk5J+Ic96lorJr8HjFmZmneUWFpKKwD3AW4FNwDuzYZD6dES8OiKuAT4EfDjddxOwFbga2AL81/T9cjHVsnA3lJlZJ3m2LDYDeyJib0SMA9uBm7IbRMTRzOoiYOKr/U3A9ogYi4h/BPak75eLYhoWHuA2M+uslON7rwGey6wPA9e1byTpPcD7gArw85l9H2rbd00+1YTy5AC3w8LMrJM8WxbqUHbaX+OIuCciXgG8H/i9M9lX0m2ShiQNHThw4KwrWiykYxYe4DYz6yjPsBgGLs+srwX2TbP9duBtZ7JvRNwbEYMRMTgwMHDWFS371Fkzs2nlGRa7gI2SNkiqkAxY78huIGljZvWXgKfS5R3AVklVSRuAjcC38qroxJiFu6HMzDrLbcwiIhqSbgfuA4rAtojYLeluYCgidgC3S7oeqAOHgFvSfXdL+izwBNAA3hMRzbzqWnI3lJnZtPIc4CYidgI728ruyiy/d5p9/wD4g/xqN2Xy1Fl3Q5mZdeQruIGiz4YyM5uWwwIop91Qvs7CzKwzhwVTA9x1X8FtZtaRw4KpU2fdsjAz68xhQebUWZ8NZWbWkcMCKPuus2Zm03JYkG1ZeMzCzKwThwXZ6yzcsjAz68RhwdTkRx7gNjPrzGHBVMvCp86amXXmsABKPnXWzGxaDgsyF+U5LMzMOnJYkLndh7uhzMw6cljgGwmamfXisMCnzpqZ9eKwYGryIw9wm5l15rBgqmUx3vCYhZlZJw4LoFAQlWKBcQ9wm5l1lGtYSNoi6UlJeyTd0eH190l6QtJjkh6QdEXmtaakR9PHjjzrCVAtFRit5zbNt5nZnJbbHNySisA9wFuAYWCXpB0R8URms0eAwYg4Kek3gQ8BN6evjUTENXnVr121XGDM3VBmZh3l2bLYDOyJiL0RMQ5sB27KbhARX4mIk+nqQ8DaHOszrWqpyFjdYWFm1kmeYbEGeC6zPpyWdXMr8LeZ9ZqkIUkPSXpbpx0k3ZZuM3TgwIFzqmy1XGC04W4oM7NOcuuGAtShrOO5qZJ+HRgE3pQpXhcR+yRdCTwo6fGIePqUN4u4F7gXYHBw8JzOe3XLwsysuzxbFsPA5Zn1tcC+9o0kXQ98ALgxIsYmyiNiX/q8F/gqcG2OdaVWLjDmloWZWUd5hsUuYKOkDZIqwFbglLOaJF0LfIIkKF7IlK+QVE2XVwOvB7ID4+ddteQBbjOzbnLrhoqIhqTbgfuAIrAtInZLuhsYiogdwB8Di4HPSQJ4NiJuBK4CPiGpRRJof9h2FtV5Vy0VOXxyPM8fYWY2Z+U5ZkFE7AR2tpXdlVm+vst+3wBenWfd2tV86qyZWVe+gjtVLRUdFmZmXTgsUr6C28ysO4dFqlZ2y8LMrBuHRapaKjDmloWZWUcOi1RyBbdbFmZmnTgsUrVSkWYraPg25WZmp3FYpKrl5FfhcQszs9M5LFLVUhHAZ0SZmXXgsEjV3LIwM+vKYZGaaFk4LMzMTuewSFVLya/C3VBmZqdzWKRqZbcszMy6cVikJloWvjDPzOx0DovUxKmzvjDPzOx0DovU5AC3WxZmZqdxWKR86qyZWXcOi5QvyjMz625GYSHpkzMp67DNFklPStoj6Y4Or79P0hOSHpP0gKQrMq/dIump9HHLTOp5Lny7DzOz7mbasrg6uyKpCLx2uh3Sbe4B3gpsAt4paVPbZo8AgxHxGuDzwIfSfVcCHwSuAzYDH5S0YoZ1PSu+KM/MrLtpw0LSnZKOAa+RdDR9HANeAP6mx3tvBvZExN6IGAe2AzdlN4iIr0TEyXT1IWBtuvwLwP0RcTAiDgH3A1vO6MjO0OSpsw13Q5mZtZs2LCLiv0TEEuCPI2Jp+lgSEasi4s4e770GeC6zPpyWdXMr8Ldnue85m7qC2y0LM7N2M+2G+pKkRQCSfl3Sh7PjC12oQ1
l03FD6dWAQ+OMz2VfSbZKGJA0dOHCgR3WmJymZLc8tCzOz08w0LP4UOCnpx4HfBX4A/GWPfYaByzPra4F97RtJuh74AHBjRIydyb4RcW9EDEbE4MDAwAwPpbtkalW3LMzM2s00LBoRESRjDh+NiI8CS3rsswvYKGmDpAqwFdiR3UDStcAnSILihcxL9wE3SFqRDmzfkJblqlYuumVhZtZBaYbbHZN0J/AvgDekZzqVp9shIhqSbif5I18EtkXEbkl3A0MRsYOk22kx8DlJAM9GxI0RcVDS75MEDsDdEXHwjI/uDFXLblmYmXUy07C4GfjnwL+OiOclrWNqfKGriNgJ7GwruyuzfP00+24Dts2wfudFtVT0qbNmZh3MqBsqIp4HPgUsk/TLwGhE9BqzmHNq5YKv4DYz62CmV3D/GvAt4B3ArwHflPT2PCs2G9yyMDPrbKbdUB8AXjcxCC1pAPh7kquu5w2fOmtm1tlMz4YqtJ2t9NIZ7Dtn1MpFX5RnZtbBTFsWfyfpPuCv0vWbaRu4ng/csjAz62zasJD0SuCSiPgdSb8C/AzJ1dX/QDLgPa8kYeGWhZlZu15dSR8BjgFExF9HxPsi4t+RtCo+knflLrSkG8otCzOzdr3CYn1EPNZeGBFDwPpcajSL3LIwM+usV1jUpnmt73xW5GJQLRd9BbeZWQe9wmKXpH/bXijpVuDhfKo0e2qlAqONJsltsMzMbEKvs6F+G/iipHcxFQ6DQAX4Z3lWbDZUy0UioN4MKqVOd0k3M1uYpg2LiPgR8NOSfg74sbT4f0fEg7nXbBZkZ8urlObdZSRmZmdtRtdZRMRXgK/kXJdZVy0n83CP1lssmW60xsxsgfHX54ypqVV9+qyZWZbDIqMvbVn4Km4zs1M5LDJqmW4oMzOb4rDImGhZjLgbyszsFLmGhaQtkp6UtEfSHR1ef6Okb0tqtM+PIakp6dH0saN93zzUyh6zMDPrZKZ3nT1j6Tzd9wBvAYZJLvDbERFPZDZ7FvgN4D90eIuRiLgmr/p1MtENNTLusDAzy8otLIDNwJ6I2AsgaTtwEzAZFhHxTPraRTFIMDlm4ftDmZmdIs9uqDXAc5n14bRspmqShiQ9JOlt57dqnfVV0rBwy8LM7BR5tiw63S/jTG66tC4i9km6EnhQ0uMR8fQpP0C6DbgNYN26dWdf01Rt4joLnzprZnaKPFsWw8DlmfW1wL6Z7hwR+9LnvcBXgWs7bHNvRAxGxODAwMC51RaPWZiZdZNnWOwCNkraIKkCbAVmdFaTpBWSqunyauD1ZMY68uLrLMzMOsstLCKiAdwO3Ad8D/hsROyWdLekGwEkvU7SMPAO4BOSdqe7XwUMSfoOyT2p/rDtLKpcFAuiUiy4G8rMrE2eYxZExE6SKVizZXdllneRdE+17/cN4NV51q2bWrngbigzsza+grtNrVz0vaHMzNo4LNr0VYpuWZiZtXFYtKmVih7gNjNr47BoU6sUfSNBM7M2Dos2tVLBNxI0M2vjsGjTVyk6LMzM2jgs2njMwszsdA6LNrVywWMWZmZtHBZt3A1lZnY6h0WbaslhYWbWzmHRJmlZeMzCzCzLYdGmVioy3mzRbJ3J1BtmZvObw6JNXyWdAMldUWZmkxwWbabmtHBYmJlNcFi0mZwtz2FhZjbJYdHGs+WZmZ3OYdGmz/Nwm5mdxmHRZnE1mTzw+FhjlmtiZnbxyDUsJG2R9KSkPZLu6PD6GyV9W1JD0tvbXrtF0lPp45Y865nlsDAzO11uYSGpCNwDvBXYBLxT0qa2zZ4FfgP4dNu+K4EPAtcBm4EPSlqRV12zFteSsDjhsDAzm5Rny2IzsCci9kbEOLAduCm7QUQ8ExGPAe2jyb8A3B8RByPiEHA/sCXHuk5aVE3GLI45LMzMJuUZFmuA5zLrw2nZedtX0m2ShiQNHThw4KwrmrWkWgbg+KjDwsxsQp5hoQ5lM72Hxoz2jYh7I2IwIgYHBgbOqHLd1MoFigW5G8rMLCPPsBgGLs+srwX2XYB9z4kkFlWKHuA2M8vIMyx2ARslbZBUAbYCO2a4733ADZJWpAPbN6RlF8SSWplj7oYyM5uUW1hERAO4neSP/PeAz0bEbkl3S7oRQNLrJA0D7wA+IWl3uu9B4PdJAmcXcHdadkEsqhbdDWVmllHK880jYiews63srszyLpIupk77bgO25Vm/bpbUyhwbq8/GjzYzuyj5Cu4OlveVOXTCYWFmNsFh0cHy/gqHT47PdjXMzC4aDosOVi4qc+ikWxZmZhMcFh0s768wUm96AiQzs5TDooMV/RUADrkryswMcFh0tKI/ueWHB7nNzBIOiw5WLEpaFh7kNjNLOCw6mOiGOuiwMDMDHBYdTXZD+YwoMzPAYdHR8rRlcfiEWxZmZuCw6KhSKrC4WnLLwsws5bDoYnl/2afOmpmlHBZdrOivOCzMzFIOiy5WLKq4G8rMLOWw6GJFf9nXWZiZpRwWXazor3DQZ0OZmQEOi66W9ydTqzaardmuipnZrMs1LCRtkfSkpD2S7ujwelXSZ9LXvylpfVq+XtKIpEfTx5/lWc9OVk7c8mPE4xZmZrlNqyqpCNwDvAUYBnZJ2hERT2Q2uxU4FBGvlLQV+CPg5vS1pyPimrzq18vkhXknx1m9uDpb1TAzuyjk2bLYDOyJiL0RMQ5sB25q2+Ym4H+ky58H3ixJOdZpxiZu+XHQd541M8s1LNYAz2XWh9OyjttERAM4AqxKX9sg6RFJX5P0hhzr2ZHntDAzm5JbNxTQqYUQM9xmP7AuIl6S9Frgf0m6OiKOnrKzdBtwG8C6devOQ5WnvGxJ0vX0o6Oj5/V9zczmojxbFsPA5Zn1tcC+bttIKgHLgIMRMRYRLwFExMPA08Cr2n9ARNwbEYMRMTgwMHBeK796cZVKscAPD42c1/c1M5uL8gyLXcBGSRskVYCtwI62bXYAt6TLbwcejIiQNJAOkCPpSmAjsDfHup6mUBCXLa8xfNhhYWaWWzdURDQk3Q7cBxSBbRGxW9LdwFBE7AD+AvikpD3AQZJAAXgjcLekBtAE3h0RB/OqazdrVvS5ZWFmRr5jFkTETmBnW9ldmeVR4B0d9vsC8IU86zYT61b287fffZ6I4CI5ScvMbFb4Cu5p/NOXL+XwyTrPe5DbzBY4h8U0rr5sKQC7f3i0x5ZmZvObw2Iamy5bSqkgHn720GxXxcxsVjksptFfKXHN5cv5xp4XZ7sqZmazymHRw5teNcBjPzzCcwdPznZVzMxmjcOih1957VoAPjv0XI8tzczmL4dFD2uW9/GmVw3w2aHnGGs0Z7s6ZmazwmExA7f+zAZ+dHSMbV9/ZrarYmY2KxwWM/CGjQNcf9UlfPzBp9h/xFd0m9nC47CYobt+eRPNCN79P7/tubnNbMFxWMzQulX9fGzrtXx//1He/qff4NmXfHaUmS0cDoszcMPVL+dT/+Y6Xjoxzps//FXe//nHOOzJkcxsAVBE+3xEc9Pg4GAMDQ1dkJ/13MGT/Pn/2csnH/oBlWKBG65+Oddevpx3DK5lS
a18QepgZnY+SHo4IgZ7buewOHu79x3hM7ue428e3ceRkTpLayWuu3IVb3zVAL/6E2vor+R6U18zs3PmsLiAWq3g8R8eYdv//Ue+9Nh+mq3kd3rl6kW8bv1KNm9YyY+tWcYVq/qplYuzUkczs04cFrMkInjw+y9w/xM/4uEfHOLpA8dJswMJLlvWx7qV/axaXGHVogovW1rjsuU1Ll3Wx2XL+hhYUqVWLnj+DDO7IGYaFu4nOc8k8earLuHNV10CJK2O3fuOsvfF4zzz4kmeeekEzx48ya5nDnJyvMmx0cZp71EpFfgnlyzhZUuqLOsvs7yvwvL+MtVSgXKxwIbVi1i1uEK1VORlS6qUimJRpUSh4IAxs3w4LHJWKIhXr13Gq9cu6/j6yfEG+4+Msv/wKPuOjHDg2BgHT4zz/eeP8vzRUb7//DGOjNQ5PnZ6qGRVigUuXV5jUaVEtVygWiqwqFKir1KkWipOllVLRSqlieXkUUlDqFxMlivpc6kgAigWkjAqFUVBUJAoFQqUiqJUEMX0USgk6wWlZZIDzGyeyDUsJG0BPkoyB/d/i4g/bHu9Cvwl8FrgJeDmiHgmfe1O4FaSObh/KyLuy7Ous6W/UuIVA4t5xcDiabdrNFuMN1uM1ls89aNjHBttcPDEOEdG6pwYbzBSb7L/8CgnxxuMNVqM1pvsPzLKSL3JeKPFWKPJWKPFWKPFeKN1gY4uMRUcTAZINkyKabgoDaLscmHyeWp/peXJdlPbZNez27W/Z/v7d96v888oiFNenwjDbF3V9jyxz2nrE/sVQGRfn1ie+lmCZLvM+rHRBrVygUYruGRpjWYriAgmOpazPcwFQb0ZnBhrUG+2uGLVIvYfGWGk3uTK1YspFODwyTrjjRbL+sq8eHyM5f1lFlfLLO0r8cLRMV46Mcaa5f28dHyMsWaLK1b2A1AuFmhFcHSkwcrFlcnP29JamdF6k6OjDS5bVuOlE+NIMLC4ytHRBuWiaLaCxdUSJ8aak19oKqUCx0YblIsFRupNVi+usO/wKK1IjvPoSJ2+cjGp54kxFlVKjDda1JstVi6qcHikzor+CkdH6iyplSgWxNMHTvDyZTWK0indvBHB/iOjvHxpbfLfcqzRpFwonPJFZ7TenPxi1K2LuNkKBIzUm/RXikRwyns0mi2Ufv5mMlXzxBDBxdIlnVtYSCoC9wBvAYaBXZJ2RMQTmc1uBQ5FxCslbQX+CLhZ0iZgK3A1cBnw95JeFREL9k5+pWKBUrFAfwWuu3LVOb1XqxWMN1tpeDSpN4N6+p9tvNmi3gzG6k0a6Ye/GcHx0QbNCFqR7N9oBc1Wsm2zlTxaMVEetFqRbJ8+N1vQilO3PfU5eb2VXY6g1ZpYzpSldciuN5qt07drte3T9v4RTNbh1P06/7xgaj+bfQUxOR7YTaWYXEo23pz6gtRfKdJMP8NFifFmi4KS0IuAeqtFXzlpgUf6OTg+1iAClvWVJ4M3IigURKM58Rmfer9KsUC5KJb2lSc/QxP/h4oSjVaLpbUyAaf8f2n/P1ErF1i1qMpYo8lovUVE0F8tEQEj4w1WL6lSkNh06VLueddP5PjbzrdlsRnYExF7ASRtB24CsmFxE/Af0+XPAx9XEqM3AdsjYgz4R0l70vf7hxzru2AUCqJWKKZnZvm6kDMVaWBkQwam1putgIlAYip0Isjs12E9fe9smHV6bkWyXalYoN5sEQHHx+pTLbDMN9GJxUYrKKffihutFsdGG1y6LGmNHDg2RpD8saw3ky8NtXIx+fKQtkYXVUusWlThpRPjDCyuMtpocujEOAWJevqNeVGlyMGT45STphLHRhv0lYv0V4qcGE/+2AZwYqzBqkUVjo02WFwtcXhknGqpSCuCkXoTIZbUSun+BQ6erFOUqJQKjNab9KV/7I+PNVhSKzFab1FOu0xbEfSVixwZqbNyUYUDx8dotZJv8bVykVYrODHemOxmHW+0qJQKk/9uBYlyURwbbdCKSFt0oq9cpBnBsdEGpYIopyHUaLYopcsRUC6JSrHAyfFm0pqvtyZbsrVykVJBjKZlY40WxWxrW1PduRNlR06OJy3ISpFaKTmTcqSe/C7LxQLHRuu0Aq5Y1Z/75z7PsFgDZCeBGAau67ZNRDQkHQFWpeUPte27pv0HSLoNuA1g3bp1563iZtOZ7FLi4ugeMLsQ8rzdR6f/Se2Nxm7bzGRfIuLeiBiMiMGBgYGzqKKZmc1EnmExDFyeWV8L7Ou2jaQSsAw4OMN9zczsAskzLHYBGyVtkFQhGbDe0bbNDuCWdPntwIORnAKwA9gqqSppA7AR+FaOdTUzs2nkNmaRjkHcDtxHcurstojYLeluYCgidgB/AXwyHcA+SBIopNt9lmQwvAG8ZyGfCWVmNtt8uw8zswVsprf78HwWZmbWk8PCzMx6cliYmVlP82bMQtIB4Afn8BargRfPU3UuJvPxuObjMYGPa66ZL8d1RUT0vFBt3oTFuZI0NJNBnrlmPh7XfDwm8HHNNfP1uLpxN5SZmfXksDAzs54cFlPune0K5GQ+Htd8PCbwcc018/W4OvKYhZmZ9eSWhZmZ9bTgw0LSFklPStoj6Y7Zrs+ZkLRN0guSvpspWynpfklPpc8r0nJJ+lh6nI9JyndarXMg6XJJX5H0PUm7Jb03LZ/TxyapJulbkr6THtd/Sss3SPpmelyfSW+8SXojzc+kx/VNSetns/7TkVSU9IikL6Xr8+GYnpH0uKRHJQ2lZXP6M3guFnRYZKZ+fSuwCXhnOqXrXPHfgS1tZXcAD0TERuCBdB2SY9yYPm4D/vQC1fFsNIB/HxFXAT8JvCf9d5nrxzYG/HxE/DhwDbBF0k+STCf8J+lxHSKZbhgy0w4Df5Jud7F6L/C9zPp8OCaAn4uIazKnyM71z+DZS6aIXJgP4KeA+zLrdwJ3zna9zvAY1gPfzaw/CVyaLl8KPJkufwJ4Z6ftLvYH8Dckc7nPm2MD+oFvk8we+SJQSssnP5Mkd2z+qXS5lG6n2a57h2NZS/KH8+eBL5FMXjanjymt3zPA6rayefMZPNPHgm5Z0Hnq19Omb51jLomI/QDp88vS8jl5rGk3xbXAN5kHx5Z21zwKvADcDzwNHI6IRrpJtu6nTDsMTEw7fLH5CPC7QCtdX8XcPyZIZuf8sqSH0ymcYR58Bs9WnnNwzwUzmr51nphzxyppMfAF4Lcj4qjUdc7rOXNskczLco2k5cAXgas6bZY+X/THJemXgRci4mFJPztR3GHTOXNMGa+PiH2SXgbcL+n702w7l47rrCz0lsV8nL71R5IuBUifX0jL59SxSiqTBMWnIuKv0+J5cWwAEXEY+CrJmMzydFphOLXu3aYdvpi8HrhR0jPAdpKuqI8wt48JgIjYlz6/QBLsm5lHn8EztdDDYiZTv8412alqbyHp758o/5fpWRs/CRyZaE5fbJQ0If4C+F5EfDjz0pw+NkkDaYsCSX3A9SSDwl8hmVYYTj+uTtMOXzQi4s6IWBsR60n+/zwYEe9iDh8TgKRFkpZMLAM3AN9ljn8Gz8ls
D5rM9gP4ReD/kfQdf2C263OGdf8rYD9QJ/lmcytJ/+8DwFPp88p0W5Gc+fU08DgwONv1n+a4foakCf8Y8Gj6+MW5fmzAa4BH0uP6LnBXWn4lyRzze4DPAdW0vJau70lfv3K2j6HH8f0s8KX5cExp/b+TPnZP/G2Y65/Bc3n4Cm4zM+tpoXdDmZnZDDgszMysJ4eFmZn15LAwM7OeHBZmZtaTw8LsDEhqpnchnXictzsVS1qvzB2EzS4mC/12H2ZnaiQirpntSphdaG5ZmJ0H6dwHf5TOV/EtSa9My6+Q9EA6x8EDktal5ZdI+mI6t8V3JP10+lZFSX+eznfx5fRKb7NZ57AwOzN9bd1QN2deOxoRm4GPk9wfiXT5LyPiNcCngI+l5R8DvhbJ3BY/QXKVMCTzIdwTEVcDh4Ffzfl4zGbEV3CbnQFJxyNicYfyZ0gmNtqb3gRoJxbiAAAA4ElEQVTx+YhYJelFknkN6mn5/ohYLekAsDYixjLvsR64P5KJdZD0fqAcEf85/yMzm55bFmbnT3RZ7rZNJ2OZ5SYeV7SLhMPC7Py5OfP8D+nyN0juxgrwLuDr6fIDwG/C5IRISy9UJc3Ohr+1mJ2ZvnSmuwl/FxETp89WJX2T5EvYO9Oy3wK2Sfod4ADwr9Ly9wL3SrqVpAXxmyR3EDa7KHnMwuw8SMcsBiPixdmui1ke3A1lZmY9uWVhZmY9uWVhZmY9OSzMzKwnh4WZmfXksDAzs54cFmZm1pPDwszMevr/9+FuGHi0DGAAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "plt.plot(rnn.loss_list)\n", "plt.xlabel(\"Epoch\")\n", @@ -502,28 +1149,36 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "#save in a folder that describes the model\n", - "folder = \"./rnn_model_\" + str(rnn._) + \"_\" + str(len(rnn.ncells)) + \"l_\" + str(rnn.ncells).replace(\" \",\"\") + \"c/rnn_basic\"\n", + "folder = \"./rnn_model_\" + str(rnn._)+ \"_\" + rnn.__ + \"_\" + str(len(rnn.ncells)) + \"l_\" + str(rnn.ncells).replace(\" \",\"\") + \"c/rnn_basic\"\n", "rnn.save(folder)" ] }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 24, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "INFO:tensorflow:Restoring parameters from ./rnn_model_lstm_leaky_relu_5l_[50,40,30,20,10]c/rnn_basic\n" + ] + } + ], "source": [ - "#folder = \"./trained_models/rnn_model_\" + str(rnn._) + \"_\" + str(len(rnn.ncells)) + \"l_\" + str(rnn.ncells).replace(\" \",\"\") + \"c/rnn_basic\"\n", + "#folder = \"./rnn_model_\" + str(rnn._)+ \"_\" + rnn.__ + \"_\" + str(len(rnn.ncells)) + \"l_\" + str(rnn.ncells).replace(\" \",\"\") + \"c/rnn_basic\"\n", "#rnn.load(folder)" ] }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 25, "metadata": {}, "outputs": [], "source": [ @@ -532,7 +1187,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 28, "metadata": {}, "outputs": [], "source": [ @@ -543,31 +1198,36 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 41, "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "ValueError", + "evalue": "operands could not be broadcast together with shapes (469,21) (24,) (469,21) ", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mValueError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m()\u001b[0m\n\u001b[0;32m 1\u001b[0m \u001b[1;31m#Here i subtract a prediction (random particle) from the target to get an idea of the predictions\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 2\u001b[1;33m \u001b[0mmin_max_scaler\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mtest_input\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 3\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 4\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 5\u001b[0m \u001b[1;31m#print(min_max_scaler_inv(test_pred)-min_max_scaler_inv(test_target))\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m\u001b[0m in \u001b[0;36mmin_max_scaler\u001b[1;34m(arr, min_max_scalor)\u001b[0m\n\u001b[0;32m 13\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 14\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0marr\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;36m3\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 15\u001b[1;33m \u001b[0marr\u001b[0m \u001b[1;33m=\u001b[0m 
\u001b[0mreshapor\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmin_max_scalor\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtransform\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mreshapor_inv\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0marr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 16\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 17\u001b[0m \u001b[0marr\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmin_max_scalor\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mtransform\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0marr\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32mc:\\users\\sa_li\\anaconda3\\envs\\rnn-tf-ker\\lib\\site-packages\\sklearn\\preprocessing\\data.py\u001b[0m in \u001b[0;36mtransform\u001b[1;34m(self, X)\u001b[0m\n\u001b[0;32m 367\u001b[0m \u001b[0mX\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mcheck_array\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mX\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mcopy\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcopy\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mdtype\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mFLOAT_DTYPES\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 368\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 369\u001b[1;33m \u001b[0mX\u001b[0m \u001b[1;33m*=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mscale_\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 370\u001b[0m \u001b[0mX\u001b[0m \u001b[1;33m+=\u001b[0m \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mmin_\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 371\u001b[0m \u001b[1;32mreturn\u001b[0m \u001b[0mX\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;31mValueError\u001b[0m: operands could not be broadcast together with shapes (469,21) (24,) (469,21) " ] } ], "source": [ "#Here I subtract a prediction (random particle) from the target to get an idea of the predictions\n", + "min_max_scaler(test_input)\n", "\n", - "#print(test_pred[5,:,:]-test_target[5,:,:])" + "\n", + "#print(min_max_scaler_inv(test_pred)-min_max_scaler_inv(test_target))" ] }, { "cell_type": "code", - "execution_count": 25, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "4.3867254" - ] - }, - "execution_count": 25, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "#Here I evaluate my model on the test set based on mean_squared_error\n", "\n",
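
The `ValueError` above is a direct consequence of the normalization scheme: the scaler was fit on full 8-hit tracks (8×3 = 24 flat columns), while `test_input` holds only 7 hits (7×3 = 21 columns), so `transform` cannot broadcast its 24-element `scale_` onto `(469, 21)` data. A sketch of one possible way around it — not the notebook's code, all names hypothetical — is to fit the scaler on single `(x, y, z)` coordinates so it becomes independent of the number of hits:

```python
import numpy as np
from sklearn import preprocessing

def fit_coord_scaler(arr3d, feature_range=(-1, 1)):
    # Fit on the (N*timesteps, 3) view: one scale per coordinate axis,
    # valid for tracks with any number of hits.
    scaler = preprocessing.MinMaxScaler(feature_range=feature_range)
    scaler.fit(arr3d.reshape(-1, 3))
    return scaler

def scale_coords(arr3d, scaler):
    return scaler.transform(arr3d.reshape(-1, 3)).reshape(arr3d.shape)

train_tracks = np.random.uniform(-100.0, 100.0, size=(469, 8, 3))  # dummy data
test_input = train_tracks[:, :7, :]        # 7-hit inputs, as in the failing cell

coord_scaler = fit_coord_scaler(train_tracks)
print(scale_coords(test_input, coord_scaler).shape)  # (469, 7, 3), no ValueError
```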