{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "from utils import mnist, plot_graphs, plot_mnist\n",
    "import numpy as np\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader, valid_loader, test_loader = mnist(valid=10000, transform=transforms.ToTensor())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ConvLayer(nn.Module):\n",
    "    def __init__(self, size, padding=1, pool_layer=nn.MaxPool2d(2, stride=2),\n",
    "                 bn=False, dropout=False, activation_fn=nn.ReLU()):\n",
    "        super(ConvLayer, self).__init__()\n",
    "        layers = []\n",
    "        layers.append(nn.Conv2d(size[0], size[1], size[2], padding=padding))\n",
    "        if pool_layer is not None:\n",
    "            layers.append(pool_layer)\n",
    "        if bn:\n",
    "            layers.append(nn.BatchNorm2d(size[1]))\n",
    "        if dropout:\n",
    "            layers.append(nn.Dropout2d())\n",
    "        layers.append(activation_fn)\n",
    "        \n",
    "        self.model = nn.Sequential(*layers)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        return self.model(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FullyConnected(nn.Module):\n",
    "    def __init__(self, sizes, dropout=False, activation_fn=nn.Tanh):\n",
    "        super(FullyConnected, self).__init__()\n",
    "        layers = []\n",
    "        \n",
    "        for i in range(len(sizes) - 2):\n",
    "            layers.append(nn.Linear(sizes[i], sizes[i+1]))\n",
    "            if dropout:\n",
    "                layers.append(nn.Dropout())\n",
    "            layers.append(activation_fn())\n",
    "        else: # нам не нужен дропаут и фнкция активации в последнем слое\n",
    "            layers.append(nn.Linear(sizes[-2], sizes[-1]))\n",
    "        \n",
    "        self.model = nn.Sequential(*layers)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        return self.model(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    def __init__(self, batchnorm=False, dropout=False, lr=1e-4, l2=0.):\n",
    "        super(Net, self).__init__()\n",
    "        \n",
    "        \n",
    "        self._conv1 = ConvLayer([1, 16, 3], bn=batchnorm)\n",
    "        self._conv2 = ConvLayer([16, 32, 3], bn=batchnorm, activation_fn=nn.Sigmoid())\n",
    "        \n",
    "        self.fc = FullyConnected([32*7*7, 10], dropout=dropout)\n",
    "        \n",
    "        self._loss = None\n",
    "        self.optim = optim.Adam(self.parameters(), lr=lr, weight_decay=l2)\n",
    "    \n",
    "    def conv(self, x):\n",
    "        x = self._conv1(x)\n",
    "        x = self._conv2(x)\n",
    "        return x\n",
    "        \n",
    "    def forward(self, x):\n",
    "        x = self.conv(x)\n",
    "        x = x.view(-1, 32*7*7)\n",
    "        x = self.fc(x)\n",
    "        return x\n",
    "    \n",
    "    def loss(self, output, target, **kwargs):\n",
    "        self._loss = F.cross_entropy(output, target, **kwargs)\n",
    "        self._correct = output.data.max(1, keepdim=True)[1]\n",
    "        self._correct = self._correct.eq(target.data.view_as(self._correct)).to(torch.float).cpu().mean()\n",
    "        return self._loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "models = {'bn': Net(True), 'drop': Net(False, True), 'plain': Net()}\n",
    "train_log = {k: [] for k in models}\n",
    "test_log = {k: [] for k in models}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(epoch, models, log=None):\n",
    "    train_size = len(train_loader.sampler)\n",
    "    for batch_idx, (data, target) in enumerate(train_loader):\n",
    "        for model in models.values():\n",
    "            model.optim.zero_grad()\n",
    "            output = model(data)\n",
    "            loss = model.loss(output, target)\n",
    "            loss.backward()\n",
    "            model.optim.step()\n",
    "            \n",
    "        if batch_idx % 200 == 0:\n",
    "            line = 'Train Epoch: {} [{}/{} ({:.0f}%)]\\tLosses '.format(\n",
    "                epoch, batch_idx * len(data), train_size, 100. * batch_idx / len(train_loader))\n",
    "            losses = ' '.join(['{}: {:.6f}'.format(k, m._loss.item()) for k, m in models.items()])\n",
    "            print(line + losses)\n",
    "            \n",
    "    else:\n",
    "        batch_idx += 1\n",
    "        line = 'Train Epoch: {} [{}/{} ({:.0f}%)]\\tLosses '.format(\n",
    "            epoch, batch_idx * len(data), train_size, 100. * batch_idx / len(train_loader))\n",
    "        losses = ' '.join(['{}: {:.6f}'.format(k, m._loss.item()) for k, m in models.items()])\n",
    "        if log is not None:\n",
    "            for k in models:\n",
    "                log[k].append((models[k]._loss, models[k]._correct))\n",
    "        print(line + losses)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(models, loader, log=None):\n",
    "    test_size = len(loader.sampler)\n",
    "    avg_lambda = lambda l: 'Loss: {:.4f}'.format(l)\n",
    "    acc_lambda = lambda c, p: 'Accuracy: {}/{} ({:.0f}%)'.format(c, test_size, p)\n",
    "    line = lambda i, l, c, p: '{}: '.format(i) + avg_lambda(l) + '\\t' + acc_lambda(c, p)\n",
    "\n",
    "    test_loss = {k: 0. for k in models}\n",
    "    correct = {k: 0. for k in models}\n",
    "    with torch.no_grad():\n",
    "        for data, target in loader:\n",
    "            output = {k: m(data) for k, m in models.items()}\n",
    "            for k, m in models.items():\n",
    "                test_loss[k] += m.loss(output[k], target, reduction='sum').item() # sum up batch loss\n",
    "                pred = output[k].data.max(1, keepdim=True)[1] # get the index of the max log-probability\n",
    "                correct[k] += pred.eq(target.data.view_as(pred)).cpu().sum().item()\n",
    "    \n",
    "    for k in models:\n",
    "        test_loss[k] /= test_size\n",
    "    correct_pct = {k: c / test_size for k, c in correct.items()}\n",
    "    lines = '\\n'.join([line(k, test_loss[k], correct[k], 100*correct_pct[k]) for k in models]) + '\\n'\n",
    "    report = 'Test set:\\n' + lines\n",
    "    if log is not None:\n",
    "        for k in models:\n",
    "            log[k].append((test_loss[k], correct_pct[k]))\n",
    "    print(report)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Epoch: 1 [0/50000 (0%)]\tLosses bn: 2.345084 drop: 2.314619 plain: 2.376831\n",
      "Train Epoch: 1 [10000/50000 (20%)]\tLosses bn: 1.897495 drop: 2.278836 plain: 2.303455\n",
      "Train Epoch: 1 [20000/50000 (40%)]\tLosses bn: 1.538954 drop: 2.127944 plain: 2.251878\n",
      "Train Epoch: 1 [30000/50000 (60%)]\tLosses bn: 1.134661 drop: 1.686998 plain: 2.077407\n",
      "Train Epoch: 1 [40000/50000 (80%)]\tLosses bn: 0.925215 drop: 1.266169 plain: 1.634986\n",
      "Train Epoch: 1 [50000/50000 (100%)]\tLosses bn: 0.796221 drop: 0.981006 plain: 1.247663\n",
      "Test set:\n",
      "bn: Loss: 0.7953\tAccuracy: 8999.0/10000 (90%)\n",
      "drop: Loss: 0.9641\tAccuracy: 8176.0/10000 (82%)\n",
      "plain: Loss: 1.2313\tAccuracy: 7646.0/10000 (76%)\n",
      "\n",
      "Train Epoch: 2 [0/50000 (0%)]\tLosses bn: 0.910296 drop: 1.106905 plain: 1.297997\n",
      "Train Epoch: 2 [10000/50000 (20%)]\tLosses bn: 0.612191 drop: 0.681521 plain: 0.881276\n",
      "Train Epoch: 2 [20000/50000 (40%)]\tLosses bn: 0.555682 drop: 0.652948 plain: 0.722826\n",
      "Train Epoch: 2 [30000/50000 (60%)]\tLosses bn: 0.463708 drop: 0.514407 plain: 0.604591\n",
      "Train Epoch: 2 [40000/50000 (80%)]\tLosses bn: 0.450578 drop: 0.489749 plain: 0.571646\n",
      "Train Epoch: 2 [50000/50000 (100%)]\tLosses bn: 0.277848 drop: 0.337389 plain: 0.384886\n",
      "Test set:\n",
      "bn: Loss: 0.3730\tAccuracy: 9374.0/10000 (94%)\n",
      "drop: Loss: 0.4383\tAccuracy: 8923.0/10000 (89%)\n",
      "plain: Loss: 0.5157\tAccuracy: 8714.0/10000 (87%)\n",
      "\n",
      "Train Epoch: 3 [0/50000 (0%)]\tLosses bn: 0.493524 drop: 0.637629 plain: 0.764546\n",
      "Train Epoch: 3 [10000/50000 (20%)]\tLosses bn: 0.368176 drop: 0.433181 plain: 0.476079\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(1, 41):\n",
    "    for model in models.values():\n",
    "        model.train()\n",
    "    train(epoch, models, train_log)\n",
    "    for model in models.values():\n",
    "        model.eval()\n",
    "    test(models, valid_loader, test_log)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_graphs(test_log, 'loss')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_graphs(test_log, 'accuracy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
