-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathaidx.py
122 lines (110 loc) · 3.84 KB
/
aidx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import argparse
import sys
from dotenv import load_dotenv
import os
from src.main.create_mimicllm import main as create_mimicllm_main
from src.main.tokenize_mimicllm import main as tokenize_mimicllm_main
from src.main.create_parquet_datasets import main as create_parquet_datasets_main
from src.main.fine_tune_model import main as fine_tune_model_main
# Load environment variables from the project-local env file; values already
# present in the process environment take precedence over the file.
load_dotenv("config/.env")
# Required database connection settings. os.environ[...] (rather than .get)
# is deliberate: a missing variable fails fast at import time with a KeyError
# naming the absent key.
# NOTE(review): these constants are not referenced elsewhere in this file —
# presumably imported by the src.main pipeline modules; verify before removing.
HOST_IP = os.environ["DATABASE_IP"]
DATABASE_USER = os.environ["DATABASE_USER"]
DATABASE_PASSWORD = os.environ["DATABASE_PASSWORD"]
DATABASE_PORT = os.environ["DATABASE_PORT"]
def main():
    """Parse the AIDx command line and dispatch to the selected subcommand.

    Each subcommand registers its own options plus a ``func`` default that
    points at the pipeline entry point implementing it; after parsing,
    ``args.func(args)`` runs the chosen step. Running the script with no
    arguments prints the help text instead of an argparse usage error.
    """
    parser = argparse.ArgumentParser(description="AIDx Entry Point")
    # dest="command" is set explicitly: with required=True and no dest,
    # argparse reports a missing subcommand with an unnamed destination
    # (and raises a TypeError on some older 3.x releases) instead of a
    # clear "the following arguments are required: command" message.
    subparsers = parser.add_subparsers(
        title="subcommands", dest="command", required=True
    )

    # One registration helper per pipeline stage keeps each subcommand's
    # options in a single, self-contained place.
    _register_create_mimicllm(subparsers)
    _register_tokenize_mimicllm(subparsers)
    _register_create_parquet_datasets(subparsers)
    _register_fine_tune_model(subparsers)

    # No CLI arguments at all -> show help rather than an error exit.
    args = parser.parse_args(args=None if sys.argv[1:] else ["--help"])
    args.func(args)


def _register_create_mimicllm(subparsers):
    """Register the ``create-mimicllm`` subcommand (builds the MIMIC-LLM DB)."""
    p = subparsers.add_parser(
        "create-mimicllm", help="Create the MIMIC-LLM database"
    )
    p.add_argument(
        "--rewrite-log-db",
        action="store_true",
        help="If set, will rewrite the log database",
    )
    p.add_argument(
        "--discharge-note-only",
        action="store_true",
        help="If set, will only process discharge notes",
    )
    p.set_defaults(func=create_mimicllm_main)


def _register_tokenize_mimicllm(subparsers):
    """Register the ``tokenize-mimicllm`` subcommand (tokenizes the DB)."""
    p = subparsers.add_parser(
        "tokenize-mimicllm", help="Tokenize the MIMIC-LLM database"
    )
    p.add_argument(
        "--batch-size",
        type=int,
        default=128,
        help="The batch size to use for tokenization",
    )
    p.add_argument(
        "--rewrite-log-db",
        action="store_true",
        help="If set, will rewrite the log database",
    )
    p.set_defaults(func=tokenize_mimicllm_main)


def _register_create_parquet_datasets(subparsers):
    """Register the ``create-parquet-datasets`` subcommand (train/test split)."""
    p = subparsers.add_parser(
        "create-parquet-datasets", help="Create the Parquet train/test datasets"
    )
    p.add_argument(
        "--chunk-size",
        type=int,
        default=10000,
        help="The chunk size to use for querying the database",
    )
    p.add_argument(
        "--test-size",
        type=float,
        default=0.2,
        help="The test size to use for splitting the data",
    )
    p.add_argument(
        "--parquet-dir",
        type=str,
        default="data/parquet",
        help="The directory to store the Parquet files",
    )
    p.set_defaults(func=create_parquet_datasets_main)


def _register_fine_tune_model(subparsers):
    """Register the ``fine-tune-model`` subcommand (fine-tunes the LLM)."""
    p = subparsers.add_parser(
        "fine-tune-model", help="Fine tune the model"
    )
    p.add_argument(
        "--model-name",
        type=str,
        default="mistralai/Mixtral-8x7B-Instruct-v0.1",
        help="The model name to use for fine tuning",
    )
    # NOTE(review): default intentionally differs from create-parquet-datasets'
    # "data/parquet" — confirm "data/mimicllm" is the intended read location.
    p.add_argument(
        "--parquet-dir",
        type=str,
        default="data/mimicllm",
        help="The directory with your parquet files",
    )
    p.add_argument(
        "--stream-data",
        action="store_true",
        help="If set, will stream the data from the parquet files",
    )
    p.add_argument(
        "--clear-cache",
        action="store_true",
        help="If set, will clear the cache",
    )
    p.add_argument(
        "--max-batch-size",
        type=int,
        default=16,
        help="The maximum batch size to use for fine tuning",
    )
    p.add_argument(
        "--num-epochs",
        type=int,
        default=2,
        help="The number of epochs to use for fine tuning",
    )
    p.set_defaults(func=fine_tune_model_main)
# Script entry point: only dispatch when executed directly, not on import.
if __name__ == "__main__":
    main()