-
Notifications
You must be signed in to change notification settings - Fork 953
Expand file tree
/
Copy pathtest_config.py
More file actions
131 lines (114 loc) · 5.07 KB
/
test_config.py
File metadata and controls
131 lines (114 loc) · 5.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import unittest
from typing import get_args, List, Union
import torch
from executorch.devtools.bundled_program.config import DataContainer
from executorch.devtools.bundled_program.util.test_util import (
get_random_test_suites,
get_random_test_suites_with_eager_model,
SampleModel,
)
from executorch.extension.pytree import tree_flatten
class TestConfig(unittest.TestCase):
def assertTensorEqual(self, t1: torch.Tensor, t2: torch.Tensor) -> None:
self.assertTrue((t1 == t2).all())
def assertIOListEqual(
self,
tl1: List[Union[bool, float, int, torch.Tensor]],
tl2: List[Union[bool, float, int, torch.Tensor]],
) -> None:
self.assertEqual(len(tl1), len(tl2))
for t1, t2 in zip(tl1, tl2):
if isinstance(t1, torch.Tensor):
assert isinstance(t2, torch.Tensor)
self.assertTensorEqual(t1, t2)
else:
self.assertTrue(t1 == t2)
def test_create_test_suites(self) -> None:
n_sets_per_plan_test = 10
n_method_test_suites = 5
(
rand_method_names,
rand_inputs,
rand_expected_outpus,
method_test_suites,
) = get_random_test_suites(
n_model_inputs=2,
model_input_sizes=[[2, 2], [2, 2]],
n_model_outputs=1,
model_output_sizes=[[2, 2]],
dtype=torch.int32,
n_sets_per_plan_test=n_sets_per_plan_test,
n_method_test_suites=n_method_test_suites,
)
self.assertEqual(len(method_test_suites), n_method_test_suites)
# Compare to see if bundled execution plan test match expectations.
for method_test_suite_idx in range(n_method_test_suites):
self.assertEqual(
method_test_suites[method_test_suite_idx].method_name,
rand_method_names[method_test_suite_idx],
)
for testset_idx in range(n_sets_per_plan_test):
self.assertIOListEqual(
# pyre-ignore [6]: expected `List[Union[bool, float, int, Tensor]]` but got `Sequence[Union[bool, float, int, Tensor]]
rand_inputs[method_test_suite_idx][testset_idx],
method_test_suites[method_test_suite_idx]
.test_cases[testset_idx]
.inputs,
)
self.assertIOListEqual(
# pyre-ignore [6]: expected `List[Union[bool, float, int, Tensor]]` but got `Sequence[Union[bool, float, int, Tensor]]
rand_expected_outpus[method_test_suite_idx][testset_idx],
method_test_suites[method_test_suite_idx]
.test_cases[testset_idx]
.expected_outputs,
)
def test_create_test_suites_from_eager_model(self) -> None:
n_sets_per_plan_test = 10
eager_model = SampleModel()
method_names: List[str] = eager_model.method_names
rand_inputs, method_test_suites = get_random_test_suites_with_eager_model(
eager_model=eager_model,
method_names=method_names,
n_model_inputs=2,
model_input_sizes=[[2, 2], [2, 2]],
dtype=torch.int32,
n_sets_per_plan_test=n_sets_per_plan_test,
)
self.assertEqual(len(method_test_suites), len(method_names))
# Compare to see if bundled testcases match expectations.
for method_test_suite_idx in range(len(method_names)):
self.assertEqual(
method_test_suites[method_test_suite_idx].method_name,
method_names[method_test_suite_idx],
)
for testset_idx in range(n_sets_per_plan_test):
ri = rand_inputs[method_test_suite_idx][testset_idx]
self.assertIOListEqual(
# pyre-ignore [6]: expected `List[Union[bool, float, int, Tensor]]` but got `Sequence[Union[bool, float, int, Tensor]]
ri,
method_test_suites[method_test_suite_idx]
.test_cases[testset_idx]
.inputs,
)
model_outputs = getattr(
eager_model, method_names[method_test_suite_idx]
)(*ri)
if isinstance(model_outputs, get_args(DataContainer)):
# pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
flatten_eager_model_outputs = tree_flatten(model_outputs)
else:
flatten_eager_model_outputs = [
model_outputs,
]
self.assertIOListEqual(
flatten_eager_model_outputs,
method_test_suites[method_test_suite_idx]
.test_cases[testset_idx]
.expected_outputs,
)