{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import itertools\n",
    "import pickle\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import math \n",
    "\n",
    "import sys\n",
    "sys.path.append('../')\n",
    "import utils\n",
    "import wiki_utils\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "corpus = wiki_utils.Texts('./wikitext/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 128\n",
    "sequence_length = 30\n",
    "grad_clip = 0.1\n",
    "lr = 4.\n",
    "best_val_loss = None\n",
    "log_interval = 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_batch_size = 128\n",
    "train_loader = wiki_utils.TextLoader(corpus.train, batch_size=batch_size)\n",
    "val_loader = wiki_utils.TextLoader(corpus.valid, batch_size=eval_batch_size)\n",
    "test_loader = wiki_utils.TextLoader(corpus.test, batch_size=eval_batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RNNModel(nn.Module):\n",
    "\n",
    "    def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5):\n",
    "        super(RNNModel, self).__init__()\n",
    "        self.drop = nn.Dropout(dropout)\n",
    "        self.encoder = nn.Embedding(ntoken, ninp)\n",
    "        if rnn_type == 'LSTM':\n",
    "            self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)\n",
    "        elif rnn_type == 'GRU':\n",
    "            self.rnn = nn.GRU(ninp, nhid, nlayers, dropout=dropout)\n",
    "        self.decoder = nn.Linear(nhid, ntoken)\n",
    "\n",
    "        self.init_weights()\n",
    "\n",
    "        self.rnn_type = rnn_type\n",
    "        self.nhid = nhid\n",
    "        self.nlayers = nlayers\n",
    "\n",
    "    def init_weights(self):\n",
    "        initrange = 0.1\n",
    "        self.encoder.weight.data.uniform_(-initrange, initrange)\n",
    "        self.decoder.bias.data.fill_(0)\n",
    "        self.decoder.weight.data.uniform_(-initrange, initrange)\n",
    "\n",
    "    def forward(self, x, hidden=None):\n",
    "        emb = self.drop(self.encoder(x))\n",
    "        output, hidden = self.rnn(emb, hidden)\n",
    "        output = self.drop(output)\n",
    "        decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))\n",
    "        return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden\n",
    "\n",
    "    def init_hidden(self, bsz):\n",
    "        weight = next(self.parameters()).data\n",
    "        if self.rnn_type == 'LSTM':\n",
    "            return (weight.new(self.nlayers, bsz, self.nhid).zero_(),\n",
    "                    weight.new(self.nlayers, bsz, self.nhid).zero_())\n",
    "        else:\n",
    "            return weight.new(self.nlayers, bsz, self.nhid).zero_()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(data_loader):\n",
    "    model.eval()\n",
    "    total_loss = 0\n",
    "    ntokens = len(corpus.dictionary)\n",
    "    hidden = model.init_hidden(eval_batch_size)\n",
    "    for i, (data, targets) in enumerate(data_loader):\n",
    "        output, hidden = model(data)\n",
    "        output_flat = output.view(-1, ntokens)\n",
    "        total_loss += len(data) * criterion(output_flat, targets).item()\n",
    "    return total_loss / len(data_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train():\n",
    "    model.train()\n",
    "    total_loss = 0\n",
    "    ntokens = len(corpus.dictionary)\n",
    "    for batch, (data, targets) in enumerate(train_loader):\n",
    "        model.zero_grad()\n",
    "        output, hidden = model(data)\n",
    "        loss = criterion(output.view(-1, ntokens), targets)\n",
    "        loss.backward()\n",
    "\n",
    "        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)\n",
    "        for p in model.parameters():\n",
    "            p.data.add_(-lr, p.grad.data)\n",
    "\n",
    "        total_loss += loss.item()\n",
    "\n",
    "        if batch % log_interval == 0 and batch > 0:\n",
    "            cur_loss = total_loss / log_interval\n",
    "            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | loss {:5.2f} | ppl {:8.2f}'.format(\n",
    "                epoch, batch, len(train_loader) // sequence_length, lr, cur_loss, math.exp(cur_loss)))\n",
    "            total_loss = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "ntokens = len(corpus.dictionary)\n",
    "model = RNNModel('LSTM', ntokens, 128, 128, 2, 0.3)\n",
    "criterion = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate(n=50, temp=1.):\n",
    "    model.eval()\n",
    "    x = torch.rand(1, 1).mul(ntokens).long()\n",
    "    hidden = None\n",
    "    out = []\n",
    "    for i in range(n):\n",
    "        output, hidden = model(x, hidden)\n",
    "        s_weights = output.squeeze().data.div(temp).exp()\n",
    "        s_idx = torch.multinomial(s_weights, 1)[0]\n",
    "        x.data.fill_(s_idx)\n",
    "        s = corpus.dictionary.idx2symbol[s_idx]\n",
    "        out.append(s)\n",
    "    return ''.join(out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sample:\n",
      " s'W¡კgcử’]3ิÆ>GŁDų’ÜòãIยิLầg:±μ機0ä×€‑tx&ო<?С?3&ア9T \n",
      "\n",
      "| epoch   1 |   100/ 2807 batches | lr 4.00 | loss  3.61 | ppl    36.79\n",
      "| epoch   1 |   200/ 2807 batches | lr 4.00 | loss  3.29 | ppl    26.74\n",
      "| epoch   1 |   300/ 2807 batches | lr 4.00 | loss  3.25 | ppl    25.81\n",
      "| epoch   1 |   400/ 2807 batches | lr 4.00 | loss  3.23 | ppl    25.16\n",
      "| epoch   1 |   500/ 2807 batches | lr 4.00 | loss  3.21 | ppl    24.89\n",
      "| epoch   1 |   600/ 2807 batches | lr 4.00 | loss  3.09 | ppl    22.04\n",
      "| epoch   1 |   700/ 2807 batches | lr 4.00 | loss  2.99 | ppl    19.79\n",
      "| epoch   1 |   800/ 2807 batches | lr 4.00 | loss  2.89 | ppl    18.04\n",
      "| epoch   1 |   900/ 2807 batches | lr 4.00 | loss  2.80 | ppl    16.51\n",
      "| epoch   1 |  1000/ 2807 batches | lr 4.00 | loss  2.71 | ppl    15.08\n",
      "| epoch   1 |  1100/ 2807 batches | lr 4.00 | loss  2.61 | ppl    13.59\n",
      "| epoch   1 |  1200/ 2807 batches | lr 4.00 | loss  2.55 | ppl    12.76\n",
      "| epoch   1 |  1300/ 2807 batches | lr 4.00 | loss  2.50 | ppl    12.14\n",
      "| epoch   1 |  1400/ 2807 batches | lr 4.00 | loss  2.45 | ppl    11.56\n",
      "| epoch   1 |  1500/ 2807 batches | lr 4.00 | loss  2.41 | ppl    11.17\n",
      "| epoch   1 |  1600/ 2807 batches | lr 4.00 | loss  2.38 | ppl    10.80\n",
      "| epoch   1 |  1700/ 2807 batches | lr 4.00 | loss  2.35 | ppl    10.45\n",
      "| epoch   1 |  1800/ 2807 batches | lr 4.00 | loss  2.32 | ppl    10.14\n",
      "| epoch   1 |  1900/ 2807 batches | lr 4.00 | loss  2.29 | ppl     9.92\n",
      "| epoch   1 |  2000/ 2807 batches | lr 4.00 | loss  2.26 | ppl     9.63\n",
      "| epoch   1 |  2100/ 2807 batches | lr 4.00 | loss  2.25 | ppl     9.45\n",
      "| epoch   1 |  2200/ 2807 batches | lr 4.00 | loss  2.22 | ppl     9.24\n",
      "| epoch   1 |  2300/ 2807 batches | lr 4.00 | loss  2.22 | ppl     9.18\n",
      "| epoch   1 |  2400/ 2807 batches | lr 4.00 | loss  2.19 | ppl     8.96\n",
      "| epoch   1 |  2500/ 2807 batches | lr 4.00 | loss  2.18 | ppl     8.85\n",
      "| epoch   1 |  2600/ 2807 batches | lr 4.00 | loss  2.17 | ppl     8.71\n",
      "| epoch   1 |  2700/ 2807 batches | lr 4.00 | loss  2.15 | ppl     8.59\n",
      "| epoch   1 |  2800/ 2807 batches | lr 4.00 | loss  2.13 | ppl     8.38\n",
      "-----------------------------------------------------------------------------------------\n",
      "| end of epoch   1 | valid loss  1.95 | valid ppl     7.04\n",
      "-----------------------------------------------------------------------------------------\n",
      "sample:\n",
      "  Porded 2041 janurre of spopent , . Fhlun . \n",
      " × fl \n",
      "\n",
      "| epoch   2 |   100/ 2807 batches | lr 4.00 | loss  2.14 | ppl     8.48\n",
      "| epoch   2 |   200/ 2807 batches | lr 4.00 | loss  2.10 | ppl     8.16\n",
      "| epoch   2 |   300/ 2807 batches | lr 4.00 | loss  2.09 | ppl     8.08\n",
      "| epoch   2 |   400/ 2807 batches | lr 4.00 | loss  2.08 | ppl     8.00\n",
      "| epoch   2 |   500/ 2807 batches | lr 4.00 | loss  2.07 | ppl     7.92\n",
      "| epoch   2 |   600/ 2807 batches | lr 4.00 | loss  2.06 | ppl     7.82\n",
      "| epoch   2 |   700/ 2807 batches | lr 4.00 | loss  2.05 | ppl     7.78\n",
      "| epoch   2 |   800/ 2807 batches | lr 4.00 | loss  2.04 | ppl     7.67\n",
      "| epoch   2 |   900/ 2807 batches | lr 4.00 | loss  2.03 | ppl     7.64\n",
      "| epoch   2 |  1000/ 2807 batches | lr 4.00 | loss  2.03 | ppl     7.59\n",
      "| epoch   2 |  1100/ 2807 batches | lr 4.00 | loss  2.01 | ppl     7.47\n",
      "| epoch   2 |  1200/ 2807 batches | lr 4.00 | loss  2.00 | ppl     7.42\n",
      "| epoch   2 |  1300/ 2807 batches | lr 4.00 | loss  2.00 | ppl     7.37\n",
      "| epoch   2 |  1400/ 2807 batches | lr 4.00 | loss  1.98 | ppl     7.25\n",
      "| epoch   2 |  1500/ 2807 batches | lr 4.00 | loss  1.98 | ppl     7.24\n",
      "| epoch   2 |  1600/ 2807 batches | lr 4.00 | loss  1.98 | ppl     7.21\n",
      "| epoch   2 |  1700/ 2807 batches | lr 4.00 | loss  1.96 | ppl     7.12\n",
      "| epoch   2 |  1800/ 2807 batches | lr 4.00 | loss  1.96 | ppl     7.09\n",
      "| epoch   2 |  1900/ 2807 batches | lr 4.00 | loss  1.96 | ppl     7.10\n",
      "| epoch   2 |  2000/ 2807 batches | lr 4.00 | loss  1.94 | ppl     6.98\n",
      "| epoch   2 |  2100/ 2807 batches | lr 4.00 | loss  1.95 | ppl     7.00\n",
      "| epoch   2 |  2200/ 2807 batches | lr 4.00 | loss  1.94 | ppl     6.94\n",
      "| epoch   2 |  2300/ 2807 batches | lr 4.00 | loss  1.94 | ppl     6.95\n",
      "| epoch   2 |  2400/ 2807 batches | lr 4.00 | loss  1.92 | ppl     6.84\n",
      "| epoch   2 |  2500/ 2807 batches | lr 4.00 | loss  1.92 | ppl     6.81\n",
      "| epoch   2 |  2600/ 2807 batches | lr 4.00 | loss  1.92 | ppl     6.82\n",
      "| epoch   2 |  2700/ 2807 batches | lr 4.00 | loss  1.91 | ppl     6.78\n",
      "| epoch   2 |  2800/ 2807 batches | lr 4.00 | loss  1.90 | ppl     6.68\n",
      "-----------------------------------------------------------------------------------------\n",
      "| end of epoch   2 | valid loss  1.70 | valid ppl     5.48\n",
      "-----------------------------------------------------------------------------------------\n",
      "sample:\n",
      "  and the funres evilion of a riid and 1961 ovely f \n",
      "\n",
      "| epoch   3 |   100/ 2807 batches | lr 4.00 | loss  1.92 | ppl     6.79\n",
      "| epoch   3 |   200/ 2807 batches | lr 4.00 | loss  1.89 | ppl     6.61\n",
      "| epoch   3 |   300/ 2807 batches | lr 4.00 | loss  1.89 | ppl     6.60\n",
      "| epoch   3 |   400/ 2807 batches | lr 4.00 | loss  1.88 | ppl     6.58\n",
      "| epoch   3 |   500/ 2807 batches | lr 4.00 | loss  1.88 | ppl     6.54\n",
      "| epoch   3 |   600/ 2807 batches | lr 4.00 | loss  1.87 | ppl     6.50\n",
      "| epoch   3 |   700/ 2807 batches | lr 4.00 | loss  1.87 | ppl     6.51\n",
      "| epoch   3 |   800/ 2807 batches | lr 4.00 | loss  1.87 | ppl     6.46\n",
      "| epoch   3 |   900/ 2807 batches | lr 4.00 | loss  1.87 | ppl     6.47\n",
      "| epoch   3 |  1000/ 2807 batches | lr 4.00 | loss  1.87 | ppl     6.46\n",
      "| epoch   3 |  1100/ 2807 batches | lr 4.00 | loss  1.85 | ppl     6.39\n",
      "| epoch   3 |  1200/ 2807 batches | lr 4.00 | loss  1.86 | ppl     6.40\n",
      "| epoch   3 |  1300/ 2807 batches | lr 4.00 | loss  1.85 | ppl     6.35\n",
      "| epoch   3 |  1400/ 2807 batches | lr 4.00 | loss  1.84 | ppl     6.29\n",
      "| epoch   3 |  1500/ 2807 batches | lr 4.00 | loss  1.84 | ppl     6.30\n",
      "| epoch   3 |  1600/ 2807 batches | lr 4.00 | loss  1.84 | ppl     6.29\n",
      "| epoch   3 |  1700/ 2807 batches | lr 4.00 | loss  1.83 | ppl     6.26\n",
      "| epoch   3 |  1800/ 2807 batches | lr 4.00 | loss  1.83 | ppl     6.23\n",
      "| epoch   3 |  1900/ 2807 batches | lr 4.00 | loss  1.84 | ppl     6.29\n",
      "| epoch   3 |  2000/ 2807 batches | lr 4.00 | loss  1.83 | ppl     6.21\n",
      "| epoch   3 |  2100/ 2807 batches | lr 4.00 | loss  1.83 | ppl     6.24\n",
      "| epoch   3 |  2200/ 2807 batches | lr 4.00 | loss  1.83 | ppl     6.21\n",
      "| epoch   3 |  2300/ 2807 batches | lr 4.00 | loss  1.83 | ppl     6.24\n",
      "| epoch   3 |  2400/ 2807 batches | lr 4.00 | loss  1.82 | ppl     6.15\n",
      "| epoch   3 |  2500/ 2807 batches | lr 4.00 | loss  1.82 | ppl     6.15\n",
      "| epoch   3 |  2600/ 2807 batches | lr 4.00 | loss  1.82 | ppl     6.18\n",
      "| epoch   3 |  2700/ 2807 batches | lr 4.00 | loss  1.82 | ppl     6.15\n",
      "| epoch   3 |  2800/ 2807 batches | lr 4.00 | loss  1.80 | ppl     6.08\n",
      "-----------------------------------------------------------------------------------------\n",
      "| end of epoch   3 | valid loss  1.60 | valid ppl     4.93\n",
      "-----------------------------------------------------------------------------------------\n",
      "sample:\n",
      "  when initusable . The playar in Mili <unk> and as \n",
      "\n",
      "| epoch   4 |   100/ 2807 batches | lr 4.00 | loss  1.82 | ppl     6.20\n",
      "| epoch   4 |   200/ 2807 batches | lr 4.00 | loss  1.80 | ppl     6.04\n",
      "| epoch   4 |   300/ 2807 batches | lr 4.00 | loss  1.80 | ppl     6.06\n",
      "| epoch   4 |   400/ 2807 batches | lr 4.00 | loss  1.80 | ppl     6.04\n",
      "| epoch   4 |   500/ 2807 batches | lr 4.00 | loss  1.80 | ppl     6.03\n",
      "| epoch   4 |   600/ 2807 batches | lr 4.00 | loss  1.79 | ppl     5.99\n",
      "| epoch   4 |   700/ 2807 batches | lr 4.00 | loss  1.79 | ppl     6.01\n",
      "| epoch   4 |   800/ 2807 batches | lr 4.00 | loss  1.79 | ppl     5.99\n",
      "| epoch   4 |   900/ 2807 batches | lr 4.00 | loss  1.79 | ppl     5.99\n",
      "| epoch   4 |  1000/ 2807 batches | lr 4.00 | loss  1.79 | ppl     6.00\n",
      "| epoch   4 |  1100/ 2807 batches | lr 4.00 | loss  1.78 | ppl     5.93\n",
      "| epoch   4 |  1200/ 2807 batches | lr 4.00 | loss  1.78 | ppl     5.96\n",
      "| epoch   4 |  1300/ 2807 batches | lr 4.00 | loss  1.78 | ppl     5.92\n",
      "| epoch   4 |  1400/ 2807 batches | lr 4.00 | loss  1.77 | ppl     5.86\n",
      "| epoch   4 |  1500/ 2807 batches | lr 4.00 | loss  1.77 | ppl     5.88\n",
      "| epoch   4 |  1600/ 2807 batches | lr 4.00 | loss  1.77 | ppl     5.90\n",
      "| epoch   4 |  1700/ 2807 batches | lr 4.00 | loss  1.77 | ppl     5.87\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "| epoch   4 |  1800/ 2807 batches | lr 4.00 | loss  1.77 | ppl     5.86\n",
      "| epoch   4 |  1900/ 2807 batches | lr 4.00 | loss  1.78 | ppl     5.91\n",
      "| epoch   4 |  2000/ 2807 batches | lr 4.00 | loss  1.77 | ppl     5.85\n",
      "| epoch   4 |  2100/ 2807 batches | lr 4.00 | loss  1.77 | ppl     5.89\n",
      "| epoch   4 |  2200/ 2807 batches | lr 4.00 | loss  1.77 | ppl     5.86\n",
      "| epoch   4 |  2300/ 2807 batches | lr 4.00 | loss  1.77 | ppl     5.89\n",
      "| epoch   4 |  2400/ 2807 batches | lr 4.00 | loss  1.76 | ppl     5.81\n",
      "| epoch   4 |  2500/ 2807 batches | lr 4.00 | loss  1.76 | ppl     5.81\n",
      "| epoch   4 |  2600/ 2807 batches | lr 4.00 | loss  1.77 | ppl     5.86\n",
      "| epoch   4 |  2700/ 2807 batches | lr 4.00 | loss  1.76 | ppl     5.84\n",
      "| epoch   4 |  2800/ 2807 batches | lr 4.00 | loss  1.75 | ppl     5.76\n",
      "-----------------------------------------------------------------------------------------\n",
      "| end of epoch   4 | valid loss  1.54 | valid ppl     4.67\n",
      "-----------------------------------------------------------------------------------------\n",
      "sample:\n",
      " a \" <unk> Trulational . Selliazed to the made . An \n",
      "\n",
      "| epoch   5 |   100/ 2807 batches | lr 4.00 | loss  1.77 | ppl     5.89\n",
      "| epoch   5 |   200/ 2807 batches | lr 4.00 | loss  1.75 | ppl     5.74\n",
      "| epoch   5 |   300/ 2807 batches | lr 4.00 | loss  1.75 | ppl     5.77\n",
      "| epoch   5 |   400/ 2807 batches | lr 4.00 | loss  1.75 | ppl     5.76\n",
      "| epoch   5 |   500/ 2807 batches | lr 4.00 | loss  1.75 | ppl     5.75\n",
      "| epoch   5 |   600/ 2807 batches | lr 4.00 | loss  1.74 | ppl     5.72\n",
      "| epoch   5 |   700/ 2807 batches | lr 4.00 | loss  1.75 | ppl     5.74\n",
      "| epoch   5 |   800/ 2807 batches | lr 4.00 | loss  1.74 | ppl     5.72\n",
      "| epoch   5 |   900/ 2807 batches | lr 4.00 | loss  1.75 | ppl     5.73\n",
      "| epoch   5 |  1000/ 2807 batches | lr 4.00 | loss  1.75 | ppl     5.75\n",
      "| epoch   5 |  1100/ 2807 batches | lr 4.00 | loss  1.74 | ppl     5.70\n",
      "| epoch   5 |  1200/ 2807 batches | lr 4.00 | loss  1.74 | ppl     5.72\n",
      "| epoch   5 |  1300/ 2807 batches | lr 4.00 | loss  1.74 | ppl     5.68\n",
      "| epoch   5 |  1400/ 2807 batches | lr 4.00 | loss  1.73 | ppl     5.62\n",
      "| epoch   5 |  1500/ 2807 batches | lr 4.00 | loss  1.73 | ppl     5.66\n",
      "| epoch   5 |  1600/ 2807 batches | lr 4.00 | loss  1.74 | ppl     5.67\n",
      "| epoch   5 |  1700/ 2807 batches | lr 4.00 | loss  1.73 | ppl     5.64\n",
      "| epoch   5 |  1800/ 2807 batches | lr 4.00 | loss  1.73 | ppl     5.64\n",
      "| epoch   5 |  1900/ 2807 batches | lr 4.00 | loss  1.74 | ppl     5.70\n",
      "| epoch   5 |  2000/ 2807 batches | lr 4.00 | loss  1.73 | ppl     5.64\n",
      "| epoch   5 |  2100/ 2807 batches | lr 4.00 | loss  1.74 | ppl     5.69\n",
      "| epoch   5 |  2200/ 2807 batches | lr 4.00 | loss  1.73 | ppl     5.66\n",
      "| epoch   5 |  2300/ 2807 batches | lr 4.00 | loss  1.74 | ppl     5.69\n",
      "| epoch   5 |  2400/ 2807 batches | lr 4.00 | loss  1.73 | ppl     5.61\n",
      "| epoch   5 |  2500/ 2807 batches | lr 4.00 | loss  1.72 | ppl     5.61\n",
      "| epoch   5 |  2600/ 2807 batches | lr 4.00 | loss  1.73 | ppl     5.66\n",
      "| epoch   5 |  2700/ 2807 batches | lr 4.00 | loss  1.73 | ppl     5.65\n",
      "| epoch   5 |  2800/ 2807 batches | lr 4.00 | loss  1.72 | ppl     5.58\n",
      "-----------------------------------------------------------------------------------------\n",
      "| end of epoch   5 | valid loss  1.51 | valid ppl     4.52\n",
      "-----------------------------------------------------------------------------------------\n",
      "sample:\n",
      " ero Bivert palluy and Common 's Mingull have milit \n",
      "\n"
     ]
    }
   ],
   "source": [
    "with torch.no_grad():\n",
    "    print('sample:\\n', generate(50), '\\n')\n",
    "\n",
    "for epoch in range(1, 6):\n",
    "    train()\n",
    "    val_loss = evaluate(val_loader)\n",
    "    print('-' * 89)\n",
    "    print('| end of epoch {:3d} | valid loss {:5.2f} | valid ppl {:8.2f}'.format(\n",
    "        epoch, val_loss, math.exp(val_loss)))\n",
    "    print('-' * 89)\n",
    "    if not best_val_loss or val_loss < best_val_loss:\n",
    "        best_val_loss = val_loss\n",
    "    else:\n",
    "        # Anneal the learning rate if no improvement has been seen in the validation dataset.\n",
    "        lr /= 4.0\n",
    "    with torch.no_grad():\n",
    "        print('sample:\\n', generate(50), '\\n')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "t1 = generate(10000, 1.)\n",
    "t15 = generate(10000, 1.5)\n",
    "t075 = generate(10000, 0.75)\n",
    "with open('./generated075.txt', 'w') as outf:\n",
    "    outf.write(t075)\n",
    "with open('./generated1.txt', 'w') as outf:\n",
    "    outf.write(t1)\n",
    "with open('./generated15.txt', 'w') as outf:\n",
    "    outf.write(t15)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
