-
Notifications
You must be signed in to change notification settings - Fork 112
78 lines (69 loc) · 3.64 KB
/
train.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
name: Train a model

# This workflow is triggered manually from the GitHub website.
# To run it, click the "Actions" tab on the repo page, pick this workflow,
# and fill in the inputs below.
on:
  workflow_dispatch:
    inputs:
      # Destination model on Replicate, written as `owner/model`.
      model-name:
        required: true
        description: >-
          The name of the Replicate model to publish, in the format
          `your-replicate-username/desired-model-name`. If the model doesn't
          already exist, it will be created automatically.
      # Token inserted into the training prompt ("a photo of a <token> person").
      prompt-identifier:
        required: true
        description: >-
          A short string representing your custom trained style or concept.
          This should be an uncommon string of letters like `zxz`, rather than
          a common word or name like `sarah` or `dog`. You'll use this string
          in prompts when running the model like "a pencil sketch of zxz
          sitting in a meadow".
        default: zxz
      # Sent to the training API as `max_train_steps`.
      max-train-steps:
        required: true
        description: >-
          Total number of training steps to perform. Higher numbers produce
          more accurate results, but take longer to run. Set this to something
          low like 100 to test that your workflow is working, then run again
          at a higher number like 2000 to train more accurately.
        type: number
        default: 2000
# Your Replicate API token is listed at https://replicate.com/account.
# Store it as the REPLICATE_API_TOKEN repository secret; the "Check secrets"
# step reads it back through `env.replicate-api-token`.
env:
  replicate-api-token: "${{ secrets.REPLICATE_API_TOKEN }}"
jobs:
  # Single job: zip ./data, upload it to Replicate's experimental DreamBooth
  # API, then start a training run for the requested model.
  train:
    runs-on: ubuntu-latest
    steps:
      # Runs only when the workflow-level token is empty, and aborts the job
      # with instructions for adding the missing secret.
      - name: Check secrets
        if: ${{ env.replicate-api-token == '' }}
        run: |
          echo "🙈 Uh oh! Missing repository secret: REPLICATE_API_TOKEN"
          echo "Go to https://replicate.com/account to copy your API token,"
          echo "then visit https://github.com/${{ github.repository }}/settings/secrets/actions/new to set it."
          echo
          echo
          echo
          exit 1

      - name: Checkout code
        uses: actions/checkout@v4

      # The training images are expected in the repository's `data` directory.
      - name: Zip training data
        run: |
          zip -r data.zip data

      # Two-phase upload: ask the API for an upload URL, PUT the archive to it,
      # then publish the serving URL as the step output INSTANCE_DATA_URL.
      - name: Upload training data
        id: upload-training-data
        env:
          # Passed through the environment so the token is never spliced into
          # the script text.
          REPLICATE_API_TOKEN: ${{ secrets.REPLICATE_API_TOKEN }}
        run: |
          set -euo pipefail

          # --fail-with-body makes HTTP errors fatal; echo the captured body so
          # the API's error message is visible in the job log.
          RESPONSE=$(curl -sS --fail-with-body -X POST \
            -H "Authorization: Token $REPLICATE_API_TOKEN" \
            https://dreambooth-api-experimental.replicate.com/v1/upload/data.zip) || {
            echo "Upload request failed: $RESPONSE" >&2
            exit 1
          }

          # `jq -e` exits non-zero when the field is missing or null, instead
          # of handing the literal string "null" to later steps.
          UPLOAD_URL=$(jq -er '.upload_url' <<< "$RESPONSE")
          SERVING_URL=$(jq -er '.serving_url' <<< "$RESPONSE")

          curl --fail-with-body -X PUT -H "Content-Type: application/zip" \
            --upload-file data.zip "$UPLOAD_URL"

          echo "INSTANCE_DATA_URL=$SERVING_URL"
          echo "INSTANCE_DATA_URL=$SERVING_URL" >> "$GITHUB_OUTPUT"

      # User-supplied inputs reach the shell only as environment variables and
      # the JSON body is assembled by jq, so quotes or shell metacharacters in
      # an input cannot break the payload or inject commands.
      - name: Start training
        env:
          REPLICATE_API_TOKEN: ${{ secrets.REPLICATE_API_TOKEN }}
          PROMPT_IDENTIFIER: ${{ inputs.prompt-identifier }}
          MAX_TRAIN_STEPS: ${{ inputs.max-train-steps }}
          MODEL_NAME: ${{ inputs.model-name }}
          INSTANCE_DATA_URL: ${{ steps.upload-training-data.outputs.INSTANCE_DATA_URL }}
        run: |
          set -euo pipefail

          # --argjson keeps max_train_steps a JSON number and rejects
          # non-numeric input before any request is sent.
          PAYLOAD=$(jq -n \
            --arg instance_prompt "a photo of a $PROMPT_IDENTIFIER person" \
            --arg instance_data "$INSTANCE_DATA_URL" \
            --argjson max_train_steps "$MAX_TRAIN_STEPS" \
            --arg model "$MODEL_NAME" \
            '{
              input: {
                instance_prompt: $instance_prompt,
                class_prompt: "a photo of a person",
                instance_data: $instance_data,
                max_train_steps: $max_train_steps
              },
              model: $model,
              trainer_version: "cd3f925f7ab21afaef7d45224790eedbb837eeac40d22e8fefe015489ab644aa"
            }')

          # Fail the job on an HTTP error so the success message below is only
          # printed when the training request was accepted.
          curl -sS --fail-with-body -X POST \
            -H "Authorization: Token $REPLICATE_API_TOKEN" \
            -H "Content-Type: application/json" \
            -d "$PAYLOAD" \
            https://dreambooth-api-experimental.replicate.com/v1/trainings

      - name: Link to model
        env:
          MODEL_NAME: ${{ inputs.model-name }}
        run: |
          echo "🚂 Your model is now training!"
          echo "This takes about 10-30 minutes, depending on how many training steps you're running."
          echo "To see your model, visit https://replicate.com/$MODEL_NAME and refresh until you see the prediction form."
          echo "For more info on how to run your new model, see https://replicate.com/blog/dreambooth-api"