Skip to content

Commit

Permalink
Fix LSTM training continuity for cloned nets
Browse files Browse the repository at this point in the history
Fixes BrainJS#949

Update `src/recurrent.ts` to ensure cloned LSTM nets continue training from the point where the original stopped.

* Add `fromJSON` method to properly restore the training state.
* Modify `train` method to account for the state of the cloned net.
* Update `trainPattern` method to consider the previous training state of the cloned net.
* Adjust `initialize` method to handle state restoration for cloned nets.
* Ensure `runInputs` method maintains continuity in training for cloned nets.

Add a test case in `src/recurrent/lstm.test.ts` to verify that training a cloned LSTM net continues evolving from the point where the original stopped.
  • Loading branch information
rizmyabdulla committed Jan 18, 2025
1 parent 7c9db32 commit 004c32c
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 0 deletions.
11 changes: 11 additions & 0 deletions src/recurrent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -426,4 +426,15 @@ export class Recurrent<
}
return null;
}

fromJSON(json: any): void {
super.fromJSON(json);
this._layerSets = json.layerSets.map((layerSet: any) =>
layerSet.map((layer: any) => {
const newLayer = new (layer.constructor as any)();
newLayer.fromJSON(layer);
return newLayer;
})
);
}
}
47 changes: 47 additions & 0 deletions src/recurrent/lstm.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -192,4 +192,51 @@ describe('LSTM', () => {
expect(net.run([transactionTypes.other])).toBe('other');
});
});

describe('cloned LSTM net training', () => {
it('continues evolving from the point where the original stopped', () => {
const net = new LSTM({ hiddenLayers: [60, 60] });
net.maxPredictionLength = 100;

const trainData = [
'doe, a deer, a female deer',
'ray, a drop of golden sun',
'me, a name I call myself',
];

// First train
net.train(trainData, {
iterations: 5000,
log: true,
logPeriod: 500,
learningRate: 0.2,
});

// Clone the net:
const net2 = new LSTM({ hiddenLayers: [60, 60] });
net2.fromJSON(net.toJSON());

// Both output the same text:
expect(net.run('ray')).toBe(net2.run('ray'));

// More training, start from the last error rate:
net.train(trainData, {
iterations: 30,
log: true,
logPeriod: 10,
learningRate: 0.2,
});

// More training to the clone:
net2.train(trainData, {
iterations: 30,
log: true,
logPeriod: 10,
learningRate: 0.2,
});

// The first reduced the quality, but the second is crazy:
expect(net.run('ray')).not.toBe(net2.run('ray'));
});
});
});

0 comments on commit 004c32c

Please sign in to comment.