Skip to content

Commit 7fa5c42

Browse files
committed
feat: added test case for custom callbacks. works great and somehow serializes.
1 parent 2ddcad9 commit 7fa5c42

File tree

2 files changed

+38
-4
lines changed

2 files changed

+38
-4
lines changed

src/linear_model/LinearRegression.test.ts

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,38 @@ describe('LinearRegression', function () {
1717
expect(roughlyEqual(lr.intercept as number, 0)).toBe(true)
1818
}, 30000)
1919

20+
it('Works on arrays (small example) with custom callbacks', async function () {
21+
let trainingHasStarted = false
22+
const onTrainBegin = async (logs: any) => {
23+
trainingHasStarted = true
24+
console.log('training begins')
25+
}
26+
const lr = new LinearRegression({
27+
modelFitOptions: { callbacks: [new tf.CustomCallback({ onTrainBegin })] }
28+
})
29+
await lr.fit([[1], [2]], [2, 4])
30+
expect(tensorEqual(lr.coef, tf.tensor1d([2]), 0.1)).toBe(true)
31+
expect(roughlyEqual(lr.intercept as number, 0)).toBe(true)
32+
expect(trainingHasStarted).toBe(true)
33+
}, 30000)
34+
35+
it('Works on arrays (small example) with custom callbacks', async function () {
36+
let trainingHasStarted = false
37+
const onTrainBegin = async (logs: any) => {
38+
trainingHasStarted = true
39+
console.log('training begins')
40+
}
41+
const lr = new LinearRegression({
42+
modelFitOptions: { callbacks: [new tf.CustomCallback({ onTrainBegin })] }
43+
})
44+
await lr.fit([[1], [2]], [2, 4])
45+
46+
const serialized = await lr.toJSON()
47+
const newModel = await fromJSON(serialized)
48+
expect(tensorEqual(newModel.coef, tf.tensor1d([2]), 0.1)).toBe(true)
49+
expect(roughlyEqual(newModel.intercept as number, 0)).toBe(true)
50+
}, 30000)
51+
2052
it('Works on small multi-output example (small example)', async function () {
2153
const lr = new LinearRegression()
2254
await lr.fit(

src/linear_model/LinearRegression.ts

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ export interface LinearRegressionParams {
4141
*/
4242
fitIntercept?: boolean
4343
modelFitOptions?: Partial<ModelFitArgs>
44-
4544
}
4645

4746
/*
@@ -53,7 +52,7 @@ Next steps:
5352
/** Linear Least Squares
5453
* @example
5554
* ```js
56-
* import {LinearRegression} from 'scikitjs'
55+
* import { LinearRegression } from 'scikitjs'
5756
*
5857
* let X = [
5958
* [1, 2],
@@ -63,13 +62,16 @@ Next steps:
6362
* [10, 20]
6463
* ]
6564
* let y = [3, 5, 8, 8, 30]
66-
* const lr = new LinearRegression({fitIntercept: false})
65+
* const lr = new LinearRegression({ fitIntercept: false })
6766
await lr.fit(X, y)
6867
lr.coef.print() // probably around [1, 1]
6968
* ```
7069
*/
7170
export class LinearRegression extends SGDRegressor {
72-
constructor({ fitIntercept = true, modelFitOptions }: LinearRegressionParams = {}) {
71+
constructor({
72+
fitIntercept = true,
73+
modelFitOptions
74+
}: LinearRegressionParams = {}) {
7375
let tf = getBackend()
7476
super({
7577
modelCompileArgs: {

0 commit comments

Comments
 (0)