Skip to content

Commit ec71323

Browse files
committed
feat: updated serialization
1 parent 7fbea07 commit ec71323

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+128
-488
lines changed

src/cluster/KMeans.test.ts

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
import { KMeans } from './KMeans'
2-
import { fromObject } from '../index'
1+
import { fromObject, KMeans } from '../index'
32
// Next steps: Improve on kmeans cluster testing
43
describe('KMeans', () => {
54
const X = [

src/compose/ColumnTransformer.test.ts

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1-
import { ColumnTransformer } from './ColumnTransformer'
2-
import { MinMaxScaler } from '../preprocessing/MinMaxScaler'
3-
import { SimpleImputer } from '../impute/SimpleImputer'
1+
import {
2+
fromObject,
3+
SimpleImputer,
4+
MinMaxScaler,
5+
ColumnTransformer
6+
} from '../index'
47
import * as dfd from 'danfojs-node'
58

69
describe('ColumnTransformer', function () {
@@ -30,4 +33,26 @@ describe('ColumnTransformer', function () {
3033

3134
expect(result.arraySync()).toEqual(expected)
3235
})
36+
it('ColumnTransformer serialize/deserialize test', async function () {
37+
const X = [
38+
[2, 2], // [1, .5]
39+
[2, 3], // [1, .75]
40+
[0, NaN], // [0, 1]
41+
[2, 0] // [.5, 0]
42+
]
43+
let newDf = new dfd.DataFrame(X)
44+
45+
const transformer = new ColumnTransformer({
46+
transformers: [
47+
['minmax', new MinMaxScaler(), [0]],
48+
['simpleImpute', new SimpleImputer({ strategy: 'median' }), [1]]
49+
]
50+
})
51+
52+
transformer.fitTransform(newDf)
53+
let obj = await transformer.toObject()
54+
let myResult = await fromObject(obj)
55+
56+
expect(myResult.transformers.length).toEqual(2)
57+
})
3358
})

src/compose/ColumnTransformer.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ export class ColumnTransformer extends Serialize {
7676
transformers = [],
7777
remainder = 'drop'
7878
}: ColumnTransformerParams = {}) {
79+
super()
7980
this.transformers = transformers
8081
this.remainder = remainder
8182
}

src/dummy/DummyClassifier.test.ts

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
import { DummyClassifier } from './DummyClassifier'
2-
import { fromObject } from '../simpleSerializer'
1+
import { DummyClassifier, fromObject } from '../index'
32
describe('DummyClassifier', function () {
43
it('Use DummyClassifier on simple example (mostFrequent)', function () {
54
const clf = new DummyClassifier()

src/dummy/DummyRegressor.test.ts

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import { DummyRegressor } from './DummyRegressor'
2-
import { toObject, fromObject } from '../simpleSerializer'
1+
import { DummyRegressor, fromObject } from '../index'
2+
33
describe('DummyRegressor', function () {
44
it('Use DummyRegressor on simple example (mean)', function () {
55
const reg = new DummyRegressor()
@@ -68,12 +68,13 @@ describe('DummyRegressor', function () {
6868
name: 'DummyRegressor',
6969
EstimatorType: 'regressor',
7070
strategy: 'constant',
71-
constant: 10
71+
constant: 10,
72+
quantile: undefined
7273
}
7374

7475
reg.fit(X, y)
7576

76-
expect(saveResult).toEqual(await toObject(reg))
77+
expect(saveResult).toEqual(await reg.toObject())
7778
})
7879

7980
it('Should load serialized DummyRegressor', async function () {
@@ -92,7 +93,7 @@ describe('DummyRegressor', function () {
9293
]
9394

9495
reg.fit(X, y)
95-
const saveReg = await toObject(reg)
96+
const saveReg = await reg.toObject()
9697
const newReg = await fromObject(saveReg)
9798

9899
expect(newReg.predict(predictX).arraySync()).toEqual([10, 10, 10])

src/ensemble/VotingClassifier.test.ts

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1-
import { makeVotingClassifier, VotingClassifier } from './VotingClassifier'
2-
import { DummyClassifier } from '../dummy/DummyClassifier'
3-
4-
import { LogisticRegression } from '../linear_model/LogisticRegression'
1+
import {
2+
makeVotingClassifier,
3+
VotingClassifier,
4+
DummyClassifier,
5+
LogisticRegression,
6+
fromObject
7+
} from '../index'
58

69
describe('VotingClassifier', function () {
710
it('Use VotingClassifier on simple example (voting = hard)', async function () {
@@ -118,8 +121,8 @@ describe('VotingClassifier', function () {
118121

119122
await voter.fit(X, y)
120123

121-
const savedModel = (await voter.toJson()) as string
122-
const newModel = new VotingClassifier({}).fromJson(savedModel)
124+
const savedModel = await voter.toObject()
125+
const newModel = await fromObject(savedModel)
123126

124127
expect(newModel.predict(X).arraySync()).toEqual([1, 1, 1, 1, 1])
125128
}, 30000)

src/ensemble/VotingClassifier.ts

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ import { Scikit1D, Scikit2D } from '../types'
22
import { tf } from '../shared/globals'
33
import { ClassifierMixin } from '../mixins'
44
import { LabelEncoder } from '../preprocessing/LabelEncoder'
5-
import { fromJson, toJson } from './serializeEnsemble'
65

76
/*
87
Next steps:
@@ -154,15 +153,6 @@ export class VotingClassifier extends ClassifierMixin {
154153
): Promise<Array<tf.Tensor1D> | Array<tf.Tensor2D>> {
155154
return (await this.fit(X, y)).transform(X)
156155
}
157-
158-
public fromJson(model: string) {
159-
return fromJson(this, model)
160-
}
161-
162-
public async toJson(): Promise<string> {
163-
const classJson = JSON.parse(super.toJson() as string)
164-
return toJson(this, classJson)
165-
}
166156
}
167157

168158
export function makeVotingClassifier(...args: any[]) {

src/ensemble/VotingRegressor.test.ts

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1-
import { makeVotingRegressor, VotingRegressor } from './VotingRegressor'
2-
import { DummyRegressor } from '../dummy/DummyRegressor'
3-
import { LinearRegression } from '../linear_model/LinearRegression'
1+
import {
2+
makeVotingRegressor,
3+
VotingRegressor,
4+
fromObject,
5+
DummyRegressor,
6+
LinearRegression
7+
} from '../index'
48

59
describe('VotingRegressor', function () {
610
it('Use VotingRegressor on simple example ', async function () {
@@ -51,8 +55,8 @@ describe('VotingRegressor', function () {
5155

5256
await voter.fit(X, y)
5357

54-
const savedModel = (await voter.toJson()) as string
55-
const newModel = new VotingRegressor({}).fromJson(savedModel)
58+
const savedModel = await voter.toObject()
59+
const newModel = await fromObject(savedModel)
5660
expect(newModel.score(X, y)).toEqual(voter.score(X, y))
5761
}, 30000)
5862
})

src/ensemble/VotingRegressor.ts

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import { Scikit1D, Scikit2D } from '../types'
22
import { tf } from '../shared/globals'
33
import { RegressorMixin } from '../mixins'
4-
import { fromJson, toJson } from './serializeEnsemble'
54
/*
65
Next steps:
76
0. Write validation code to check Estimator inputs
@@ -95,15 +94,6 @@ export class VotingRegressor extends RegressorMixin {
9594
public async fitTransform(X: Scikit2D, y: Scikit1D) {
9695
return (await this.fit(X, y)).transform(X)
9796
}
98-
99-
public fromJson(model: string) {
100-
return fromJson(this, model) as this
101-
}
102-
103-
public async toJson(): Promise<string> {
104-
const classJson = JSON.parse(super.toJson() as string)
105-
return toJson(this, classJson)
106-
}
10797
}
10898

10999
/**

src/ensemble/serializeEnsemble.ts

Lines changed: 0 additions & 90 deletions
This file was deleted.

0 commit comments

Comments
 (0)