-
Notifications
You must be signed in to change notification settings - Fork 326
Expand file tree
/
Copy pathmultidimensional_data.py
More file actions
39 lines (30 loc) · 1.26 KB
/
multidimensional_data.py
File metadata and controls
39 lines (30 loc) · 1.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import numpy as np
from modAL.batch import uncertainty_batch_sampling
from modAL.expected_error import expected_error_reduction
from modAL.models import ActiveLearner
from modAL.uncertainty import entropy_sampling, margin_sampling
from sklearn.base import BaseEstimator
class MockClassifier(BaseEstimator):
    """Minimal scikit-learn-style classifier for smoke-testing query strategies.

    It never looks at its training data, so it accepts inputs of any
    dimensionality. The only requirement is that ``len(X)`` gives the number
    of samples (e.g. arrays of shape ``(n, 5, 5)``).

    Args:
        n_classes: Number of class labels. Predicted labels are drawn from
            ``range(n_classes)``.
    """

    def __init__(self, n_classes=2):
        # Attribute name matches the constructor argument (scikit-learn
        # estimator convention for parameters).
        self.n_classes = n_classes

    def fit(self, X, y):
        """Ignore ``X`` and ``y`` and return ``self``; no learning happens."""
        return self

    def predict(self, X):
        """Return one uniformly random label in ``[0, n_classes)`` per sample.

        Returns:
            Integer array of shape ``(len(X),)``.
        """
        # ``np.random.randint`` takes ``size=`` (not ``shape=``) for the output
        # shape; a 1-D array of labels matches the usual ``predict`` contract.
        return np.random.randint(0, self.n_classes, size=len(X))

    def predict_proba(self, X):
        """Return a uniform class distribution for every sample.

        Returns:
            Float array of shape ``(len(X), n_classes)`` whose entries all
            equal ``1 / n_classes``.
        """
        return np.ones(shape=(len(X), self.n_classes)) / self.n_classes
if __name__ == '__main__':
    # Smoke test: drive each query strategy with samples that are 5x5 arrays
    # (X has shape (10, 5, 5)) rather than flat feature vectors.
    X_train = np.random.rand(10, 5, 5)
    y_train = np.random.randint(0, 2, size=10)
    X_pool = np.random.rand(10, 5, 5)
    y_pool = np.random.randint(0, 2, size=10)

    strategies = [
        margin_sampling,
        entropy_sampling,
        uncertainty_batch_sampling,
        expected_error_reduction,
    ]
    for query_strategy in strategies:
        print("testing %s..." % query_strategy.__name__)
        # Fresh learner per strategy, so each run starts from the same
        # training set and does not see labels taught in earlier iterations.
        learner = ActiveLearner(
            estimator=MockClassifier(), query_strategy=query_strategy,
            X_training=X_train, y_training=y_train
        )
        # The query result is discarded: the point is only that the strategy
        # runs on multidimensional input. Then teach the whole pool to
        # exercise the update path.
        learner.query(X_pool)
        learner.teach(X_pool, y_pool)