adding the url filter and its pipeline
May15farhat committed Jul 2, 2024
1 parent 52eff0d commit 18b552a
Showing 4 changed files with 233 additions and 1 deletion.
Empty file removed arabicweb/__init__.py
Empty file.
3 changes: 2 additions & 1 deletion arabicweb/pipelines/filters/__init__.py
@@ -1,3 +1,4 @@
from .c4_filters import C4BadWordsFilter, C4ParagraphFilter, C4QualityFilter
from .gopher_quality_filter import GopherQualityFilter
from .fineweb_quality_filter import FineWebQualityFilter
from .url_filter import URLFilter
182 changes: 182 additions & 0 deletions arabicweb/pipelines/filters/url_filter.py
@@ -0,0 +1,182 @@
"""
This file has been modified from its original version obtained from Datatrove.
Changes made:
- Updated the normalize function so it can read and parse the list of Arabic banned words.
- Added the banned_words_path argument to include our custom Arabic banned words.
- The custom list we use is named en_ar_banned_words and contains both Arabic and English banned words.
"""

import os
import re
import tarfile
from typing import Iterable

from huggingface_hub import cached_assets_path
from loguru import logger
import unicodedata
from datatrove.data import Document
from datatrove.io import safely_create_file
from datatrove.utils._import_utils import ASSETS_PATH

from datatrove.pipeline.writers.disk_base import DiskWriter
from datatrove.pipeline.filters.base_filter import BaseFilter


normalizer = re.compile(r"[^a-zA-Z0-9]+")


NUMBERS_PATTERN = re.compile(r"\d+")
WHITESPACE_PATTERN = re.compile(r"\s+")


def normalize(text, config=None):
    if config is None:
        config = {}

    # Default configuration
    default_config = {
        "lowercase": True,
        "norm_whitespace": True,
        "remove_punctuation": True,
        "norm_unicode_diacritics": True,
        "norm_numbers": True,
    }

    # Merge provided config with default config
    final_config = {**default_config, **config}

    # Apply text normalization based on configuration
    if final_config["lowercase"]:
        text = text.lower()

    if final_config["norm_whitespace"]:
        text = WHITESPACE_PATTERN.sub(" ", text.strip())

    if final_config["remove_punctuation"]:
        punctuation = '!/—":%1〈&(、━\\【#%「」,】;+^]~"《„\';\'{|∶´[=-`*.(–?!:$~«〉,><》)?)。…@_."}►»'
        text = "".join([char for char in text if char not in punctuation])

    if final_config["norm_unicode_diacritics"]:
        text = "".join(c for c in unicodedata.normalize("NFD", text) if unicodedata.category(c) != "Mn")

    if final_config["norm_numbers"]:
        text = NUMBERS_PATTERN.sub("0", text)

    return text.strip()


def parse_list(line, do_normalize=True):
    return {normalize(x) if do_normalize else x.strip() for x in line if x[0] != "#"}


def get_list(abs_path: str, file_name: str, extra: set, do_normalize: bool = True):
    with open(os.path.join(abs_path, file_name)) as f:
        return parse_list(f, do_normalize).union(extra)


class URLFilter(BaseFilter):
    """
    Performs filtering based on samples urls.
    Samples are removed if:
    - their domain is present on `block_listed_domains`
    - if their subdomain is present on `block_listed_domains`
    - if the full url is present on `block_listed_url`
    - if any word from `banned_words` is in the url
    - if there are at least `soft_word_threshold` words from `soft_banned_words` in the url
    - if any word from `banned_subwords` is a substring of the url
    """

    name = "😈 Url-filter"
    _requires_dependencies = ["tldextract", "fasteners", ("ahocorasick", "pyahocorasick")]

    def __init__(
        self,
        soft_word_threshold: int = 2,
        extra_domains: Iterable = None,
        extra_urls: Iterable = None,
        banned_words: Iterable = None,
        banned_subwords: Iterable = None,
        soft_banned_words: Iterable = None,
        use_integrated_lists: bool = True,
        banned_words_path: str | None = None,
        exclusion_writer: DiskWriter = None,
    ):
        import ahocorasick
        from tldextract import TLDExtract

        super().__init__(exclusion_writer)
        self.soft_word_threshold = soft_word_threshold
        self.block_listed_domains = parse_list(extra_domains, do_normalize=False) if extra_domains else set()
        self.block_listed_url = parse_list(extra_urls, do_normalize=False) if extra_urls else set()
        self.banned_words = parse_list(banned_words) if banned_words else set()
        self.banned_subwords = parse_list(banned_subwords) if banned_subwords else set()
        self.soft_banned_words = parse_list(soft_banned_words) if soft_banned_words else set()
        self.use_integrated_lists = use_integrated_lists
        self._downloaded = False
        self.tldextractor = TLDExtract()
        self.banned_words_path = banned_words_path

        self.banned_subwords_automaton = ahocorasick.Automaton(ahocorasick.STORE_INTS)
        for word in self.banned_subwords:
            self.banned_subwords_automaton.add_word(word, len(self.banned_subwords_automaton))

        if not self.use_integrated_lists:
            self.banned_subwords_automaton.make_automaton()

    def download_data(self):
        if self._downloaded or not self.use_integrated_lists:
            return
        download_dir = cached_assets_path(library_name="datatrove", namespace="filters", subfolder="url_filter")
        file_to_lock = os.path.join(download_dir, "url_filterblacklists.tar.gz")

        def do_extract():
            logger.info("💥 Extracting url filter blacklists...")
            with tarfile.open(os.path.join(ASSETS_PATH, "url_filterblacklists.tar.gz"), "r:gz") as tar:
                tar.extractall(download_dir)
            logger.info("💥 Extracted url filter blacklists.")

        safely_create_file(file_to_lock, do_extract)

        self.block_listed_domains = get_list(
            download_dir, "adult/domains", self.block_listed_domains, do_normalize=False
        )
        self.block_listed_url = get_list(download_dir, "adult/urls", self.block_listed_url, do_normalize=False)
        self.banned_words = get_list(self.banned_words_path, "en_ar_banned_words.txt", self.banned_words)
        self.banned_subwords = get_list(ASSETS_PATH, "banned_subwords.txt", self.banned_subwords)
        self.soft_banned_words = get_list(ASSETS_PATH, "soft_banned_words.txt", self.soft_banned_words)
        for word in self.banned_subwords:
            self.banned_subwords_automaton.add_word(word, len(self.banned_subwords_automaton))
        self.banned_subwords_automaton.make_automaton()
        self._downloaded = True

    def filter(self, document: Document) -> bool | tuple[bool, str]:
        self.download_data()
        url = document.metadata.get("source")

        assert url, "Document does not have url in its metadata"
        url_info = self.tldextractor(url)

        if url_info.registered_domain in self.block_listed_domains:
            return False, "domain"

        if url_info.fqdn in self.block_listed_domains:
            return False, "subdomain"

        if url in self.block_listed_url:
            return False, "url"

        url_words = set(normalizer.split(url))
        if any(word in url_words for word in self.banned_words):
            return False, "hard_blacklisted"

        nb_soft_words = sum([word in url_words for word in self.soft_banned_words])
        if nb_soft_words >= self.soft_word_threshold:
            return False, "soft_blacklisted"

        normalized_space = normalize(url)
        if self.banned_subwords and next(self.banned_subwords_automaton.iter(normalized_space), False):
            return False, "blacklisted_subword"

        return True
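
As a quick illustration of how the normalize and parse_list helpers above treat a mixed Arabic/English banned-words file, here is a minimal sketch. It is not part of the commit, and the sample entries are hypothetical rather than taken from en_ar_banned_words.txt; it assumes it runs in the same module as the helpers.

# Illustrative sketch only: sample entries are hypothetical, not the real en_ar_banned_words.txt.
sample_lines = [
    "# lines starting with '#' are skipped by parse_list",
    "Casino",
    "قِمَار",  # Arabic entry with diacritics (tashkeel)
]
banned = parse_list(sample_lines)
# normalize() lowercases, removes punctuation, maps digit runs to "0", and drops Unicode
# combining marks (category Mn), so the diacritics are removed: banned == {"casino", "قمار"}
print(banned)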
49 changes: 49 additions & 0 deletions arabicweb/process/url_filtering.py
@@ -0,0 +1,49 @@
import sys
import os


sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from datatrove.executor.base import PipelineExecutor
from datatrove.executor.local import LocalPipelineExecutor
from datatrove.pipeline.readers import JsonlReader
from datatrove.pipeline.writers.jsonl import JsonlWriter
from pipelines.filters import URLFilter


PART = "url-filtered"
S3_FILTERED_PATH = f""
S3_LOGS_FOLDER = f"{S3_FILTERED_PATH}/logs/{PART}"
S3_DATA_PATH = f""


INPUT_READER = JsonlReader(
    data_folder=S3_DATA_PATH,
    recursive=True,
    file_progress=True,
)


# Make sure to set banned_words_path to the directory containing your customized list
def run_pipeline():
    pipeline = [
        INPUT_READER,
        URLFilter(
            banned_words_path="",
            exclusion_writer=JsonlWriter(output_folder=f"{S3_FILTERED_PATH}/{PART}/removed-bad-URLs"),
        ),
        JsonlWriter(output_folder=f"{S3_FILTERED_PATH}/{PART}/output"),
    ]

    executor: PipelineExecutor = LocalPipelineExecutor(
        pipeline=pipeline,
        workers=50,
        tasks=50,
        logging_dir=f"{S3_LOGS_FOLDER}/url-filtered",
    )

    print(executor.run())


if __name__ == "__main__":
    run_pipeline()
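
For a quick sanity check before launching the full executor, one could run the filter on a single document. The sketch below is assumed usage rather than part of the commit; the banned-words directory is a placeholder path that must contain en_ar_banned_words.txt, and the import relies on the same sys.path setup as the script above.

# Assumed smoke test, not part of the commit. The path below is a placeholder and must
# point to a directory containing en_ar_banned_words.txt; imports assume the sys.path setup above.
from datatrove.data import Document
from pipelines.filters import URLFilter

url_filter = URLFilter(banned_words_path="/path/to/banned_words_dir")  # hypothetical path
doc = Document(text="...", id="0", metadata={"source": "https://example.com/casino-games"})
# filter() returns True to keep the document or (False, reason) to drop it,
# e.g. (False, "hard_blacklisted") if "casino" appears in the banned-words list.
print(url_filter.filter(doc))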
