{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "\n",
    "import numpy as np\n",
    "import scipy.stats as stats\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration: RNG seed and parameters of the two Gaussian classes.\n",
    "# Seeding torch's global generator makes the layer initialisation and every\n",
    "# sampled batch reproducible under Restart & Run All.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# (mean, standard deviation) of class 0 and of class 1\n",
    "mu0, sigma0 = -2., 1.\n",
    "mu1, sigma1 = 3., 2."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample(d0, d1, n=32):\n",
    "    \"\"\"Draw a balanced, labelled batch from two distributions.\n",
    "\n",
    "    Args:\n",
    "        d0: distribution of the negative class; must provide ``sample(shape)``.\n",
    "        d1: distribution of the positive class; same interface as ``d0``.\n",
    "        n: number of points drawn from each distribution.\n",
    "\n",
    "    Returns:\n",
    "        A pair ``(x, y)``: ``x`` is the ``n`` draws from ``d0`` followed by the\n",
    "        ``n`` draws from ``d1``, concatenated along dim 0; ``y`` has shape\n",
    "        ``(2 * n, 1)`` and holds 0 for the ``d0`` rows and 1 for the ``d1`` rows.\n",
    "    \"\"\"\n",
    "    x0 = d0.sample((n,))\n",
    "    x1 = d1.sample((n,))\n",
    "    # Create the labels with the dtype and device of the samples, so the pair can\n",
    "    # be fed to a loss even for float64 or GPU-resident distributions.\n",
    "    y0 = torch.zeros((n, 1), dtype=x0.dtype, device=x0.device)\n",
    "    y1 = torch.ones((n, 1), dtype=x1.dtype, device=x1.device)\n",
    "    return torch.cat([x0, x1], 0), torch.cat([y0, y1], 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class-conditional Gaussians. The parameters are wrapped in one-element\n",
    "# tensors, so `d.sample((n,))` yields a column of shape (n, 1).\n",
    "d0, d1 = (\n",
    "    torch.distributions.Normal(loc=torch.tensor([mu]), scale=torch.tensor([sigma]))\n",
    "    for mu, sigma in ((mu0, sigma0), (mu1, sigma1))\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Logistic regression on a scalar input: one weight and one bias.\n",
    "layer = nn.Linear(1, 1)\n",
    "# Show the initial parameters by name as plain floats. `.item()` is used rather\n",
    "# than `.data[0]`, which goes through the legacy `.data` attribute and yields\n",
    "# tensors of different rank for the weight and the bias.\n",
    "print({name: p.item() for name, p in layer.named_parameters()})\n",
    "layer_opt = optim.SGD(layer.parameters(), lr=1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training schedule. Note that sample() draws the given number of points from\n",
    "# *each* class, so a batch holds twice as many rows.\n",
    "n_iters = 10000      # SGD steps\n",
    "log_freq = 500       # estimate and print the loss every `log_freq` steps\n",
    "batch_size = 1024    # points per class in a training batch\n",
    "eval_size = 100000   # points per class in a loss estimate\n",
    "\n",
    "for i in range(n_iters):\n",
    "    if i % log_freq == 0:\n",
    "        # Estimate the current loss on a large fresh sample, without autograd.\n",
    "        with torch.no_grad():\n",
    "            x, y = sample(d0, d1, eval_size)\n",
    "            loss = F.binary_cross_entropy_with_logits(layer(x), y)\n",
    "        # Print the number of completed steps `i`, which is what the message\n",
    "        # announces (not the index of the log line, i / log_freq).\n",
    "        print('Ошибка после %d итераций: %f' % (i, loss.item()))\n",
    "    layer_opt.zero_grad()\n",
    "    x, y = sample(d0, d1, batch_size)\n",
    "    # BCE is computed straight from the logits: applying sigmoid first and then\n",
    "    # taking its log loses precision for large-magnitude logits.\n",
    "    loss = F.binary_cross_entropy_with_logits(layer(x), y)\n",
    "    loss.backward()\n",
    "    layer_opt.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation grid and the exact class-conditional densities on it.\n",
    "x_scale = np.linspace(-10, 10, 5000)\n",
    "d0_pdf = stats.norm(loc=mu0, scale=sigma0).pdf(x_scale)\n",
    "d1_pdf = stats.norm(loc=mu1, scale=sigma1).pdf(x_scale)\n",
    "\n",
    "# Classifier output on the grid, fed to the layer as a float32 column.\n",
    "x_tensor = torch.from_numpy(x_scale).float().unsqueeze(1)\n",
    "with torch.no_grad():\n",
    "    logits = layer(x_tensor)\n",
    "    dist = torch.sigmoid(logits).numpy()\n",
    "\n",
    "# Share of the class-1 density in the total density at each grid point;\n",
    "# this is the reference curve that `dist` is compared against.\n",
    "ratio = d1_pdf / (d0_pdf + d1_pdf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "# The densities are drawn at twice their height purely to make the plot look\n",
    "# nicer; the scaling does not affect the distributions themselves.\n",
    "ax.plot(x_scale, 2 * d0_pdf, label='d0')\n",
    "ax.plot(x_scale, 2 * d1_pdf, label='d1')\n",
    "ax.plot(x_scale, dist.ravel(), label='pred')\n",
    "ax.plot(x_scale, ratio, label='ratio')\n",
    "ax.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Learned parameters by name, as plain floats. `.item()` is used rather than\n",
    "# `.data[0]`, which goes through the legacy `.data` attribute and yields\n",
    "# tensors of different rank for the weight and the bias.\n",
    "{name: p.item() for name, p in layer.named_parameters()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Naive log(sigmoid(x)) for a large negative logit: the sigmoid is expected to\n",
    "# underflow to 0 in float32, so the log comes out as -inf.\n",
    "# `torch.sigmoid` is used (as in the rest of the notebook) instead of the\n",
    "# deprecated `F.sigmoid` alias.\n",
    "torch.log(torch.sigmoid(torch.tensor(-100.)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fused log-sigmoid on the same large negative logit; the result is expected\n",
    "# to stay finite (about -100) where the two-step computation does not.\n",
    "large_negative_logit = torch.tensor(-100.)\n",
    "F.logsigmoid(large_negative_logit)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
