{ "cells": [
  { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [
    "import math\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib as mpl\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "from tensorflow.python.framework import ops\n",
    "\n",
    "# Seed the two RNGs this notebook draws from: `random` (train/test split)\n",
    "# and `np.random` (minibatch shuffling), so reruns give the same split and\n",
    "# batch order. TensorFlow weight initialisation is NOT covered by these seeds.\n",
    "SEED = 42\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)"
  ] },
  { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [
    "# Load the matched tracks: one row per track, 8 hits with x, y, z each.\n",
    "# NOTE: unpickling can execute arbitrary code -- only load files you trust.\n",
    "DATA_PATH = 'matched_8hittracks.pkl'\n",
    "testset = pd.read_pickle(DATA_PATH)"
  ] },
  { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(46896, 24)\n", "[-20.411108 -9.417887 4.7599998]\n", "[-27.813803 -6.944843 4.7599998]\n", "[-66.736946 22.9032 4.3599997]\n", "[-74.0961 35.649506 4.04 ]\n", "[78.324196 26.359665 -3.7200012]\n", "[69.040436 14.306461 -4.04 ]\n", "[26.880571 -9.817033 -4.84 ]\n", "[ 19.68401 -11.173258 -5. 
]\n", "[ -2.2485821 23.380732 -6.04 -6.489999 28.598572\n", " -5.6400003 -21.724771 67.052704 -3.2400002 -22.225971\n", " 79.267685 -2.6000004 82.22602 3.0700002 7.24\n", " 70.390724 0.19000006 7.5599995 28.802656 3.9014618\n", " 6.04 21.421392 6.978845 5.64 ]\n" ] } ], "source": [
    "# Sanity check: convert to a float32 array and print one track hit by hit.\n",
    "tset = np.array(testset).astype('float32')\n",
    "n_hits = tset.shape[1] // 3\n",
    "print(tset.shape)\n",
    "for i in range(n_hits):\n",
    "    print(tset[1, 3*i:3*i + 3])\n",
    "print(tset[0, :])"
  ] },
  { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [
    "def reshapor(arr_orig):\n",
    "    \"\"\"Reshape flat track rows into shape (n_tracks, timesteps, 3).\n",
    "\n",
    "    Each row is laid out as [x0, y0, z0, x1, y1, z1, ...]; columns 3t..3t+2\n",
    "    become time step t. Trailing columns that do not form a full triple are\n",
    "    dropped. Always returns a new float64 array.\n",
    "    \"\"\"\n",
    "    arr_orig = np.asarray(arr_orig)\n",
    "    number_examples = arr_orig.shape[0]\n",
    "    timesteps = arr_orig.shape[1] // 3\n",
    "    # A row-major reshape performs the same column-to-step mapping as an\n",
    "    # element-by-element copy, in a single vectorised operation.\n",
    "    return (arr_orig[:, :3 * timesteps]\n",
    "            .reshape(number_examples, timesteps, 3)\n",
    "            .astype(np.float64))"
  ] },
  { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [
    "def create_random_sets(dataset, train_to_total_ratio):\n",
    "    \"\"\"Randomly split the rows of `dataset` into a train and a test set.\n",
    "\n",
    "    dataset: 2-D array, one flat track per row (see `reshapor`).\n",
    "    train_to_total_ratio: fraction of rows put in the train set (the row\n",
    "        count is truncated with int()).\n",
    "\n",
    "    Returns (train_set, test_set), both passed through `reshapor`; rows keep\n",
    "    their original relative order inside each set. Which rows go to training\n",
    "    is drawn with Python's `random` module.\n",
    "    \"\"\"\n",
    "    dataset = np.asarray(dataset)\n",
    "    num_examples = dataset.shape[0]\n",
    "    train_set_size = int(num_examples * train_to_total_ratio)\n",
    "    random_indices = random.sample(range(num_examples), train_set_size)\n",
    "    # Boolean mask instead of testing `i in random_indices` (a list scan per\n",
    "    # row, quadratic overall). Rows come from `dataset`, not a notebook global.\n",
    "    is_train = np.zeros(num_examples, dtype=bool)\n",
    "    is_train[random_indices] = True\n",
    "    train_set = reshapor(dataset[is_train])\n",
    "    test_set = reshapor(dataset[~is_train])\n",
    "    return train_set, test_set"
  ] },
  { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { 
"name": "stdout", "output_type": "stream", "text": [ "(469, 8, 3) (46427, 8, 3) (46896, 8, 3)\n", "[[-19.7146244 11.09934807 12.5199995 ]\n", " [-25.76811028 13.2408371 12.35999966]\n", " [-56.47006226 43.07745743 11.88000011]\n", " [-61.01356888 55.37170029 11.55999947]\n", " [ 53.24511719 62.70946121 18.60000038]\n", " [ 49.52627182 50.02124786 18.76000023]\n", " [ 21.59110832 19.66292 18.44000053]\n", " [ 17.10472298 17.40020561 18.27999878]]\n" ] } ], "source": [
    "train_set, test_set = create_random_sets(tset, 0.99)\n",
    "assert len(train_set) + len(test_set) == len(tset)\n",
    "print(test_set.shape, train_set.shape, reshapor(tset).shape)\n",
    "print(test_set[0, :, :])"
  ] },
  { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [
    "def target_and_input(data_set, window=4):\n",
    "    \"\"\"Build sliding-window inputs and next-hit targets.\n",
    "\n",
    "    data_set: array of shape (num_ex, timesteps, n_coords).\n",
    "    window: number of consecutive hits concatenated into one input step.\n",
    "\n",
    "    Step i (0 <= i < timesteps - window) uses hits i .. i+window-1, flattened\n",
    "    to window * n_coords values, as input and hit i+window as target. For\n",
    "    8-hit tracks with 3 coordinates and window=4 this gives inputs of shape\n",
    "    (num_ex, 4, 12) and targets of shape (num_ex, 4, 3).\n",
    "\n",
    "    Raises ValueError if tracks have no more than `window` hits.\n",
    "    \"\"\"\n",
    "    data_set = np.asarray(data_set)\n",
    "    num_ex, timesteps, n_coords = data_set.shape\n",
    "    n_steps = timesteps - window\n",
    "    if n_steps < 1:\n",
    "        raise ValueError('need more than window=%d hits per track' % window)\n",
    "    inputt = np.zeros((num_ex, n_steps, window * n_coords))\n",
    "    target = np.zeros((num_ex, n_steps, n_coords))\n",
    "    for i in range(n_steps):\n",
    "        target[:, i, :] = data_set[:, i + window, :]\n",
    "        # Row-major reshape lays hits i..i+window-1 side by side.\n",
    "        inputt[:, i, :] = data_set[:, i:i + window, :].reshape(num_ex, window * n_coords)\n",
    "    return inputt, target"
  ] },
  { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[ -2.24858212 23.38073158 -6.03999996 -6.48999882 28.59857178\n", " -5.64000034 -21.7247715 67.05270386 -3.24000025 -22.22597122\n", " 79.26768494 -2.60000038]\n", " [ -6.48999882 28.59857178 -5.64000034 -21.7247715 67.05270386\n", " -3.24000025 -22.22597122 79.26768494 -2.60000038 82.22602081\n", " 3.07000017 7.23999977]\n", " [-21.7247715 67.05270386 -3.24000025 -22.22597122 79.26768494\n", " -2.60000038 82.22602081 3.07000017 7.23999977 70.39072418\n", " 0.19000006 7.55999947]\n", " [-22.22597122 79.26768494 -2.60000038 82.22602081 3.07000017\n", " 7.23999977 70.39072418 0.19000006 7.55999947 28.80265617\n", " 3.90146184 6.03999996]]\n", "[[82.22602081 3.07000017 
7.23999977]\n", " [70.39072418 0.19000006 7.55999947]\n", " [28.80265617 3.90146184 6.03999996]\n", " [21.42139244 6.97884512 5.63999987]]\n" ] } ], "source": [
    "inputt_train, target_train = target_and_input(train_set)\n",
    "inputt_test, target_test = target_and_input(test_set)\n",
    "print(inputt_train[0, :, :])\n",
    "print(target_train[0, :, :])"
  ] },
  { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [
    "def unison_shuffled_copies(a, b):\n",
    "    \"\"\"Return copies of `a` and `b` reordered by one shared random permutation\n",
    "    of axis 0 (drawn from np.random), so input/target pairs stay aligned.\"\"\"\n",
    "    assert a.shape[0] == b.shape[0]\n",
    "    p = np.random.permutation(a.shape[0])\n",
    "    return a[p], b[p]\n",
    "\n",
    "\n",
    "def random_mini_batches(inputt, target, minibatch_size=500):\n",
    "    \"\"\"Shuffle `inputt` and `target` together and cut them into minibatches.\n",
    "\n",
    "    Returns a list of (input_batch, target_batch) tuples: all full batches of\n",
    "    `minibatch_size` examples, followed by one shorter batch with the remainder\n",
    "    when the example count is not a multiple of `minibatch_size`. No batch is\n",
    "    ever empty.\n",
    "\n",
    "    Raises ValueError if minibatch_size is not positive.\n",
    "    \"\"\"\n",
    "    if minibatch_size <= 0:\n",
    "        raise ValueError('minibatch_size must be positive, got %r' % (minibatch_size,))\n",
    "    _i, _t = unison_shuffled_copies(inputt, target)\n",
    "    # Stepping over batch start offsets covers a partial last batch and never\n",
    "    # yields an empty one (the training loop runs an update on every batch).\n",
    "    return [(_i[start:start + minibatch_size], _t[start:start + minibatch_size])\n",
    "            for start in range(0, inputt.shape[0], minibatch_size)]"
  ] },
  { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(46427, 4, 3)\n", "93\n", "(46427, 7, 3)\n", "[[-20.34987831 -9.56570816 -19.71999931]\n", " [-25.26750183 -14.78154945 -19.07999992]\n", " [-59.30515289 -38.06190491 -16.68000031]\n", " [-70.66841888 -42.76715851 -15.88000011]\n", " [ 9.37641907 82.20050812 22.52000046]\n", " [ 9.68933487 70.27758789 23.39999962]\n", " [ 0.87000084 28.59857178 26.44000053]] [[-25.26750183 -14.78154945 -19.07999992]\n", "
[-59.30515289 -38.06190491 -16.68000031]\n", " [-70.66841888 -42.76715851 -15.88000011]\n", " [ 9.37641907 82.20050812 22.52000046]\n", " [ 9.68933487 70.27758789 23.39999962]\n", " [ 0.87000084 28.59857178 26.44000053]\n", " [ -0.80022049 24.64358711 26.60000038]]\n" ] } ], "source": [
    "# Sequence-to-sequence setup: the RNN reads every hit of a track but the last\n",
    "# and, at each step, predicts the following hit, so targets are the inputs\n",
    "# shifted by one step.\n",
    "minibatches = random_mini_batches(train_set[:, :-1, :], train_set[:, 1:, :])\n",
    "print(len(minibatches))\n",
    "train, target = minibatches[0]\n",
    "test_input, test_target = test_set[:, :-1, :], test_set[:, 1:, :]\n",
    "print(train[0, :, :], target[0, :, :])"
  ] },
  { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [
    "class RNNPlacePrediction():\n",
    "    \"\"\"Per-step regressor: an RNN cell unrolled over a sequence of hits.\n",
    "\n",
    "    X has shape (batch, time_steps, ninputs). At every step a linear projection\n",
    "    of the cell output yields `num_output` values, compared with Y of shape\n",
    "    (batch, time_steps, num_output) by mean squared error. Call\n",
    "    `set_cost_and_functions()` once before `fit`, `predict`, `save` or `load`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, time_steps, future_steps, ninputs, ncells, num_output, cell_type=\"basic_rnn\"):\n",
    "        \"\"\"Create the placeholders, the RNN cell and a session in the default graph.\n",
    "\n",
    "        time_steps: sequence length of X and Y.\n",
    "        future_steps: stored as `self.future_steps`; not used by the model.\n",
    "        ninputs: features per time step in X.\n",
    "        ncells: units in the recurrent cell.\n",
    "        num_output: values predicted per time step (last dimension of Y).\n",
    "        cell_type: \"basic_rnn\", \"lstm\" or \"GRU\"; anything else raises ValueError.\n",
    "        \"\"\"\n",
    "        self.nsteps = time_steps\n",
    "        self.future_steps = future_steps\n",
    "        self.ninputs = ninputs\n",
    "        self.ncells = ncells\n",
    "        self.num_output = num_output\n",
    "\n",
    "        self.X = tf.placeholder(dtype=tf.float32, shape=(None, time_steps, ninputs))\n",
    "        # Targets are compared element-wise with the projected output, so their\n",
    "        # last dimension is num_output (not ninputs).\n",
    "        self.Y = tf.placeholder(dtype=tf.float32, shape=(None, time_steps, num_output))\n",
    "\n",
    "        if cell_type == \"basic_rnn\":\n",
    "            self.cell = tf.contrib.rnn.BasicRNNCell(num_units=ncells, activation=tf.nn.relu)\n",
    "        elif cell_type == \"lstm\":\n",
    "            self.cell = tf.contrib.rnn.BasicLSTMCell(num_units=ncells, activation=tf.nn.relu)\n",
    "        elif cell_type == \"GRU\":\n",
    "            self.cell = tf.contrib.rnn.GRUCell(num_units=ncells, activation=tf.nn.relu)\n",
    "        else:\n",
    "            # Raise instead of `assert False`, which `python -O` strips out.\n",
    "            raise ValueError(\"Wrong rnn cell type: {!r}\".format(cell_type))\n",
    "\n",
    "        # Linear projection of each step's cell output to num_output values.\n",
    "        self.RNNCell = tf.contrib.rnn.OutputProjectionWrapper(self.cell, output_size=num_output)\n",
    "\n",
    "        self.sess = tf.Session()\n",
    "\n",
    "    def set_cost_and_functions(self, LR=0.001):\n",
    "        \"\"\"Unroll the cell over X, define the MSE cost and an Adam step, and\n",
    "        initialise all global variables in `self.sess`.\n",
    "\n",
    "        LR: Adam learning rate.\n",
    "        \"\"\"\n",
    "        self.output, self.state = tf.nn.dynamic_rnn(self.RNNCell, self.X, dtype=tf.float32)\n",
    "        # mean_squared_error already returns the scalar mean over all elements.\n",
    "        self.cost = tf.losses.mean_squared_error(self.Y, self.output)\n",
    "        self.train = tf.train.AdamOptimizer(LR).minimize(self.cost)\n",
    "        self.init = tf.global_variables_initializer()\n",
    "        self.saver = tf.train.Saver()\n",
    "        self.sess.run(self.init)\n",
    "\n",
    "    def fit(self, minibatches, epochs, print_step):\n",
    "        \"\"\"Run `epochs` passes of Adam updates over `minibatches`.\n",
    "\n",
    "        minibatches: list of (input, target) pairs shaped like X and Y.\n",
    "        print_step: print the epoch cost every `print_step` epochs.\n",
    "\n",
    "        The epoch cost -- the sum over batches of each batch's MSE, measured\n",
    "        right after that batch's update -- is appended to `self.loss_list`.\n",
    "        \"\"\"\n",
    "        self.loss_list = []\n",
    "        for iep in range(epochs):\n",
    "            loss = 0\n",
    "            for train, target in minibatches:\n",
    "                feed = {self.X: train, self.Y: target}\n",
    "                self.sess.run(self.train, feed_dict=feed)\n",
    "                loss += self.sess.run(self.cost, feed_dict=feed)\n",
    "            if iep % print_step == 0:\n",
    "                print(\"Epoch number \", iep)\n",
    "                print(\"Cost: \", loss)\n",
    "            self.loss_list.append(loss)\n",
    "\n",
    "    def save(self, filename=\"./rnn_model_GRU_30/rnn_basic\"):\n",
    "        \"\"\"Write a checkpoint with prefix `filename`, creating its directory first.\"\"\"\n",
    "        import os\n",
    "        directory = os.path.dirname(filename)\n",
    "        if directory:\n",
    "            # tf.train.Saver.save fails if the parent directory does not exist.\n",
    "            os.makedirs(directory, exist_ok=True)\n",
    "        self.saver.save(self.sess, filename)\n",
    "\n",
    "    def load(self, filename=\"./rnn_model_GRU_30/rnn_basic\"):\n",
    "        \"\"\"Restore all variables from the checkpoint with prefix `filename`.\"\"\"\n",
    "        self.saver.restore(self.sess, filename)\n",
    "\n",
    "    def predict(self, x):\n",
    "        \"\"\"Return the model output for `x`, shaped like the X placeholder.\"\"\"\n",
    "        return self.sess.run(self.output, feed_dict={self.X: x})\n",
    "\n",
    "    def close(self):\n",
    "        \"\"\"Close the TensorFlow session held by this model.\"\"\"\n",
    "        self.sess.close()"
  ] },
  { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "tf.reset_default_graph()" ] },
  { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [
    "# The last hit of a track is only ever a target, so the RNN sees one step fewer.\n",
    "timesteps = train_set.shape[1] - 1\n",
    "future_steps = 1\n",
    "ninputs = 3\n",
    "ncells = 30\n",
    "num_output = 3"
  ] },
  { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From c:\\users\\sa_li\\anaconda3\\envs\\rnn-tf-ker\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use the retry module or similar alternatives.\n" ] } ], "source": [
    "rnn = RNNPlacePrediction(time_steps=timesteps, future_steps=future_steps, ninputs=ninputs,\n",
    "                         ncells=ncells, num_output=num_output, cell_type=\"GRU\")"
  ] },
  { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "rnn.set_cost_and_functions()" ] },
  { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Epoch number 0\n", "Cost: 104797.03875732422\n", "Epoch number 100\n", "Cost: 1482.824694633484\n", "Epoch number 200\n", "Cost: 1166.6929140090942\n", "Epoch number 300\n", "Cost: 1034.4315690994263\n", "Epoch number 400\n", "Cost: 954.4337291717529\n", "Epoch number 500\n", "Cost: 904.032078742981\n", "Epoch number 600\n", "Cost: 860.2765083312988\n", "Epoch number 700\n", "Cost: 823.9659514427185\n", "Epoch number 800\n", "Cost: 797.8692417144775\n", "Epoch number 900\n", "Cost: 774.48921251297\n", "Epoch number 1000\n", "Cost: 752.8398714065552\n", "Epoch number 1100\n", "Cost: 734.5328130722046\n", "Epoch number 1200\n", "Cost: 720.6981329917908\n", "Epoch number 1300\n",
"Cost: 709.523627281189\n", "Epoch number 1400\n", "Cost: 701.528938293457\n", "Epoch number 1500\n", "Cost: 695.8975224494934\n", "Epoch number 1600\n", "Cost: 689.2830562591553\n", "Epoch number 1700\n", "Cost: 684.1840767860413\n", "Epoch number 1800\n", "Cost: 679.4332590103149\n", "Epoch number 1900\n", "Cost: 674.5709180831909\n", "Epoch number 2000\n", "Cost: 670.97270154953\n", "Epoch number 2100\n", "Cost: 667.7384643554688\n", "Epoch number 2200\n", "Cost: 665.9748268127441\n", "Epoch number 2300\n", "Cost: 663.046612739563\n", "Epoch number 2400\n", "Cost: 660.604160785675\n", "Epoch number 2500\n", "Cost: 659.0691528320312\n", "Epoch number 2600\n", "Cost: 658.2915244102478\n", "Epoch number 2700\n", "Cost: 654.9598126411438\n", "Epoch number 2800\n", "Cost: 652.0928063392639\n", "Epoch number 2900\n", "Cost: 650.0017580986023\n", "Epoch number 3000\n", "Cost: 650.9711427688599\n", "Epoch number 3100\n", "Cost: 646.0216059684753\n", "Epoch number 3200\n", "Cost: 644.435601234436\n", "Epoch number 3300\n", "Cost: 645.7997555732727\n", "Epoch number 3400\n", "Cost: 641.1044583320618\n", "Epoch number 3500\n", "Cost: 639.9977240562439\n", "Epoch number 3600\n", "Cost: 638.4698357582092\n", "Epoch number 3700\n", "Cost: 637.1783366203308\n", "Epoch number 3800\n", "Cost: 635.7812042236328\n", "Epoch number 3900\n", "Cost: 634.1737952232361\n", "Epoch number 4000\n", "Cost: 633.4426860809326\n", "Epoch number 4100\n", "Cost: 632.3123679161072\n", "Epoch number 4200\n", "Cost: 631.4027585983276\n", "Epoch number 4300\n", "Cost: 630.4044184684753\n", "Epoch number 4400\n", "Cost: 629.121660232544\n", "Epoch number 4500\n", "Cost: 628.0477848052979\n", "Epoch number 4600\n", "Cost: 627.2914171218872\n", "Epoch number 4700\n", "Cost: 626.7988724708557\n", "Epoch number 4800\n", "Cost: 626.3834252357483\n", "Epoch number 4900\n", "Cost: 625.7865376472473\n" ] } ], "source": [ "rnn.fit(minibatches, epochs=5000, print_step=100)" ] },
  { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "rnn.save()" ] },
  { "cell_type": "markdown", "metadata": {}, "source": [
    "To evaluate previously saved weights without retraining, run `rnn.load()` after `rnn.set_cost_and_functions()` (which creates the saver) instead of `rnn.fit(...)`."
  ] },
  { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "test_pred = rnn.predict(test_input)" ] },
  { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[ -0.5854187 0.02950048 -0.74694061]\n", " [ 2.12476349 5.43320465 -0.11310959]\n", " [ 0.30315399 -0.89936066 -2.11120224]\n", " [-15.19197083 12.00935364 0.21447372]\n", " [ 0.62501144 -0.96383286 1.62054443]\n", " [ 0.52139854 0.09034729 -3.62054443]\n", " [ -0.28830719 -0.77872753 0.46095276]]\n" ] } ], "source": [
    "# Residuals (prediction - truth) at every step of one held-out track.\n",
    "print(test_pred[5, :, :] - test_target[5, :, :])"
  ] },
  { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "6.648698" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [
    "# Mean squared error on the held-out tracks (the same cost used in training).\n",
    "rnn.sess.run(rnn.cost, feed_dict={rnn.X: test_input, rnn.Y: test_target})"
  ] }
 ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4,
"nbformat_minor": 2 }