|
| 1 | +import base64 |
| 2 | +import os |
| 3 | +from typing import Iterable, List |
| 4 | + |
| 5 | +from cryptography.hazmat.primitives.ciphers.aead import AESGCM |
| 6 | +from temporalio.api.common.v1 import Payload |
| 7 | +from temporalio.converter import PayloadCodec |
| 8 | + |
| 9 | +default_key = base64.b64decode(b"MkUb3RVdHQuOTedqETZW7ra2GkZqpBRmYWRACUospMc=") |
| 10 | +default_key_id = "my-key" |
| 11 | + |
| 12 | + |
| 13 | +class EncryptionCodec(PayloadCodec): |
| 14 | + def __init__(self, key_id: str = default_key_id, key: bytes = default_key) -> None: |
| 15 | + super().__init__() |
| 16 | + self.key_id = key_id |
| 17 | + # We are using direct AESGCM to be compatible with samples from |
| 18 | + # TypeScript and Go. Pure Python samples may prefer the higher-level, |
| 19 | + # safer APIs. |
| 20 | + self.encryptor = AESGCM(key) |
| 21 | + |
| 22 | + async def encode(self, payloads: Iterable[Payload]) -> List[Payload]: |
| 23 | + # We blindly encode all payloads with the key and set the metadata |
| 24 | + # saying which key we used |
| 25 | + return [ |
| 26 | + Payload( |
| 27 | + metadata={ |
| 28 | + "encoding": b"binary/encrypted", |
| 29 | + "encryption-key-id": self.key_id.encode(), |
| 30 | + }, |
| 31 | + data=self.encrypt(p.SerializeToString()), |
| 32 | + ) |
| 33 | + for p in payloads |
| 34 | + ] |
| 35 | + |
| 36 | + async def decode(self, payloads: Iterable[Payload]) -> List[Payload]: |
| 37 | + ret: List[Payload] = [] |
| 38 | + for p in payloads: |
| 39 | + # Ignore ones w/out our expected encoding |
| 40 | + if p.metadata.get("encoding", b"").decode() != "binary/encrypted": |
| 41 | + ret.append(p) |
| 42 | + continue |
| 43 | + # Confirm our key ID is the same |
| 44 | + key_id = p.metadata.get("encryption-key-id", b"").decode() |
| 45 | + if key_id != self.key_id: |
| 46 | + raise ValueError(f"Unrecognized key ID {key_id}") |
| 47 | + # Decrypt and append |
| 48 | + ret.append(Payload.FromString(self.decrypt(p.data))) |
| 49 | + return ret |
| 50 | + |
| 51 | + def encrypt(self, data: bytes) -> bytes: |
| 52 | + nonce = os.urandom(12) |
| 53 | + return nonce + self.encryptor.encrypt(nonce, data, None) |
| 54 | + |
| 55 | + def decrypt(self, data: bytes) -> bytes: |
| 56 | + return self.encryptor.decrypt(data[:12], data[12:], None) |
0 commit comments