fine_tune2.py
import torch
from sentence_transformers import SentenceTransformer, InputExample
from torch.utils.data import DataLoader
from torch.optim import AdamW
import torch.nn as nn
# Load pre-trained MiniLM model
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
# Prepare the dataset: sentences paired with their intent labels
# (0 = pick/grab/get/lift, 1 = place/put down/leave, 2 = move home/park)
train_examples = [
    InputExample(texts=["Pick the object"], label=0),             # pick = 0
    InputExample(texts=["Pick the product"], label=0),            # pick = 0
    InputExample(texts=["Pick the cube"], label=0),               # pick = 0
    InputExample(texts=["Pick the thing"], label=0),              # pick = 0
    InputExample(texts=["Lift the thing"], label=0),              # lift = 0
    InputExample(texts=["Lift the object"], label=0),             # lift = 0
    InputExample(texts=["Lift the cube"], label=0),               # lift = 0
    InputExample(texts=["Lift it up"], label=0),                  # lift = 0
    InputExample(texts=["Grab the cube"], label=0),               # grab = 0
    InputExample(texts=["Grab the object"], label=0),             # grab = 0
    InputExample(texts=["Grab the product"], label=0),            # grab = 0
    InputExample(texts=["Get the cube"], label=0),                # get = 0
    InputExample(texts=["Get it"], label=0),                      # get = 0
    InputExample(texts=["Get the object"], label=0),              # get = 0
    InputExample(texts=["Get the product"], label=0),             # get = 0
    InputExample(texts=["Place the object on table"], label=1),   # place = 1
    InputExample(texts=["Place the product on table"], label=1),  # place = 1
    InputExample(texts=["Place the object"], label=1),            # place = 1
    InputExample(texts=["Place it down"], label=1),               # place = 1
    InputExample(texts=["Put down the object"], label=1),         # put down = 1
    InputExample(texts=["Put it down"], label=1),                 # put down = 1
    InputExample(texts=["Leave it"], label=1),                    # leave = 1
    InputExample(texts=["Leave it down"], label=1),               # leave = 1
    InputExample(texts=["Move to home position"], label=2),       # move = 2
    InputExample(texts=["Go home"], label=2),                     # go = 2
    InputExample(texts=["Go to home position"], label=2),         # go = 2
    InputExample(texts=["Parking position"], label=2),            # park = 2
]
# Custom collate function: unpack InputExample objects into raw texts and a label tensor
def collate_fn(batch):
    texts = [example.texts[0] for example in batch]
    labels = torch.tensor([example.label for example in batch], dtype=torch.long)
    return texts, labels
# Create a DataLoader with the custom collate function
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16, collate_fn=collate_fn)
# Freeze the pre-trained encoder so only the classification head is trained
for param in model.parameters():
    param.requires_grad = False
# Add a linear classification head on top of the sentence embeddings
device = next(model.parameters()).device  # keep the head on the encoder's device
classifier = nn.Linear(model.get_sentence_embedding_dimension(), 3).to(device)  # 3 classes: pick (0), place (1), home (2)
# Define optimizer and loss function (only the classifier's weights are trained)
optimizer = AdamW(classifier.parameters(), lr=2e-3)
loss_fn = nn.CrossEntropyLoss()
# Training loop: only the classifier head is updated; the frozen encoder just
# provides embeddings (model.encode runs under no_grad, which is fine here
# because the encoder's weights are not being trained)
epochs = 500  # many epochs to compensate for the tiny dataset
for epoch in range(epochs):
    classifier.train()
    total_loss = 0
    for texts, labels in train_dataloader:
        labels = labels.to(device)
        # Encode the sentences into fixed-size embeddings
        embeddings = model.encode(texts, convert_to_tensor=True)
        # Forward pass through the classifier head
        outputs = classifier(embeddings)
        # Compute the loss
        loss = loss_fn(outputs, labels)
        # Backward pass and optimization step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(train_dataloader):.4f}")
# Save the fine-tuned classifier head and the (unchanged, frozen) encoder
torch.save(classifier.state_dict(), "fine-tuned-classifier-robotic-commands.pt")
model.save("fine-tuned-sentence-transformer")
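
# --- Minimal inference sketch: reload the saved artifacts and classify a new
# command. Assumes the two files saved above exist; the class-name list is an
# illustrative mapping of the label scheme (0 = pick, 1 = place, 2 = home).
inference_model = SentenceTransformer("fine-tuned-sentence-transformer")
inference_classifier = nn.Linear(inference_model.get_sentence_embedding_dimension(), 3)
inference_classifier.load_state_dict(
    torch.load("fine-tuned-classifier-robotic-commands.pt", map_location="cpu")
)
inference_classifier.eval()

with torch.no_grad():
    embedding = inference_model.encode(["Grab the box"], convert_to_tensor=True)
    prediction = inference_classifier(embedding.cpu()).argmax(dim=1).item()
print(["pick", "place", "home"][prediction])  # expected: "pick"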