Skip to content

Commit 153df0d

Browse files
authored
Fix cartopy tri default transform for Triangulation inputs (#595)
1 parent 9e4bccf commit 153df0d

2 files changed

Lines changed: 48 additions & 18 deletions

File tree

ultraplot/internals/inputs.py

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@
1616
from cartopy.crs import PlateCarree
1717
except ModuleNotFoundError:
1818
PlateCarree = object
19+
try:
20+
from matplotlib.tri import Triangulation
21+
except ModuleNotFoundError:
22+
Triangulation = object
1923

2024

2125
# Constants
@@ -300,8 +304,16 @@ def triangulation_wrapper(self, *args, **kwargs):
300304
# Manually set the name to the original function's name
301305
triangulation_wrapper.__name__ = func.__name__
302306

307+
def _tri_cartopy_default(args, kwargs):
308+
# If the first parsed argument is already a Triangulation then it may
309+
# be in projected coordinates, so skip implicit PlateCarree defaults.
310+
return not (args and isinstance(args[0], Triangulation))
311+
303312
final_wrapper = _preprocess_or_redirect(
304-
*keys, keywords=keywords, allow_extra=allow_extra
313+
*keys,
314+
keywords=keywords,
315+
allow_extra=allow_extra,
316+
cartopy_default_transform=_tri_cartopy_default,
305317
)(triangulation_wrapper)
306318

307319
# Finally make sure all other metadata is correct
@@ -311,7 +323,9 @@ def triangulation_wrapper(self, *args, **kwargs):
311323
return _decorator
312324

313325

314-
def _preprocess_or_redirect(*keys, keywords=None, allow_extra=True):
326+
def _preprocess_or_redirect(
327+
*keys, keywords=None, allow_extra=True, cartopy_default_transform=True
328+
):
315329
"""
316330
Redirect internal plotting calls to native matplotlib methods. Also convert
317331
keyword args to positional and pass arguments through 'data' dictionary.
@@ -335,18 +349,6 @@ def _preprocess_or_redirect(self, *args, **kwargs):
335349
func_native = getattr(super(PlotAxes, self), name)
336350
return func_native(*args, **kwargs)
337351
else:
338-
# Impose default coordinate system
339-
from ..constructor import Proj
340-
341-
if self._name == "basemap" and name in BASEMAP_FUNCS:
342-
if kwargs.get("latlon", None) is None:
343-
kwargs["latlon"] = True
344-
if self._name == "cartopy" and name in CARTOPY_FUNCS:
345-
if kwargs.get("transform", None) is None:
346-
kwargs["transform"] = PlateCarree()
347-
else:
348-
kwargs["transform"] = Proj(kwargs["transform"])
349-
350352
# Process data args
351353
# NOTE: Raises error if there are more args than keys
352354
args, kwargs = _kwargs_to_args(
@@ -358,6 +360,25 @@ def _preprocess_or_redirect(self, *args, **kwargs):
358360
for key in set(keywords) & set(kwargs):
359361
kwargs[key] = _from_data(data, kwargs[key])
360362

363+
# Impose default coordinate system using parsed inputs. This keeps
364+
# behavior consistent across positional/keyword/data pathways.
365+
from ..constructor import Proj
366+
367+
if self._name == "basemap" and name in BASEMAP_FUNCS:
368+
if kwargs.get("latlon", None) is None:
369+
kwargs["latlon"] = True
370+
if self._name == "cartopy" and name in CARTOPY_FUNCS:
371+
if kwargs.get("transform", None) is None:
372+
use_default_transform = cartopy_default_transform
373+
if callable(use_default_transform):
374+
use_default_transform = bool(
375+
use_default_transform(args, kwargs)
376+
)
377+
if use_default_transform:
378+
kwargs["transform"] = PlateCarree()
379+
else:
380+
kwargs["transform"] = Proj(kwargs["transform"])
381+
361382
# Auto-setup matplotlib with the input unit registry
362383
_load_objects()
363384
for arg in args:

ultraplot/tests/test_geographic.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -931,10 +931,10 @@ def test_rasterize_feature():
931931

932932
def test_check_tricontourf():
933933
"""
934-
Ensure that tricontour functions are getting
935-
the transform for GeoAxes.
934+
Ensure transform defaults are applied only when appropriate for tri-plots.
936935
"""
937936
import cartopy.crs as ccrs
937+
from matplotlib.tri import Triangulation
938938

939939
lon0 = 90
940940
lon = np.linspace(-180, 180, 10)
@@ -947,6 +947,7 @@ def test_check_tricontourf():
947947
data[mask_box] = 1.5
948948

949949
lon, lat, data = map(np.ravel, (lon2d, lat2d, data))
950+
triangulation = Triangulation(lon, lat)
950951

951952
fig, ax = uplt.subplots(proj="cyl", proj_kw={"lon0": lon0})
952953
original_func = ax[0]._call_native
@@ -956,10 +957,18 @@ def test_check_tricontourf():
956957
autospec=True,
957958
side_effect=original_func,
958959
) as mocked:
959-
for func in "tricontour tricontourf".split():
960-
getattr(ax[0], func)(lon, lat, data)
960+
ax[0].tricontourf(lon, lat, data)
961961
assert "transform" in mocked.call_args.kwargs
962962
assert isinstance(mocked.call_args.kwargs["transform"], ccrs.PlateCarree)
963+
964+
with mock.patch.object(
965+
ax[0],
966+
"_call_native",
967+
autospec=True,
968+
side_effect=original_func,
969+
) as mocked:
970+
ax[0].tricontourf(triangulation, data)
971+
assert "transform" not in mocked.call_args.kwargs
963972
uplt.close(fig)
964973

965974

0 commit comments

Comments
 (0)