Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions src/it/java/io/weaviate/integration/RbacITest.java
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,6 @@ public void test_roles_Lifecycle() throws IOException {
Permission.groups("my-group", GroupType.OIDC, GroupsPermission.Action.READ));
});

requireAtLeast(Weaviate.Version.V132, () -> {
permissions.add(
Permission.aliases("ThingsAlias", myCollection, AliasesPermission.Action.CREATE));
});
requireAtLeast(Weaviate.Version.V133, () -> {
permissions.add(
Permission.groups("my-group", GroupType.OIDC, GroupsPermission.Action.READ));
});

// Act: create role
client.roles.create(nsRole, permissions);

Expand Down
20 changes: 20 additions & 0 deletions src/main/java/io/weaviate/client6/v1/api/Authentication.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,26 @@ public static Authentication resourceOwnerPassword(String username, String passw
};
}

/**
* Authenticate using Resource Owner Password Credentials authorization grant.
*
* @param clientSecret Client secret.
* @param username Resource owner username.
* @param password Resource owner password.
* @param scopes Client scopes.
*
* @return Authentication provider.
* @throws WeaviateOAuthException if an error occurred at any point of the token
* exchange process.
*/
public static Authentication resourceOwnerPasswordCredentials(String clientSecret, String username, String password,
List<String> scopes) {
return transport -> {
OidcConfig oidc = OidcUtils.getConfig(transport).withScopes(scopes).withScopes("offline_access");
return TokenProvider.resourceOwnerPasswordCredentials(oidc, clientSecret, username, password);
};
}

/**
* Authenticate using Client Credentials authorization grant.
*
Expand Down
21 changes: 17 additions & 4 deletions src/main/java/io/weaviate/client6/v1/api/Config.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import io.weaviate.client6.v1.internal.BuildInfo;
import io.weaviate.client6.v1.internal.ObjectBuilder;
import io.weaviate.client6.v1.internal.Proxy;
import io.weaviate.client6.v1.internal.Timeout;
import io.weaviate.client6.v1.internal.TokenProvider;
import io.weaviate.client6.v1.internal.TransportOptions;
Expand All @@ -24,7 +25,8 @@ public record Config(
Map<String, String> headers,
Authentication authentication,
TrustManagerFactory trustManagerFactory,
Timeout timeout) {
Timeout timeout,
Proxy proxy) {

public static Config of(Function<Custom, ObjectBuilder<Config>> fn) {
return fn.apply(new Custom()).build();
Expand All @@ -40,23 +42,24 @@ private Config(Builder<?> builder) {
builder.headers,
builder.authentication,
builder.trustManagerFactory,
builder.timeout);
builder.timeout,
builder.proxy);
}

RestTransportOptions restTransportOptions() {
return restTransportOptions(null);
}

RestTransportOptions restTransportOptions(TokenProvider tokenProvider) {
return new RestTransportOptions(scheme, httpHost, httpPort, headers, tokenProvider, trustManagerFactory, timeout);
return new RestTransportOptions(scheme, httpHost, httpPort, headers, tokenProvider, trustManagerFactory, timeout, proxy);
}

GrpcChannelOptions grpcTransportOptions() {
return grpcTransportOptions(null);
}

GrpcChannelOptions grpcTransportOptions(TokenProvider tokenProvider) {
return new GrpcChannelOptions(scheme, grpcHost, grpcPort, headers, tokenProvider, trustManagerFactory, timeout);
return new GrpcChannelOptions(scheme, grpcHost, grpcPort, headers, tokenProvider, trustManagerFactory, timeout, proxy);
}

private abstract static class Builder<SelfT extends Builder<SelfT>> implements ObjectBuilder<Config> {
Expand All @@ -70,6 +73,7 @@ private abstract static class Builder<SelfT extends Builder<SelfT>> implements O
protected TrustManagerFactory trustManagerFactory;
protected Timeout timeout = new Timeout();
protected Map<String, String> headers = new HashMap<>();
protected Proxy proxy;

/**
* Set URL scheme. Subclasses may increase the visibility of this method to
Expand Down Expand Up @@ -175,6 +179,15 @@ public SelfT timeout(int initSeconds, int querySeconds, int insertSeconds) {
return (SelfT) this;
}

/**
* Set proxy for all requests.
*/
@SuppressWarnings("unchecked")
public SelfT proxy(Proxy proxy) {
this.proxy = proxy;
return (SelfT) this;
}

