@@ -179,10 +179,19 @@ LibHandle& getActiveHandle() {
179179 return activeHandle;
180180}
181181
182+ LibHandle& getPreviousHandle () {
183+ thread_local LibHandle previousHandle =
184+ AFSymbolManager::getInstance ().getDefaultHandle ();
185+ return previousHandle;
186+ }
187+
182188AFSymbolManager::AFSymbolManager ()
183- : defaultHandle(nullptr )
189+ : bkndHandles{}
190+ , defaultHandle(nullptr )
184191 , numBackends(0 )
192+ , newCustomHandleIndex(NUM_BACKENDS)
185193 , backendsAvailable(0 )
194+ , defaultBackend(AF_BACKEND_DEFAULT)
186195 , logger(loggerFactory(" unified" )) {
187196 // In order of priority.
188197 static const af_backend order[] = {AF_BACKEND_CUDA, AF_BACKEND_OPENCL,
@@ -229,20 +238,87 @@ af_err setBackend(af::Backend bknd) {
229238 auto & instance = AFSymbolManager::getInstance ();
230239 if (bknd == AF_BACKEND_DEFAULT) {
231240 if (instance.getDefaultHandle ()) {
232- getActiveHandle () = instance.getDefaultHandle ();
233- getActiveBackend () = instance.getDefaultBackend ();
241+ getPreviousHandle () = getActiveHandle ();
242+ getActiveHandle () = instance.getDefaultHandle ();
243+ getActiveBackend () = instance.getDefaultBackend ();
234244 return AF_SUCCESS;
235245 } else {
236- UNIFIED_ERROR_LOAD_LIB ();
246+ UNIFIED_ERROR_LOAD_LIB (AF_ERR_NO_TGT_BKND_LIB );
237247 }
238248 }
239249 int idx = bknd >> 1U ; // Convert 1, 2, 4 -> 0, 1, 2
240250 if (instance.getHandle (idx)) {
241- getActiveHandle () = instance.getHandle (idx);
251+ getPreviousHandle () = getActiveHandle ();
252+ getActiveHandle () = instance.getHandle (idx);
253+ getActiveBackend () = bknd;
254+ return AF_SUCCESS;
255+ } else {
256+ UNIFIED_ERROR_LOAD_LIB (AF_ERR_NO_TGT_BKND_LIB);
257+ }
258+ }
259+
260+ af_err AFSymbolManager::addBackendLibrary (const char * lib_path) {
261+ if ((newCustomHandleIndex + 1 ) > MAX_BKND_HANDLES) {
262+ // No more space for an additional handle
263+ UNIFIED_ERROR_LOAD_LIB (AF_ERR_BKND_LIB_LIST_FULL);
264+ }
265+
266+ string show_flag = getEnvVar (" AF_SHOW_LOAD_PATH" );
267+ bool show_load_path = show_flag == " 1" ;
268+
269+ typedef af_err (*func)(int *);
270+ LibHandle handle = nullptr ;
271+ if ((handle = loadLibrary (lib_path))) {
272+ func count_func =
273+ (func)getFunctionPointer (handle, " af_get_device_count" );
274+ if (count_func) {
275+ int count = 0 ;
276+ count_func (&count);
277+ AF_TRACE (" Device Count: {}." , count);
278+ if (count == 0 ) {
279+ // No available device for this backend
280+ handle = nullptr ;
281+ UNIFIED_ERROR_LOAD_LIB (AF_ERR_BKND_NO_DEVICE);
282+ }
283+ } else {
284+ // Loaded library is invalid
285+ handle = nullptr ;
286+ UNIFIED_ERROR_LOAD_LIB (AF_ERR_BKND_LIB_INVALID);
287+ }
288+
289+ if (show_load_path) { printf (" Using %s\n " , lib_path); }
290+
291+ bkndHandles[newCustomHandleIndex] = handle;
292+ newCustomHandleIndex++;
293+
294+ return AF_SUCCESS;
295+ } else {
296+ // loadLibrary failed, maybe because path is invalid or another reason
297+ UNIFIED_ERROR_LOAD_LIB (AF_ERR_LOAD_LIB);
298+ }
299+ }
300+
301+ af_err AFSymbolManager::setBackendLibrary (int lib_idx) {
302+ typedef af_err (*func)(af_backend*);
303+ int actual_idx = lib_idx + NUM_BACKENDS;
304+
305+ if (actual_idx >= MAX_BKND_HANDLES) {
306+ // lib_idx more than the capacity of bkndHandles
307+ UNIFIED_ERROR_LOAD_LIB (AF_ERR_BKND_LIB_IDX_INVALID);
308+ }
309+
310+ if (bkndHandles[actual_idx]) {
311+ getPreviousHandle () = getActiveHandle ();
312+ getActiveHandle () = getHandle (actual_idx);
313+ af_backend bknd = (af_backend)0 ;
314+ func get_backend_func = (func)getFunctionPointer (
315+ getActiveHandle (), " af_get_active_backend" );
316+ if (get_backend_func) { get_backend_func (&bknd); }
242317 getActiveBackend () = bknd;
243318 return AF_SUCCESS;
244319 } else {
245- UNIFIED_ERROR_LOAD_LIB ();
320+ // lib_idx not pointing to a library yet
321+ UNIFIED_ERROR_LOAD_LIB (AF_ERR_NO_TGT_BKND_LIB);
246322 }
247323}
248324
0 commit comments