Skip to content

Commit 08777d9

Browse files
MatteB03ndem0
authored andcommitted
Update tutorial to work with domain folder
1 parent 0194fab commit 08777d9

2 files changed

Lines changed: 100 additions & 55 deletions

File tree

tutorials/tutorial6/tutorial.ipynb

Lines changed: 68 additions & 43 deletions
Large diffs are not rendered by default.

tutorials/tutorial6/tutorial.py

Lines changed: 32 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#!/usr/bin/env python
22
# coding: utf-8
33

4-
# # Tutorial: Building custom geometries with PINA `Location` class
4+
# # Tutorial: Building custom geometries with PINA `DomainInterface` class
55
#
66
# [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mathLab/PINA/blob/master/tutorials/tutorial6/tutorial.ipynb)
77
#
@@ -27,7 +27,7 @@
2727

2828
import matplotlib.pyplot as plt
2929
plt.style.use('tableau-colorblind10')
30-
from pina.geometry import EllipsoidDomain, Difference, CartesianDomain, Union, SimplexDomain
30+
from pina.domain import EllipsoidDomain, Difference, CartesianDomain, Union, SimplexDomain, DomainInterface
3131
from pina.label_tensor import LabelTensor
3232

3333
def plot_scatter(ax, pts, title):
@@ -164,7 +164,7 @@ def plot_scatter(ax, pts, title):
164164
plot_scatter(ax, c_e_nb_d_points, 'Difference')
165165

166166

167-
# ## Create Custom Location
167+
# ## Create Custom DomainInterface
168168

169169
# We will take a look on how to create our own geometry. The one we will try to make is a heart defined by the function $$(x^2+y^2-1)^3-x^2y^3 \le 0$$
170170

@@ -174,30 +174,35 @@ def plot_scatter(ax, pts, title):
174174

175175

176176
import torch
177-
from pina import Location
178177
from pina import LabelTensor
179178
import random
180179

181180

182-
# Next, we will create the `Heart(Location)` class and initialize it.
181+
# Next, we will create the `Heart(DomainInterface)` class and initialize it.
183182

184183
# In[12]:
185184

186185

187-
class Heart(Location):
186+
class Heart(DomainInterface):
188187
"""Implementation of the Heart Domain."""
189188

190189
def __init__(self, sample_border=False):
191190
super().__init__()
192191

193192

194193

195-
# Because the `Location` class we are inheriting from requires both a `sample` method and `is_inside` method, we will create them and just add in "pass" for the moment.
194+
# In[ ]:
195+
196+
197+
198+
199+
200+
# Because the `DomainInterface` class we are inheriting from requires both a `sample` method and `is_inside` method, we will create them and just add in "pass" for the moment. We also observe that the methods `sample_modes` and `variables` of the `DomainInterface` class are initialized as `abstractmethod`, so we need to redefine them both in the subclass `Heart` .
196201

197202
# In[13]:
198203

199204

200-
class Heart(Location):
205+
class Heart(DomainInterface):
201206
"""Implementation of the Heart Domain."""
202207

203208
def __init__(self, sample_border=False):
@@ -208,15 +213,22 @@ def is_inside(self):
208213

209214
def sample(self):
210215
pass
216+
217+
@property
218+
def sample_modes(self):
219+
pass
211220

221+
@property
222+
def variables(self):
223+
pass
212224

213-
# Now we have the skeleton for our `Heart` class. The `sample` method is where most of the work is done so let's fill it out.
214225

215-
# In[14]:
226+
# Now we have the skeleton for our `Heart` class. Also the `sample` method is where most of the work is done so let's fill it out.
216227

228+
# In[14]:
217229

218230

219-
class Heart(Location):
231+
class Heart(DomainInterface):
220232
"""Implementation of the Heart Domain."""
221233

222234
def __init__(self, sample_border=False):
@@ -225,7 +237,7 @@ def __init__(self, sample_border=False):
225237
def is_inside(self):
226238
pass
227239

228-
def sample(self, n, mode='random', variables='all'):
240+
def sample(self, n):
229241
sampled_points = []
230242

231243
while len(sampled_points) < n:
@@ -235,6 +247,14 @@ def sample(self, n, mode='random', variables='all'):
235247
sampled_points.append([x.item(), y.item()])
236248

237249
return LabelTensor(torch.tensor(sampled_points), labels=['x','y'])
250+
251+
@property
252+
def sample_modes(self):
253+
pass
254+
255+
@property
256+
def variables(self):
257+
pass
238258

239259

240260
# To create the Heart geometry we simply run:

0 commit comments

Comments
 (0)