{ "cells": [
 {
  "cell_type": "code",
  "execution_count": 1,
  "metadata": {},
  "outputs": [],
  "source": [
   "### Imports (stdlib first, then third-party) and reproducibility seeds ###\n",
   "import math\n",
   "import random\n",
   "\n",
   "import matplotlib as mpl\n",
   "import matplotlib.pyplot as plt\n",
   "import numpy as np\n",
   "import pandas as pd\n",
   "import tensorflow as tf\n",
   "from tensorflow.python.framework import ops\n",
   "\n",
   "# The train/test split (random.sample) and the minibatch shuffling\n",
   "# (np.random.permutation) are both random: seed them so that\n",
   "# Restart & Run All reproduces the same split and batches.\n",
   "SEED = 42\n",
   "random.seed(SEED)\n",
   "np.random.seed(SEED)"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 2,
  "metadata": {},
  "outputs": [],
  "source": [
   "#import data as array\n",
   "# 8 hits with x,y,z\n",
   "# NOTE: unpickling can execute arbitrary code - only load trusted files.\n",
   "\n",
   "testset = pd.read_pickle('matched_8hittracks.pkl')"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 3,
  "metadata": {},
  "outputs": [
   {
    "name": "stdout",
    "output_type": "stream",
    "text": [
     "(46896, 24)\n",
     "[-20.411108 -9.417887 4.7599998]\n",
     "[-27.813803 -6.944843 4.7599998]\n",
     "[-66.736946 22.9032 4.3599997]\n",
     "[-74.0961 35.649506 4.04 ]\n",
     "[78.324196 26.359665 -3.7200012]\n",
     "[69.040436 14.306461 -4.04 ]\n",
     "[26.880571 -9.817033 -4.84 ]\n",
     "[ 19.68401 -11.173258 -5. 
]\n",
     "[ -2.2485821 23.380732 -6.04 -6.489999 28.598572\n",
     " -5.6400003 -21.724771 67.052704 -3.2400002 -22.225971\n",
     " 79.267685 -2.6000004 82.22602 3.0700002 7.24\n",
     " 70.390724 0.19000006 7.5599995 28.802656 3.9014618\n",
     " 6.04 21.421392 6.978845 5.64 ]\n"
    ]
   }
  ],
  "source": [
   "#Check testset with arbitrary particle\n",
   "\n",
   "tset = np.array(testset)\n",
   "tset = tset.astype('float32')\n",
   "print(tset.shape)\n",
   "for i in range(8):\n",
   "    print(tset[1,3*i:(3*i+3)])\n",
   "print(tset[0,:])"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 4,
  "metadata": {},
  "outputs": [],
  "source": [
   "### Reshape original array into the shape (particlenumber, timesteps, input = coordinates)###\n",
   "\n",
   "def reshapor(arr_orig):\n",
   "    \"\"\"Reshape flat hit rows into per-hit (x, y, z) triples.\n",
   "\n",
   "    Parameters\n",
   "    ----------\n",
   "    arr_orig : array-like of shape (n_examples, 3 * n_hits)\n",
   "        Each row is laid out as [x0, y0, z0, x1, y1, z1, ...].\n",
   "\n",
   "    Returns\n",
   "    -------\n",
   "    np.ndarray of shape (n_examples, n_hits, 3), dtype float64\n",
   "        A new array; trailing columns that do not form a full triple are dropped.\n",
   "    \"\"\"\n",
   "    arr_orig = np.asarray(arr_orig)\n",
   "    number_examples = arr_orig.shape[0]\n",
   "    timesteps = arr_orig.shape[1] // 3\n",
   "    # Vectorised: replaces a Python double loop over every example and hit.\n",
   "    return (arr_orig[:, :3 * timesteps]\n",
   "            .reshape(number_examples, timesteps, 3)\n",
   "            .astype(np.float64))"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 5,
  "metadata": {},
  "outputs": [],
  "source": [
   "### create the training set and the test set###\n",
   "\n",
   "def create_random_sets(dataset, train_to_total_ratio):\n",
   "    \"\"\"Randomly split the rows of `dataset` into a train and a test set.\n",
   "\n",
   "    Parameters\n",
   "    ----------\n",
   "    dataset : np.ndarray of shape (n_examples, 3 * n_hits)\n",
   "        Flat hit coordinates, one track per row.\n",
   "    train_to_total_ratio : float\n",
   "        Fraction of rows assigned to the training set (count truncated to int).\n",
   "\n",
   "    Returns\n",
   "    -------\n",
   "    train_set, test_set : np.ndarray\n",
   "        Shapes (n_train, n_hits, 3) and (n_test, n_hits, 3), dtype float64.\n",
   "        Rows keep their original relative order within each set.\n",
   "    \"\"\"\n",
   "    dataset = np.asarray(dataset)\n",
   "    num_examples = dataset.shape[0]\n",
   "    train_set_size = int(num_examples * train_to_total_ratio)\n",
   "    random_indices = random.sample(range(num_examples), train_set_size)\n",
   "\n",
   "    # Boolean mask instead of testing `i in random_indices` per row (a list scan, O(n^2)).\n",
   "    is_train = np.zeros(num_examples, dtype=bool)\n",
   "    is_train[random_indices] = True\n",
   "\n",
   "    # Split the `dataset` argument itself; the old code silently read the global `tset`.\n",
   "    train_set = reshapor(dataset[is_train])\n",
   "    test_set = reshapor(dataset[~is_train])\n",
   "\n",
   "    return train_set, test_set"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 6,
  "metadata": {},
  "outputs": [
   {
    "name": "stdout",
    "output_type": "stream",
    "text": [
     "(469, 8, 3) (46427, 8, 3) (46896, 8, 3)\n",
     "[[ 18.9492569 -12.94710732 -6.44000053]\n",
     " [ 21.15423965 -19.98032379 -6.60000038]\n",
     " [ 27.98397064 -65.37555695 -7.88000011]\n",
     " [ 27.91954994 -77.96816254 -8.03999996]\n",
     " [-48.85122299 66.21347809 -21.55999947]\n",
     " [-38.2697525 59.18515396 -21.56000137]\n",
     " [ -7.20999908 28.59857178 -21.96000099]\n",
     " [ -3.13550663 23.01335526 -22.12000084]]\n"
    ]
   }
  ],
  "source": [
   "train_set, test_set = create_random_sets(tset, 0.99)\n",
   "print(test_set.shape, train_set.shape, reshapor(tset).shape)\n",
   "print(test_set[0,:,:])"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 7,
  "metadata": {},
  "outputs": [],
  "source": [
   "### Create sliding-window inputs of shape (num_examples, 4, 12) and targets of shape (num_examples, 4, 3)###\n",
   "\n",
   "def target_and_input(data_set):\n",
   "    \"\"\"Cut 8-hit tracks into four 4-hit input windows and the hit that follows each.\n",
   "\n",
   "    Window k (k = 0..3) concatenates the (x, y, z) of hits k, k+1, k+2, k+3\n",
   "    into a 12-vector; its target is hit k+4.\n",
   "\n",
   "    Parameters\n",
   "    ----------\n",
   "    data_set : np.ndarray of shape (num_examples, >= 8, 3)\n",
   "\n",
   "    Returns\n",
   "    -------\n",
   "    inputt : float64 array of shape (num_examples, 4, 12)\n",
   "    target : float64 array of shape (num_examples, 4, 3)\n",
   "    \"\"\"\n",
   "    num_ex = data_set.shape[0]\n",
   "    # Every slot of both arrays is written below, so np.empty is safe.\n",
   "    inputt = np.empty((num_ex, 4, 12))\n",
   "    target = np.empty((num_ex, 4, 3))\n",
   "\n",
   "    for window in range(4):\n",
   "        target[:, window] = data_set[:, window + 4]\n",
   "        inputt[:, window] = np.concatenate(\n",
   "            [data_set[:, window + offset] for offset in range(4)], axis=1)\n",
   "\n",
   "    return inputt, target"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 8,
  "metadata": {},
  "outputs": [
   {
    "name": "stdout",
    "output_type": "stream",
    "text": [
     "[[ -2.24858212 23.38073158 -6.03999996 -6.48999882 28.59857178\n",
     " -5.64000034 -21.7247715 67.05270386 -3.24000025 -22.22597122\n",
     " 79.26768494 -2.60000038]\n",
     " [ -6.48999882 28.59857178 -5.64000034 -21.7247715 67.05270386\n",
     " -3.24000025 -22.22597122 79.26768494 -2.60000038 82.22602081\n",
     " 3.07000017 7.23999977]\n",
     " [-21.7247715 67.05270386 -3.24000025 -22.22597122 79.26768494\n",
     " -2.60000038 82.22602081 3.07000017 7.23999977 70.39072418\n",
     " 0.19000006 7.55999947]\n",
     " [-22.22597122 79.26768494 -2.60000038 82.22602081 3.07000017\n",
     " 7.23999977 70.39072418 0.19000006 7.55999947 28.80265617\n",
     " 3.90146184 6.03999996]]\n",
     "[[82.22602081 
3.07000017 7.23999977]\n",
     " [70.39072418 0.19000006 7.55999947]\n",
     " [28.80265617 3.90146184 6.03999996]\n",
     " [21.42139244 6.97884512 5.63999987]]\n"
    ]
   }
  ],
  "source": [
   "inputt_train, target_train = target_and_input(train_set)\n",
   "inputt_test, target_test = target_and_input(test_set)\n",
   "print(inputt_train[0,:,:])\n",
   "print(target_train[0,:,:])"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 9,
  "metadata": {},
  "outputs": [],
  "source": [
   "###create random mini_batches###\n",
   "\n",
   "\n",
   "def unison_shuffled_copies(a, b):\n",
   "    \"\"\"Shuffle `a` and `b` along axis 0 with one shared random permutation.\n",
   "\n",
   "    Both arrays must have the same length along axis 0 (asserted). The\n",
   "    permutation comes from NumPy's global RNG (np.random.permutation).\n",
   "    \"\"\"\n",
   "    assert a.shape[0] == b.shape[0]\n",
   "    p = np.random.permutation(a.shape[0])\n",
   "    # Index only axis 0 so arrays of any rank work (was hard-coded to 3-D).\n",
   "    return a[p], b[p]\n",
   "\n",
   "def random_mini_batches(inputt, target, minibatch_size = 500):\n",
   "    \"\"\"Shuffle (inputt, target) pairs together and cut them into minibatches.\n",
   "\n",
   "    Parameters\n",
   "    ----------\n",
   "    inputt, target : np.ndarray\n",
   "        Arrays with the same number of examples along axis 0.\n",
   "    minibatch_size : int, default 500\n",
   "        Number of examples per full batch; must be positive.\n",
   "\n",
   "    Returns\n",
   "    -------\n",
   "    list of (input_batch, target_batch) tuples\n",
   "        num_examples // minibatch_size full batches, followed by one smaller\n",
   "        batch holding the remainder when there is one. No empty batch is\n",
   "        produced (the old code appended one whenever the size divided evenly).\n",
   "\n",
   "    Raises\n",
   "    ------\n",
   "    ValueError\n",
   "        If minibatch_size is not positive.\n",
   "    \"\"\"\n",
   "    if minibatch_size <= 0:\n",
   "        raise ValueError(\"minibatch_size must be positive, got %r\" % (minibatch_size,))\n",
   "\n",
   "    num_examples = inputt.shape[0]\n",
   "    shuffled_inputt, shuffled_target = unison_shuffled_copies(inputt, target)\n",
   "\n",
   "    # Stepping range() yields the start of every full batch plus, if examples\n",
   "    # are left over, the start of the final partial batch.\n",
   "    return [\n",
   "        (shuffled_inputt[start:start + minibatch_size],\n",
   "         shuffled_target[start:start + minibatch_size])\n",
   "        for start in range(0, num_examples, minibatch_size)\n",
   "    ]"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 10,
  "metadata": {},
  "outputs": [
   {
    "name": "stdout",
    "output_type": "stream",
    "text": [
     "(46427, 4, 3)\n",
     "93\n",
     "(46427, 7, 3)\n",
     "[[-14.14815044 18.45176888 20.84000015]\n",
     " [-17.72400856 22.47253418 20.76000023]\n",
     " [-25.43392563 66.05883789 17.79999924]\n",
     " [-18.5602417 80.10436249 16.52000046]\n",
     " [ 65.26576996 50.039608 10.35999966]\n",
     " [ 61.54515457 34.18211365 9.63999939]\n",
     " [ 27.2204895 8.77087116 10.35999966]] [[-17.72400856 22.47253418 20.76000023]\n",
     " 
[-25.43392563 66.05883789 17.79999924]\n",
     " [-18.5602417 80.10436249 16.52000046]\n",
     " [ 65.26576996 50.039608 10.35999966]\n",
     " [ 61.54515457 34.18211365 9.63999939]\n",
     " [ 27.2204895 8.77087116 10.35999966]\n",
     " [ 20.99278641 8.01359081 10.43999958]]\n"
    ]
   }
  ],
  "source": [
   "minibatches = random_mini_batches(inputt_train, target_train)\n",
   "\n",
   "\n",
   "# The last entry is the (possibly smaller) remainder batch.\n",
   "testinputt, testtarget = minibatches[-1]\n",
   "\n",
   "print(len(minibatches))\n",
   "\n",
   "# Sequence-to-sequence data: hits 0..6 as input, hits 1..7 (shifted by one) as target.\n",
   "minibatches = random_mini_batches(train_set[:,:-1,:], train_set[:,1:,:])\n",
   "train, target = minibatches[0]\n",
   "test_input, test_target = test_set[:,:-1,:], test_set[:,1:,:]\n",
   "print(train[0,:,:], target[0,:,:])"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 11,
  "metadata": {},
  "outputs": [],
  "source": [
   "class RNNPlacePrediction():\n",
   "    \"\"\"Recurrent network that predicts an output vector at every input time step.\n",
   "\n",
   "    Inputs are fed as (batch, time_steps, ninputs); the recurrent cell output is\n",
   "    projected to num_output values per step. Built on the TensorFlow 1.x graph API\n",
   "    (placeholders, tf.contrib.rnn cells, tf.Session).\n",
   "\n",
   "    Usage: construct, call set_cost_and_functions(), then fit() / predict().\n",
   "    \"\"\"\n",
   "\n",
   "    def __init__(self, time_steps, future_steps, ninputs, ncells, num_output, cell_type=\"basic_rnn\"):\n",
   "        \"\"\"Create the input/target placeholders, the recurrent cell and a session.\n",
   "\n",
   "        Parameters\n",
   "        ----------\n",
   "        time_steps : int\n",
   "            Length of every input/target sequence.\n",
   "        future_steps : int\n",
   "            Stored on the instance; not used to build the graph.\n",
   "        ninputs : int\n",
   "            Number of features per input step.\n",
   "        ncells : int\n",
   "            Number of units in the recurrent cell.\n",
   "        num_output : int\n",
   "            Number of predicted features per step.\n",
   "        cell_type : {\"basic_rnn\", \"lstm\", \"GRU\"}\n",
   "            Recurrent cell to use; each is built with a ReLU activation.\n",
   "\n",
   "        Raises\n",
   "        ------\n",
   "        ValueError\n",
   "            If cell_type is not one of the supported names.\n",
   "        \"\"\"\n",
   "        self.nsteps = time_steps\n",
   "        self.future_steps = future_steps\n",
   "        self.ninputs = ninputs\n",
   "        self.ncells = ncells\n",
   "        self.num_output = num_output\n",
   "\n",
   "        #### The input is of shape (nbatches, time_steps, ninputs)\n",
   "        #### ninputs is the dimensionality (number of features) of the time series\n",
   "        self.X = tf.placeholder(dtype=tf.float32, shape=(None, time_steps, ninputs))\n",
   "        # Targets are compared with the projected output, so their last dim is\n",
   "        # num_output (previously ninputs, which only worked because both are 3).\n",
   "        self.Y = tf.placeholder(dtype=tf.float32, shape=(None, time_steps, num_output))\n",
   "\n",
   "        cell_classes = {\n",
   "            \"basic_rnn\": tf.contrib.rnn.BasicRNNCell,\n",
   "            \"lstm\": tf.contrib.rnn.BasicLSTMCell,\n",
   "            \"GRU\": tf.contrib.rnn.GRUCell,\n",
   "        }\n",
   "        if cell_type not in cell_classes:\n",
   "            # An exception, unlike `assert False`, is not stripped by `python -O`.\n",
   "            raise ValueError(\"Wrong rnn cell type: \" + str(cell_type))\n",
   "        self.cell = cell_classes[cell_type](num_units=ncells, activation=tf.nn.relu)\n",
   "\n",
   "        #### Project the cell output to num_output values per time step\n",
   "        self.RNNCell = tf.contrib.rnn.OutputProjectionWrapper(self.cell, output_size=num_output)\n",
   "\n",
   "        self.sess = tf.Session()\n",
   "\n",
   "    def set_cost_and_functions(self, LR=0.001):\n",
   "        \"\"\"Unroll the RNN, define the MSE cost and the Adam step, and initialise variables.\n",
   "\n",
   "        Must be called once before fit(), predict(), save() or load(), which use\n",
   "        the attributes created here.\n",
   "\n",
   "        Parameters\n",
   "        ----------\n",
   "        LR : float\n",
   "            Learning rate for tf.train.AdamOptimizer.\n",
   "        \"\"\"\n",
   "        #### Unroll the (projected) RNN cell over the time axis of X\n",
   "        self.output, self.state = tf.nn.dynamic_rnn(self.RNNCell, self.X, dtype=tf.float32)\n",
   "        #### Mean squared error; tf.losses.mean_squared_error already returns a scalar\n",
   "        self.cost = tf.losses.mean_squared_error(self.Y, self.output)\n",
   "\n",
   "        self.train = tf.train.AdamOptimizer(LR).minimize(self.cost)\n",
   "        #### Variable initializer\n",
   "        self.init = tf.global_variables_initializer()\n",
   "        self.saver = tf.train.Saver()\n",
   "        self.sess.run(self.init)\n",
   "\n",
   "    def fit(self, minibatches, epochs, print_step, patience=100, min_delta=0.8):\n",
   "        \"\"\"Train on `minibatches` for at most `epochs` epochs, with early stopping.\n",
   "\n",
   "        Parameters\n",
   "        ----------\n",
   "        minibatches : sequence of (inputs, targets) pairs\n",
   "            Fed to self.X and self.Y respectively.\n",
   "        epochs : int\n",
   "            Maximum number of passes over all minibatches.\n",
   "        print_step : int\n",
   "            Print the epoch cost every `print_step` epochs.\n",
   "        patience : int, default 100\n",
   "            Early-stopping window, in epochs.\n",
   "        min_delta : float, default 0.8\n",
   "            Stop once the summed epoch cost has dropped by less than this over\n",
   "            the last `patience` epochs.\n",
   "\n",
   "        The summed batch cost of every epoch is appended to self.loss_list.\n",
   "        \"\"\"\n",
   "        self.loss_list = []\n",
   "        for iep in range(epochs):\n",
   "            loss = 0\n",
   "            for train, target in minibatches:\n",
   "                #### In this notebook x is the hit sequence and y the same sequence shifted by 1 step.\n",
   "                # Fetch the cost in the same run as the update (one forward pass instead of two);\n",
   "                # the reported cost is therefore the one computed just before this update.\n",
   "                _, batch_cost = self.sess.run([self.train, self.cost],\n",
   "                                              feed_dict={self.X: train, self.Y: target})\n",
   "                loss += batch_cost\n",
   "\n",
   "            self.loss_list.append(loss)\n",
   "\n",
   "            if iep % print_step == 0:\n",
   "                print(\"Epoch number \", iep)\n",
   "                print(\"Cost: \", loss)\n",
   "\n",
   "            #### Early stopping: the cost improved by less than min_delta over `patience` epochs.\n",
   "            # (The old check called the list, self.loss_list(iep), which raises TypeError, and\n",
   "            # subtracted in the wrong order, so a decreasing cost would have triggered the stop.)\n",
   "            if iep > patience and self.loss_list[iep - patience] - self.loss_list[iep] < min_delta:\n",
   "                print(\"Early stopping at epoch \", iep)\n",
   "                break\n",
   "\n",
   "    def save(self, filename=\"./rnn_model/rnn_basic\"):\n",
   "        \"\"\"Save the session's variables with the tf.train.Saver to `filename`.\"\"\"\n",
   "        self.saver.save(self.sess, filename)\n",
   "\n",
   "    def load(self, filename=\"./rnn_model/rnn_basic\"):\n",
   "        \"\"\"Restore the session's variables from `filename` with the tf.train.Saver.\"\"\"\n",
   "        self.saver.restore(self.sess, filename)\n",
   "\n",
   "    def predict(self, x):\n",
   "        \"\"\"Return the network output for inputs `x` shaped (batch, time_steps, ninputs).\"\"\"\n",
   "        return self.sess.run(self.output, feed_dict={self.X: x})"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 12,
  "metadata": {},
  "outputs": [],
  "source": [
   "tf.reset_default_graph()"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 13,
  "metadata": {},
  "outputs": [],
  "source": [
   "timesteps = 7\n",
   "future_steps = 1\n",
   "ninputs = 3\n",
   "ncells = 30\n",
   "num_output = 3"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 14,
  "metadata": {},
  "outputs": [],
  "source": [
   "rnn = RNNPlacePrediction(time_steps=timesteps, future_steps=future_steps, ninputs=ninputs, \n",
   "                         ncells=ncells, num_output=num_output, cell_type=\"lstm\")"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 15,
  "metadata": {},
  "outputs": [],
  "source": [
   "rnn.set_cost_and_functions()"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": 16,
  "metadata": {},
  "outputs": [
   {
    "name": "stdout",
    "output_type": "stream",
    "text": [
     "108164.484375\n",
     "Epoch number 0\n",
     "Cost: 108164.484375\n",
     "57559.07028198242\n",
     "23417.070388793945\n",
     "14891.905502319336\n",
     "12499.229606628418\n",
     "11122.449005126953\n",
     "10109.593124389648\n",
     "9254.187873840332\n",
     "8533.7841796875\n",
     "7911.04532623291\n",
     "7397.553016662598\n",
     "6970.429878234863\n",
     "6615.027526855469\n",
     "6295.746952056885\n",
     "5980.429901123047\n",
     "5706.560279846191\n",
     "5445.698154449463\n",
     "5201.659439086914\n",
     "4964.256591796875\n",
     "4748.1755447387695\n",
     "4547.136264801025\n",
     "4359.255859375\n",
     "4191.035224914551\n",
     "4038.443386077881\n",
     "3900.314914703369\n",
     "3775.984146118164\n",
     "3661.9904403686523\n",
     "3555.3094215393066\n",
     "3457.7933769226074\n",
     "3364.8507347106934\n",
     "3273.5197563171387\n",
     "3184.8794593811035\n",
     "3099.1916484832764\n",
     "3017.266839981079\n",
     "2937.607749938965\n",
     "2861.451858520508\n",
     "2788.084276199341\n",
     "2718.1153564453125\n",
     "2651.056453704834\n",
     "2588.111207962036\n",
     "2528.6895790100098\n",
     "2472.8057384490967\n",
     "2419.3036403656006\n",
     "2368.177661895752\n",
     "2320.108238220215\n",
     "2275.061798095703\n",
     "2232.7341709136963\n",
     "2192.880346298218\n",
     "2155.6921787261963\n",
     "2120.9273471832275\n",
     "2088.330379486084\n",
     "2057.153877258301\n",
     "2027.961145401001\n"
    ]
   }
  ],
  "source": [
   "rnn.fit(minibatches, epochs=5000, print_step=500)"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
   "rnn.save()"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
   "###rnn.load()###"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
   "###test_input.shape###"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [
   "#test_pred = rnn.predict(test_input)"
  ]
 },
 {
  "cell_type": "code",
  "execution_count": null,
  "metadata": {},
  "outputs": [],
  "source": [ 
"#print(test_pred[5,:,:]-test_target[5,:,:])" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "#rnn.sess.run(rnn.cost, feed_dict={rnn.X:test_input, rnn.Y:test_target})" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": { "scrolled": true }, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }