Support custom text formats and recursive (#496)
* Add custom text types and recursive

* Add custom text types and recursive

* Fix format

* Update qdrant, Add pdf to unstructured

* Use unstructured as the default text extractor if installed

* Add tests for unstructured

* Update tests env for unstructured

* Fix error if last message is a function call, issue #569

* Remove csv, md and tsv from UNSTRUCTURED_FORMATS

* Update docstring of docs_path

* Update test for get_files_from_dir

* Update docstring of custom_text_types

* Fix missing search_string in update_context

* Add custom_text_types to notebook example
thinkall authored Nov 21, 2023
1 parent ef1c3d3 commit 07646d4
Showing 7 changed files with 516 additions and 269 deletions.
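
Taken together, the changes extend RetrieveChat so that `docs_path` accepts a list of directories, files, and URLs, and so that directory scans can be filtered by file type (`custom_text_types`) and toggled between recursive and top-level (`recursive`). A minimal sketch of the new configuration surface; the agent name, paths, and type list are illustrative placeholders, not taken from this commit:

```python
# Sketch only: agent name, docs_path entries, and type list are placeholders.
from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent

ragproxyagent = RetrieveUserProxyAgent(
    name="ragproxyagent",
    human_input_mode="NEVER",
    retrieve_config={
        "task": "qa",
        # docs_path now accepts a list of directories, files, and urls
        "docs_path": ["./docs", "./README.md", "https://example.com/notes.txt"],
        # only these file types are collected from the directories in docs_path;
        # explicitly listed files and urls are chunked regardless of type
        "custom_text_types": ["txt", "md", "rst"],
        # search the directories in docs_path recursively (the new default)
        "recursive": True,
    },
)
```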
4 changes: 0 additions & 4 deletions .github/workflows/build.yml
@@ -42,10 +42,6 @@ jobs:
       python -c "import autogen"
       pip install -e. pytest mock
       pip uninstall -y openai
-    - name: Install unstructured if not windows
-      if: matrix.os != 'windows-2019'
-      run: |
-        pip install "unstructured[all-docs]"
     - name: Test with pytest
       if: matrix.python-version != '3.10'
       run: |
8 changes: 6 additions & 2 deletions .github/workflows/contrib-tests.yml
@@ -34,10 +34,14 @@ jobs:
       run: |
         python -m pip install --upgrade pip wheel
         pip install pytest
-    - name: Install qdrant_client when python-version is 3.10
-      if: matrix.python-version == '3.10' || matrix.python-version == '3.8'
+    - name: Install qdrant_client when python-version is 3.8 and 3.10
+      if: matrix.python-version == '3.8' || matrix.python-version == '3.10'
       run: |
         pip install qdrant_client[fastembed]
+    - name: Install unstructured when python-version is 3.9 and 3.11 and not windows
+      if: (matrix.python-version == '3.9' || matrix.python-version == '3.11') && matrix.os != 'windows-2019'
+      run: |
+        pip install unstructured[all-docs]
     - name: Install packages and dependencies for RetrieveChat
       run: |
         pip install -e .[retrievechat]
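
One commit bullet says unstructured becomes the default text extractor when it is installed. That change lives in autogen/retrieve_utils.py, which is not rendered on this page; the sketch below shows the optional-dependency pattern it implies. The `partition` entry point is unstructured's real API, but `extract_text` and the surrounding structure are assumptions, not the actual implementation:

```python
# Sketch of an optional-dependency fallback; extract_text is an illustrative
# name, not the function used in autogen/retrieve_utils.py.
try:
    from unstructured.partition.auto import partition

    HAS_UNSTRUCTURED = True
except ImportError:
    HAS_UNSTRUCTURED = False


def extract_text(path: str) -> str:
    """Prefer unstructured (handles pdf, docx, ...), fall back to a plain read."""
    if HAS_UNSTRUCTURED:
        return "\n".join(str(element) for element in partition(filename=path))
    with open(path, encoding="utf-8", errors="ignore") as f:
        return f.read()
```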
44 changes: 31 additions & 13 deletions autogen/agentchat/contrib/qdrant_retrieve_user_proxy_agent.py
@@ -1,7 +1,7 @@
 from typing import Callable, Dict, List, Optional
 
 from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent
-from autogen.retrieve_utils import get_files_from_dir, split_files_to_chunks
+from autogen.retrieve_utils import get_files_from_dir, split_files_to_chunks, TEXT_FORMATS
 import logging
 
 logger = logging.getLogger(__name__)
@@ -45,8 +45,8 @@ def __init__(
                     prompt will be different for different tasks. The default value is `default`, which supports both code and qa.
                 - client (Optional, qdrant_client.QdrantClient(":memory:")): A QdrantClient instance. If not provided, an in-memory instance will be assigned. Not recommended for production.
                     will be used. If you want to use other vector db, extend this class and override the `retrieve_docs` function.
-                - docs_path (Optional, str): the path to the docs directory. It can also be the path to a single file,
-                    or the url to a single file. Default is None, which works only if the collection is already created.
+                - docs_path (Optional, Union[str, List[str]]): the path to the docs directory. It can also be the path to a single file,
+                    the url to a single file or a list of directories, files and urls. Default is None, which works only if the collection is already created.
                 - collection_name (Optional, str): the name of the collection.
                     If key not provided, a default name `autogen-docs` will be used.
                 - model (Optional, str): the model to use for the retrieve chat.
@@ -66,11 +66,14 @@
                 - customized_answer_prefix (Optional, str): the customized answer prefix for the retrieve chat. Default is "".
                     If not "" and the customized_answer_prefix is not in the answer, `Update Context` will be triggered.
                 - update_context (Optional, bool): if False, will not apply `Update Context` for interactive retrieval. Default is True.
-                - custom_token_count_function(Optional, Callable): a custom function to count the number of tokens in a string.
+                - custom_token_count_function (Optional, Callable): a custom function to count the number of tokens in a string.
                     The function should take a string as input and return three integers (token_count, tokens_per_message, tokens_per_name).
                     Default is None, tiktoken will be used and may not be accurate for non-OpenAI models.
-                - custom_text_split_function(Optional, Callable): a custom function to split a string into a list of strings.
+                - custom_text_split_function (Optional, Callable): a custom function to split a string into a list of strings.
                     Default is None, will use the default function in `autogen.retrieve_utils.split_text_to_chunks`.
+                - custom_text_types (Optional, List[str]): a list of file types to be processed. Default is `autogen.retrieve_utils.TEXT_FORMATS`.
+                    This only applies to files under the directories in `docs_path`. Explicitly included files and urls will be chunked regardless of their types.
+                - recursive (Optional, bool): whether to search documents recursively in the docs_path. Default is True.
                 - parallel (Optional, int): How many parallel workers to use for embedding. Defaults to the number of CPU cores.
                 - on_disk (Optional, bool): Whether to store the collection on disk. Default is False.
                 - quantization_config: Quantization configuration. If None, quantization will be disabled.
@@ -111,6 +114,8 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
             must_break_at_empty_line=self._must_break_at_empty_line,
             embedding_model=self._embedding_model,
             custom_text_split_function=self.custom_text_split_function,
+            custom_text_types=self._custom_text_types,
+            recursive=self._recursive,
             parallel=self._parallel,
             on_disk=self._on_disk,
             quantization_config=self._quantization_config,
@@ -139,15 +144,17 @@ def create_qdrant_from_dir(
     must_break_at_empty_line: bool = True,
     embedding_model: str = "BAAI/bge-small-en-v1.5",
     custom_text_split_function: Callable = None,
+    custom_text_types: List[str] = TEXT_FORMATS,
+    recursive: bool = True,
     parallel: int = 0,
     on_disk: bool = False,
     quantization_config: Optional[models.QuantizationConfig] = None,
     hnsw_config: Optional[models.HnswConfigDiff] = None,
     payload_indexing: bool = False,
     qdrant_client_options: Optional[Dict] = {},
 ):
-    """Create a Qdrant collection from all the files in a given directory, the directory can also be a single file or a url to
-    a single file.
+    """Create a Qdrant collection from all the files in a given directory, the directory can also be a single file or a
+    url to a single file.
 
     Args:
         dir_path (str): the path to the directory, file or url.
@@ -156,24 +163,35 @@
         collection_name (Optional, str): the name of the collection. Default is "all-my-documents".
         chunk_mode (Optional, str): the chunk mode. Default is "multi_lines".
         must_break_at_empty_line (Optional, bool): Whether to break at empty line. Default is True.
-        embedding_model (Optional, str): the embedding model to use. Default is "BAAI/bge-small-en-v1.5". The list of all the available models can be at https://qdrant.github.io/fastembed/examples/Supported_Models/.
+        embedding_model (Optional, str): the embedding model to use. Default is "BAAI/bge-small-en-v1.5".
+            The list of all the available models can be at https://qdrant.github.io/fastembed/examples/Supported_Models/.
         custom_text_split_function (Optional, Callable): a custom function to split a string into a list of strings.
             Default is None, will use the default function in `autogen.retrieve_utils.split_text_to_chunks`.
+        custom_text_types (Optional, List[str]): a list of file types to be processed. Default is TEXT_FORMATS.
+        recursive (Optional, bool): whether to search documents recursively in the dir_path. Default is True.
         parallel (Optional, int): How many parallel workers to use for embedding. Defaults to the number of CPU cores
         on_disk (Optional, bool): Whether to store the collection on disk. Default is False.
-        quantization_config: Quantization configuration. If None, quantization will be disabled. Ref: https://qdrant.github.io/qdrant/redoc/index.html#tag/collections/operation/create_collection
-        hnsw_config: HNSW configuration. If None, default configuration will be used. Ref: https://qdrant.github.io/qdrant/redoc/index.html#tag/collections/operation/create_collection
+        quantization_config: Quantization configuration. If None, quantization will be disabled.
+            Ref: https://qdrant.github.io/qdrant/redoc/index.html#tag/collections/operation/create_collection
+        hnsw_config: HNSW configuration. If None, default configuration will be used.
+            Ref: https://qdrant.github.io/qdrant/redoc/index.html#tag/collections/operation/create_collection
         payload_indexing: Whether to create a payload index for the document field. Default is False.
-        qdrant_client_options: (Optional, dict): the options for instantiating the qdrant client. Reference: https://github.com/qdrant/qdrant-client/blob/master/qdrant_client/qdrant_client.py#L36-L58.
+        qdrant_client_options: (Optional, dict): the options for instantiating the qdrant client.
+            Ref: https://github.com/qdrant/qdrant-client/blob/master/qdrant_client/qdrant_client.py#L36-L58.
     """
     if client is None:
         client = QdrantClient(**qdrant_client_options)
         client.set_model(embedding_model)
 
     if custom_text_split_function is not None:
         chunks = split_files_to_chunks(
-            get_files_from_dir(dir_path), custom_text_split_function=custom_text_split_function
+            get_files_from_dir(dir_path, custom_text_types, recursive),
+            custom_text_split_function=custom_text_split_function,
         )
     else:
-        chunks = split_files_to_chunks(get_files_from_dir(dir_path), max_tokens, chunk_mode, must_break_at_empty_line)
+        chunks = split_files_to_chunks(
+            get_files_from_dir(dir_path, custom_text_types, recursive), max_tokens, chunk_mode, must_break_at_empty_line
+        )
     logger.info(f"Found {len(chunks)} chunks.")
 
     # Check if collection by same name exists, if not, create it with custom options
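
For callers that build the Qdrant collection directly, the two new keyword arguments thread through to `get_files_from_dir`. A hedged usage sketch, assuming keyword arguments throughout; the directory, collection name, and type list are placeholders:

```python
# Sketch only: paths, collection name, and type list are illustrative.
from qdrant_client import QdrantClient

from autogen.agentchat.contrib.qdrant_retrieve_user_proxy_agent import create_qdrant_from_dir

client = QdrantClient(":memory:")  # in-memory instance; not recommended for production
create_qdrant_from_dir(
    dir_path="./docs",
    client=client,
    collection_name="all-my-documents",
    custom_text_types=["md", "txt"],  # only chunk these types found under dir_path
    recursive=False,  # scan only the top level of dir_path
)
```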
29 changes: 22 additions & 7 deletions autogen/agentchat/contrib/retrieve_user_proxy_agent.py
@@ -6,7 +6,7 @@
     raise ImportError("Please install dependencies first. `pip install pyautogen[retrievechat]`")
 from autogen.agentchat.agent import Agent
 from autogen.agentchat import UserProxyAgent
-from autogen.retrieve_utils import create_vector_db_from_dir, query_vector_db
+from autogen.retrieve_utils import create_vector_db_from_dir, query_vector_db, TEXT_FORMATS
 from autogen.token_count_utils import count_token
 from autogen.code_utils import extract_code
 
@@ -97,8 +97,8 @@ def __init__(
                     prompt will be different for different tasks. The default value is `default`, which supports both code and qa.
                 - client (Optional, chromadb.Client): the chromadb client. If key not provided, a default client `chromadb.Client()`
                     will be used. If you want to use other vector db, extend this class and override the `retrieve_docs` function.
-                - docs_path (Optional, str): the path to the docs directory. It can also be the path to a single file,
-                    or the url to a single file. Default is None, which works only if the collection is already created.
+                - docs_path (Optional, Union[str, List[str]]): the path to the docs directory. It can also be the path to a single file,
+                    the url to a single file or a list of directories, files and urls. Default is None, which works only if the collection is already created.
                 - collection_name (Optional, str): the name of the collection.
                     If key not provided, a default name `autogen-docs` will be used.
                 - model (Optional, str): the model to use for the retrieve chat.
@@ -124,11 +124,14 @@
                 - update_context (Optional, bool): if False, will not apply `Update Context` for interactive retrieval. Default is True.
                 - get_or_create (Optional, bool): if True, will create/return a collection for the retrieve chat. This is the same as that used in chromadb.
                     Default is False. Will raise ValueError if the collection already exists and get_or_create is False. Will be set to True if docs_path is None.
-                - custom_token_count_function(Optional, Callable): a custom function to count the number of tokens in a string.
+                - custom_token_count_function (Optional, Callable): a custom function to count the number of tokens in a string.
                     The function should take (text:str, model:str) as input and return the token_count(int). the retrieve_config["model"] will be passed in the function.
                     Default is autogen.token_count_utils.count_token that uses tiktoken, which may not be accurate for non-OpenAI models.
-                - custom_text_split_function(Optional, Callable): a custom function to split a string into a list of strings.
+                - custom_text_split_function (Optional, Callable): a custom function to split a string into a list of strings.
                     Default is None, will use the default function in `autogen.retrieve_utils.split_text_to_chunks`.
+                - custom_text_types (Optional, List[str]): a list of file types to be processed. Default is `autogen.retrieve_utils.TEXT_FORMATS`.
+                    This only applies to files under the directories in `docs_path`. Explicitly included files and urls will be chunked regardless of their types.
+                - recursive (Optional, bool): whether to search documents recursively in the docs_path. Default is True.
             **kwargs (dict): other kwargs in [UserProxyAgent](../user_proxy_agent#__init__).
 
         Example of overriding retrieve_docs:
@@ -181,6 +184,8 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
         self._get_or_create = self._retrieve_config.get("get_or_create", False) if self._docs_path is not None else True
         self.custom_token_count_function = self._retrieve_config.get("custom_token_count_function", count_token)
         self.custom_text_split_function = self._retrieve_config.get("custom_text_split_function", None)
+        self._custom_text_types = self._retrieve_config.get("custom_text_types", TEXT_FORMATS)
+        self._recursive = self._retrieve_config.get("recursive", True)
         self._context_max_tokens = self._max_tokens * 0.8
         self._collection = True if self._docs_path is None else False  # whether the collection is created
         self._ipython = get_ipython()
@@ -189,6 +194,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
         self._intermediate_answers = set()  # the intermediate answers
         self._doc_contents = []  # the contents of the current used doc
         self._doc_ids = []  # the ids of the current used doc
+        self._search_string = ""  # the search string used in the current query
         # update the termination message function
         self._is_termination_msg = (
             self._is_termination_msg_retrievechat if is_termination_msg is None else is_termination_msg
@@ -282,6 +288,8 @@ def _generate_message(self, doc_contents, task="default"):
     def _check_update_context(self, message):
         if isinstance(message, dict):
             message = message.get("content", "")
+        elif not isinstance(message, str):
+            message = ""
         update_context_case1 = "UPDATE CONTEXT" in message[-20:].upper() or "UPDATE CONTEXT" in message[:20].upper()
         update_context_case2 = self.customized_answer_prefix and self.customized_answer_prefix not in message.upper()
         return update_context_case1, update_context_case2
@@ -320,7 +328,9 @@ def _generate_retrieve_user_reply(
             if not doc_contents:
                 for _tmp_retrieve_count in range(1, 5):
                     self._reset(intermediate=True)
-                    self.retrieve_docs(self.problem, self.n_results * (2 * _tmp_retrieve_count + 1))
+                    self.retrieve_docs(
+                        self.problem, self.n_results * (2 * _tmp_retrieve_count + 1), self._search_string
+                    )
                     doc_contents = self._get_context(self._results)
                     if doc_contents:
                         break
@@ -329,7 +339,9 @@
             # docs in the retrieved doc results to the context.
             for _tmp_retrieve_count in range(5):
                 self._reset(intermediate=True)
-                self.retrieve_docs(_intermediate_info[0], self.n_results * (2 * _tmp_retrieve_count + 1))
+                self.retrieve_docs(
+                    _intermediate_info[0], self.n_results * (2 * _tmp_retrieve_count + 1), self._search_string
+                )
                 self._get_context(self._results)
                 doc_contents = "\n".join(self._doc_contents)  # + "\n" + "\n".join(self._intermediate_answers)
                 if doc_contents:
@@ -371,6 +383,8 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
                 get_or_create=self._get_or_create,
                 embedding_function=self._embedding_function,
                 custom_text_split_function=self.custom_text_split_function,
+                custom_text_types=self._custom_text_types,
+                recursive=self._recursive,
             )
             self._collection = True
             self._get_or_create = True
@@ -384,6 +398,7 @@ def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str =
             embedding_model=self._embedding_model,
             embedding_function=self._embedding_function,
         )
+        self._search_string = search_string
        self._results = results
         print("doc_ids: ", results["ids"])
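
The two-line guard in `_check_update_context` is what resolves issue #569: when the last message is a function call, its extracted content can be None rather than a dict or string, so the guard normalizes it to "" before the `message[-20:]` slice can raise. A standalone mirror of that normalization; the function name is illustrative, not from the codebase:

```python
from typing import Any


def normalize_message_content(message: Any) -> str:
    """Mirror of the guard added to _check_update_context: dicts yield their
    "content" field, strings pass through, anything else becomes ""."""
    if isinstance(message, dict):
        return message.get("content", "")
    if not isinstance(message, str):
        return ""
    return message


assert normalize_message_content("UPDATE CONTEXT") == "UPDATE CONTEXT"
assert normalize_message_content(None) == ""  # e.g. a function-call message with no text
```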