Skip to content

Commit bcf9e2e

Browse files
committed
Add GradientDispatch bridge for custom gradient adapter dispatch
1 parent d5d2ba3 commit bcf9e2e

2 files changed

Lines changed: 101 additions & 0 deletions

File tree

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
package org.tensorflow.op;
2+
3+
import java.lang.reflect.Constructor;
4+
import java.util.List;
5+
import java.util.concurrent.ConcurrentHashMap;
6+
import java.util.concurrent.ConcurrentMap;
7+
import org.tensorflow.AbstractGradientAdapter;
8+
import org.tensorflow.Graph;
9+
import org.tensorflow.GraphOperation;
10+
import org.tensorflow.Operand;
11+
import org.tensorflow.Output;
12+
import org.tensorflow.internal.c_api.TFJ_Scope;
13+
14+
final class DispatchingGradientAdapter extends AbstractGradientAdapter {
15+
16+
private final ConcurrentMap<String, RawCustomGradient> raw = new ConcurrentHashMap<>();
17+
private final ConcurrentMap<String, TypedEntry<?>> typed = new ConcurrentHashMap<>();
18+
19+
static final class TypedEntry<T extends RawOpInputs<?>> {
20+
final CustomGradient<T> grad;
21+
final Class<T> inputClass;
22+
final Constructor<T> ctor;
23+
24+
TypedEntry(CustomGradient<T> grad, Class<T> inputClass) {
25+
this.grad = grad;
26+
this.inputClass = inputClass;
27+
try {
28+
this.ctor = inputClass.getConstructor(org.tensorflow.GraphOperation.class);
29+
} catch (NoSuchMethodException e) {
30+
throw new IllegalArgumentException(
31+
"Inputs class " + inputClass.getName() + " must have a public ctor(GraphOperation).", e);
32+
}
33+
}
34+
}
35+
36+
void putRaw(String opType, RawCustomGradient g) {
37+
raw.put(opType, g);
38+
}
39+
40+
<T extends RawOpInputs<?>> void putTyped(String opType, CustomGradient<T> g, Class<T> inputClass) {
41+
typed.put(opType, new TypedEntry<>(g, inputClass));
42+
}
43+
44+
@Override
45+
protected List<Operand<?>> apply(
46+
Graph graph, TFJ_Scope scope, GraphOperation operation, List<Output<?>> gradInputs) {
47+
48+
final String opType = operation.type();
49+
50+
RawCustomGradient rg = raw.get(opType);
51+
if (rg != null) {
52+
// NativeScope & Ops constructors are package-private => must be in org.tensorflow.op
53+
Scope nativeScope = new NativeScope(scope, graph, operation.name()).withSubScope(operation.name());
54+
return rg.call(new Ops(nativeScope), operation, gradInputs);
55+
}
56+
57+
@SuppressWarnings("rawtypes")
58+
TypedEntry te = typed.get(opType);
59+
if (te != null) {
60+
return applyTyped(graph, scope, operation, gradInputs, te);
61+
}
62+
63+
throw new IllegalStateException("No Java custom gradient registered for op type: " + opType);
64+
}
65+
66+
private <T extends RawOpInputs<?>> List<Operand<?>> applyTyped(
67+
Graph graph, TFJ_Scope scope, GraphOperation operation, List<Output<?>> gradInputs, TypedEntry<T> te) {
68+
try {
69+
T inputs = te.ctor.newInstance(operation);
70+
Scope nativeScope = new NativeScope(scope, graph, operation.name()).withSubScope(operation.name());
71+
return te.grad.call(new Ops(nativeScope), inputs, gradInputs);
72+
} catch (ReflectiveOperationException e) {
73+
throw new RuntimeException("Failed to instantiate inputs for " + te.inputClass.getName(), e);
74+
}
75+
}
76+
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package org.tensorflow.op;
2+
3+
import org.tensorflow.internal.c_api.TFJ_GradFuncAdapter;
4+
5+
/** Public bridge to a single native gradient adapter. */
6+
public final class GradientDispatch {
7+
8+
// package-private adapter that can access NativeScope/Ops constructors
9+
static final DispatchingGradientAdapter ADAPTER = new DispatchingGradientAdapter();
10+
11+
private GradientDispatch() {}
12+
13+
public static TFJ_GradFuncAdapter adapter() {
14+
return ADAPTER;
15+
}
16+
17+
public static void putRaw(String opType, RawCustomGradient gradient) {
18+
ADAPTER.putRaw(opType, gradient);
19+
}
20+
21+
public static <T extends RawOpInputs<?>> void putTyped(
22+
String opType, CustomGradient<T> gradient, Class<T> inputClass) {
23+
ADAPTER.putTyped(opType, gradient, inputClass);
24+
}
25+
}

0 commit comments

Comments
 (0)