/**
* Weaviate will use the URL in this header to call Weaviate Embeddings
* Service if an appropriate vectorizer is configured for collection.
Expand Down
9 changes: 6 additions & 3 deletions src/main/java/io/weaviate/client6/v1/api/WeaviateClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,13 @@ public class WeaviateClient implements AutoCloseable {
public final WeaviateClusterClient cluster;

public WeaviateClient(Config config) {
RestTransportOptions restOpt;
RestTransportOptions restOpt = config.restTransportOptions();
GrpcChannelOptions grpcOpt;
if (config.authentication() == null) {
restOpt = config.restTransportOptions();
grpcOpt = config.grpcTransportOptions();
} else {
TokenProvider tokenProvider;
try (final var noAuthRest = new DefaultRestTransport(config.restTransportOptions())) {
try (final var noAuthRest = new DefaultRestTransport(restOpt)) {
tokenProvider = config.authentication().getTokenProvider(noAuthRest);
} catch (Exception e) {
// Generally exceptions are caught in TokenProvider internals.
Expand Down Expand Up @@ -126,6 +125,10 @@ public WeaviateClient(Config config) {
this.config = config;
}

public Config getConfig() {
return config;
}

/**
* Create {@link WeaviateClientAsync} with identical configurations.
* It is a shorthand for:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ public CollectionHandle<Map<String, Object>> use(
return use(CollectionDescriptor.ofMap(collectionName), fn);
}

private <PropertiesT> CollectionHandle<PropertiesT> use(CollectionDescriptor<PropertiesT> collection,
public <PropertiesT> CollectionHandle<PropertiesT> use(CollectionDescriptor<PropertiesT> collection,
Comment thread
bevzzz marked this conversation as resolved.
Function<CollectionHandleDefaults.Builder, ObjectBuilder<CollectionHandleDefaults>> fn) {
return new CollectionHandle<>(restTransport, grpcTransport, collection, CollectionHandleDefaults.of(fn));
}
Expand Down
15 changes: 15 additions & 0 deletions src/main/java/io/weaviate/client6/v1/internal/Proxy.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package io.weaviate.client6.v1.internal;

import javax.annotation.Nullable;

public record Proxy(
String scheme,
String host,
int port,
@Nullable String username,
@Nullable String password
) {
public Proxy(String host, int port) {
this("http", host, port, null, null);
}
}
18 changes: 18 additions & 0 deletions src/main/java/io/weaviate/client6/v1/internal/TokenProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,24 @@ public static TokenProvider resourceOwnerPassword(OidcConfig oidc, String userna
return background(reuse(null, exchange(oidc, passwordGrant), DEFAULT_EARLY_EXPIRY));
}

/**
* Create a TokenProvider that uses Resource Owner Password Credentials authorization grant.
*
* @param oidc OIDC config.
* @param clientSecret Client secret.
* @param username Resource owner username.
* @param password Resource owner password.
*
* @return Internal TokenProvider implementation.
* @throws WeaviateOAuthException if an error occurred at any point of the token
* exchange process.
*/
public static TokenProvider resourceOwnerPasswordCredentials(OidcConfig oidc, String clientSecret, String username,
String password) {
final var passwordGrant = NimbusTokenProvider.resouceOwnerPasswordCredentials(oidc, clientSecret, username, password);
return background(reuse(null, exchange(oidc, passwordGrant), DEFAULT_EARLY_EXPIRY));
}

