Skip to content

Commit fbcbd6c

Browse files
authored
Merge pull request wso2#1113 from Arshardh/gateway-proxy
Add gateway changes for proxy provider auth
2 parents 083d444 + 34efce6 commit fbcbd6c

13 files changed

Lines changed: 244 additions & 54 deletions

gateway/gateway-controller/api/openapi.yaml

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3337,12 +3337,37 @@ components:
33373337
properties:
33383338
type:
33393339
type: string
3340-
enum: [ api-key, bearer ]
3340+
enum: [ api-key ]
33413341
header:
33423342
type: string
33433343
value:
33443344
type: string
33453345

3346+
LLMUpstreamAuth:
3347+
type: object
3348+
required:
3349+
- type
3350+
properties:
3351+
type:
3352+
type: string
3353+
enum: [ api-key ]
3354+
header:
3355+
type: string
3356+
value:
3357+
type: string
3358+
3359+
LLMProxyProvider:
3360+
type: object
3361+
required:
3362+
- id
3363+
properties:
3364+
id:
3365+
type: string
3366+
description: Unique id of a deployed llm provider
3367+
example: wso2-openai-provider
3368+
auth:
3369+
$ref: '#/components/schemas/LLMUpstreamAuth'
3370+
33463371
LLMAccessControl:
33473372
type: object
33483373
required:
@@ -3484,9 +3509,7 @@ components:
34843509
maxLength: 253
34853510
example: "api.openai"
34863511
provider:
3487-
type: string
3488-
description: Unique id of a deployed llm provider
3489-
example: wso2-openai-provider
3512+
$ref: '#/components/schemas/LLMProxyProvider'
34903513
policies:
34913514
type: array
34923515
description: List of policies applied only to this operation (overrides or adds to API-level policies)

gateway/gateway-controller/pkg/api/generated/generated.go

Lines changed: 26 additions & 8 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

