{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "from utils import mnist, plot_graphs\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the train/validation/test data loaders with the mnist() helper from the local utils module.\n",
    "# NOTE(review): valid=10000 presumably sets the size of the validation split - confirm in utils.mnist.\n",
    "train_loader, valid_loader, test_loader = mnist(valid=10000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    \"\"\"Three-layer fully connected classifier for 28x28 inputs.\n",
    "\n",
    "    Inputs are flattened to 784 features and passed through two 128-unit\n",
    "    tanh layers to a 10-way log-softmax output. Every instance creates its\n",
    "    own Adam optimizer (``self.optim``) over its parameters.\n",
    "\n",
    "    Args:\n",
    "        batchnorm: if True, apply ``BatchNorm1d`` to the activations of the\n",
    "            first hidden layer.\n",
    "        dropout: if True, apply dropout with p=0.5 to the activations of the\n",
    "            second hidden layer while the module is in training mode.\n",
    "        lr: learning rate passed to Adam.\n",
    "        l2: value passed to Adam as ``weight_decay``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, batchnorm=False, dropout=False, lr=1e-4, l2=0.):\n",
    "        super(Net, self).__init__()\n",
    "        self.fc1 = nn.Linear(28*28, 128)\n",
    "        self.fc2 = nn.Linear(128, 128)\n",
    "        self.fc3 = nn.Linear(128, 10)\n",
    "        if batchnorm:\n",
    "            self.bn = nn.BatchNorm1d(128)\n",
    "        self.batchnorm = batchnorm\n",
    "        self.dropout = dropout\n",
    "        # Created last so that self.parameters() already contains every layer\n",
    "        # registered above, including the optional batch-norm parameters.\n",
    "        self.optim = optim.Adam(self.parameters(), lr=lr, weight_decay=l2)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return per-class log-probabilities, one row of 10 values per input.\"\"\"\n",
    "        x = x.view(-1, 28*28)\n",
    "        x = torch.tanh(self.fc1(x))\n",
    "        if self.batchnorm:\n",
    "            x = self.bn(x)\n",
    "        x = torch.tanh(self.fc2(x))\n",
    "        if self.dropout:\n",
    "            # F.dropout defaults to training=True and would keep dropping units\n",
    "            # after model.eval(); follow the module's train/eval mode instead.\n",
    "            x = F.dropout(x, 0.5, training=self.training)\n",
    "        x = self.fc3(x)\n",
    "        x = torch.log_softmax(x, dim=1)\n",
    "        return x\n",
    "\n",
    "    def loss(self, output, target, **kwargs):\n",
    "        \"\"\"Compute the NLL loss, remember it in ``self._loss`` and return it.\n",
    "\n",
    "        Extra keyword arguments (e.g. ``reduction``) are forwarded to\n",
    "        ``F.nll_loss``.\n",
    "        \"\"\"\n",
    "        self._loss = F.nll_loss(output, target, **kwargs)\n",
    "        return self._loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(epoch, models):\n",
    "    \"\"\"Train every model for one pass over the global ``train_loader``.\n",
    "\n",
    "    Each mini-batch is given to every model in ``models`` in turn, so all\n",
    "    models see the same batches in the same order. A progress line with each\n",
    "    model's most recent batch loss is printed every 200 batches and once more\n",
    "    after the last batch.\n",
    "\n",
    "    Args:\n",
    "        epoch: epoch number; only used in the printed progress lines.\n",
    "        models: dict mapping a display name to a model that is callable on a\n",
    "            batch and exposes ``optim``, ``loss(output, target)`` and ``_loss``.\n",
    "    \"\"\"\n",
    "    train_size = len(train_loader.sampler)\n",
    "    n_batches = len(train_loader)\n",
    "\n",
    "    def report(n_samples, n_done):\n",
    "        # Single place that formats a progress line (in-loop and end-of-epoch).\n",
    "        line = 'Train Epoch: {} [{}/{} ({:.0f}%)]\\tLosses '.format(\n",
    "            epoch, n_samples, train_size, 100. * n_done / n_batches)\n",
    "        losses = ' '.join(['{}: {:.6f}'.format(k, m._loss.item()) for k, m in models.items()])\n",
    "        print(line + losses)\n",
    "\n",
    "    seen = 0  # samples actually processed; stays exact when the last batch is smaller\n",
    "    for batch_idx, (data, target) in enumerate(train_loader):\n",
    "        for model in models.values():\n",
    "            model.optim.zero_grad()\n",
    "            output = model(data)\n",
    "            loss = model.loss(output, target)\n",
    "            loss.backward()\n",
    "            model.optim.step()\n",
    "        seen += len(data)\n",
    "\n",
    "        if batch_idx % 200 == 0:\n",
    "            report(batch_idx * len(data), batch_idx)\n",
    "\n",
    "    # End-of-epoch summary. Skipped for an empty loader, where no batch was\n",
    "    # processed and no model has a loss to report.\n",
    "    if seen:\n",
    "        report(seen, n_batches)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One network per regularisation setting; the keys label the printed losses and the log entries.\n",
    "models = {\n",
    "    'default': Net(batchnorm=False, dropout=False),\n",
    "    'bn': Net(batchnorm=True, dropout=False),\n",
    "    'drop': Net(batchnorm=False, dropout=True),\n",
    "    'both': Net(batchnorm=True, dropout=True),\n",
    "}\n",
    "# Empty per-model histories, keyed like `models`.\n",
    "train_log = dict((name, []) for name in models)\n",
    "test_log = dict((name, []) for name in models)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(models, loader, log=None):\n",
    "    \"\"\"Evaluate every model on ``loader`` and print its mean loss and accuracy.\n",
    "\n",
    "    Args:\n",
    "        models: dict mapping a display name to a model that is callable on a\n",
    "            batch and exposes ``loss(output, target, reduction=...)``.\n",
    "        loader: iterable of ``(data, target)`` batches; the length of its\n",
    "            ``sampler`` is used as the number of evaluated samples.\n",
    "        log: optional dict of lists keyed like ``models``. When given, one\n",
    "            ``(mean_loss, accuracy_percent)`` tuple is appended per model.\n",
    "    \"\"\"\n",
    "    test_size = len(loader.sampler)\n",
    "\n",
    "    def format_line(name, avg_loss, n_correct, pct):\n",
    "        return '{}: Loss: {:.4f}\\tAccuracy: {}/{} ({:.0f}%)'.format(\n",
    "            name, avg_loss, n_correct, test_size, pct)\n",
    "\n",
    "    test_loss = {k: 0. for k in models}\n",
    "    # Integer counters so the report shows e.g. 9800/10000 rather than 9800.0/10000.\n",
    "    correct = {k: 0 for k in models}\n",
    "    with torch.no_grad():\n",
    "        for data, target in loader:\n",
    "            for k, m in models.items():\n",
    "                output = m(data)\n",
    "                # Sum (not mean) per batch so dividing by test_size below gives the true mean.\n",
    "                test_loss[k] += m.loss(output, target, reduction='sum').item()\n",
    "                # Index of the largest log-probability is the predicted class.\n",
    "                pred = output.argmax(dim=1, keepdim=True)\n",
    "                correct[k] += pred.eq(target.view_as(pred)).sum().item()\n",
    "\n",
    "    for k in models:\n",
    "        test_loss[k] /= test_size\n",
    "    correct_pct = {k: 100. * correct[k] / test_size for k in correct}\n",
    "    lines = '\\n'.join([format_line(k, test_loss[k], correct[k], correct_pct[k]) for k in models]) + '\\n'\n",
    "    report = 'Test set:\\n' + lines\n",
    "    if log is not None:\n",
    "        for k in models:\n",
    "            log[k].append((test_loss[k], correct_pct[k]))\n",
    "    print(report)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "N_EPOCHS = 100  # number of full passes over the training data\n",
    "\n",
    "for epoch in range(1, N_EPOCHS + 1):\n",
    "    # Put every model in training mode before the epoch's updates.\n",
    "    for model in models.values():\n",
    "        model.train()\n",
    "    train(epoch, models)\n",
    "    # Switch to evaluation mode before measuring on the validation loader.\n",
    "    for model in models.values():\n",
    "        model.eval()\n",
    "    test(models, valid_loader, test_log)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the 'loss' series from test_log, which test() fills from valid_loader once per epoch.\n",
    "# NOTE(review): plot_graphs comes from the local utils module; its exact rendering is not visible here.\n",
    "plot_graphs(test_log, 'loss')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the 'accuracy' series from test_log, which test() fills from valid_loader once per epoch.\n",
    "# NOTE(review): plot_graphs comes from the local utils module; its exact rendering is not visible here.\n",
    "plot_graphs(test_log, 'accuracy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the raw log: for each model name, a list of (mean loss, accuracy in percent) tuples, one per test() call.\n",
    "test_log"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
