Could you share the script to train the model? #44
As title. Thanks!

Comments
I am also interested in the training script for the model mentioned by @xianruizhong. Could you kindly share it, if available? Your assistance would be greatly appreciated. Thank you!
For future runs: something simple like this would work :)

from transformers import AlbertForMaskedLM, AlbertConfig, Trainer, TrainingArguments
from datasets import load_dataset
from transformers import DataCollatorForLanguageModeling
from rxnmapper.tokenization_smiles import SmilesTokenizer


def tokenize_function(examples):
    tokenized = tokenizer(examples["rxn"])
    return tokenized


# Tokenizer setup
vocab_path = "PATH/TO/vocab.txt"
tokenizer = SmilesTokenizer(vocab_path)

# Dataset setup
dataset_path = "PATH/TO/data.csv"
dataset = load_dataset(
    "csv",
    data_files=dataset_path,
)

# Tokenize the dataset
dataset = dataset.map(
    tokenize_function,
    batched=True,
    num_proc=24,
)

# Model setup
alberta_config_path = "PATH/TO/config.json"
model_config = AlbertConfig.from_pretrained(alberta_config_path)
model = AlbertForMaskedLM(model_config)

# Data collator for MLM, using a mask probability of 15%
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)

# Training arguments setup
training_args = TrainingArguments(
    output_dir="models/alberta",
    learning_rate=2e-4,
    # num_train_epochs=3,
    max_steps=1_500_000,
    weight_decay=0.001,
    logging_steps=100,
    eval_strategy="no",
    save_strategy="steps",
    save_steps=1_000,
    save_total_limit=2,
    save_only_model=True,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,
    logging_dir="logs",
    logging_first_step=True,
    report_to="wandb",
    run_name="alberta_test_run",
)

# Trainer setup
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset["train"],
)

# Train the model
trainer.train()
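For reference, a minimal sketch of what the input CSV could look like, assuming a single column named "rxn" containing reaction SMILES (the column name is taken from tokenize_function above; the reactions themselves are only illustrative):

rxn
CC(=O)O.OCC>>CC(=O)OCC
C=CC=C.C=C>>C1=CCCCC1

The "PATH/TO/vocab.txt" and "PATH/TO/config.json" placeholders would point to a SMILES vocabulary file and an ALBERT config of your choosing, for example the ones shipped with the installed rxnmapper package (an assumption; adjust the paths to your setup).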