{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "from utils import mnist, plot_graphs, plot_mnist, to_onehot\n",
    "import numpy as np\n",
    "import os \n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "root_folder = 'FC_CAAE_results'\n",
    "fixed_folder = root_folder + '/Fixed_results/'\n",
    "recon_folder = root_folder + '/Recon_results/'\n",
    "\n",
    "if os.path.isdir(root_folder):\n",
    "    !rm -r $root_folder\n",
    "os.mkdir(root_folder)\n",
    "os.mkdir(fixed_folder)\n",
    "os.mkdir(recon_folder)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mnist_tanh = transforms.Compose([\n",
    "                transforms.ToTensor(),\n",
    "                transforms.Normalize((0.5,), (0.5,)),\n",
    "                lambda x: x.to(device)\n",
    "           ])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(device)\n",
    "\n",
    "lr = 0.0001\n",
    "prior_size = 10\n",
    "train_epoch = 1000\n",
    "batch_size = 250\n",
    "train_loader, valid_loader, test_loader = mnist(batch_size=batch_size, valid=10000, transform=mnist_tanh)\n",
    "fixed_z = torch.randn((10, prior_size)).repeat((1,10)).view(-1, prior_size).to(device)\n",
    "fixed_z_label = to_onehot(torch.tensor(list(range(10))).repeat((10)), 10).to(device)\n",
    "fixed_data, fixed_label = next(iter(test_loader))\n",
    "fixed_data = fixed_data[:100].to(device)\n",
    "fixed_label = to_onehot(fixed_label[:100], 10).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data, label = next(iter(train_loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FullyConnected(nn.Module):\n",
    "    def __init__(self, sizes, dropout=False, activation_fn=nn.Tanh(), flatten=False, \n",
    "                 last_fn=None, first_fn=None, device='cpu'):\n",
    "        super(FullyConnected, self).__init__()\n",
    "        layers = []\n",
    "        self.flatten = flatten\n",
    "        if first_fn is not None:\n",
    "            layers.append(first_fn)\n",
    "        for i in range(len(sizes) - 2):\n",
    "            layers.append(nn.Linear(sizes[i], sizes[i+1]))\n",
    "            if dropout:\n",
    "                layers.append(nn.Dropout(dropout))\n",
    "            layers.append(activation_fn) # нам не нужен дропаут и фнкция активации в последнем слое\n",
    "        else: \n",
    "            layers.append(nn.Linear(sizes[-2], sizes[-1]))\n",
    "        if last_fn is not None:\n",
    "            layers.append(last_fn)\n",
    "        self.model = nn.Sequential(*layers)\n",
    "        self.to(device)\n",
    "        \n",
    "    def forward(self, x, y=None):\n",
    "        if self.flatten:\n",
    "            x = x.view(x.shape[0], -1)\n",
    "        if y is not None:\n",
    "            x = torch.cat([x, y], dim=1)\n",
    "        return self.model(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Enc = FullyConnected([28*28, 1024, 1024, prior_size], activation_fn=nn.LeakyReLU(0.2), flatten=True, device=device)\n",
    "Dec = FullyConnected([prior_size+10, 1024, 1024, 28*28], activation_fn=nn.LeakyReLU(0.2), last_fn=nn.Tanh(), device=device)\n",
    "Disc = FullyConnected([prior_size+10, 1024, 1024, 1], dropout=0.3, activation_fn=nn.LeakyReLU(0.2), device=device)\n",
    "\n",
    "Enc_optimizer = optim.Adam(Enc.parameters(), lr=lr)\n",
    "Dec_optimizer = optim.Adam(Dec.parameters(), lr=lr)\n",
    "Disc_optimizer = optim.Adam(Disc.parameters(), lr=lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_log = {'E': [],'AE': [], 'D': []}\n",
    "test_log = {'E': [],'AE': [], 'D': []}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_zeros = torch.zeros((batch_size, 1)).to(device)\n",
    "batch_ones = torch.ones((batch_size, 1)).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(epoch, Enc, Dec, Disc, log=None):\n",
    "    train_size = len(train_loader.sampler)\n",
    "    for batch_idx, (data, label) in enumerate(train_loader):\n",
    "        label = to_onehot(label, 10, device)\n",
    "        # train D\n",
    "        Enc.zero_grad()\n",
    "        Disc.zero_grad()\n",
    "        \n",
    "        z = torch.randn((batch_size, prior_size)).to(device)\n",
    "        z_label = to_onehot(np.random.randint(0, 10, (batch_size)), 10, device)\n",
    "\n",
    "        fake_pred = Disc(Enc(data), label)\n",
    "        true_pred = Disc(z, z_label)\n",
    "\n",
    "        \n",
    "        fake_loss = F.binary_cross_entropy_with_logits(fake_pred, batch_zeros)\n",
    "        true_loss = F.binary_cross_entropy_with_logits(true_pred, batch_ones)\n",
    "        \n",
    "        Disc_loss = 0.5*(fake_loss + true_loss)\n",
    "        \n",
    "        Disc_loss.backward()\n",
    "        Disc_optimizer.step()\n",
    "        \n",
    "        # train AE\n",
    "        Enc.zero_grad()\n",
    "        Dec.zero_grad()\n",
    "        Disc.zero_grad()\n",
    "        \n",
    "        z = torch.randn((batch_size, prior_size))\n",
    "        z_label = to_onehot(np.random.randint(0, 10, (batch_size)), 10)\n",
    "        \n",
    "        latent = Enc(data)\n",
    "        reconstructed = Dec(latent, label).view(-1, 1, 28, 28)\n",
    "        fake_pred = Disc(latent, label)\n",
    "        \n",
    "        Enc_loss = F.binary_cross_entropy_with_logits(fake_pred, batch_ones)\n",
    "        AE_loss = F.mse_loss(reconstructed, data)\n",
    "        G_loss = AE_loss + Enc_loss\n",
    "        \n",
    "        G_loss.backward()\n",
    "        Dec_optimizer.step()\n",
    "        Enc_optimizer.step()\n",
    "            \n",
    "        if batch_idx % 100 == 0:\n",
    "            line = 'Train Epoch: {} [{}/{} ({:.0f}%)]\\tLosses '.format(\n",
    "                epoch, batch_idx * len(data), train_size, 100. * batch_idx / len(train_loader))\n",
    "            losses = 'E: {:.4f}, AE: {:.4f}, D: {:.4f}'.format(Enc_loss.item(), AE_loss.item(), Disc_loss.item())\n",
    "            print(line + losses)\n",
    "            \n",
    "    else:\n",
    "        batch_idx += 1\n",
    "        line = 'Train Epoch: {} [{}/{} ({:.0f}%)]\\tLosses '.format(\n",
    "            epoch, batch_idx * len(data), train_size, 100. * batch_idx / len(train_loader))\n",
    "        losses = 'E: {:.4f}, AE: {:.4f}, D: {:.4f}'.format(Enc_loss.item(), AE_loss.item(), Disc_loss.item())\n",
    "        print(line + losses)\n",
    "        log['E'].append(Enc_loss.item())\n",
    "        log['AE'].append(AE_loss.item())\n",
    "        log['D'].append(Disc_loss.item())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(Enc, Dec, Disc, loader, epoch, log=None):\n",
    "    test_size = len(loader)\n",
    "    E_loss = 0.\n",
    "    AE_loss = 0.\n",
    "    D_loss = 0.\n",
    "    test_loss = {'E': 0., 'AE': 0., 'D': 0.}\n",
    "    with torch.no_grad():\n",
    "        for data, label in loader:\n",
    "            label = to_onehot(label, 10, device)\n",
    "            z = torch.randn((batch_size, prior_size)).to(device)\n",
    "            z_label = to_onehot(np.random.randint(0, 10, (batch_size)), 10, device)\n",
    "            latent = Enc(data)\n",
    "            reconstructed = Dec(latent, label).view(-1, 1, 28, 28)\n",
    "            fake_pred = Disc(latent, label)\n",
    "            true_pred = Disc(z, z_label)\n",
    "        \n",
    "            fake_loss = F.binary_cross_entropy_with_logits(fake_pred, batch_zeros).item()\n",
    "            true_loss = F.binary_cross_entropy_with_logits(true_pred, batch_ones).item()\n",
    "            \n",
    "            D_loss += 0.5*(fake_loss + true_loss)\n",
    "            E_loss += F.binary_cross_entropy_with_logits(fake_pred, batch_ones).item()\n",
    "            AE_loss += F.mse_loss(reconstructed, data)\n",
    "            \n",
    "        E_loss /= test_size\n",
    "        D_loss /= test_size\n",
    "        AE_loss /= test_size\n",
    "\n",
    "        fixed_gen = Dec(fixed_z, fixed_z_label).cpu().data.numpy().reshape(100, 1, 28, 28)\n",
    "        plot_mnist(fixed_gen, (10, 10), True, fixed_folder + '%03d.png' % epoch)\n",
    "        fixed_reconstruction = Dec(Enc(fixed_data), fixed_label).cpu().data.numpy().reshape(100, 1, 28, 28)\n",
    "        plot_mnist(fixed_reconstruction, (10, 10), True, recon_folder + '%03d.png' % epoch)\n",
    "        \n",
    "    report = 'Test losses. E: {:.4f}, AE: {:.4f}, D: {:.4f}'.format(E_loss, AE_loss, D_loss)\n",
    "    print(report)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for epoch in range(1, 1001):\n",
    "    Enc.train()\n",
    "    Dec.train()\n",
    "    Disc.train()\n",
    "    train(epoch, Enc, Dec, Disc, train_log)\n",
    "    Enc.eval()\n",
    "    Dec.eval()\n",
    "    Disc.eval()\n",
    "    test(Enc, Dec, Disc, valid_loader, epoch, test_log)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
