Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 68 additions & 2 deletions deeplabcut/pose_estimation_pytorch/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,25 @@
from scipy.stats import truncnorm


def build_transforms(augmentations: dict) -> A.BaseCompose:
transforms = []

def transforms_legacy(augmentations):

"""
Apply DeepLabCut original augmentation methods.

This function preserves DLC's default augmentation implementations
for backward compatibility.

Args:
augmentations: dictionary of augmentations from config file

Returns:

The list of augmentation transforms

"""

transforms = []
if resize_aug := augmentations.get("resize", False):
transforms += build_resize_transforms(resize_aug)

Expand Down Expand Up @@ -121,6 +137,7 @@ def build_transforms(augmentations: dict) -> A.BaseCompose:
transforms.append(ElasticTransform(sigma=5, p=0.5))
if augmentations.get("grayscale", False):
transforms.append(Grayscale(alpha=(0.5, 1.0)))

if noise := augmentations.get("gaussian_noise", False):
# TODO inherit custom gaussian transform to support per_channel = 0.5
if not isinstance(noise, (int, float)):
Expand All @@ -146,6 +163,54 @@ def build_transforms(augmentations: dict) -> A.BaseCompose:
if augmentations.get("scale_to_unit_range"):
transforms.append(ScaleToUnitRange())

return transforms



def transform_new(augmentations: list[dict[str, Any]]) -> list[A.BasicTransform]:

"""
Creates augmentation using Albumentation library with hybrid support.

Args:
augmentations: the augmentations from "transform" in list format

Raises:
Warnings:
Augmentation method must exist in augmentations.
Augmentation must exist in the Albumentation library.

Returns:
The list of augmentation transforms

"""

custom_augmentations =['auto_padding', 'gaussian_noise', 'motion_blur', 'normalize_images']

transforms =[]
for aug in augmentations:
method = aug.get('augmentation', "")
if method == "":
warnings.warn("No augmentation method specified. Skipping this transform.")
continue

args = {key:value for key, value in aug.items() if key != 'augmentation'}
if method.lower() in custom_augmentations:
add_aug = transforms_legacy({method.lower(): args})
transforms.extend(add_aug)
elif hasattr(A, method):
func = getattr(A,method)
transforms.append(func(**args))
else:
warnings.warn(f"Albumentations has no method named {method}. Skipping this transform.")

return transforms


def build_transforms(augmentations: dict) -> A.BaseCompose:
transforms = transforms_legacy(augmentations)
transforms = transforms + transform_new(augmentations.get("transform", []))

return A.Compose(
transforms,
keypoint_params=A.KeypointParams(
Expand All @@ -155,6 +220,7 @@ def build_transforms(augmentations: dict) -> A.BaseCompose:
)



def build_auto_padding(
min_height: int | None = None,
min_width: int | None = None,
Expand Down