{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import itertools\n",
    "import pickle\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "import math \n",
    "\n",
    "import sys\n",
    "sys.path.append('../')\n",
    "import utils\n",
    "import wiki_utils\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the corpus object, pointing wiki_utils.Texts at ./wikitext/. Later cells read\n",
    "# corpus.train / corpus.valid / corpus.test and corpus.dictionary from it.\n",
    "corpus = wiki_utils.Texts('./wikitext/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters and bookkeeping shared (as globals) by the cells below.\n",
    "batch_size = 128  # batch size given to the training TextLoader\n",
    "# In this notebook sequence_length is only used for the batch count in the training log.\n",
    "# NOTE(review): it is not passed to TextLoader -- confirm it matches the loader's own setting.\n",
    "sequence_length = 30\n",
    "grad_clip = 0.1  # max_norm handed to clip_grad_norm_ in train()\n",
    "lr = 4.  # SGD step size; divided by 4 whenever validation loss fails to improve\n",
    "best_val_loss = None  # lowest validation loss so far; None until the first epoch finishes\n",
    "log_interval = 100  # train() prints statistics every log_interval batches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One loader per corpus split: training uses batch_size, validation and test share\n",
    "# eval_batch_size.\n",
    "eval_batch_size = 128\n",
    "train_loader = wiki_utils.TextLoader(corpus.train, batch_size=batch_size)\n",
    "val_loader = wiki_utils.TextLoader(corpus.valid, batch_size=eval_batch_size)\n",
    "test_loader = wiki_utils.TextLoader(corpus.test, batch_size=eval_batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class RNNModel(nn.Module):\n",
    "    \"\"\"Recurrent language model: embedding -> LSTM/GRU stack -> linear decoder.\n",
    "\n",
    "    Args:\n",
    "        rnn_type: 'LSTM' or 'GRU'; any other value raises ValueError.\n",
    "        ntoken: vocabulary size (embedding rows and decoder outputs).\n",
    "        ninp: embedding dimension fed to the recurrent layers.\n",
    "        nhid: hidden units per recurrent layer.\n",
    "        nlayers: number of stacked recurrent layers.\n",
    "        dropout: dropout probability applied to the embeddings, between the\n",
    "            recurrent layers and to the recurrent output.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5):\n",
    "        super(RNNModel, self).__init__()\n",
    "        self.drop = nn.Dropout(dropout)\n",
    "        self.encoder = nn.Embedding(ntoken, ninp)\n",
    "        if rnn_type == 'LSTM':\n",
    "            self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)\n",
    "        elif rnn_type == 'GRU':\n",
    "            self.rnn = nn.GRU(ninp, nhid, nlayers, dropout=dropout)\n",
    "        else:\n",
    "            # Fail at construction instead of with an AttributeError on the first forward().\n",
    "            raise ValueError(\n",
    "                'rnn_type must be LSTM or GRU, got {!r}'.format(rnn_type))\n",
    "        self.decoder = nn.Linear(nhid, ntoken)\n",
    "\n",
    "        self.init_weights()\n",
    "\n",
    "        self.rnn_type = rnn_type\n",
    "        self.nhid = nhid\n",
    "        self.nlayers = nlayers\n",
    "\n",
    "    def init_weights(self):\n",
    "        \"\"\"Draw encoder/decoder weights uniformly from [-0.1, 0.1] and zero the decoder bias.\"\"\"\n",
    "        initrange = 0.1\n",
    "        self.encoder.weight.data.uniform_(-initrange, initrange)\n",
    "        self.decoder.bias.data.fill_(0)\n",
    "        self.decoder.weight.data.uniform_(-initrange, initrange)\n",
    "\n",
    "    def forward(self, x, hidden=None):\n",
    "        \"\"\"Return (logits, hidden) for the token indices in x.\n",
    "\n",
    "        The logits keep the first two dimensions of the recurrent output and\n",
    "        have ntoken entries in the last one. With hidden=None the recurrent\n",
    "        layer starts from its default (zero) state.\n",
    "        \"\"\"\n",
    "        emb = self.drop(self.encoder(x))\n",
    "        output, hidden = self.rnn(emb, hidden)\n",
    "        output = self.drop(output)\n",
    "        # Flatten the two leading dimensions for the linear layer, then restore them.\n",
    "        decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))\n",
    "        return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden\n",
    "\n",
    "    def init_hidden(self, bsz):\n",
    "        \"\"\"Return a zero initial state for batch size bsz.\n",
    "\n",
    "        LSTM gets an (h, c) tuple, any other type a single tensor; each tensor\n",
    "        has shape (nlayers, bsz, nhid) and the dtype/device of the parameters.\n",
    "        \"\"\"\n",
    "        weight = next(self.parameters())\n",
    "        shape = (self.nlayers, bsz, self.nhid)\n",
    "        if self.rnn_type == 'LSTM':\n",
    "            return (weight.new_zeros(shape), weight.new_zeros(shape))\n",
    "        return weight.new_zeros(shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(data_loader):\n",
    "    \"\"\"Return the average loss of the global `model` over `data_loader`.\n",
    "\n",
    "    Each batch loss is weighted by len(data) and the weighted sum is divided\n",
    "    by len(data_loader). Reads the globals `model`, `criterion` and `corpus`,\n",
    "    and leaves the model in eval mode.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    total_loss = 0\n",
    "    ntokens = len(corpus.dictionary)\n",
    "    # Evaluation needs no gradients; skipping graph construction saves memory and time.\n",
    "    with torch.no_grad():\n",
    "        for data, targets in data_loader:\n",
    "            # No hidden state is passed in, so every batch starts from the default state.\n",
    "            output, _ = model(data)\n",
    "            output_flat = output.view(-1, ntokens)\n",
    "            total_loss += len(data) * criterion(output_flat, targets).item()\n",
    "    return total_loss / len(data_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train():\n",
    "    \"\"\"Run one epoch of plain SGD over `train_loader`, logging every `log_interval` batches.\n",
    "\n",
    "    Reads the globals `model`, `criterion`, `corpus`, `train_loader`, `lr`,\n",
    "    `grad_clip`, `log_interval`, `sequence_length` and `epoch` (the loop\n",
    "    variable of the training cell). Updates the model parameters in place and\n",
    "    returns nothing.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    total_loss = 0\n",
    "    ntokens = len(corpus.dictionary)\n",
    "    for batch, (data, targets) in enumerate(train_loader):\n",
    "        model.zero_grad()\n",
    "        output, _ = model(data)\n",
    "        loss = criterion(output.view(-1, ntokens), targets)\n",
    "        loss.backward()\n",
    "\n",
    "        # `clip_grad_norm_` helps prevent the exploding gradient problem in RNNs / LSTMs.\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)\n",
    "        # Manual SGD step. The positional add_(scalar, tensor) overload is deprecated,\n",
    "        # so the step size goes through the `alpha` keyword instead.\n",
    "        with torch.no_grad():\n",
    "            for p in model.parameters():\n",
    "                if p.grad is not None:  # parameters that received no gradient are left as-is\n",
    "                    p.add_(p.grad, alpha=-lr)\n",
    "\n",
    "        total_loss += loss.item()\n",
    "\n",
    "        if batch % log_interval == 0 and batch > 0:\n",
    "            cur_loss = total_loss / log_interval\n",
    "            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | loss {:5.2f} | ppl {:8.2f}'.format(\n",
    "                epoch, batch, len(train_loader) // sequence_length, lr, cur_loss, math.exp(cur_loss)))\n",
    "            total_loss = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Vocabulary size: used as both the embedding row count and the decoder output width.\n",
    "ntokens = len(corpus.dictionary)\n",
    "# LSTM with ninp=128, nhid=128, nlayers=2 and dropout=0.3 (see RNNModel.__init__).\n",
    "model = RNNModel('LSTM', ntokens, 128, 128, 2, 0.3)\n",
    "criterion = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate(n=50, temp=1.):\n",
    "    \"\"\"Sample `n` symbols from the global `model` and return them joined into one string.\n",
    "\n",
    "    Args:\n",
    "        n: number of symbols to sample.\n",
    "        temp: sampling temperature; the logits are divided by it, so values\n",
    "            below 1 sharpen the distribution and values above 1 flatten it.\n",
    "\n",
    "    Reads the globals `model`, `ntokens` and `corpus`; leaves the model in eval mode.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    # Start from a uniformly random symbol index in a (1, 1) tensor.\n",
    "    x = torch.rand(1, 1).mul(ntokens).long()\n",
    "    hidden = None\n",
    "    out = []\n",
    "    # Without no_grad the autograd graph would keep growing through `hidden` at every step.\n",
    "    with torch.no_grad():\n",
    "        for _ in range(n):\n",
    "            output, hidden = model(x, hidden)\n",
    "            # softmax(logits / temp) is the normalised form of exp(logits / temp) and\n",
    "            # cannot overflow to inf at small temperatures.\n",
    "            probs = torch.softmax(output.squeeze().div(temp), dim=-1)\n",
    "            # .item() gives a plain int, usable as a list index or dict key alike.\n",
    "            s_idx = torch.multinomial(probs, 1).item()\n",
    "            x.fill_(s_idx)\n",
    "            out.append(corpus.dictionary.idx2symbol[s_idx])\n",
    "    return ''.join(out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sample:\n",
      " 4JμュaµსʻzCв\"ėMFôćÍプ﻿ัвง;ŻśK&3=ã戦ュิ(*mュłžยtí.ńFgĐ£L \n",
      "\n",
      "| epoch   1 |   100/ 2807 batches | lr 4.00 | loss  3.60 | ppl    36.43\n",
      "| epoch   1 |   200/ 2807 batches | lr 4.00 | loss  3.28 | ppl    26.63\n",
      "| epoch   1 |   300/ 2807 batches | lr 4.00 | loss  3.25 | ppl    25.76\n",
      "| epoch   1 |   400/ 2807 batches | lr 4.00 | loss  3.22 | ppl    25.12\n",
      "| epoch   1 |   500/ 2807 batches | lr 4.00 | loss  3.21 | ppl    24.73\n",
      "| epoch   1 |   600/ 2807 batches | lr 4.00 | loss  3.08 | ppl    21.78\n",
      "| epoch   1 |   700/ 2807 batches | lr 4.00 | loss  2.98 | ppl    19.69\n",
      "| epoch   1 |   800/ 2807 batches | lr 4.00 | loss  2.89 | ppl    18.05\n",
      "| epoch   1 |   900/ 2807 batches | lr 4.00 | loss  2.78 | ppl    16.09\n",
      "| epoch   1 |  1000/ 2807 batches | lr 4.00 | loss  2.68 | ppl    14.56\n",
      "| epoch   1 |  1100/ 2807 batches | lr 4.00 | loss  2.58 | ppl    13.24\n",
      "| epoch   1 |  1200/ 2807 batches | lr 4.00 | loss  2.53 | ppl    12.54\n",
      "| epoch   1 |  1300/ 2807 batches | lr 4.00 | loss  2.48 | ppl    12.00\n",
      "| epoch   1 |  1400/ 2807 batches | lr 4.00 | loss  2.44 | ppl    11.44\n",
      "| epoch   1 |  1500/ 2807 batches | lr 4.00 | loss  2.40 | ppl    11.06\n",
      "| epoch   1 |  1600/ 2807 batches | lr 4.00 | loss  2.37 | ppl    10.69\n",
      "| epoch   1 |  1700/ 2807 batches | lr 4.00 | loss  2.33 | ppl    10.33\n",
      "| epoch   1 |  1800/ 2807 batches | lr 4.00 | loss  2.30 | ppl    10.01\n",
      "| epoch   1 |  1900/ 2807 batches | lr 4.00 | loss  2.28 | ppl     9.79\n",
      "| epoch   1 |  2000/ 2807 batches | lr 4.00 | loss  2.25 | ppl     9.49\n",
      "| epoch   1 |  2100/ 2807 batches | lr 4.00 | loss  2.23 | ppl     9.32\n",
      "| epoch   1 |  2200/ 2807 batches | lr 4.00 | loss  2.21 | ppl     9.10\n",
      "| epoch   1 |  2300/ 2807 batches | lr 4.00 | loss  2.20 | ppl     9.04\n",
      "| epoch   1 |  2400/ 2807 batches | lr 4.00 | loss  2.17 | ppl     8.80\n",
      "| epoch   1 |  2500/ 2807 batches | lr 4.00 | loss  2.16 | ppl     8.70\n",
      "| epoch   1 |  2600/ 2807 batches | lr 4.00 | loss  2.15 | ppl     8.55\n",
      "| epoch   1 |  2700/ 2807 batches | lr 4.00 | loss  2.13 | ppl     8.42\n",
      "| epoch   1 |  2800/ 2807 batches | lr 4.00 | loss  2.10 | ppl     8.20\n",
      "-----------------------------------------------------------------------------------------\n",
      "| end of epoch   1 | valid loss  1.93 | valid ppl     6.90\n",
      "-----------------------------------------------------------------------------------------\n",
      "sample:\n",
      " ond , Por hiirs as \" in jeritiviciaatt Dhois \" sur \n",
      "\n",
      "| epoch   2 |   100/ 2807 batches | lr 4.00 | loss  2.12 | ppl     8.30\n",
      "| epoch   2 |   200/ 2807 batches | lr 4.00 | loss  2.08 | ppl     7.99\n",
      "| epoch   2 |   300/ 2807 batches | lr 4.00 | loss  2.07 | ppl     7.89\n",
      "| epoch   2 |   400/ 2807 batches | lr 4.00 | loss  2.06 | ppl     7.83\n",
      "| epoch   2 |   500/ 2807 batches | lr 4.00 | loss  2.04 | ppl     7.72\n",
      "| epoch   2 |   600/ 2807 batches | lr 4.00 | loss  2.03 | ppl     7.63\n",
      "| epoch   2 |   700/ 2807 batches | lr 4.00 | loss  2.02 | ppl     7.57\n",
      "| epoch   2 |   800/ 2807 batches | lr 4.00 | loss  2.01 | ppl     7.49\n",
      "| epoch   2 |   900/ 2807 batches | lr 4.00 | loss  2.01 | ppl     7.43\n",
      "| epoch   2 |  1000/ 2807 batches | lr 4.00 | loss  2.00 | ppl     7.39\n",
      "| epoch   2 |  1100/ 2807 batches | lr 4.00 | loss  1.98 | ppl     7.27\n",
      "| epoch   2 |  1200/ 2807 batches | lr 4.00 | loss  1.98 | ppl     7.23\n",
      "| epoch   2 |  1300/ 2807 batches | lr 4.00 | loss  1.97 | ppl     7.16\n",
      "| epoch   2 |  1400/ 2807 batches | lr 4.00 | loss  1.95 | ppl     7.06\n",
      "| epoch   2 |  1500/ 2807 batches | lr 4.00 | loss  1.95 | ppl     7.04\n",
      "| epoch   2 |  1600/ 2807 batches | lr 4.00 | loss  1.95 | ppl     7.00\n",
      "| epoch   2 |  1700/ 2807 batches | lr 4.00 | loss  1.94 | ppl     6.94\n",
      "| epoch   2 |  1800/ 2807 batches | lr 4.00 | loss  1.93 | ppl     6.88\n",
      "| epoch   2 |  1900/ 2807 batches | lr 4.00 | loss  1.93 | ppl     6.92\n",
      "| epoch   2 |  2000/ 2807 batches | lr 4.00 | loss  1.92 | ppl     6.80\n",
      "| epoch   2 |  2100/ 2807 batches | lr 4.00 | loss  1.92 | ppl     6.81\n",
      "| epoch   2 |  2200/ 2807 batches | lr 4.00 | loss  1.91 | ppl     6.75\n",
      "| epoch   2 |  2300/ 2807 batches | lr 4.00 | loss  1.91 | ppl     6.77\n",
      "| epoch   2 |  2400/ 2807 batches | lr 4.00 | loss  1.89 | ppl     6.65\n",
      "| epoch   2 |  2500/ 2807 batches | lr 4.00 | loss  1.89 | ppl     6.63\n",
      "| epoch   2 |  2600/ 2807 batches | lr 4.00 | loss  1.89 | ppl     6.64\n",
      "| epoch   2 |  2700/ 2807 batches | lr 4.00 | loss  1.89 | ppl     6.61\n",
      "| epoch   2 |  2800/ 2807 batches | lr 4.00 | loss  1.87 | ppl     6.49\n",
      "-----------------------------------------------------------------------------------------\n",
      "| end of epoch   2 | valid loss  1.67 | valid ppl     5.30\n",
      "-----------------------------------------------------------------------------------------\n",
      "sample:\n",
      "  the 1650 . It in the <unk> scited in be subfire n \n",
      "\n",
      "| epoch   3 |   100/ 2807 batches | lr 4.00 | loss  1.89 | ppl     6.61\n",
      "| epoch   3 |   200/ 2807 batches | lr 4.00 | loss  1.86 | ppl     6.42\n",
      "| epoch   3 |   300/ 2807 batches | lr 4.00 | loss  1.86 | ppl     6.42\n",
      "| epoch   3 |   400/ 2807 batches | lr 4.00 | loss  1.86 | ppl     6.40\n",
      "| epoch   3 |   500/ 2807 batches | lr 4.00 | loss  1.85 | ppl     6.38\n",
      "| epoch   3 |   600/ 2807 batches | lr 4.00 | loss  1.85 | ppl     6.33\n",
      "| epoch   3 |   700/ 2807 batches | lr 4.00 | loss  1.84 | ppl     6.32\n",
      "| epoch   3 |   800/ 2807 batches | lr 4.00 | loss  1.84 | ppl     6.29\n",
      "| epoch   3 |   900/ 2807 batches | lr 4.00 | loss  1.84 | ppl     6.28\n",
      "| epoch   3 |  1000/ 2807 batches | lr 4.00 | loss  1.84 | ppl     6.28\n",
      "| epoch   3 |  1100/ 2807 batches | lr 4.00 | loss  1.82 | ppl     6.20\n",
      "| epoch   3 |  1200/ 2807 batches | lr 4.00 | loss  1.83 | ppl     6.22\n",
      "| epoch   3 |  1300/ 2807 batches | lr 4.00 | loss  1.82 | ppl     6.18\n",
      "| epoch   3 |  1400/ 2807 batches | lr 4.00 | loss  1.81 | ppl     6.10\n",
      "| epoch   3 |  1500/ 2807 batches | lr 4.00 | loss  1.81 | ppl     6.12\n",
      "| epoch   3 |  1600/ 2807 batches | lr 4.00 | loss  1.81 | ppl     6.13\n",
      "| epoch   3 |  1700/ 2807 batches | lr 4.00 | loss  1.81 | ppl     6.09\n",
      "| epoch   3 |  1800/ 2807 batches | lr 4.00 | loss  1.80 | ppl     6.06\n",
      "| epoch   3 |  1900/ 2807 batches | lr 4.00 | loss  1.81 | ppl     6.13\n",
      "| epoch   3 |  2000/ 2807 batches | lr 4.00 | loss  1.80 | ppl     6.06\n",
      "| epoch   3 |  2100/ 2807 batches | lr 4.00 | loss  1.80 | ppl     6.08\n",
      "| epoch   3 |  2200/ 2807 batches | lr 4.00 | loss  1.80 | ppl     6.05\n",
      "| epoch   3 |  2300/ 2807 batches | lr 4.00 | loss  1.80 | ppl     6.07\n",
      "| epoch   3 |  2400/ 2807 batches | lr 4.00 | loss  1.79 | ppl     5.98\n",
      "| epoch   3 |  2500/ 2807 batches | lr 4.00 | loss  1.79 | ppl     5.98\n",
      "| epoch   3 |  2600/ 2807 batches | lr 4.00 | loss  1.79 | ppl     6.01\n",
      "| epoch   3 |  2700/ 2807 batches | lr 4.00 | loss  1.79 | ppl     5.99\n",
      "| epoch   3 |  2800/ 2807 batches | lr 4.00 | loss  1.78 | ppl     5.90\n",
      "-----------------------------------------------------------------------------------------\n",
      "| end of epoch   3 | valid loss  1.57 | valid ppl     4.79\n",
      "-----------------------------------------------------------------------------------------\n",
      "sample:\n",
      " I4 : He for Peyhaling <unk> that the storyon <unk> \n",
      "\n",
      "| epoch   4 |   100/ 2807 batches | lr 4.00 | loss  1.80 | ppl     6.03\n",
      "| epoch   4 |   200/ 2807 batches | lr 4.00 | loss  1.77 | ppl     5.88\n",
      "| epoch   4 |   300/ 2807 batches | lr 4.00 | loss  1.77 | ppl     5.90\n",
      "| epoch   4 |   400/ 2807 batches | lr 4.00 | loss  1.77 | ppl     5.89\n",
      "| epoch   4 |   500/ 2807 batches | lr 4.00 | loss  1.77 | ppl     5.88\n",
      "| epoch   4 |   600/ 2807 batches | lr 4.00 | loss  1.77 | ppl     5.84\n",
      "| epoch   4 |   700/ 2807 batches | lr 4.00 | loss  1.77 | ppl     5.85\n",
      "| epoch   4 |   800/ 2807 batches | lr 4.00 | loss  1.76 | ppl     5.83\n",
      "| epoch   4 |   900/ 2807 batches | lr 4.00 | loss  1.76 | ppl     5.83\n",
      "| epoch   4 |  1000/ 2807 batches | lr 4.00 | loss  1.76 | ppl     5.84\n",
      "| epoch   4 |  1100/ 2807 batches | lr 4.00 | loss  1.75 | ppl     5.78\n",
      "| epoch   4 |  1200/ 2807 batches | lr 4.00 | loss  1.76 | ppl     5.81\n",
      "| epoch   4 |  1300/ 2807 batches | lr 4.00 | loss  1.76 | ppl     5.78\n",
      "| epoch   4 |  1400/ 2807 batches | lr 4.00 | loss  1.74 | ppl     5.71\n",
      "| epoch   4 |  1500/ 2807 batches | lr 4.00 | loss  1.75 | ppl     5.74\n",
      "| epoch   4 |  1600/ 2807 batches | lr 4.00 | loss  1.75 | ppl     5.75\n",
      "| epoch   4 |  1700/ 2807 batches | lr 4.00 | loss  1.74 | ppl     5.72\n",
      "| epoch   4 |  1800/ 2807 batches | lr 4.00 | loss  1.74 | ppl     5.71\n",
      "| epoch   4 |  1900/ 2807 batches | lr 4.00 | loss  1.76 | ppl     5.78\n",
      "| epoch   4 |  2000/ 2807 batches | lr 4.00 | loss  1.74 | ppl     5.71\n",
      "| epoch   4 |  2100/ 2807 batches | lr 4.00 | loss  1.75 | ppl     5.75\n",
      "| epoch   4 |  2200/ 2807 batches | lr 4.00 | loss  1.74 | ppl     5.71\n",
      "| epoch   4 |  2300/ 2807 batches | lr 4.00 | loss  1.75 | ppl     5.75\n",
      "| epoch   4 |  2400/ 2807 batches | lr 4.00 | loss  1.74 | ppl     5.67\n",
      "| epoch   4 |  2500/ 2807 batches | lr 4.00 | loss  1.73 | ppl     5.66\n",
      "| epoch   4 |  2600/ 2807 batches | lr 4.00 | loss  1.74 | ppl     5.72\n",
      "| epoch   4 |  2700/ 2807 batches | lr 4.00 | loss  1.74 | ppl     5.71\n",
      "| epoch   4 |  2800/ 2807 batches | lr 4.00 | loss  1.73 | ppl     5.62\n",
      "-----------------------------------------------------------------------------------------\n",
      "| end of epoch   4 | valid loss  1.52 | valid ppl     4.56\n",
      "-----------------------------------------------------------------------------------------\n",
      "sample:\n",
      "  of Hill Performational diversomen , alfoir Soase  \n",
      "\n",
      "| epoch   5 |   100/ 2807 batches | lr 4.00 | loss  1.75 | ppl     5.75\n",
      "| epoch   5 |   200/ 2807 batches | lr 4.00 | loss  1.73 | ppl     5.61\n",
      "| epoch   5 |   300/ 2807 batches | lr 4.00 | loss  1.73 | ppl     5.64\n",
      "| epoch   5 |   400/ 2807 batches | lr 4.00 | loss  1.73 | ppl     5.63\n",
      "| epoch   5 |   500/ 2807 batches | lr 4.00 | loss  1.73 | ppl     5.62\n",
      "| epoch   5 |   600/ 2807 batches | lr 4.00 | loss  1.72 | ppl     5.60\n",
      "| epoch   5 |   700/ 2807 batches | lr 4.00 | loss  1.72 | ppl     5.61\n",
      "| epoch   5 |   800/ 2807 batches | lr 4.00 | loss  1.72 | ppl     5.58\n",
      "| epoch   5 |   900/ 2807 batches | lr 4.00 | loss  1.72 | ppl     5.60\n",
      "| epoch   5 |  1000/ 2807 batches | lr 4.00 | loss  1.72 | ppl     5.61\n",
      "| epoch   5 |  1100/ 2807 batches | lr 4.00 | loss  1.72 | ppl     5.57\n",
      "| epoch   5 |  1200/ 2807 batches | lr 4.00 | loss  1.72 | ppl     5.58\n",
      "| epoch   5 |  1300/ 2807 batches | lr 4.00 | loss  1.72 | ppl     5.56\n",
      "| epoch   5 |  1400/ 2807 batches | lr 4.00 | loss  1.70 | ppl     5.49\n",
      "| epoch   5 |  1500/ 2807 batches | lr 4.00 | loss  1.71 | ppl     5.53\n",
      "| epoch   5 |  1600/ 2807 batches | lr 4.00 | loss  1.71 | ppl     5.54\n",
      "| epoch   5 |  1700/ 2807 batches | lr 4.00 | loss  1.71 | ppl     5.51\n",
      "| epoch   5 |  1800/ 2807 batches | lr 4.00 | loss  1.71 | ppl     5.51\n",
      "| epoch   5 |  1900/ 2807 batches | lr 4.00 | loss  1.72 | ppl     5.58\n",
      "| epoch   5 |  2000/ 2807 batches | lr 4.00 | loss  1.71 | ppl     5.51\n",
      "| epoch   5 |  2100/ 2807 batches | lr 4.00 | loss  1.72 | ppl     5.56\n",
      "| epoch   5 |  2200/ 2807 batches | lr 4.00 | loss  1.71 | ppl     5.54\n",
      "| epoch   5 |  2300/ 2807 batches | lr 4.00 | loss  1.71 | ppl     5.55\n",
      "| epoch   5 |  2400/ 2807 batches | lr 4.00 | loss  1.70 | ppl     5.48\n",
      "| epoch   5 |  2500/ 2807 batches | lr 4.00 | loss  1.70 | ppl     5.49\n",
      "| epoch   5 |  2600/ 2807 batches | lr 4.00 | loss  1.71 | ppl     5.54\n",
      "| epoch   5 |  2700/ 2807 batches | lr 4.00 | loss  1.71 | ppl     5.52\n",
      "| epoch   5 |  2800/ 2807 batches | lr 4.00 | loss  1.70 | ppl     5.45\n",
      "-----------------------------------------------------------------------------------------\n",
      "| end of epoch   5 | valid loss  1.49 | valid ppl     4.42\n",
      "-----------------------------------------------------------------------------------------\n",
      "sample:\n",
      "  housed to the Jlow , for works an game became or  \n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Sample from the untrained model first, as a baseline for the per-epoch samples.\n",
    "with torch.no_grad():\n",
    "    print('sample:\\n', generate(50), '\\n')\n",
    "\n",
    "for epoch in range(1, 6):\n",
    "    train()\n",
    "    # Validation needs no gradients; no_grad avoids building autograd graphs for it.\n",
    "    with torch.no_grad():\n",
    "        val_loss = evaluate(val_loader)\n",
    "    print('-' * 89)\n",
    "    print('| end of epoch {:3d} | valid loss {:5.2f} | valid ppl {:8.2f}'.format(\n",
    "        epoch, val_loss, math.exp(val_loss)))\n",
    "    print('-' * 89)\n",
    "    # Compare against None explicitly so that a loss of exactly 0.0 is not mistaken for unset.\n",
    "    if best_val_loss is None or val_loss < best_val_loss:\n",
    "        best_val_loss = val_loss\n",
    "    else:\n",
    "        # Anneal the learning rate if no improvement has been seen in the validation dataset.\n",
    "        lr /= 4.0\n",
    "    with torch.no_grad():\n",
    "        print('sample:\\n', generate(50), '\\n')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample three long texts at different temperatures. no_grad keeps the autograd graph\n",
    "# from growing through the recurrent state over the 10000 steps of each run.\n",
    "with torch.no_grad():\n",
    "    t1 = generate(10000, 1.)\n",
    "    t15 = generate(10000, 1.5)\n",
    "    t075 = generate(10000, 0.75)\n",
    "# The vocabulary contains non-ASCII symbols, so write UTF-8 explicitly instead of relying\n",
    "# on the platform default encoding.\n",
    "for path, text in (('./generated075.txt', t075),\n",
    "                   ('./generated1.txt', t1),\n",
    "                   ('./generated15.txt', t15)):\n",
    "    with open(path, 'w', encoding='utf-8') as outf:\n",
    "        outf.write(text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
