diff --git a/examples/attention_table.png b/examples/attention_table.png new file mode 100644 index 0000000..ea37459 Binary files /dev/null and b/examples/attention_table.png differ diff --git a/examples/example.ipynb b/examples/example.ipynb index 56e35b6..3e4d9e8 100644 --- a/examples/example.ipynb +++ b/examples/example.ipynb @@ -6,9 +6,14 @@ "metadata": {}, "source": [ "### Chat with your PDFs using byaldi + Claude 🚀\n", - "This notebook runs through the following examples:\n", - "- academic paper\n", - "- financial report" + "\n", + "How does it work?\n", + "- first download your chosen model (e.g. [ColPali](https://huggingface.co/vidore/colpali))\n", + "- then create an index for your pdf\n", + "- search the index for your chosen query\n", + "- pass the top search result to Claude along with your query\n", + "\n", + "In this notebook we'll chat with an academic paper and financial report." ] }, { @@ -23,7 +28,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "id": "1ab91a5b-8a54-473e-a316-e15f02aaedfa", "metadata": {}, "outputs": [ @@ -40,7 +45,7 @@ "text": [ "/home/byaldi/env/lib/python3.10/site-packages/transformers/models/paligemma/configuration_paligemma.py:137: FutureWarning: The `vocab_size` attribute is deprecated and will be removed in v4.44, Please use `text_config.vocab_size` instead.\n", " warnings.warn(\n", - "Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00, 1.05s/it]\n" + "Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00, 1.33it/s]\n" ] } ], @@ -51,7 +56,7 @@ "from claudette import *\n", "\n", "os.environ[\"HF_TOKEN\"] = \"YOUR_HF_TOKEN\" # to download the ColPali model\n", - "os.environ[\"ANTHROPIC_API_KEY\"] = \"ANTHROPIC_API_KEY\"\n", + "os.environ[\"ANTHROPIC_API_KEY\"] = \"YOUR_ANTHROPIC_API_KEY\"\n", "model = RAGMultiModalModel.from_pretrained(\"vidore/colpali\")" ] }, @@ -60,12 +65,36 @@ "id": "dfe9ded4-d8d3-4eed-929e-e3aee3a9861d", "metadata": {}, "source": [ - "### Academic Paper" + 
"### Academic Paper\n", + "\n", + "We're going to chat with the epochal Attention Is All You Need [paper](https://arxiv.org/pdf/1706.03762)\n", + "\n", + "Specifically we're going to ask \"What's the BLEU score for the transformer base model?\" The answer to this question is found in Table 2 on page 8.\n", + "\n", + "<img src=\"attention_table.png\" alt=\"Table 2\">" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 2, + "id": "4a09ea18-d6a1-4e80-8a34-09688f5fac94", + "metadata": {}, + "outputs": [], + "source": [ + "query = \"What's the BLEU score for the transformer base model?\"" ] }, + { + "cell_type": "markdown", + "id": "21e2cb2f-2251-48e1-bd1f-934d9dc025da", + "metadata": {}, + "source": [ + "First, let's index the paper" + ] + }, + { + "cell_type": "code", + "execution_count": 3, "id": "014e932d-b9b2-4dd0-ae1d-1355db06eca5", "metadata": {}, "outputs": [ @@ -73,6 +102,7 @@ "name": "stdout", "output_type": "stream", "text": [ + "overwrite is on. Deleting existing index attention to build a new one.\n", "Added page 1 of document 0 to index.\n", "Added page 2 of document 0 to index.\n", "Added page 3 of document 0 to index.\n", @@ -94,7 +124,6 @@ } ], "source": [ - "# Let's create the index\n", "model.index(\n", " input_path=\"attention.pdf\",\n", " index_name=\"attention\",\n", @@ -103,23 +132,54 @@ ")" ] }, + { + "cell_type": "markdown", + "id": "6e8dd3ea-bc81-4f48-b3fc-4ae4cf4303ee", + "metadata": {}, + "source": [ + "Now, let's search the index for our query. We expect the top result to be page 8."
+ ] + }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 4, "id": "30d6488c-ba09-4d0c-ae70-409c08d0f866", "metadata": { "scrolled": true }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Doc ID: 0, Page: 8, Score: 23.375\n", - "Doc ID: 0, Page: 9, Score: 22.625\n", - "Doc ID: 0, Page: 2, Score: 20.875\n" - ] - }, + "data": { + "text/plain": [ + "8" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "results = model.search(query, k=1)\n", + "results[0].page_num" + ] + }, + { + "cell_type": "markdown", + "id": "c955b0a9-139b-4fcb-b4cd-262048ff23e2", + "metadata": {}, + "source": [ + "Finally, we convert the top search result to bytes and pass it to Claude with our query. We use the Sonnet model as it is well suited to this task. If everything works as expected Claude should tell us that the BLEU score is 27.3 for EN-DE and 38.1 for EN-FR.\n", + "\n", + "**Note**: The image passed to Claude is large so depending on your account settings you might hit a token limit. If this happens try the smaller Financial Report example further down in the notebook. " + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "67b865be-e12f-4833-b451-9d6867cdf0e9", + "metadata": {}, + "outputs": [ { "data": { "text/markdown": [ @@ -130,7 +190,7 @@ "\n", "
\n", "\n", - "- id: `msg_01HmK5iN2JtLnnmTeMs587mw`\n", + "- id: `msg_01VsGfhxT8kq2Yt2Lq3v4xkx`\n", "- content: `[{'text': 'According to the table in the image, the BLEU score for the Transformer (base model) is:\\n\\n- 27.3 for EN-DE (English to German)\\n- 38.1 for EN-FR (English to French)', 'type': 'text'}]`\n", "- model: `claude-3-5-sonnet-20240620`\n", "- role: `assistant`\n", @@ -142,22 +202,17 @@ "
" ], "text/plain": [ - "Message(id='msg_01HmK5iN2JtLnnmTeMs587mw', content=[TextBlock(text='According to the table in the image, the BLEU score for the Transformer (base model) is:\\n\\n- 27.3 for EN-DE (English to German)\\n- 38.1 for EN-FR (English to French)', type='text')], model='claude-3-5-sonnet-20240620', role='assistant', stop_reason='end_turn', stop_sequence=None, type='message', usage=In: 1522; Out: 58; Total: 1580)" + "Message(id='msg_01VsGfhxT8kq2Yt2Lq3v4xkx', content=[TextBlock(text='According to the table in the image, the BLEU score for the Transformer (base model) is:\\n\\n- 27.3 for EN-DE (English to German)\\n- 38.1 for EN-FR (English to French)', type='text')], model='claude-3-5-sonnet-20240620', role='assistant', stop_reason='end_turn', stop_sequence=None, type='message', usage=In: 1522; Out: 58; Total: 1580)" ] }, - "execution_count": 8, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "query = \"What's the BLEU score for the transfomer base model?\"\n", - "results = model.search(query, k=3)\n", - "for result in results:\n", - " print(f\"Doc ID: {result.doc_id}, Page: {result.page_num}, Score: {result.score}\")\n", - " \n", - "chat = Chat(models[1]) # let's use sonnet\n", "image_bytes = base64.b64decode(results[0].base64)\n", + "chat = Chat(models[1])\n", "chat([image_bytes, query])" ] }, @@ -166,12 +221,36 @@ "id": "3f771e13-55d5-463d-aba1-f20b073978a6", "metadata": {}, "source": [ - "### Financial Report" + "### Financial Report\n", + "\n", + "Now, we're going to chat with a financial report for a fictitious company called ACME. The report contains monthly revenue data for ACME's 5 products that are creatively named (A, B, C, D, E). \n", + "\n", + "We're going to ask \"In which month did Product C generate the most revenue?\". 
The expected answer is **June**.\n", + "\n", + "<img src=\"product_c.png\" alt=\"Table\">" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 6, + "id": "e8a931d6-f28c-496c-a0c4-31bd4eb0ffd5", + "metadata": {}, + "outputs": [], + "source": [ + "query = \"In which month did Product C generate the most revenue?\"" ] }, + { + "cell_type": "markdown", + "id": "7fd8752a-951a-4763-baff-9a09bfc1b928", + "metadata": {}, + "source": [ + "First, let's index the report" + ] + }, + { + "cell_type": "code", + "execution_count": 7, "id": "65c2a13c-ee4e-4b45-8eb9-f466b5b2b93d", "metadata": {}, "outputs": [ @@ -179,6 +258,7 @@ "name": "stdout", "output_type": "stream", "text": [ + "overwrite is on. Deleting existing index financial_report to build a new one.\n", "Added page 1 of document 1 to index.\n", "Added page 2 of document 1 to index.\n", "Added page 3 of document 1 to index.\n", @@ -191,7 +271,6 @@ } ], "source": [ - "# Let's create the index\n", "model.index(\n", " input_path=\"financial_report.pdf\",\n", " index_name=\"financial_report\",\n", @@ -200,21 +279,50 @@ ")" ] }, + { + "cell_type": "markdown", + "id": "55cc7aa6-baf3-4297-ad47-3e404702009f", + "metadata": {}, + "source": [ + "Now, let's search the index for our query. We expect the top result to be page 4."
+ ] + }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 8, "id": "a3d9fec7-b262-41a4-a3e4-82b7c7356983", "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Doc ID: 1, Page: 4, Score: 20.375\n", - "Doc ID: 1, Page: 6, Score: 18.625\n", - "Doc ID: 1, Page: 5, Score: 18.5\n" - ] - }, + "data": { + "text/plain": [ + "4" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "results = model.search(query, k=1)\n", + "results[0].page_num" + ] + }, + { + "cell_type": "markdown", + "id": "2218061e-57fd-477a-8812-d206db602e6c", + "metadata": {}, + "source": [ + "Finally, we convert the top search result to bytes and pass it to Claude with our query. We use the Sonnet model as it is well suited to this task. If everything works as expected Claude should tell us Product C generated the most revenue in **June**." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "8598f8af-a701-45c5-9c48-f214ef5bd4ff", + "metadata": {}, + "outputs": [ { "data": { "text/markdown": [ @@ -222,7 +330,7 @@ "\n", "
\n", "\n", - "- id: `msg_01FEU91Kwi8PfLGpqvrycM8n`\n", + "- id: `msg_011g8XCR3VsXBSZVfzs5b9GY`\n", "- content: `[{'text': 'According to the bar graph showing monthly revenue for Product C, the month with the highest revenue was June. The bar for June is visibly the tallest, reaching above 2500 on the revenue scale, indicating it generated the most revenue compared to all other months shown.', 'type': 'text'}]`\n", "- model: `claude-3-5-sonnet-20240620`\n", "- role: `assistant`\n", @@ -234,21 +342,16 @@ "
" ], "text/plain": [ - "Message(id='msg_01FEU91Kwi8PfLGpqvrycM8n', content=[TextBlock(text='According to the bar graph showing monthly revenue for Product C, the month with the highest revenue was June. The bar for June is visibly the tallest, reaching above 2500 on the revenue scale, indicating it generated the most revenue compared to all other months shown.', type='text')], model='claude-3-5-sonnet-20240620', role='assistant', stop_reason='end_turn', stop_sequence=None, type='message', usage=In: 1573; Out: 59; Total: 1632)" + "Message(id='msg_011g8XCR3VsXBSZVfzs5b9GY', content=[TextBlock(text='According to the bar graph showing monthly revenue for Product C, the month with the highest revenue was June. The bar for June is visibly the tallest, reaching above 2500 on the revenue scale, indicating it generated the most revenue compared to all other months shown.', type='text')], model='claude-3-5-sonnet-20240620', role='assistant', stop_reason='end_turn', stop_sequence=None, type='message', usage=In: 1573; Out: 59; Total: 1632)" ] }, - "execution_count": 10, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "query = \"In which month did Product C generate the most revenue?\"\n", - "results = model.search(query, k=3)\n", - "for result in results:\n", - " print(f\"Doc ID: {result.doc_id}, Page: {result.page_num}, Score: {result.score}\")\n", - " \n", - "chat = Chat(models[1]) # let's use sonnet\n", + "chat = Chat(models[1])\n", "image_bytes = base64.b64decode(results[0].base64)\n", "chat([image_bytes, query])" ] diff --git a/examples/product_c.png b/examples/product_c.png new file mode 100644 index 0000000..de6200c Binary files /dev/null and b/examples/product_c.png differ