diff --git a/.gitignore b/.gitignore index 6339aeb6..50d213a6 100644 --- a/.gitignore +++ b/.gitignore @@ -139,3 +139,6 @@ doc/dist .idea .vscode/* !.vscode/settings.json + +# local configuration files +*.local.* diff --git a/GEMINI.md b/GEMINI.md new file mode 100644 index 00000000..d5d23ae9 --- /dev/null +++ b/GEMINI.md @@ -0,0 +1,99 @@ +# GEMINI.md + +This file provides guidance to coding agents when working with code in this repository. + +## Project Overview + +Firebase Functions Python SDK - A Python SDK for defining Firebase Functions that respond to Firebase and Google Cloud events using decorators or HTTP requests. + +## Development Commands + +All commands use the Makefile: + +- `make install` - Install dependencies with uv +- `make lint` - Run ruff linter +- `make format` - Format code with ruff +- `make format-check` - Check formatting without modifying +- `make typecheck` - Run type checking with mypy +- `make test` - Run tests +- `make test-cov` - Run tests with coverage report +- `make docs` - Generate documentation + +### Testing Individual Functions + +```bash +# Run specific test file +uv run pytest tests/test_https_fn.py + +# Run specific test +uv run pytest tests/test_https_fn.py::test_on_request_no_args + +# Run with verbose output +uv run pytest -vv tests/test_firestore_fn.py +``` + +## Architecture + +### Core Design Pattern + +The SDK uses a decorator-based API where functions are defined using specific decorators for different trigger types: + +```python +@https_fn.on_request() +@firestore_fn.on_document_created(document="posts/{postId}") +@pubsub_fn.on_message_published(topic="my-topic") +``` + +### Key Components + +1. **Function Modules** (`src/firebase_functions/`): + - Each trigger type has its own module (e.g., `https_fn.py`, `firestore_fn.py`) + - All decorators follow a similar pattern: validate options → create endpoint metadata → wrap function + +2. **Core Infrastructure** (`src/firebase_functions/core.py`): + - `CloudEvent` and `Event` classes for event data + - Endpoint metadata stored in `__firebase_endpoint__` attribute + - Type definitions for various event types + +3. **Private Modules** (`src/firebase_functions/private/`): + - `manifest.py` - Generates deployment manifests + - `serving.py` - Handles function serving and registration + - `util.py` - Common utilities + +4. **Options Pattern**: + - Each function type has an options dataclass (e.g., `HttpsOptions`, `FirestoreOptions`) + - Options control deployment settings like memory, timeout, regions, etc. + +### Function Registration Flow + +1. Decorator validates options and creates `ManifestEndpoint` +2. Endpoint metadata attached to function via `__firebase_endpoint__` +3. During deployment, `manifest.py` collects all decorated functions +4. Functions registered with `functions-framework` for runtime handling + +## Testing Strategy + +- Each function type has comprehensive tests in `/tests/` +- Tests use `unittest.mock` for mocking Firebase services +- Don't overuse mocks +- Test both successful cases and error conditions +- Verify endpoint metadata is correctly set + +## Code Quality Requirements + +Before committing changes: + +1. Run `make lint` and fix any issues +2. Run `make typecheck` and fix type errors +3. Run `make format` to ensure consistent formatting +4. Run `make test` to ensure all tests pass +5. Add tests for new functionality + +## Important Notes + +- Python 3.10+ required +- Uses `uv` package manager (not pip/poetry) +- Built on top of Google's `functions-framework` +- All public APIs should maintain backward compatibility +- Follow existing patterns when adding new trigger types + diff --git a/Makefile b/Makefile new file mode 100644 index 00000000..e6382d6a --- /dev/null +++ b/Makefile @@ -0,0 +1,40 @@ +.PHONY: help install lint format format-check typecheck test test-cov docs clean + +help: ## Show this help + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-15s\033[0m %s\n", $$1, $$2}' + +install: ## Install dependencies with uv + uv sync + +lint: ## Run ruff linter + uv run ruff check . + +format: ## Format code with ruff + uv run ruff format . + +format-check: ## Check code formatting without modifying files + uv run ruff format --check . + +fix: ## Fix linting issues and format code + uv run ruff check . --fix + uv run ruff format . + +typecheck: ## Run type checking with mypy + uv run mypy . + +test: ## Run tests + uv run pytest + +test-cov: ## Run tests with coverage report + uv run pytest --cov=src --cov-report term --cov-report html --cov-report xml -vv + +docs: ## Generate documentation + mkdir -p ./docs/build + uv run ./docs/generate.sh --out=./docs/build/ --pypath=src/ + +clean: ## Clean build artifacts + rm -rf build dist *.egg-info + find . -type d -name __pycache__ -exec rm -rf {} + + find . -type f -name "*.pyc" -delete + rm -rf htmlcov .coverage coverage.xml .pytest_cache + rm -rf docs/build \ No newline at end of file diff --git a/async.md b/async.md new file mode 100644 index 00000000..fb16fab4 --- /dev/null +++ b/async.md @@ -0,0 +1,253 @@ +# Async Support for Firebase Functions Python + +## Overview + +This document outlines the design and implementation plan for adding async function support to firebase-functions-python. The goal is to leverage the new async capabilities in functions-framework while maintaining full backward compatibility with existing sync functions. + +## Background + +Functions-framework recently added async support via the `--asgi` flag, allowing async functions to be defined like: + +```python +import functions_framework.aio + +@functions_framework.aio.http +async def hello_async(request): # Starlette.Request + await asyncio.sleep(1) + return "Hello, async world!" +``` + +## Design Goals + +1. **No code duplication** - Reuse existing decorators and logic +2. **Backward compatibility** - All existing sync functions must continue to work +3. **Unified API** - Users shouldn't need different decorators for sync vs async +4. **Type safety** - Proper typing for both sync and async cases +5. **Flexibility** - The aio namespace accepts both sync and async functions +6. **Universal support** - Async should work for ALL function types, not just HTTP + +## Function Types to Support + +Firebase Functions Python supports multiple trigger types that all need async support: + +### 1. HTTP Functions +- `@https_fn.on_request()` - Raw HTTP requests +- `@https_fn.on_call()` - Callable functions with auth/validation + +### 2. Firestore Functions +- `@firestore_fn.on_document_created()` +- `@firestore_fn.on_document_updated()` +- `@firestore_fn.on_document_deleted()` +- `@firestore_fn.on_document_written()` + +### 3. Realtime Database Functions +- `@db_fn.on_value_created()` +- `@db_fn.on_value_updated()` +- `@db_fn.on_value_deleted()` +- `@db_fn.on_value_written()` + +### 4. Cloud Storage Functions +- `@storage_fn.on_object_archived()` +- `@storage_fn.on_object_deleted()` +- `@storage_fn.on_object_finalized()` +- `@storage_fn.on_object_metadata_updated()` + +### 5. Pub/Sub Functions +- `@pubsub_fn.on_message_published()` + +### 6. Scheduler Functions +- `@scheduler_fn.on_schedule()` + +### 7. Task Queue Functions +- `@tasks_fn.on_task_dispatched()` + +### 8. EventArc Functions +- `@eventarc_fn.on_custom_event_published()` + +### 9. Remote Config Functions +- `@remote_config_fn.on_config_updated()` + +### 10. Test Lab Functions +- `@test_lab_fn.on_test_matrix_completed()` + +### 11. Alerts Functions +- Various alert triggers for billing, crashlytics, performance, etc. + +### 12. Identity Functions +- `@identity_fn.before_user_created()` +- `@identity_fn.before_user_signed_in()` + +## Implementation Strategy + +### Phase 1: Core Infrastructure + +#### 1.1 Async Detection Mechanism +- Add utility function to detect if a function is async using `inspect.iscoroutinefunction()` +- This detection should happen at decoration time + +#### 1.2 Metadata Storage +- Extend the `__firebase_endpoint__` attribute to include runtime mode information +- Add a field to `ManifestEndpoint` to indicate async functions: + ```python + @dataclasses.dataclass(frozen=True) + class ManifestEndpoint: + # ... existing fields ... + runtime_mode: Literal["sync", "async"] | None = "sync" + ``` + +#### 1.3 Type System Updates +- Create type unions to handle both sync and async cases +- For HTTP functions: + - Sync: `flask.Request` and `flask.Response` + - Async: `starlette.requests.Request` and response types +- For event functions: + - Both sync and async will receive the same event objects + - The difference is whether the handler is async + +### Phase 2: Decorator Updates + +#### 2.1 Namespace-based Approach +Instead of modifying existing decorators, we created a new `aio` namespace: +- `firebase_functions.aio.https_fn` for async HTTP functions +- The aio decorators accept both sync and async functions +- ASGI runtime handles sync functions by running them in a thread pool + +#### 2.2 Shared Implementation +To avoid code duplication, we extracted shared business logic: +- `_validate_on_call_request_headers()` - Validates headers and method +- `_process_on_call_request_body()` - Processes request body after reading +- `_format_on_call_response()` - Formats successful responses +- `_format_on_call_error()` - Formats error responses +- `_add_cors_headers_to_response()` - Adds CORS headers to any response + +The decorators use a shared implementation with an `asgi` parameter to differentiate between sync and async modes. + +#### 2.2 HTTP Functions Special Handling +HTTP functions need special care because the request type changes: +- Sync: `flask.Request` +- Async: `starlette.requests.Request` + +We'll need to handle this in the type system and potentially in request processing. + +### Phase 3: Manifest and Deployment + +#### 3.1 Manifest Generation +- Update `serving.py` to include runtime mode in the manifest +- The functions.yaml should indicate which functions need async runtime + +#### 3.2 Firebase CLI Integration +- The CLI needs to read the runtime mode from the manifest +- When deploying async functions, it should: + - Set appropriate environment variables + - Pass the `--asgi` flag to functions-framework + - Potentially use different container configurations + +### Phase 4: Testing and Validation + +#### 4.1 Test Coverage +- Add async versions of existing tests +- Test mixed deployments (both sync and async functions) +- Verify proper error handling in async contexts +- Test timeout behavior for async functions + +#### 4.2 Example Updates +- Update examples to show async usage +- Create migration guide for converting sync to async + +## Example Usage + +### HTTP Functions +```python +from firebase_functions import https_fn +from firebase_functions.aio import https_fn as async_https_fn + +# Sync (existing) +@https_fn.on_request() +def sync_http(request: Request) -> Response: + return Response("Hello sync") + +# Async (new) +@async_https_fn.on_request() +async def async_http(request) -> Response: # Will be Starlette Request + result = await some_async_api_call() + return Response(f"Hello async: {result}") + +# Sync function in aio namespace (also supported) +@async_https_fn.on_request() +def sync_in_async_http(request) -> Response: + # This sync function will run in ASGI's thread pool + return Response("Hello from sync in async") +``` + +### Firestore Functions +```python +# Sync (existing) +@firestore_fn.on_document_created(document="users/{userId}") +def sync_user_created(event: Event[DocumentSnapshot]) -> None: + print(f"User created: {event.data.id}") + +# Async (new) +@firestore_fn.on_document_created(document="users/{userId}") +async def async_user_created(event: Event[DocumentSnapshot]) -> None: + await send_welcome_email(event.data.get("email")) + await update_analytics(event.data.id) +``` + +### Pub/Sub Functions +```python +# Async (new) +@pubsub_fn.on_message_published(topic="process-queue") +async def async_process_message(event: CloudEvent[MessagePublishedData]) -> None: + message = event.data.message + await process_job(message.data) +``` + +## Benefits + +1. **Performance**: Async functions can handle I/O-bound operations more efficiently +2. **Scalability**: Better resource utilization for functions that make external API calls +3. **Modern Python**: Aligns with Python's async/await ecosystem +4. **Flexibility**: Users can choose sync or async based on their needs + +## Considerations + +1. **Cold Start**: Need to verify async functions don't increase cold start times +2. **Memory Usage**: Monitor if async runtime uses more memory +3. **Debugging**: Ensure stack traces and error messages are clear for async functions +4. **Timeouts**: Verify timeout behavior works correctly with async functions + +## Migration Path + +1. Start with HTTP functions as proof of concept +2. Extend to event-triggered functions +3. Update documentation and examples +4. Release as minor version update (backward compatible) + +## Implementation Status + +### Completed (Phase 1) +- ✅ HTTP functions (on_request and on_call) with async support +- ✅ Shared business logic to avoid code duplication +- ✅ CORS handling for async functions +- ✅ Type safety with overloads +- ✅ Support for both sync and async functions in aio namespace +- ✅ Comprehensive tests + +### Remaining Work +- Event-triggered functions (Firestore, Database, Storage, etc.) +- Documentation and examples +- Integration with Firebase CLI for deployment + +## Open Questions + +1. Should we support both Flask and Starlette response types for async HTTP functions? +2. How should we handle async context managers and cleanup? +3. Should we provide async versions of Firebase Admin SDK operations? + +## Next Steps + +1. Prototype async support for HTTP functions +2. Test with functions-framework in ASGI mode +3. Design type system for handling both sync and async +4. Update manifest generation +5. Coordinate with Firebase CLI team for deployment support \ No newline at end of file diff --git a/examples/async_example.py b/examples/async_example.py new file mode 100644 index 00000000..01952fd6 --- /dev/null +++ b/examples/async_example.py @@ -0,0 +1,62 @@ +"""Example showing async HTTP functions with firebase-functions-python.""" + +import asyncio + +from flask import Request, Response + +from firebase_functions import https_fn +from firebase_functions.aio import https_fn as async_https_fn + + +# Traditional synchronous function +@https_fn.on_request() +def sync_hello(request: Request) -> Response: + """A traditional synchronous HTTP function.""" + name = request.args.get("name", "World") + return Response(f"Hello {name}! (sync)") + + +# New async function using aio namespace +@async_https_fn.on_request() +async def async_hello(request) -> dict: + """An async HTTP function that can use await.""" + # Simulate async operation (e.g., database query, API call) + await asyncio.sleep(0.1) + + # In async functions, request is a Starlette Request object + name = request.query_params.get("name", "World") + + # Can return dict which will be JSON serialized + return {"message": f"Hello {name}! (async)", "type": "async"} + + +# Async callable function +@async_https_fn.on_call() +async def async_callable(request: async_https_fn.CallableRequest) -> dict: + """An async callable function.""" + # Access the data sent by the client + name = request.data.get("name", "World") + + # Simulate async work + await asyncio.sleep(0.1) + + # Access auth information if available + user_id = request.auth.uid if request.auth else "anonymous" + + return { + "message": f"Hello {name}!", + "user": user_id, + "timestamp": asyncio.get_event_loop().time(), + } + + +# Example of mixing sync and async in the same file +@https_fn.on_request() +def list_functions(request: Request) -> Response: + """List all functions in this module.""" + functions = [ + {"name": "sync_hello", "type": "sync", "url": "/sync_hello"}, + {"name": "async_hello", "type": "async", "url": "/async_hello"}, + {"name": "async_callable", "type": "async_callable", "url": "/async_callable"}, + ] + return Response(str(functions), content_type="application/json") diff --git a/src/firebase_functions/aio/__init__.py b/src/firebase_functions/aio/__init__.py new file mode 100644 index 00000000..704c4a58 --- /dev/null +++ b/src/firebase_functions/aio/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2025 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Module for async Firebase Functions. +""" + +from firebase_functions.aio import https_fn + +__all__ = ["https_fn"] diff --git a/src/firebase_functions/aio/https_fn.py b/src/firebase_functions/aio/https_fn.py new file mode 100644 index 00000000..c2961806 --- /dev/null +++ b/src/firebase_functions/aio/https_fn.py @@ -0,0 +1,119 @@ +# Copyright 2025 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Module for async functions that listen to HTTPS endpoints. +These can be raw web requests and Callable RPCs. +""" + +import inspect as _inspect +import typing as _typing + +# Import the shared implementation and types from the sync module +from firebase_functions.https_fn import ( + CallableRequest, + FunctionsErrorCode, + HttpsError, + _create_on_call_decorator, + _create_on_request_decorator, +) +from firebase_functions.options import HttpsOptions +from firebase_functions.private import util as _util + +# Type stubs for Starlette types (to avoid hard dependency) +# In production, these would be actual imports from starlette +if _typing.TYPE_CHECKING: + from starlette.requests import Request as StarletteRequest + from starlette.responses import Response as StarletteResponse +else: + # Runtime placeholder types + StarletteRequest = _typing.Any + StarletteResponse = _typing.Any + +# Type aliases for async handlers +_AsyncC1 = _typing.Callable[[StarletteRequest], _typing.Awaitable[StarletteResponse]] +_AsyncC2 = _typing.Callable[[CallableRequest[_typing.Any]], _typing.Awaitable[_typing.Any]] + + +@_util.copy_func_kwargs(HttpsOptions) +def on_request(**kwargs) -> _typing.Callable[[_AsyncC1], _AsyncC1]: + """ + Handler which handles async HTTPS requests. + Requires an async function that takes a Starlette ``Request`` and returns a ``Response``. + + Example: + + .. code-block:: python + + from firebase_functions.aio import https_fn + + @https_fn.on_request() + async def example(request) -> Response: + await some_async_operation() + return Response("Hello async world!") + + :param \\*\\*kwargs: Https options. + :type \\*\\*kwargs: as :exc:`firebase_functions.options.HttpsOptions` + :rtype: :exc:`typing.Callable` \\[ \\[ :exc:`starlette.requests.Request` \\], + :exc:`typing.Awaitable` \\[ :exc:`starlette.responses.Response` \\] \\] + An async function that takes a Starlette Request and returns a Response. + """ + options = HttpsOptions(**kwargs) + + def on_request_inner_decorator(func: _AsyncC1) -> _AsyncC1: + # Allow both sync and async functions - ASGI can handle both + return _typing.cast(_AsyncC1, _create_on_request_decorator(func, options, asgi=True)) + + return on_request_inner_decorator + + +@_util.copy_func_kwargs(HttpsOptions) +def on_call(**kwargs) -> _typing.Callable[[_AsyncC2], _AsyncC2]: + """ + Declares an async callable method for clients to call using a Firebase SDK. + Requires an async function that takes a ``CallableRequest``. + + Example: + + .. code-block:: python + + from firebase_functions.aio import https_fn + + @https_fn.on_call() + async def example(request: CallableRequest) -> Any: + await some_async_operation() + return {"message": "Hello async world!"} + + :param \\*\\*kwargs: Https options. + :type \\*\\*kwargs: as :exc:`firebase_functions.options.HttpsOptions` + :rtype: :exc:`typing.Callable` + \\[ \\[ :exc:`firebase_functions.https.CallableRequest` \\[ + :exc:`object` \\] \\], :exc:`typing.Awaitable` \\[ :exc:`object` \\] \\] + An async function that takes a ``CallableRequest`` and returns an object. + """ + options = HttpsOptions(**kwargs) + + def on_call_inner_decorator(func: _AsyncC2) -> _AsyncC2: + # Allow both sync and async functions - ASGI can handle both + return _typing.cast(_AsyncC2, _create_on_call_decorator(func, options, asgi=True)) + + return on_call_inner_decorator + + +# Re-export common types and exceptions so users don't need to import from both modules +__all__ = [ + "on_request", + "on_call", + "HttpsError", + "FunctionsErrorCode", + "CallableRequest", +] diff --git a/src/firebase_functions/https_fn.py b/src/firebase_functions/https_fn.py index 7e692de0..05f1ff02 100644 --- a/src/firebase_functions/https_fn.py +++ b/src/firebase_functions/https_fn.py @@ -19,6 +19,7 @@ import dataclasses as _dataclasses import enum as _enum import functools as _functools +import inspect as _inspect import json as _json import typing as _typing @@ -31,7 +32,7 @@ import firebase_functions.core as _core import firebase_functions.private.util as _util -from firebase_functions.options import _GLOBAL_OPTIONS, HttpsOptions +from firebase_functions.options import _GLOBAL_OPTIONS, CorsOptions, HttpsOptions class FunctionsErrorCode(str, _enum.Enum): @@ -351,62 +352,255 @@ class CallableRequest(_typing.Generic[_core.T]): _C2 = _typing.Callable[[CallableRequest[_typing.Any]], _typing.Any] +def _validate_on_call_request_headers(method: str, headers: dict) -> None: + """Validate method and headers for on_call requests.""" + if method != "POST": + _logging.warning("Request has invalid method. %s", method) + raise HttpsError(FunctionsErrorCode.INVALID_ARGUMENT, "Bad Request") + + # Try both lowercase and capitalized versions for compatibility + content_type = headers.get("content-type", "") or headers.get("Content-Type", "") + if not content_type.startswith("application/json"): + _logging.warning("Request has invalid content type. %s", content_type) + raise HttpsError(FunctionsErrorCode.INVALID_ARGUMENT, "Bad Request") + + +def _process_on_call_request_body( + raw_request: _typing.Any, + body_bytes: bytes, + headers: dict, + method: str, + enforce_app_check: bool, +) -> CallableRequest: + """Process on_call request after body is read. Shared between sync/async.""" + # Validate headers/method + _validate_on_call_request_headers(method, headers) + + # Parse and validate JSON + try: + json_data = _json.loads(body_bytes) + except _json.JSONDecodeError: + _logging.error("Request body is not valid JSON") + raise HttpsError(FunctionsErrorCode.INVALID_ARGUMENT, "Bad Request") from None + + if "data" not in json_data: + _logging.warning("Request body is missing data.") + raise HttpsError(FunctionsErrorCode.INVALID_ARGUMENT, "Bad Request") + + # Create mock request for token checking + class HeadersAdapter: + def __init__(self, headers): + self.headers = headers + + mock_request = HeadersAdapter(headers) + token_status = _util.on_call_check_tokens(mock_request) # type: ignore + + # Validate tokens + if token_status.auth == _util.OnCallTokenState.INVALID: + raise HttpsError(FunctionsErrorCode.UNAUTHENTICATED, "Unauthenticated") + + if enforce_app_check and token_status.app in ( + _util.OnCallTokenState.MISSING, + _util.OnCallTokenState.INVALID, + ): + raise HttpsError(FunctionsErrorCode.UNAUTHENTICATED, "Unauthenticated") + + # Build context + context = CallableRequest(raw_request=raw_request, data=json_data["data"]) + + # Add app check data + if token_status.app == _util.OnCallTokenState.VALID and token_status.app_token is not None: + context = _dataclasses.replace( + context, + app=AppCheckData(token_status.app_token["sub"], token_status.app_token), + ) + + # Add auth data + if token_status.auth_token is not None: + context = _dataclasses.replace( + context, + auth=AuthData( + token_status.auth_token["uid"] if "uid" in token_status.auth_token else None, + token_status.auth_token, + ), + ) + + # Add instance ID (try both cases) + instance_id = headers.get("firebase-instance-id-token") or headers.get("Firebase-Instance-ID-Token") + if instance_id is not None: + context = _dataclasses.replace(context, instance_id_token=instance_id) + + return context + + +def _format_on_call_response(result: _typing.Any) -> dict: + """Format the result for on_call response.""" + return {"result": result} + + +def _format_on_call_error(err: Exception) -> tuple[dict, int]: + """Format error for on_call response.""" + if not isinstance(err, HttpsError): + _logging.error("Unhandled error: %s", err) + err = HttpsError(FunctionsErrorCode.INTERNAL, "INTERNAL") + + return {"error": err._as_dict()}, err._http_error_code.status + + +def _add_cors_headers_to_response( + response, # Can be Flask Response or Starlette Response + cors_options: CorsOptions | None, + allowed_methods: list[str] | None = None, +) -> None: + """Add CORS headers to any response object with headers dict.""" + if not cors_options: + return + + origins = cors_options.cors_origins or "*" + if isinstance(origins, list): + origins = ", ".join(origins) + + methods = allowed_methods or cors_options.cors_methods or ["*"] + if isinstance(methods, list): + methods = ", ".join(methods) + + response.headers["Access-Control-Allow-Origin"] = origins + response.headers["Access-Control-Allow-Methods"] = methods + response.headers["Access-Control-Allow-Headers"] = ( + "Content-Type, Authorization, Firebase-Instance-ID-Token, " + "Firebase-AppCheck, X-Firebase-AppCheck" + ) + + if origins != "*": + response.headers["Vary"] = "Origin" + response.headers["Access-Control-Allow-Credentials"] = "true" + + def _on_call_handler(func: _C2, request: Request, enforce_app_check: bool) -> Response: + """Sync on_call handler using shared logic.""" try: - if not _util.valid_on_call_request(request): - _logging.error("Invalid request, unable to process.") - raise HttpsError(FunctionsErrorCode.INVALID_ARGUMENT, "Bad Request") - context: CallableRequest = CallableRequest( + # Use shared processing + context = _process_on_call_request_body( raw_request=request, - data=_json.loads(request.data)["data"], + body_bytes=request.data, + headers=dict(request.headers), + method=request.method, + enforce_app_check=enforce_app_check, ) - token_status = _util.on_call_check_tokens(request) + # Call function + result = _core._with_init(func)(context) - if token_status.auth == _util.OnCallTokenState.INVALID: - raise HttpsError(FunctionsErrorCode.UNAUTHENTICATED, "Unauthenticated") + # Format response + return _jsonify(_format_on_call_response(result)) - if enforce_app_check and token_status.app in ( - _util.OnCallTokenState.MISSING, - _util.OnCallTokenState.INVALID, - ): - raise HttpsError(FunctionsErrorCode.UNAUTHENTICATED, "Unauthenticated") - if token_status.app == _util.OnCallTokenState.VALID and token_status.app_token is not None: - context = _dataclasses.replace( - context, - app=AppCheckData(token_status.app_token["sub"], token_status.app_token), - ) + # Disable broad exceptions lint since we want to handle all exceptions here + # and wrap as an HttpsError. + # pylint: disable=broad-except + except Exception as err: + error_dict, status = _format_on_call_error(err) + return _make_response(_jsonify(error_dict), status) - if token_status.auth_token is not None: - context = _dataclasses.replace( - context, - auth=AuthData( - token_status.auth_token["uid"] if "uid" in token_status.auth_token else None, - token_status.auth_token, - ), - ) - instance_id = request.headers.get("Firebase-Instance-ID-Token") - if instance_id is not None: - # Validating the token requires an http request, so we don't do it. - # If the user wants to use it for something, it will be validated then. - # Currently, the only real use case for this token is for sending - # pushes with FCM. In that case, the FCM APIs will validate the token. - context = _dataclasses.replace( - context, - instance_id_token=request.headers.get("Firebase-Instance-ID-Token"), - ) - result = _core._with_init(func)(context) - return _jsonify(result=result) +async def _on_call_handler_async(func: _C2, request, enforce_app_check: bool): + """Async on_call handler using shared logic.""" + # Import here to avoid runtime dependency when not using async + from starlette.responses import JSONResponse + + try: + # Read body (only async-specific part) + body_bytes = await request.body() + + # Use shared processing + context = _process_on_call_request_body( + raw_request=request, + body_bytes=body_bytes, + headers=dict(request.headers), + method=request.method, + enforce_app_check=enforce_app_check, + ) + + # Call function (check if it's async or sync) + if _inspect.iscoroutinefunction(func): + result = await _core._with_init(func)(context) + else: + result = _core._with_init(func)(context) + + # Format response + return JSONResponse(_format_on_call_response(result)) + # Disable broad exceptions lint since we want to handle all exceptions here # and wrap as an HttpsError. # pylint: disable=broad-except except Exception as err: - if not isinstance(err, HttpsError): - _logging.error("Unhandled error: %s", err) - err = HttpsError(FunctionsErrorCode.INTERNAL, "INTERNAL") - status = err._http_error_code.status - return _make_response(_jsonify(error=err._as_dict()), status) + error_dict, status = _format_on_call_error(err) + return JSONResponse(content=error_dict, status_code=status) + + +@_typing.overload +def _create_on_request_decorator(func: _C1, options: HttpsOptions, asgi: _typing.Literal[False] = False) -> _C1: + ... + +@_typing.overload +def _create_on_request_decorator(func: _typing.Callable[..., _typing.Awaitable[_typing.Any]], options: HttpsOptions, asgi: _typing.Literal[True]) -> _typing.Callable[..., _typing.Awaitable[_typing.Any]]: + ... + +def _create_on_request_decorator(func: _typing.Union[_C1, _typing.Callable[..., _typing.Awaitable[_typing.Any]]], options: HttpsOptions, asgi: bool = False) -> _typing.Union[_C1, _typing.Callable[..., _typing.Awaitable[_typing.Any]]]: + """ + Internal helper to create the on_request decorator wrapper. + This shared implementation is used by both sync and async versions. + """ + if asgi: + # For async functions, we need an async wrapper + @_functools.wraps(func) + async def async_wrapper(request): # Will receive Starlette Request + # Import here to avoid runtime dependency when not using async + from starlette.responses import JSONResponse + from starlette.responses import Response as StarletteResponse + + # Handle OPTIONS preflight + if request.method == "OPTIONS" and options.cors: + response = StarletteResponse(status_code=200) + _add_cors_headers_to_response(response, options.cors) + return response + + # Call the function (check if it's async or sync) + if _inspect.iscoroutinefunction(func): + result = await _core._with_init(func)(request) + else: + result = _core._with_init(func)(request) + + # Convert to response + if isinstance(result, dict): + response = JSONResponse(result) + elif hasattr(result, "headers"): # Already a response + response = result + else: + response = StarletteResponse(content=str(result)) + + # Add CORS headers + _add_cors_headers_to_response(response, options.cors) + return response + + wrapper = async_wrapper + else: + # For sync functions, use the existing logic + @_functools.wraps(func) + def sync_wrapper(request: Request) -> Response: + if options.cors is not None: + return _cross_origin( + methods=options.cors.cors_methods, + origins=options.cors.cors_origins, + )(func)(request) + return _core._with_init(func)(request) + + wrapper = sync_wrapper + + _util.set_func_endpoint_attr( + wrapper, + options._endpoint(func_name=func.__name__, asgi=asgi), + ) + return _typing.cast(_C1, wrapper) @_util.copy_func_kwargs(HttpsOptions) @@ -432,22 +626,79 @@ def example(request: Request) -> Response: options = HttpsOptions(**kwargs) def on_request_inner_decorator(func: _C1): - @_functools.wraps(func) - def on_request_wrapped(request: Request) -> Response: - if options.cors is not None: - return _cross_origin( - methods=options.cors.cors_methods, - origins=options.cors.cors_origins, - )(func)(request) - return _core._with_init(func)(request) + return _create_on_request_decorator(func, options, asgi=False) + + return on_request_inner_decorator + + +@_typing.overload +def _create_on_call_decorator(func: _C2, options: HttpsOptions, asgi: _typing.Literal[False] = False) -> _C2: + ... + +@_typing.overload +def _create_on_call_decorator(func: _typing.Callable[..., _typing.Awaitable[_typing.Any]], options: HttpsOptions, asgi: _typing.Literal[True]) -> _typing.Callable[..., _typing.Awaitable[_typing.Any]]: + ... + +def _create_on_call_decorator(func: _typing.Union[_C2, _typing.Callable[..., _typing.Awaitable[_typing.Any]]], options: HttpsOptions, asgi: bool = False) -> _typing.Union[_C2, _typing.Callable[..., _typing.Awaitable[_typing.Any]]]: + """ + Internal helper to create the on_call decorator wrapper. + This shared implementation is used by both sync and async versions. + """ + origins: _typing.Any = "*" + if options.cors is not None and options.cors.cors_origins is not None: + origins = options.cors.cors_origins + + # Default to False. + enforce_app_check = False + # If the global option is set, use that. + if options.enforce_app_check is None and _GLOBAL_OPTIONS.enforce_app_check is not None: + enforce_app_check = _GLOBAL_OPTIONS.enforce_app_check + # If the global option is not set, use the local option. + elif options.enforce_app_check is not None: + enforce_app_check = options.enforce_app_check - _util.set_func_endpoint_attr( - on_request_wrapped, - options._endpoint(func_name=func.__name__), + if asgi: + # For async callable functions + @_functools.wraps(func) + async def async_wrapper(request): # Will receive Starlette Request + # Import here to avoid runtime dependency when not using async + from starlette.responses import Response as StarletteResponse + + # Handle OPTIONS preflight + if request.method == "OPTIONS" and options.cors: + response = StarletteResponse(status_code=200) + _add_cors_headers_to_response(response, options.cors, ["POST"]) + return response + + # Use async handler + response = await _on_call_handler_async(func, request, enforce_app_check) + + # Add CORS headers + _add_cors_headers_to_response(response, options.cors, ["POST"]) + return response + + wrapper = async_wrapper + else: + # For sync callable functions + @_cross_origin( + methods="POST", + origins=origins, ) - return on_request_wrapped + @_functools.wraps(func) + def sync_wrapper(request: Request): + return _on_call_handler( + func, + request, + enforce_app_check, + ) - return on_request_inner_decorator + wrapper = sync_wrapper + + _util.set_func_endpoint_attr( + wrapper, + options._endpoint(func_name=func.__name__, callable=True, asgi=asgi), + ) + return _typing.cast(_C2, wrapper) @_util.copy_func_kwargs(HttpsOptions) @@ -474,35 +725,6 @@ def example(request: CallableRequest) -> Any: options = HttpsOptions(**kwargs) def on_call_inner_decorator(func: _C2): - origins: _typing.Any = "*" - if options.cors is not None and options.cors.cors_origins is not None: - origins = options.cors.cors_origins - - # Default to False. - enforce_app_check = False - # If the global option is set, use that. - if options.enforce_app_check is None and _GLOBAL_OPTIONS.enforce_app_check is not None: - enforce_app_check = _GLOBAL_OPTIONS.enforce_app_check - # If the global option is not set, use the local option. - elif options.enforce_app_check is not None: - enforce_app_check = options.enforce_app_check - - @_cross_origin( - methods="POST", - origins=origins, - ) - @_functools.wraps(func) - def on_call_wrapped(request: Request): - return _on_call_handler( - func, - request, - enforce_app_check, - ) - - _util.set_func_endpoint_attr( - on_call_wrapped, - options._endpoint(func_name=func.__name__, callable=True), - ) - return on_call_wrapped + return _create_on_call_decorator(func, options, asgi=False) return on_call_inner_decorator diff --git a/src/firebase_functions/options.py b/src/firebase_functions/options.py index badf87e5..c346bc03 100644 --- a/src/firebase_functions/options.py +++ b/src/firebase_functions/options.py @@ -1149,6 +1149,9 @@ def _endpoint( https_trigger["invoker"] = invoker kwargs_merged["httpsTrigger"] = https_trigger + if "asgi" in kwargs and kwargs["asgi"] is True: + kwargs_merged["asgi"] = True + return _manifest.ManifestEndpoint(**_typing.cast(dict, kwargs_merged)) diff --git a/src/firebase_functions/private/manifest.py b/src/firebase_functions/private/manifest.py index 7672a9f5..7f74e0f1 100644 --- a/src/firebase_functions/private/manifest.py +++ b/src/firebase_functions/private/manifest.py @@ -169,6 +169,7 @@ class ManifestEndpoint: scheduleTrigger: ScheduleTrigger | None = None blockingTrigger: BlockingTrigger | None = None taskQueueTrigger: TaskQueueTrigger | None = None + asgi: bool | None = None class ManifestRequiredApi(_typing.TypedDict): diff --git a/tests/test_async_https_fn.py b/tests/test_async_https_fn.py new file mode 100644 index 00000000..0ebc7384 --- /dev/null +++ b/tests/test_async_https_fn.py @@ -0,0 +1,328 @@ +""" +Tests for the async https module. +""" + +import asyncio +import json +import sys +import unittest +from unittest.mock import AsyncMock, Mock, patch + +from firebase_functions import core +from firebase_functions.aio import https_fn +from firebase_functions.https_fn import CallableRequest, FunctionsErrorCode, HttpsError +from firebase_functions.options import CorsOptions + + +# Mock Starlette for tests +class MockStarletteResponse: + def __init__(self, content=None, status_code=200): + self.content = content + self.status_code = status_code + self.headers = {} + +class MockJSONResponse(MockStarletteResponse): + def __init__(self, content, status_code=200): + super().__init__(json.dumps(content), status_code) + self.headers["content-type"] = "application/json" + + +class TestAsyncHttps(unittest.TestCase): + """ + Tests for the async http module. + """ + + def test_on_request_decorator_accepts_sync_function(self): + """Test that on_request accepts sync functions in ASGI mode.""" + # Should not raise for sync function (ASGI can handle both) + @https_fn.on_request() + def sync_func(request): + return {"message": "sync"} + + # Check that the function is decorated properly + self.assertTrue(hasattr(sync_func, "__firebase_endpoint__")) + endpoint = sync_func.__firebase_endpoint__ + self.assertEqual(endpoint.asgi, True) + self.assertEqual(endpoint.entryPoint, "sync_func") + + def test_on_call_decorator_accepts_sync_function(self): + """Test that on_call accepts sync functions in ASGI mode.""" + # Should not raise for sync function (ASGI can handle both) + @https_fn.on_call() + def sync_func(request): + return {"message": "sync"} + + # Check that the function is decorated properly + self.assertTrue(hasattr(sync_func, "__firebase_endpoint__")) + endpoint = sync_func.__firebase_endpoint__ + self.assertEqual(endpoint.asgi, True) + self.assertEqual(endpoint.entryPoint, "sync_func") + self.assertIsNotNone(endpoint.callableTrigger) + + def test_on_request_decorator_accepts_async_function(self): + """Test that on_request accepts async functions.""" + + # Should not raise for async function + @https_fn.on_request() + async def async_func(request): + return {"message": "async"} + + # Check that the function is decorated properly + self.assertTrue(hasattr(async_func, "__firebase_endpoint__")) + endpoint = async_func.__firebase_endpoint__ + self.assertEqual(endpoint.asgi, True) + self.assertEqual(endpoint.entryPoint, "async_func") + + def test_on_call_decorator_accepts_async_function(self): + """Test that on_call accepts async functions.""" + + # Should not raise for async function + @https_fn.on_call() + async def async_callable(request): + return {"message": "async"} + + # Check that the function is decorated properly + self.assertTrue(hasattr(async_callable, "__firebase_endpoint__")) + endpoint = async_callable.__firebase_endpoint__ + self.assertEqual(endpoint.asgi, True) + self.assertEqual(endpoint.entryPoint, "async_callable") + self.assertIsNotNone(endpoint.callableTrigger) + + def test_async_on_request_calls_init_function(self): + """Test that async on_request calls the init function.""" + hello = None + + @core.init + def init(): + nonlocal hello + hello = "world" + + func = AsyncMock(__name__="example_func") + func.return_value = {"result": "test"} + + with patch.dict(sys.modules, { + 'starlette': Mock(), + 'starlette.responses': Mock(Response=MockStarletteResponse, JSONResponse=MockJSONResponse) + }): + @https_fn.on_request() + async def decorated_func(request): + return await func(request) + + # Create a mock request + mock_request = Mock() + + # Run the async function + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete(decorated_func(mock_request)) + finally: + loop.close() + + self.assertEqual(hello, "world") + func.assert_called_once_with(mock_request) + + def test_on_request_with_options(self): + """Test that on_request passes options correctly.""" + + @https_fn.on_request( + region="us-central1", + memory=512, + timeout_sec=60, + ) + async def async_func(request): + return {"message": "async"} + + endpoint = async_func.__firebase_endpoint__ + self.assertEqual(endpoint.asgi, True) + self.assertEqual(endpoint.region, ["us-central1"]) + self.assertEqual(endpoint.availableMemoryMb, 512) + self.assertEqual(endpoint.timeoutSeconds, 60) + + def test_on_call_with_options(self): + """Test that on_call passes options correctly.""" + + @https_fn.on_call( + region="europe-west1", + enforce_app_check=True, + ) + async def async_callable(request): + return {"message": "async"} + + endpoint = async_callable.__firebase_endpoint__ + self.assertEqual(endpoint.asgi, True) + self.assertEqual(endpoint.region, ["europe-west1"]) + # Note: enforce_app_check is not stored in the endpoint directly + + def test_async_on_call_handler(self): + """Test that async on_call handler works correctly.""" + + # Patch starlette imports + with patch.dict(sys.modules, { + 'starlette': Mock(), + 'starlette.responses': Mock(Response=MockStarletteResponse, JSONResponse=MockJSONResponse) + }): + @https_fn.on_call() + async def async_callable(request: CallableRequest): + await asyncio.sleep(0.01) # Simulate async work + return {"message": "Hello " + request.data.get("name", "World")} + + # Create a mock Starlette request with proper structure + mock_request = AsyncMock() + mock_request.method = "POST" + mock_request.headers = { + "content-type": "application/json", + } + mock_request.body = AsyncMock(return_value=b'{"data": {"name": "Alice"}}') + + # Run the async function + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + response = loop.run_until_complete(async_callable(mock_request)) + # Response should be a JSONResponse with proper structure + self.assertEqual(response.status_code, 200) + # Note: In real test we'd need to check response body content + finally: + loop.close() + + def test_async_on_request_cors_preflight(self): + """Test that async on_request handles CORS preflight correctly.""" + + with patch.dict(sys.modules, { + 'starlette': Mock(), + 'starlette.responses': Mock(Response=MockStarletteResponse, JSONResponse=MockJSONResponse) + }): + @https_fn.on_request(cors=CorsOptions(cors_origins=["https://example.com"], cors_methods=["GET", "POST"])) + async def async_func(request): + return {"message": "Hello"} + + # Create a mock OPTIONS request + mock_request = Mock() + mock_request.method = "OPTIONS" + + # Run the async function + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + response = loop.run_until_complete(async_func(mock_request)) + # Response should have CORS headers + self.assertEqual(response.status_code, 200) + self.assertIn("Access-Control-Allow-Origin", response.headers) + self.assertIn("Access-Control-Allow-Methods", response.headers) + finally: + loop.close() + + def test_async_on_call_cors_headers(self): + """Test that async on_call adds CORS headers correctly.""" + + with patch.dict(sys.modules, { + 'starlette': Mock(), + 'starlette.responses': Mock(Response=MockStarletteResponse, JSONResponse=MockJSONResponse) + }): + @https_fn.on_call(cors=CorsOptions(cors_origins="*")) + async def async_callable(request: CallableRequest): + return {"result": "success"} + + # Create a mock OPTIONS request for preflight + mock_request = AsyncMock() + mock_request.method = "OPTIONS" + + # Run the async function + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + response = loop.run_until_complete(async_callable(mock_request)) + # Response should have CORS headers + self.assertEqual(response.status_code, 200) + self.assertIn("Access-Control-Allow-Origin", response.headers) + self.assertEqual(response.headers["Access-Control-Allow-Origin"], "*") + finally: + loop.close() + + def test_async_on_call_error_handling(self): + """Test that async on_call handles HttpsError correctly.""" + + with patch.dict(sys.modules, { + 'starlette': Mock(), + 'starlette.responses': Mock(Response=MockStarletteResponse, JSONResponse=MockJSONResponse) + }): + @https_fn.on_call() + async def async_callable(request: CallableRequest): + raise HttpsError(FunctionsErrorCode.INVALID_ARGUMENT, "Bad input", {"field": "name"}) + + # Create a mock request + mock_request = AsyncMock() + mock_request.method = "POST" + mock_request.headers = {"content-type": "application/json"} + mock_request.body = AsyncMock(return_value=b'{"data": {}}') + + # Run the async function + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + response = loop.run_until_complete(async_callable(mock_request)) + # Response should be error with proper status + self.assertEqual(response.status_code, 400) # INVALID_ARGUMENT maps to 400 + finally: + loop.close() + + def test_re_exported_types(self): + """Test that common types are re-exported from aio.https_fn.""" + # Check that types are available + self.assertEqual(https_fn.HttpsError, HttpsError) + self.assertEqual(https_fn.FunctionsErrorCode, FunctionsErrorCode) + self.assertEqual(https_fn.CallableRequest, CallableRequest) + + def test_sync_on_request_in_aio_namespace(self): + """Test that sync functions work in aio namespace.""" + + with patch.dict(sys.modules, { + 'starlette': Mock(), + 'starlette.responses': Mock(Response=MockStarletteResponse, JSONResponse=MockJSONResponse) + }): + @https_fn.on_request() + def sync_func(request): + # This is a sync function in the aio namespace + return {"message": "Hello from sync"} + + # Create a mock request + mock_request = Mock() + + # Run the function (even though it's sync, the wrapper is async) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + response = loop.run_until_complete(sync_func(mock_request)) + self.assertEqual(response.status_code, 200) + finally: + loop.close() + + def test_multiple_async_functions_in_same_module(self): + """Test that multiple async functions can be defined in the same module.""" + + @https_fn.on_request() + async def func1(request): + return {"function": "1"} + + @https_fn.on_request() + async def func2(request): + return {"function": "2"} + + @https_fn.on_call() + async def func3(request): + return {"function": "3"} + + # Check that all functions have proper endpoints + self.assertEqual(func1.__firebase_endpoint__.entryPoint, "func1") + self.assertEqual(func2.__firebase_endpoint__.entryPoint, "func2") + self.assertEqual(func3.__firebase_endpoint__.entryPoint, "func3") + + # All should have asgi=True + self.assertTrue(func1.__firebase_endpoint__.asgi) + self.assertTrue(func2.__firebase_endpoint__.asgi) + self.assertTrue(func3.__firebase_endpoint__.asgi) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_https_fn.py b/tests/test_https_fn.py index 1748b367..0283918b 100644 --- a/tests/test_https_fn.py +++ b/tests/test_https_fn.py @@ -60,6 +60,7 @@ def init(): json={ "data": {"test": "value"}, }, + content_type="application/json", ).get_environ() request = Request(environ) decorated_func = https_fn.on_call()(func) diff --git a/tests/test_manifest.py b/tests/test_manifest.py index 681d90f5..8e21432c 100644 --- a/tests/test_manifest.py +++ b/tests/test_manifest.py @@ -114,6 +114,48 @@ def test_endpoint_to_dict(self): "Generated endpoint spec dict does not match expected dict." ) + def test_endpoint_with_asgi_field(self): + """Check that asgi field is included in manifest when set.""" + # Create an endpoint with asgi=True + async_endpoint = _manifest.ManifestEndpoint( + entryPoint="async_func", + platform="gcfv2", + region=["us-central1"], + asgi=True, + httpsTrigger={}, + ) + + # Convert to dict + # pylint: disable=protected-access + endpoint_dict = _manifest._dataclass_to_spec(async_endpoint) + + # Check that asgi field is present and True + assert "asgi" in endpoint_dict, "asgi field should be present in manifest" + assert endpoint_dict["asgi"] is True, "asgi field should be True" + + # Check other fields are preserved + assert endpoint_dict["entryPoint"] == "async_func" + assert endpoint_dict["platform"] == "gcfv2" + assert endpoint_dict["region"] == ["us-central1"] + assert endpoint_dict["httpsTrigger"] == {} + + def test_endpoint_without_asgi_field(self): + """Check that asgi field is not included when None.""" + # Create an endpoint without asgi (default None) + sync_endpoint = _manifest.ManifestEndpoint( + entryPoint="sync_func", + platform="gcfv2", + region=["us-central1"], + httpsTrigger={}, + ) + + # Convert to dict + # pylint: disable=protected-access + endpoint_dict = _manifest._dataclass_to_spec(sync_endpoint) + + # Check that asgi field is not present + assert "asgi" not in endpoint_dict, "asgi field should not be present when None" + def test_endpoint_expressions(self): """Check Expression values convert to CEL strings.""" max_param = _params.IntParam("MAX")