gateway/gateway-controller/pkg/api/handlers/handlers.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1469,7 +1469,7 @@ func (s *APIServer) ListLLMProxies(c *gin.Context, params api.ListLLMProxiesPara
14691469
Id: stringPtr(proxy.Metadata.Name),
14701470
DisplayName: stringPtr(proxy.Spec.DisplayName),
14711471
Version: stringPtr(proxy.Spec.Version),
1472-
Provider: stringPtr(proxy.Spec.Provider),
1472+
Provider: stringPtr(proxy.Spec.Provider.Id),
14731473
Status: &status,
14741474
CreatedAt: timePtr(cfg.CreatedAt),
14751475
UpdatedAt: timePtr(cfg.UpdatedAt),

gateway/gateway-controller/pkg/api/handlers/list_operations_test.go

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -257,14 +257,20 @@ func TestListLLMProvidersWithData(t *testing.T) {
257257
ID: "provider1",
258258
Kind: "LlmProvider",
259259
Status: "active",
260-
SourceConfiguration: map[string]interface{}{
261-
"metadata": map[string]interface{}{
262-
"name": "openai-provider",
260+
SourceConfiguration: api.LLMProviderConfiguration{
261+
ApiVersion: api.LLMProviderConfigurationApiVersionGatewayApiPlatformWso2Comv1alpha1,
262+
Kind: api.LlmProvider,
263+
Metadata: api.Metadata{
264+
Name: "openai-provider",
263265
},
264-
"spec": map[string]interface{}{
265-
"displayName": "OpenAI Provider",
266-
"version": "1.0.0",
267-
"template": "openai-template",
266+
Spec: api.LLMProviderConfigData{
267+
DisplayName: "OpenAI Provider",
268+
Version: "1.0.0",
269+
Template: "openai-template",
270+
Upstream: api.LLMProviderConfigData_Upstream{
271+
Url: stringPtr("https://example.com"),
272+
},
273+
AccessControl: api.LLMAccessControl{Mode: api.AllowAll},
268274
},
269275
},
270276
CreatedAt: now,
@@ -313,14 +319,18 @@ func TestListLLMProxiesWithData(t *testing.T) {
313319
ID: "proxy1",
314320
Kind: "LlmProxy",
315321
Status: "active",
316-
SourceConfiguration: map[string]interface{}{
317-
"metadata": map[string]interface{}{
318-
"name": "llm-proxy-1",
322+
SourceConfiguration: api.LLMProxyConfiguration{
323+
ApiVersion: api.LLMProxyConfigurationApiVersionGatewayApiPlatformWso2Comv1alpha1,
324+
Kind: api.LlmProxy,
325+
Metadata: api.Metadata{
326+
Name: "llm-proxy-1",
319327
},
320-
"spec": map[string]interface{}{
321-
"displayName": "LLM Proxy 1",
322-
"version": "1.0.0",
323-
"provider": "openai-provider",
328+
Spec: api.LLMProxyConfigData{
329+
DisplayName: "LLM Proxy 1",
330+
Version: "1.0.0",
331+
Provider: api.LLMProxyProvider{
332+
Id: "openai-provider",
333+
},
324334
},
325335
},
326336
CreatedAt: now,

gateway/gateway-controller/pkg/config/label_validation_test.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,9 @@ func TestLabelValidationForAllTypes(t *testing.T) {
123123
Spec: api.LLMProxyConfigData{
124124
DisplayName: "Test Proxy",
125125
Version: "v1.0",
126-
Provider: "test-provider",
126+
Provider: api.LLMProxyProvider{
127+
Id: "test-provider",
128+
},
127129
},
128130
}
129131

gateway/gateway-controller/pkg/config/llm_validator.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -477,16 +477,16 @@ func (v *LLMValidator) validateProxyData(spec *api.LLMProxyConfigData) []Validat
477477
}
478478

479479
// Validate provider id
480-
if spec.Provider == "" {
480+
if spec.Provider.Id == "" {
481481
errors = append(errors, ValidationError{
482-
Field: "spec.provider",
482+
Field: "spec.provider.id",
483483
Message: "Provider is required",
484484
})
485485
return errors
486-
} else if !v.metadataNameRegex.MatchString(spec.Provider) {
486+
} else if !v.metadataNameRegex.MatchString(spec.Provider.Id) {
487487
errors = append(errors, ValidationError{
488-
Field: "spec.provider",
489-
Message: "spec.provider must consist of lowercase alphanumeric characters, hyphens, or dots, and must start and end with an alphanumeric character",
488+
Field: "spec.provider.id",
489+
Message: "spec.provider.id must consist of lowercase alphanumeric characters, hyphens, or dots, and must start and end with an alphanumeric character",
490490
})
491491
}
492492

gateway/gateway-controller/pkg/config/llm_validator_additional_test.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,9 @@ func TestLLMValidator_ValidateLLMProxy_NilAndMetadata(t *testing.T) {
360360
Metadata: api.Metadata{Name: longName},
361361
Spec: api.LLMProxyConfigData{
362362
DisplayName: "Test",
363-
Provider: "test-provider",
363+
Provider: api.LLMProxyProvider{
364+
Id: "test-provider",
365+
},
364366
},
365367
}
366368

@@ -399,7 +401,9 @@ func TestLLMValidator_ValidateProxyData_NilAndVersion(t *testing.T) {
399401
spec := &api.LLMProxyConfigData{
400402
DisplayName: "Test Proxy",
401403
Version: "invalid-version",
402-
Provider: "test-provider",
404+
Provider: api.LLMProxyProvider{
405+
Id: "test-provider",
406+
},
403407
}
404408

405409
errors := validator.validateProxyData(spec)
@@ -424,13 +428,15 @@ func TestLLMValidator_ProxyProviderValidation(t *testing.T) {
424428
t.Run("Invalid provider name format", func(t *testing.T) {
425429
spec := &api.LLMProxyConfigData{
426430
DisplayName: "Test Proxy",
427-
Provider: "Invalid_Provider_Name",
431+
Provider: api.LLMProxyProvider{
432+
Id: "Invalid_Provider_Name",
433+
},
428434
}
429435

430436
errors := validator.validateProxyData(spec)
431437
found := false
432438
for _, err := range errors {
433-
if err.Field == "spec.provider" {
439+
if err.Field == "spec.provider.id" {
434440
found = true
435441
break
436442
}

gateway/gateway-controller/pkg/config/mcp_validator.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ func (v *MCPValidator) validateUpstream(fieldPrefix string, upstream *api.MCPPro
249249
})
250250
}
251251

252-
if auth.Type == api.MCPProxyConfigDataUpstreamAuthTypeBearer {
252+
if auth.Type == api.MCPProxyConfigDataUpstreamAuthType("bearer") {
253253
// For Bearer token, value should start with "Bearer or "bearer "
254254
if auth.Value != nil &&
255255
!strings.HasPrefix(*auth.Value, "Bearer ") && !strings.HasPrefix(*auth.Value, "bearer ") {

gateway/gateway-controller/pkg/config/mcp_validator_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -557,7 +557,7 @@ func TestMCPValidator_ValidateUpstreamAuth(t *testing.T) {
557557
{
558558
name: "Valid bearer auth",
559559
auth: &authConfig{
560-
Type: api.MCPProxyConfigDataUpstreamAuthTypeBearer,
560+
Type: api.MCPProxyConfigDataUpstreamAuthType("bearer"),
561561
Header: stringPtr("Authorization"),
562562
Value: stringPtr("Bearer token123"),
563563
},
@@ -566,7 +566,7 @@ func TestMCPValidator_ValidateUpstreamAuth(t *testing.T) {
566566
{
567567
name: "Bearer auth without Bearer prefix",
568568
auth: &authConfig{
569-
Type: api.MCPProxyConfigDataUpstreamAuthTypeBearer,
569+
Type: api.MCPProxyConfigDataUpstreamAuthType("bearer"),
570570
Header: stringPtr("Authorization"),
571571
Value: stringPtr("token123"),
572572
},

gateway/gateway-controller/pkg/utils/llm_transformer.go

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,9 @@ func (t *LLMProviderTransformer) transformProxy(proxy *api.LLMProxyConfiguration
5454
output *api.APIConfiguration) (*api.APIConfiguration, error) {
5555

5656
// Step 1: Retrieve and validate provider reference
57-
provider := t.store.GetByKindAndHandle(string(api.LlmProvider), proxy.Spec.Provider)
57+
provider := t.store.GetByKindAndHandle(string(api.LlmProvider), proxy.Spec.Provider.Id)
5858
if provider == nil {
59-
return nil, fmt.Errorf("failed to retrieve provider by id '%s'", proxy.Spec.Provider)
59+
return nil, fmt.Errorf("failed to retrieve provider by id '%s'", proxy.Spec.Provider.Id)
6060
}
6161

6262
// Step 1.5: Get provider's template and extract template params
@@ -135,6 +135,43 @@ func (t *LLMProviderTransformer) transformProxy(proxy *api.LLMProxyConfiguration
135135
}
136136
}
137137

138+
// Step 3.5: Apply proxy-level provider auth for proxy->provider loopback upstream
139+
if proxy.Spec.Provider.Auth != nil {
140+
auth := proxy.Spec.Provider.Auth
141+
switch auth.Type {
142+
case api.LLMUpstreamAuthTypeApiKey:
143+
if auth.Value == nil || *auth.Value == "" {
144+
return nil, fmt.Errorf("provider.auth.value is required")
145+
}
146+
header := ""
147+
if auth.Header != nil {
148+
header = *auth.Header
149+
}
150+
params, err := GetUpstreamAuthApikeyPolicyParams(header, *auth.Value)
151+
if err != nil {
152+
return nil, fmt.Errorf("failed to build upstream auth params: %w", err)
153+
}
154+
policyVersion, err := t.resolvePolicyVersion(constants.UPSTREAM_AUTH_APIKEY_POLICY_NAME)
155+
if err != nil {
156+
return nil, err
157+
}
158+
mh := api.Policy{
159+
Name: constants.UPSTREAM_AUTH_APIKEY_POLICY_NAME,
160+
Version: policyVersion,
161+
Params: &params,
162+
}
163+
if spec.Policies == nil {
164+
spec.Policies = &[]api.Policy{mh}
165+
} else {
166+
existing := *spec.Policies
167+
existing = append(existing, mh)
168+
spec.Policies = &existing
169+
}
170+
default:
171+
return nil, fmt.Errorf("unsupported upstream auth type: %s", auth.Type)
172+
}
173+
}
174+
138175
// Step 4: Build operations (AllowAll mode without exceptions)
139176
// This follows the same pattern as transformProvider AllowAll mode but simplified
140177
var ops []api.Operation

0 commit comments

Comments
 (0)