-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
50 lines (38 loc) · 1.18 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from pathlib import Path
from fastai.text.all import *
import learner_ext
from data_ext import SequentialTextBlock
from transforms_ext import *
def get_annotations(o, labels_dir=None):
    """Return the whitespace-separated labels stored for text file ``o``.

    The label file has the same file name as ``o`` and is looked up in
    ``labels_dir``.

    Args:
        o: Path-like object with a ``name`` attribute. Only the file name
            is used; the directory part of ``o`` is ignored.
        labels_dir: Directory that holds the label files. When omitted,
            ``path / "../labels"`` is used, where ``path`` is the
            module-level name this script assigns before building the
            data loaders.

    Returns:
        The contents of the label file split on any whitespace, as a list
        of strings (an empty list for an empty or blank file).

    Raises:
        FileNotFoundError: If there is no label file for ``o``.
        NameError: If ``labels_dir`` is omitted and ``path`` is not defined.
    """
    if labels_dir is None:
        # Default location: the "labels" directory next to the texts directory.
        labels_dir = path / "../labels"
    # Decode explicitly so the result does not depend on the platform's
    # locale encoding.
    with open(Path(labels_dir) / o.name, encoding="utf-8") as f:
        return f.read().split()
if __name__ == "__main__":
    # `path` must stay a module-level name: get_annotations() reads it to
    # find the label files in the "labels" directory next to the texts.
    path = Path("./dialogue-corpus/texts")
    files = get_text_files(path)
    print(f"Number of text files: {len(files)}")

    # Items are the text files found under `path`; each target is the list
    # of labels read from the label file with the same name.
    dls = DataBlock(
        blocks=[
            SequentialTextBlock.from_folder(
                path, tok=BaseTokenizer(split_char=None), rules=[]
            ),
            SequentialCategoryBlock,
        ],
        get_items=get_text_files,
        # Pass the named function directly: the former lambda wrapper added
        # nothing and, unlike a module-level function, cannot be pickled.
        get_y=get_annotations,
        splitter=RandomSplitter(),
    ).dataloaders(path, path=path, seq_len=256, bs=64)

    # Sanity check: draw one batch and print the input/target shapes.
    xb, yb = dls.one_batch()
    print(xb.shape, yb.shape)

    learn = learner_ext.sequential_model_learner(
        dls, AWD_LSTM, drop_mult=0.3, metrics=accuracy
    )

    # One epoch before unfreeze(), then ten more epochs after it.
    learn.fit_one_cycle(1, 3e-3)
    learn.unfreeze()
    learn.fit_one_cycle(
        10,
        3e-3,
        # every_epoch=True asks the callback to save after each epoch.
        cbs=SaveModelCallback(
            monitor="accuracy", every_epoch=True, fname="dialogue_model"
        ),
    )