Skip to content

Commit 980af6f

Browse files
committed
Extract dates
1 parent 7a36913 commit 980af6f

8 files changed

Lines changed: 98 additions & 36 deletions

File tree

js/src/graphs.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@ type Graph2D = Graph & {
2323

2424
export type PointData = {
2525
label: string
26-
points: [number, number][]
26+
points: [(number| string), (number | string)][]
2727
}
2828

2929
type PointGraph = Graph2D & {
30-
x_ticks: number[]
30+
x_ticks: (number | string)[]
3131
x_tick_labels: string[]
32-
y_ticks: number[]
32+
y_ticks: (number | string)[]
3333
y_tick_labels: string[]
3434
elements: PointData[]
3535
}

js/tests/envVars.test.ts

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import { expect } from 'vitest'
22

3-
import { sandboxTest } from './setup'
3+
import { isDebug, sandboxTest } from './setup'
44
import { CodeInterpreter } from '../src'
55

66
// Skip this test if we are running in debug mode — the pwd and user in the testing docker container are not the same as in the actual sandbox.
7-
sandboxTest('env vars', async () => {
7+
sandboxTest.skipIf(isDebug)('env vars', async () => {
88
const sandbox = await CodeInterpreter.create({
99
envs: { TEST_ENV_VAR: 'supertest' },
1010
})
@@ -29,7 +29,7 @@ sandboxTest('env vars on sandbox override', async () => {
2929
envs: { FOO: 'bar', SBX: 'value' },
3030
})
3131
await sandbox.notebook.execCell(
32-
"import os; os.environ['FOO'] = 'bar'; os.environ['RUNTIME_ENV'] = 'value'"
32+
"import os; os.environ['FOO'] = 'bar'; os.environ['RUNTIME_ENV'] = 'runtime'"
3333
)
3434
const result = await sandbox.notebook.execCell(
3535
"import os; os.getenv('FOO')",
@@ -41,10 +41,12 @@ sandboxTest('env vars on sandbox override', async () => {
4141
const result2 = await sandbox.notebook.execCell(
4242
"import os; os.getenv('RUNTIME_ENV')"
4343
)
44-
expect(result2.results[0].text.trim()).toEqual('value')
44+
expect(result2.results[0].text.trim()).toEqual('runtime')
4545

46-
const result3 = await sandbox.notebook.execCell("import os; os.getenv('SBX')")
47-
expect(result3.results[0].text.trim()).toEqual('value')
46+
if (!isDebug) {
47+
const result3 = await sandbox.notebook.execCell("import os; os.getenv('SBX')")
48+
expect(result3.results[0].text.trim()).toEqual('value')
49+
}
4850

4951
const result4 = await sandbox.notebook.execCell("import os; os.getenv('FOO')")
5052
expect(result4.results[0].text.trim()).toEqual('bar')

js/tests/graphs/line.test.ts

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,20 @@ sandboxTest('line', async ({ sandbox }) => {
66
const code = `
77
import numpy as np
88
import matplotlib.pyplot as plt
9+
import datetime
910
1011
# Generate x values
11-
x = np.linspace(0, 2*np.pi, 100)
12+
dates = [datetime.date(2023, 9, 1) + datetime.timedelta(seconds=i) for i in range(100)]
1213
14+
x = np.linspace(0, 2*np.pi, 100)
1315
# Calculate y values
1416
y_sin = np.sin(x)
1517
y_cos = np.cos(x)
1618
1719
# Create the plot
1820
plt.figure(figsize=(10, 6))
19-
plt.plot(x, y_sin, label='sin(x)')
20-
plt.plot(x, y_cos, label='cos(x)')
21+
plt.plot(dates, y_sin, label='sin(x)')
22+
plt.plot(dates, y_cos, label='cos(x)')
2123
2224
# Add labels and title
2325
plt.xlabel("Time (s)")
@@ -39,9 +41,10 @@ plt.show()
3941
expect(graph.x_unit).toBe('s')
4042
expect(graph.y_unit).toBe('Hz')
4143

42-
expect(graph.x_ticks.every((tick: number) => typeof tick === 'number')).toBe(
44+
expect(graph.x_ticks.every((tick: number) => typeof tick === 'string')).toBe(
4345
true
4446
)
47+
expect(new Date(graph.x_ticks[0])).toBeInstanceOf(Date)
4548
expect(graph.y_ticks.every((tick: number) => typeof tick === 'number')).toBe(
4649
true
4750
)
@@ -63,16 +66,17 @@ plt.show()
6366
expect(
6467
firstLine.points.every(
6568
(point: [number, number]) =>
66-
typeof point[0] === 'number' && typeof point[1] === 'number'
69+
typeof point[0] === "string" && typeof point[1] === 'number'
6770
)
6871
).toBe(true)
72+
expect(new Date(firstLine.points[0][0])).toEqual(new Date('2023-09-01T00:00:00.000Z'))
6973

7074
expect(secondLine.label).toBe('cos(x)')
7175
expect(secondLine.points.length).toBe(100)
7276
expect(
7377
secondLine.points.every(
7478
(point: [number, number]) =>
75-
typeof point[0] === 'number' && typeof point[1] === 'number'
79+
typeof point[0] === 'string' && typeof point[1] === 'number'
7680
)
7781
).toBe(true)
7882
})

python/e2b_code_interpreter/graphs.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,31 +40,36 @@ def __init__(self, **kwargs):
4040

4141
class PointData:
4242
label: str
43-
points: List[Tuple[float, float]]
43+
points: List[Tuple[Union[str, int, float], Union[str, int, float]]]
4444

4545
def __init__(self, **kwargs):
4646
self.label = kwargs["label"]
47-
self.points = [(float(x), float(y)) for x, y in kwargs["points"]]
47+
self.points = [(x, y) for x, y in kwargs["points"]]
4848

4949

5050
class PointGraph(Graph2D):
51-
x_ticks: List[float]
51+
x_ticks: List[Union[str, int, float]]
5252
x_tick_labels: List[str]
53-
y_ticks: List[float]
53+
x_unit: Optional[str]
54+
55+
y_ticks: List[Union[str, int, float]]
5456
y_tick_labels: List[str]
57+
y_unit: Optional[str]
5558

5659
elements: List[PointData]
5760

5861
def __init__(self, **kwargs):
5962
super().__init__(**kwargs)
6063
self.x_label = kwargs["x_label"]
61-
self.y_label = kwargs["y_label"]
6264
self.x_unit = kwargs["x_unit"]
63-
self.y_unit = kwargs["y_unit"]
6465
self.x_ticks = kwargs["x_ticks"]
6566
self.x_tick_labels = kwargs["x_tick_labels"]
67+
68+
self.y_label = kwargs["y_label"]
69+
self.y_unit = kwargs["y_unit"]
6670
self.y_ticks = kwargs["y_ticks"]
6771
self.y_tick_labels = kwargs["y_tick_labels"]
72+
6873
self.elements = [PointData(**d) for d in kwargs["elements"]]
6974

7075

python/e2b_code_interpreter/models.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
import logging
23

34
from e2b import NotFoundException, TimeoutException, SandboxException
45
from dataclasses import dataclass, field
@@ -24,6 +25,8 @@
2425
Callable[[T], Awaitable[Any]],
2526
]
2627

28+
logger = logging.getLogger(__name__)
29+
2730

2831
@dataclass
2932
class OutputMessage:
@@ -133,8 +136,10 @@ def __init__(
133136
if graph:
134137
try:
135138
self.graph = deserialize_graph(graph)
136-
except Exception:
137-
pass
139+
except Exception as e:
140+
logger.error(
141+
f"Error deserializing graph, check if you are using the latest version of the library: {e}"
142+
)
138143
self.is_main_result = is_main_result
139144
self.extra = extra
140145

python/tests/async/test_async_statefulness.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,7 @@
44
async def test_stateful(async_sandbox: AsyncCodeInterpreter):
55
await async_sandbox.notebook.exec_cell("async_test_stateful = 1")
66

7-
result = await async_sandbox.notebook.exec_cell("async_test_stateful+=1; async_test_stateful")
7+
result = await async_sandbox.notebook.exec_cell(
8+
"async_test_stateful+=1; async_test_stateful"
9+
)
810
assert result.text == "2"

python/tests/graphs/test_line.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,25 @@
1+
import datetime
2+
13
from e2b_code_interpreter.code_interpreter_async import AsyncCodeInterpreter
24
from e2b_code_interpreter.graphs import LineGraph
35

46
code = """
57
import numpy as np
68
import matplotlib.pyplot as plt
9+
import datetime
710
811
# Generate x values
9-
x = np.linspace(0, 2*np.pi, 100)
12+
dates = [datetime.date(2023, 9, 1) + datetime.timedelta(seconds=i) for i in range(100)]
1013
14+
x = np.linspace(0, 2*np.pi, 100)
1115
# Calculate y values
1216
y_sin = np.sin(x)
1317
y_cos = np.cos(x)
1418
1519
# Create the plot
1620
plt.figure(figsize=(10, 6))
17-
plt.plot(x, y_sin, label='sin(x)')
18-
plt.plot(x, y_cos, label='cos(x)')
21+
plt.plot(dates, y_sin, label='sin(x)')
22+
plt.plot(dates, y_cos, label='cos(x)')
1923
2024
# Add labels and title
2125
plt.xlabel("Time (s)")
@@ -42,7 +46,9 @@ async def test_line_graph(async_sandbox: AsyncCodeInterpreter):
4246
assert graph.x_unit == "s"
4347
assert graph.y_unit == "Hz"
4448

45-
assert all(isinstance(x, float) for x in graph.x_ticks)
49+
assert all(isinstance(x, str) for x in graph.x_ticks)
50+
parsed_date = datetime.datetime.fromisoformat(graph.x_ticks[0])
51+
assert isinstance(parsed_date, datetime.datetime)
4652
assert all(isinstance(y, float) for y in graph.y_ticks)
4753

4854
assert all(isinstance(x, str) for x in graph.y_tick_labels)
@@ -56,9 +62,12 @@ async def test_line_graph(async_sandbox: AsyncCodeInterpreter):
5662
assert len(first_line.points) == 100
5763
assert all(isinstance(point, tuple) for point in first_line.points)
5864
assert all(
59-
isinstance(x, float) and isinstance(y, float) for x, y in first_line.points
65+
isinstance(x, str) and isinstance(y, float) for x, y in first_line.points
6066
)
6167

68+
parsed_date = datetime.datetime.fromisoformat(first_line.points[0][0])
69+
assert isinstance(parsed_date, datetime.datetime)
70+
6271
second_line = lines[1]
6372
assert second_line.label == "cos(x)"
6473
assert len(second_line.points) == 100

template/startup_scripts/0002_data.py

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,21 @@
1+
from datetime import date
12
import enum
23
import re
3-
from typing import Optional, List, Tuple, Literal, Any
4+
from typing import Optional, List, Tuple, Literal, Any, Union
45

6+
import matplotlib
57
import pandas
68
from matplotlib.axes import Axes
79
from matplotlib.collections import PathCollection
810
from matplotlib.lines import Line2D
911
from matplotlib.patches import Rectangle, Wedge, PathPatch
1012
from matplotlib.pyplot import Figure
13+
from matplotlib.dates import _SwitchableDateConverter
1114
import IPython
1215

1316
from IPython.core.formatters import BaseFormatter
1417
from matplotlib.text import Text
15-
from pydantic import BaseModel, Field
18+
from pydantic import BaseModel, Field, field_validator
1619
from traitlets.traitlets import Unicode, ObjectName
1720

1821

@@ -87,13 +90,29 @@ def _change_orientation(self):
8790

8891
class PointData(BaseModel):
8992
label: str
90-
points: List[Tuple[float, float]]
93+
points: List[Tuple[Union[str, int, float], Union[str, int, float]]]
94+
95+
@field_validator("points", mode="before")
96+
@classmethod
97+
def transform_points(
98+
cls, value
99+
) -> List[Tuple[Union[float, str], Union[float, str]]]:
100+
parsed_value = []
101+
for x, y in value:
102+
if isinstance(x, date):
103+
x = x.isoformat()
104+
105+
if isinstance(y, date):
106+
y = y.isoformat()
107+
108+
parsed_value.append((x, y))
109+
return parsed_value
91110

92111

93112
class PointGraph(Graph2D):
94-
x_ticks: List[float] = Field(default_factory=list)
113+
x_ticks: List[Union[str, int, float]] = Field(default_factory=list)
95114
x_tick_labels: List[str] = Field(default_factory=list)
96-
y_ticks: List[float] = Field(default_factory=list)
115+
y_ticks: List[Union[str, int, float]] = Field(default_factory=list)
97116
y_tick_labels: List[str] = Field(default_factory=list)
98117

99118
elements: List[PointData] = Field(default_factory=list)
@@ -103,11 +122,27 @@ def _extract_info(self, ax: Axes) -> None:
103122
Function to extract information for PointGraph
104123
"""
105124
super()._extract_info(ax)
106-
self.x_ticks = [float(tick) for tick in ax.get_xticks()]
125+
107126
self.x_tick_labels = [label.get_text() for label in ax.get_xticklabels()]
127+
self.x_ticks = self._extract_ticks_info(ax.xaxis.converter, ax.get_xticks())
108128

109-
self.y_ticks = [float(tick) for tick in ax.get_yticks()]
110129
self.y_tick_labels = [label.get_text() for label in ax.get_yticklabels()]
130+
self.y_ticks = self._extract_ticks_info(ax.yaxis.converter, ax.get_yticks())
131+
132+
@staticmethod
133+
def _extract_ticks_info(converter: Any, ticks: list) -> list:
134+
example_tick = ticks[0]
135+
136+
if isinstance(converter, _SwitchableDateConverter):
137+
return [matplotlib.dates.num2date(tick).isoformat() for tick in ticks]
138+
else:
139+
ticks_type = type(example_tick).__name__
140+
if ticks_type == "float":
141+
return [float(tick) for tick in ticks]
142+
elif ticks_type == "int":
143+
return [int(tick) for tick in ticks]
144+
else:
145+
return ticks
111146

112147

113148
class LineGraph(PointGraph):

0 commit comments

Comments
 (0)