{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# Ноутбук для решения задачи урока 5.1\n"
      ],
      "metadata": {
        "id": "zTBgKsVMMwV2"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Load the dataset of multimodal questions\n",
        "\n",
        "import pandas as pd\n",
        "\n",
        "df = pd.read_csv(\"https://stepik.org/media/attachments/lesson/1028705/mulimodal_questions.csv\")\n",
        "df"
      ],
      "metadata": {
        "id": "RwhtSKrLNGQs"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Pin the bitsandbytes version required for 4-bit loading.\n",
        "# %pip (not !pip) guarantees the install targets the active kernel's environment.\n",
        "\n",
        "%pip install bitsandbytes==0.40.0 -qq"
      ],
      "metadata": {
        "id": "4Ab_4iA6ZGCu"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Download and extract the zip archive with the images.\n",
        "# -nc skips the download if the archive already exists (avoids images.zip.1 copies);\n",
        "# -o overwrites extracted files without prompting, so a re-run does not hang\n",
        "# on unzip's interactive \"replace?\" question.\n",
        "\n",
        "!wget -nc https://stepik.org/media/attachments/lesson/1028705/images.zip\n",
        "!unzip -o images.zip"
      ],
      "metadata": {
        "id": "ANjMhA_WdStG"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Core dependencies: image loading, torch dtypes, and the HF pipeline API\n",
        "from PIL import Image\n",
        "import torch\n",
        "from transformers import BitsAndBytesConfig, pipeline"
      ],
      "metadata": {
        "id": "zhd_OkoSYswj"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# 4-bit quantization lets the 7B model fit into the T4's 16 GB of VRAM;\n",
        "# computation still runs in float16 for speed.\n",
        "quantization_config = BitsAndBytesConfig(\n",
        "    load_in_4bit=True,\n",
        "    bnb_4bit_compute_dtype=torch.float16\n",
        ")"
      ],
      "metadata": {
        "id": "-SEJ24lnTAzJ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Task type and model checkpoint; the quantization settings are forwarded\n",
        "# to the model's from_pretrained call via model_kwargs.\n",
        "model_id = \"llava-hf/llava-1.5-7b-hf\"\n",
        "\n",
        "pipe = pipeline(\"image-to-text\", model=model_id, model_kwargs={\"quantization_config\": quantization_config})"
      ],
      "metadata": {
        "id": "5CzWltM5TA1Y"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Paths to the ten extracted images: images/im0.jpg ... images/im9.jpg\n",
        "images = [f\"images/im{i}.jpg\" for i in range(10)]"
      ],
      "metadata": {
        "id": "gh-xqBpCbQay"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Run the model on every (image, question) pair and parse the numeric answer.\n",
        "\n",
        "import re\n",
        "\n",
        "ans = []\n",
        "for im_path, question in zip(images, df['question'].values):\n",
        "\n",
        "    im = Image.open(im_path)\n",
        "\n",
        "    prompt = f\"USER:<image>\\n{question}. answer with an int number\\nASSISTANT:\"\n",
        "    outputs = pipe(im, prompt=prompt, generate_kwargs={\"max_new_tokens\": 200})\n",
        "\n",
        "    # The generated text echoes the prompt; take everything after 'ASSISTANT:'\n",
        "    # and extract the first integer instead of int()-casting the raw tail,\n",
        "    # which crashes on replies like '5.' or 'The answer is 5'.\n",
        "    reply = outputs[0]['generated_text'].split('ASSISTANT:')[-1]\n",
        "    match = re.search(r'-?\\d+', reply)\n",
        "    if match is None:\n",
        "        raise ValueError(f'No integer found in model reply: {reply!r}')\n",
        "    ans.append(int(match.group()))\n",
        "    #break  # remove break once you are sure the code works for one image\n",
        "print(ans)"
      ],
      "metadata": {
        "id": "HaU7l0SbTA3r"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Write the answers into the dataframe and export them to CSV.\n",
        "\n",
        "df['answer'] = ans\n",
        "\n",
        "# Drop via reassignment instead of inplace=True: df itself stays intact,\n",
        "# so re-running this cell cannot raise KeyError on the already-missing column.\n",
        "submission = df.drop(columns=['image_name'])\n",
        "submission.to_csv('answer.csv', index=False)\n",
        "submission"
      ],
      "metadata": {
        "id": "-cAn0EufTBPx",
        "collapsed": true
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}