@@ -17,6 +17,38 @@ describe('LinearRegression', function () {
1717 expect ( roughlyEqual ( lr . intercept as number , 0 ) ) . toBe ( true )
1818 } , 30000 )
1919
20+ it ( 'Works on arrays (small example) with custom callbacks' , async function ( ) {
21+ let trainingHasStarted = false
22+ const onTrainBegin = async ( logs : any ) => {
23+ trainingHasStarted = true
24+ console . log ( 'training begins' )
25+ }
26+ const lr = new LinearRegression ( {
27+ modelFitOptions : { callbacks : [ new tf . CustomCallback ( { onTrainBegin } ) ] }
28+ } )
29+ await lr . fit ( [ [ 1 ] , [ 2 ] ] , [ 2 , 4 ] )
30+ expect ( tensorEqual ( lr . coef , tf . tensor1d ( [ 2 ] ) , 0.1 ) ) . toBe ( true )
31+ expect ( roughlyEqual ( lr . intercept as number , 0 ) ) . toBe ( true )
32+ expect ( trainingHasStarted ) . toBe ( true )
33+ } , 30000 )
34+
35+ it ( 'Works on arrays (small example) with custom callbacks' , async function ( ) {
36+ let trainingHasStarted = false
37+ const onTrainBegin = async ( logs : any ) => {
38+ trainingHasStarted = true
39+ console . log ( 'training begins' )
40+ }
41+ const lr = new LinearRegression ( {
42+ modelFitOptions : { callbacks : [ new tf . CustomCallback ( { onTrainBegin } ) ] }
43+ } )
44+ await lr . fit ( [ [ 1 ] , [ 2 ] ] , [ 2 , 4 ] )
45+
46+ const serialized = await lr . toJSON ( )
47+ const newModel = await fromJSON ( serialized )
48+ expect ( tensorEqual ( newModel . coef , tf . tensor1d ( [ 2 ] ) , 0.1 ) ) . toBe ( true )
49+ expect ( roughlyEqual ( newModel . intercept as number , 0 ) ) . toBe ( true )
50+ } , 30000 )
51+
2052 it ( 'Works on small multi-output example (small example)' , async function ( ) {
2153 const lr = new LinearRegression ( )
2254 await lr . fit (
0 commit comments