Newer
Older
rnn_bachelor_thesis / 1_to_1_gru_30.ipynb
@Sascha Liechti Sascha Liechti on 12 Apr 2018 21 KB first test
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "c:\\users\\sa_li\\anaconda3\\envs\\rnn-tf-ker\\lib\\site-packages\\h5py\\__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib as mpl\n",
    "import random\n",
    "import math\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "from tensorflow.python.framework import ops"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "#import data as array\n",
    "# 8 hits with x,y,z\n",
    " \n",
    "testset = pd.read_pickle('matched_8hittracks.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(46896, 24)\n",
      "[-20.411108   -9.417887    4.7599998]\n",
      "[-27.813803   -6.944843    4.7599998]\n",
      "[-66.736946   22.9032      4.3599997]\n",
      "[-74.0961    35.649506   4.04    ]\n",
      "[78.324196  26.359665  -3.7200012]\n",
      "[69.040436 14.306461 -4.04    ]\n",
      "[26.880571 -9.817033 -4.84    ]\n",
      "[ 19.68401  -11.173258  -5.      ]\n",
      "[ -2.2485821   23.380732    -6.04        -6.489999    28.598572\n",
      "  -5.6400003  -21.724771    67.052704    -3.2400002  -22.225971\n",
      "  79.267685    -2.6000004   82.22602      3.0700002    7.24\n",
      "  70.390724     0.19000006   7.5599995   28.802656     3.9014618\n",
      "   6.04        21.421392     6.978845     5.64      ]\n"
     ]
    }
   ],
   "source": [
    "#Check testset with arbitrary particle\n",
    "\n",
    "tset = np.array(testset)\n",
    "tset = tset.astype('float32')\n",
    "print(tset.shape)\n",
    "for i in range(8):\n",
    "    print(tset[1,3*i:(3*i+3)])\n",
    "print(tset[0,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "### Reshape original array into the shape (particlenumber, timesteps, input = coordinates)###\n",
    "\n",
    "def reshapor(arr_orig):\n",
    "    timesteps = int(arr_orig.shape[1]/3)\n",
    "    number_examples = int(arr_orig.shape[0])\n",
    "    arr = np.zeros((number_examples, timesteps, 3))\n",
    "    \n",
    "    for i in range(number_examples):\n",
    "        for t in range(timesteps):\n",
    "            arr[i,t,0:3] = arr_orig[i,3*t:3*t+3]\n",
    "        \n",
    "    return arr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "### create the training set and the test set###\n",
    "\n",
    "def create_random_sets(dataset, train_to_total_ratio):\n",
    "    num_examples = dataset.shape[0]\n",
    "    train_set_size = np.int(num_examples*train_to_total_ratio)\n",
    "    test_set_size = num_examples - train_set_size\n",
    "    random_indices = random.sample(range(num_examples), train_set_size)\n",
    "    train_set = np.zeros((train_set_size, dataset.shape[1]))\n",
    "    test_set = np.zeros((test_set_size, dataset.shape[1]))\n",
    "    \n",
    "    trc=0\n",
    "    tec=0\n",
    "   \n",
    "    for i in range(num_examples):\n",
    "        if i in random_indices:\n",
    "            train_set[trc,:] += tset[i,:]\n",
    "            trc += 1\n",
    "        else:\n",
    "            test_set[tec,:]  += tset[i,:]\n",
    "            tec +=1\n",
    "    \n",
    "    train_set = reshapor(train_set)\n",
    "    test_set = reshapor(test_set)\n",
    "    \n",
    "    return train_set, test_set\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(469, 8, 3) (46427, 8, 3) (46896, 8, 3)\n",
      "[[-19.7146244   11.09934807  12.5199995 ]\n",
      " [-25.76811028  13.2408371   12.35999966]\n",
      " [-56.47006226  43.07745743  11.88000011]\n",
      " [-61.01356888  55.37170029  11.55999947]\n",
      " [ 53.24511719  62.70946121  18.60000038]\n",
      " [ 49.52627182  50.02124786  18.76000023]\n",
      " [ 21.59110832  19.66292     18.44000053]\n",
      " [ 17.10472298  17.40020561  18.27999878]]\n"
     ]
    }
   ],
   "source": [
    "train_set, test_set = create_random_sets(tset, 0.99)\n",
    "print(test_set.shape, train_set.shape, reshapor(tset).shape)\n",
    "print(test_set[0,:,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "### create target and input arrays input of shape (num_examples, 8 timesteps, n_inputs)###\n",
    "\n",
    "def target_and_input(data_set):\n",
    "    \n",
    "    num_ex = data_set.shape[0]\n",
    "    inputt = np.zeros((num_ex, 4, 12))\n",
    "    target = np.zeros((num_ex, 4, 3))\n",
    "    \n",
    "    \n",
    "    for i in range(4):\n",
    "        target[:,i,:] = data_set[:,4+i,:]\n",
    "        for f in range(4):\n",
    "            inputt[:,i,3*f:3*f+3] = data_set[:,i+f,:]\n",
    "    \n",
    "        \n",
    "    \n",
    "    \n",
    "    return inputt, target\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ -2.24858212  23.38073158  -6.03999996  -6.48999882  28.59857178\n",
      "   -5.64000034 -21.7247715   67.05270386  -3.24000025 -22.22597122\n",
      "   79.26768494  -2.60000038]\n",
      " [ -6.48999882  28.59857178  -5.64000034 -21.7247715   67.05270386\n",
      "   -3.24000025 -22.22597122  79.26768494  -2.60000038  82.22602081\n",
      "    3.07000017   7.23999977]\n",
      " [-21.7247715   67.05270386  -3.24000025 -22.22597122  79.26768494\n",
      "   -2.60000038  82.22602081   3.07000017   7.23999977  70.39072418\n",
      "    0.19000006   7.55999947]\n",
      " [-22.22597122  79.26768494  -2.60000038  82.22602081   3.07000017\n",
      "    7.23999977  70.39072418   0.19000006   7.55999947  28.80265617\n",
      "    3.90146184   6.03999996]]\n",
      "[[82.22602081  3.07000017  7.23999977]\n",
      " [70.39072418  0.19000006  7.55999947]\n",
      " [28.80265617  3.90146184  6.03999996]\n",
      " [21.42139244  6.97884512  5.63999987]]\n"
     ]
    }
   ],
   "source": [
    "inputt_train, target_train = target_and_input(train_set)\n",
    "inputt_test, target_test = target_and_input(test_set)\n",
    "print(inputt_train[0,:,:])\n",
    "print(target_train[0,:,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "###create random mini_batches###\n",
    "\n",
    "\n",
    "def unison_shuffled_copies(a, b):\n",
    "    assert a.shape[0] == b.shape[0]\n",
    "    p = np.random.permutation(a.shape[0])\n",
    "    return a[p,:,:], b[p,:,:]\n",
    "\n",
    "def random_mini_batches(inputt, target, minibatch_size = 500):\n",
    "    \n",
    "    num_examples = inputt.shape[0]\n",
    "    \n",
    "    \n",
    "    #Number of complete batches\n",
    "    \n",
    "    number_of_batches = int(num_examples/minibatch_size)\n",
    "    minibatches = []\n",
    "   \n",
    "    #shuffle particles\n",
    "    _i, _t = unison_shuffled_copies(inputt, target)\n",
    "    print(_t.shape)\n",
    "        \n",
    "    \n",
    "    for i in range(number_of_batches):\n",
    "        \n",
    "        minibatch_train = _i[minibatch_size*i:minibatch_size*(i+1), :, :]\n",
    "        \n",
    "        minibatch_true = _t[minibatch_size*i:minibatch_size*(i+1), :, :]\n",
    "        \n",
    "        minibatches.append((minibatch_train, minibatch_true))\n",
    "        \n",
    "        \n",
    "    minibatches.append((_i[number_of_batches*minibatch_size:, :, :], _t[number_of_batches*minibatch_size:, :, :]))\n",
    "    \n",
    "    \n",
    "    return minibatches\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(46427, 4, 3)\n",
      "93\n",
      "(46427, 7, 3)\n",
      "[[-20.34987831  -9.56570816 -19.71999931]\n",
      " [-25.26750183 -14.78154945 -19.07999992]\n",
      " [-59.30515289 -38.06190491 -16.68000031]\n",
      " [-70.66841888 -42.76715851 -15.88000011]\n",
      " [  9.37641907  82.20050812  22.52000046]\n",
      " [  9.68933487  70.27758789  23.39999962]\n",
      " [  0.87000084  28.59857178  26.44000053]] [[-25.26750183 -14.78154945 -19.07999992]\n",
      " [-59.30515289 -38.06190491 -16.68000031]\n",
      " [-70.66841888 -42.76715851 -15.88000011]\n",
      " [  9.37641907  82.20050812  22.52000046]\n",
      " [  9.68933487  70.27758789  23.39999962]\n",
      " [  0.87000084  28.59857178  26.44000053]\n",
      " [ -0.80022049  24.64358711  26.60000038]]\n"
     ]
    }
   ],
   "source": [
    "minibatches = random_mini_batches(inputt_train, target_train)\n",
    "\n",
    "\n",
    "testinputt, testtarget = minibatches[int(inputt_train.shape[0]/500)]\n",
    "\n",
    "print(len(minibatches))\n",
    "\n",
    "minibatches = random_mini_batches(train_set[:,:-1,:], train_set[:,1:,:])\n",
    "train, target = minibatches[0]\n",
    "test_input, test_target = test_set[:,:-1,:], test_set[:,1:,:]\n",
    "print(train[0,:,:], target[0,:,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RNNPlacePrediction():\n",
    "    \n",
    "    \n",
    "    def __init__(self, time_steps, future_steps, ninputs, ncells, num_output, cell_type=\"basic_rnn\"):\n",
    "        \n",
    "        self.nsteps = time_steps\n",
    "        self.future_steps = future_steps\n",
    "        self.ninputs = ninputs\n",
    "        self.ncells = ncells\n",
    "        self.num_output = num_output\n",
    "        \n",
    "        #### The input is of shape (nbatches, time_steps, ninputs)\n",
    "        #### ninputs is the dimentionality (number of features) of the time series\n",
    "        self.X = tf.placeholder(dtype=tf.float32, shape=(None, time_steps, ninputs))\n",
    "        self.Y = tf.placeholder(dtype=tf.float32, shape=(None, time_steps, ninputs))\n",
    "        \n",
    "        \n",
    "        if cell_type==\"basic_rnn\":\n",
    "            self.cell = tf.contrib.rnn.BasicRNNCell(num_units=ncells, activation=tf.nn.relu)\n",
    "            \n",
    "        elif cell_type==\"lstm\":\n",
    "            self.cell = tf.contrib.rnn.BasicLSTMCell(num_units=ncells, activation=tf.nn.relu)\n",
    "                    \n",
    "        elif cell_type==\"GRU\":\n",
    "            self.cell = tf.contrib.rnn.GRUCell(num_units=ncells, activation=tf.nn.relu)\n",
    "            \n",
    "        else:\n",
    "            print(\"Wrong rnn cell type:   \", cell_type)\n",
    "            assert(False)\n",
    "            \n",
    "        \n",
    "        #### I now define the output\n",
    "        self.RNNCell = tf.contrib.rnn.OutputProjectionWrapper(self.cell, output_size= num_output)\n",
    "        \n",
    "        \n",
    "        \n",
    "        \n",
    "        \n",
    "        self.sess = tf.Session()\n",
    "        \n",
    "    def set_cost_and_functions(self, LR=0.001):\n",
    "        #### I define here the function that unrolls the RNN cell\n",
    "        self.output, self.state = tf.nn.dynamic_rnn(self.RNNCell, self.X, dtype=tf.float32)\n",
    "        #### I define the cost function as the square error\n",
    "        self.cost = tf.reduce_mean(tf.losses.mean_squared_error(self.Y, self.output))   \n",
    "        \n",
    "        #### the rest proceed as usual\n",
    "        self.train = tf.train.AdamOptimizer(LR).minimize(self.cost)\n",
    "        #### Variable initializer\n",
    "        self.init = tf.global_variables_initializer()\n",
    "        self.saver = tf.train.Saver()\n",
    "        self.sess.run(self.init)\n",
    "        \n",
    "    \n",
    "        \n",
    "    def fit(self, minibatches, epochs, print_step):\n",
    "        \n",
    "        self.loss_list = []\n",
    "        \n",
    "        for iep in range(epochs):\n",
    "            loss = 0\n",
    "            for batch in range(len(minibatches)):\n",
    "            #### Here I train the RNNcell\n",
    "            #### The x is the time serie, the y is shifted by 1 time step\n",
    "                train, target = minibatches[batch]\n",
    "                self.sess.run(self.train, feed_dict={self.X:train, self.Y:target})\n",
    "                \n",
    "            \n",
    "                loss += self.sess.run(self.cost, feed_dict={self.X:train, self.Y:target})\n",
    "                    \n",
    "            if iep%print_step==0:\n",
    "                print(\"Epoch number \",iep)\n",
    "                print(\"Cost: \",loss)\n",
    "                \n",
    "            self.loss_list.append(loss)\n",
    "            \n",
    "            #print(loss)\n",
    "                \n",
    "                \n",
    "    def save(self, filename=\"./rnn_model_GRU_30/rnn_basic\"):\n",
    "        self.saver.save(self.sess, filename)\n",
    "            \n",
    "            \n",
    "    def load(self, filename=\"./rnn_model_GRU_30/rnn_basic\"):\n",
    "        self.saver.restore(self.sess, filename)\n",
    "        \n",
    "        \n",
    "    def predict(self, x):\n",
    "        return self.sess.run(self.output, feed_dict={self.X:x})\n",
    "    \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
     "# Start from an empty default graph so re-running the model cells does not stack duplicate variables/ops.\n",
     "tf.reset_default_graph()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "timesteps = 7\n",
    "future_steps = 1\n",
    "ninputs = 3\n",
    "ncells = 30\n",
    "num_output = 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From c:\\users\\sa_li\\anaconda3\\envs\\rnn-tf-ker\\lib\\site-packages\\tensorflow\\contrib\\learn\\python\\learn\\datasets\\base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use the retry module or similar alternatives.\n"
     ]
    }
   ],
   "source": [
    "rnn = RNNPlacePrediction(time_steps=timesteps, future_steps=future_steps, ninputs=ninputs, \n",
    "                        ncells=ncells, num_output=num_output, cell_type=\"GRU\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "rnn.set_cost_and_functions()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch number  0\n",
      "Cost:  104797.03875732422\n",
      "Epoch number  100\n",
      "Cost:  1482.824694633484\n",
      "Epoch number  200\n",
      "Cost:  1166.6929140090942\n",
      "Epoch number  300\n",
      "Cost:  1034.4315690994263\n",
      "Epoch number  400\n",
      "Cost:  954.4337291717529\n",
      "Epoch number  500\n",
      "Cost:  904.032078742981\n",
      "Epoch number  600\n",
      "Cost:  860.2765083312988\n",
      "Epoch number  700\n",
      "Cost:  823.9659514427185\n",
      "Epoch number  800\n",
      "Cost:  797.8692417144775\n",
      "Epoch number  900\n",
      "Cost:  774.48921251297\n",
      "Epoch number  1000\n",
      "Cost:  752.8398714065552\n",
      "Epoch number  1100\n",
      "Cost:  734.5328130722046\n",
      "Epoch number  1200\n",
      "Cost:  720.6981329917908\n",
      "Epoch number  1300\n",
      "Cost:  709.523627281189\n",
      "Epoch number  1400\n",
      "Cost:  701.528938293457\n",
      "Epoch number  1500\n",
      "Cost:  695.8975224494934\n",
      "Epoch number  1600\n",
      "Cost:  689.2830562591553\n",
      "Epoch number  1700\n",
      "Cost:  684.1840767860413\n",
      "Epoch number  1800\n",
      "Cost:  679.4332590103149\n",
      "Epoch number  1900\n",
      "Cost:  674.5709180831909\n",
      "Epoch number  2000\n",
      "Cost:  670.97270154953\n",
      "Epoch number  2100\n",
      "Cost:  667.7384643554688\n",
      "Epoch number  2200\n",
      "Cost:  665.9748268127441\n",
      "Epoch number  2300\n",
      "Cost:  663.046612739563\n",
      "Epoch number  2400\n",
      "Cost:  660.604160785675\n",
      "Epoch number  2500\n",
      "Cost:  659.0691528320312\n",
      "Epoch number  2600\n",
      "Cost:  658.2915244102478\n",
      "Epoch number  2700\n",
      "Cost:  654.9598126411438\n",
      "Epoch number  2800\n",
      "Cost:  652.0928063392639\n",
      "Epoch number  2900\n",
      "Cost:  650.0017580986023\n",
      "Epoch number  3000\n",
      "Cost:  650.9711427688599\n",
      "Epoch number  3100\n",
      "Cost:  646.0216059684753\n",
      "Epoch number  3200\n",
      "Cost:  644.435601234436\n",
      "Epoch number  3300\n",
      "Cost:  645.7997555732727\n",
      "Epoch number  3400\n",
      "Cost:  641.1044583320618\n",
      "Epoch number  3500\n",
      "Cost:  639.9977240562439\n",
      "Epoch number  3600\n",
      "Cost:  638.4698357582092\n",
      "Epoch number  3700\n",
      "Cost:  637.1783366203308\n",
      "Epoch number  3800\n",
      "Cost:  635.7812042236328\n",
      "Epoch number  3900\n",
      "Cost:  634.1737952232361\n",
      "Epoch number  4000\n",
      "Cost:  633.4426860809326\n",
      "Epoch number  4100\n",
      "Cost:  632.3123679161072\n",
      "Epoch number  4200\n",
      "Cost:  631.4027585983276\n",
      "Epoch number  4300\n",
      "Cost:  630.4044184684753\n",
      "Epoch number  4400\n",
      "Cost:  629.121660232544\n",
      "Epoch number  4500\n",
      "Cost:  628.0477848052979\n",
      "Epoch number  4600\n",
      "Cost:  627.2914171218872\n",
      "Epoch number  4700\n",
      "Cost:  626.7988724708557\n",
      "Epoch number  4800\n",
      "Cost:  626.3834252357483\n",
      "Epoch number  4900\n",
      "Cost:  625.7865376472473\n"
     ]
    }
   ],
   "source": [
    "rnn.fit(minibatches, epochs=5000, print_step=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
     "# Checkpoint the trained weights to the default path ./rnn_model_GRU_30/rnn_basic\n",
     "rnn.save()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
     "# To restore the saved checkpoint instead of retraining, uncomment the next line:\n",
     "# rnn.load()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
     "# Optional sanity check of the test input shape:\n",
     "# test_input.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_pred = rnn.predict(test_input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ -0.5854187    0.02950048  -0.74694061]\n",
      " [  2.12476349   5.43320465  -0.11310959]\n",
      " [  0.30315399  -0.89936066  -2.11120224]\n",
      " [-15.19197083  12.00935364   0.21447372]\n",
      " [  0.62501144  -0.96383286   1.62054443]\n",
      " [  0.52139854   0.09034729  -3.62054443]\n",
      " [ -0.28830719  -0.77872753   0.46095276]]\n"
     ]
    }
   ],
   "source": [
    "print(test_pred[5,:,:]-test_target[5,:,:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "6.648698"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "rnn.sess.run(rnn.cost, feed_dict={rnn.X:test_input, rnn.Y:test_target})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}