/**
* Create a TokenProvider that uses Client Credentials authorization grant.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,18 @@ public abstract class TransportOptions<H> {
protected final H headers;
protected final TrustManagerFactory trustManagerFactory;
protected final Timeout timeout;
protected final Proxy proxy;

protected TransportOptions(String scheme, String host, int port, H headers, TokenProvider tokenProvider,
TrustManagerFactory tmf, Timeout timeout) {
TrustManagerFactory tmf, Timeout timeout, Proxy proxy) {
this.scheme = scheme;
this.host = host;
this.port = port;
this.tokenProvider = tokenProvider;
this.headers = headers;
this.timeout = timeout;
this.trustManagerFactory = tmf;
this.proxy = proxy;
}

public boolean isSecure() {
Expand Down Expand Up @@ -58,6 +60,11 @@ public TrustManagerFactory trustManagerFactory() {
return this.trustManagerFactory;
}

@Nullable
public Proxy proxy() {
return this.proxy;
}

/**
* isWeaviateDomain returns true if the host matches weaviate.io,
* semi.technology, or weaviate.cloud domain.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;

import io.grpc.HttpConnectProxiedSocketAddress;
import io.grpc.ManagedChannel;
import io.grpc.StatusRuntimeException;
import io.grpc.netty.shaded.io.grpc.netty.GrpcSslContexts;
Expand All @@ -22,12 +22,19 @@
import io.grpc.stub.MetadataUtils;
import io.grpc.stub.StreamObserver;
import io.weaviate.client6.v1.api.WeaviateApiException;
import io.weaviate.client6.v1.internal.Proxy;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateBlockingStub;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateGrpc.WeaviateFutureStub;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch.BatchStreamReply;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBatch.BatchStreamRequest;

import javax.net.ssl.SSLException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;

public final class DefaultGrpcTransport implements GrpcTransport {
/**
* ListenableFuture callbacks are executed
Expand Down Expand Up @@ -92,7 +99,7 @@ public <RequestT, RequestM, ReplyM, ResponseT> CompletableFuture<ResponseT> perf
var method = rpc.methodAsync();
var stub = applyTimeout(futureStub, rpc);
var reply = method.apply(stub, message);
return toCompletableFuture(reply).thenApply(r -> rpc.unmarshal(r));
return toCompletableFuture(reply).thenApply(rpc::unmarshal);
}

/**
Expand Down Expand Up @@ -146,6 +153,27 @@ private static ManagedChannel buildChannel(GrpcChannelOptions transportOptions)
channel.sslContext(sslCtx);
}

if (transportOptions.proxy() != null) {
Proxy proxy = transportOptions.proxy();
if ("http".equals(proxy.scheme()) || "https".equals(proxy.scheme())) {
final SocketAddress proxyAddress = new InetSocketAddress(proxy.host(), proxy.port());
channel.proxyDetector(targetAddress -> {
if (targetAddress instanceof InetSocketAddress) {
HttpConnectProxiedSocketAddress.Builder builder = HttpConnectProxiedSocketAddress.newBuilder()
.setProxyAddress(proxyAddress)
.setTargetAddress((InetSocketAddress) targetAddress);

if (proxy.username() != null && proxy.password() != null) {
builder.setUsername(proxy.username());
builder.setPassword(proxy.password());
}
return builder.build();
}
return null;
});
}
}

channel.intercept(MetadataUtils.newAttachHeadersInterceptor(transportOptions.headers()));
return channel.build();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import javax.net.ssl.TrustManagerFactory;

import io.grpc.Metadata;
import io.weaviate.client6.v1.internal.Proxy;
import io.weaviate.client6.v1.internal.Timeout;
import io.weaviate.client6.v1.internal.TokenProvider;
import io.weaviate.client6.v1.internal.TransportOptions;
Expand All @@ -14,20 +15,19 @@ public class GrpcChannelOptions extends TransportOptions<Metadata> {
private final OptionalInt maxMessageSize;

public GrpcChannelOptions(String scheme, String host, int port, Map<String, String> headers,
TokenProvider tokenProvider, TrustManagerFactory tmf, Timeout timeout) {
this(scheme, host, port, buildMetadata(headers), tokenProvider, tmf, null, timeout);
TokenProvider tokenProvider, TrustManagerFactory tmf, Timeout timeout, Proxy proxy) {
this(scheme, host, port, buildMetadata(headers), tokenProvider, tmf, OptionalInt.empty(), timeout, proxy);
}

private GrpcChannelOptions(String scheme, String host, int port, Metadata headers,
TokenProvider tokenProvider, TrustManagerFactory tmf, OptionalInt maxMessageSize, Timeout timeout) {
super(scheme, host, port, headers, tokenProvider, tmf, timeout);
TokenProvider tokenProvider, TrustManagerFactory tmf, OptionalInt maxMessageSize, Timeout timeout, Proxy proxy) {
super(scheme, host, port, headers, tokenProvider, tmf, timeout, proxy);
this.maxMessageSize = maxMessageSize;
}

public GrpcChannelOptions withMaxMessageSize(int maxMessageSize) {
return new GrpcChannelOptions(scheme, host, port, headers, tokenProvider, trustManagerFactory,
OptionalInt.of(maxMessageSize),
timeout);
OptionalInt.of(maxMessageSize), timeout, proxy);
}

public OptionalInt maxMessageSize() {
Expand Down
20 changes: 16 additions & 4 deletions src/main/java/io/weaviate/client6/v1/internal/oidc/OidcConfig.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.weaviate.client6.v1.internal.oidc;

import io.weaviate.client6.v1.internal.Proxy;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
Expand All @@ -11,16 +13,26 @@
public record OidcConfig(
String clientId,
String providerMetadata,
Set<String> scopes) {
Set<String> scopes,
Proxy proxy) {

public OidcConfig(String clientId, String providerMetadata, Set<String> scopes) {
public OidcConfig(String clientId, String providerMetadata, Set<String> scopes, Proxy proxy) {
this.clientId = clientId;
this.providerMetadata = providerMetadata;
this.scopes = scopes != null ? Set.copyOf(scopes) : Collections.emptySet();
this.proxy = proxy;
}

public OidcConfig(String clientId, String providerMetadata, Set<String> scopes) {
this(clientId, providerMetadata, scopes, null);
}

public OidcConfig(String clientId, String providerMetadata, List<String> scopes) {
this(clientId, providerMetadata, scopes == null ? null : new HashSet<>(scopes));
this(clientId, providerMetadata, scopes == null ? null : new HashSet<>(scopes), null);
}

public OidcConfig(String clientId, String providerMetadata, List<String> scopes, Proxy proxy) {
this(clientId, providerMetadata, scopes == null ? null : new HashSet<>(scopes), proxy);
}

/** Create a new OIDC config with extended scopes. */
Expand All @@ -31,6 +43,6 @@ public OidcConfig withScopes(String... scopes) {
/** Create a new OIDC config with extended scopes. */
public OidcConfig withScopes(List<String> scopes) {
var newScopes = Stream.concat(this.scopes.stream(), scopes.stream()).collect(Collectors.toSet());
return new OidcConfig(clientId, providerMetadata, newScopes);
return new OidcConfig(clientId, providerMetadata, newScopes, proxy);
}
}
Loading
Loading