Cf. numpy/numpy#31618 and scikit-image/scikit-image#8212, where this was originally found.
I used Claude to help debug this; I hope that's alright.
OpenBLAS uses templating to generate several variants of gemm_driver. If more than one of these variants is called concurrently, OpenBLAS's internal state can be corrupted, violating the design intention that only one multithreaded GEMM runs at a time.
OpenBLAS serializes concurrent multithreaded Level-3 calls with a mutex declared inside gemm_driver() (driver/level3/level3_thread.c):
static int gemm_driver(blas_arg_t *args, ...) {
...
static pthread_mutex_t level3_lock = PTHREAD_MUTEX_INITIALIZER; // ~line 590
...
pthread_mutex_lock(&level3_lock); // wraps job setup + exec_blas ~line 659
But gemm_driver is template-compiled once per transpose combination — GEMM_LOCAL is #defined to GEMM_NN, GEMM_NT, GEMM_TN, GEMM_TT, … (lines 58–76), producing four separate functions dgemm_thread_{nn,nt,tn,tt}. A function-local static gets one instance per compiled function, so there are four independent level3_lock mutexes. Two threads that call different variants (e.g. dgemm_thread_nn and dgemm_thread_nt) therefore take different locks and are not serialized against each other.
Consider the following C reproducer:
#define _GNU_SOURCE
#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <dlfcn.h>
#include <pthread.h>
/*
 * BLAS integer type for the dimension and leading-dimension arguments.
 * NOTE(review): `long` matches 64-bit-integer entry points such as
 * scipy_cblas_dgemm64_. A library built with 32-bit BLAS integers expects
 * `int` for these parameters, so calling it through this prototype relies on
 * the calling convention tolerating the mismatch; confirm the integer width
 * of the library named by OB_LIB.
 */
typedef long bint;

/* cblas_dgemm(order, transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc) */
typedef void (*dgemm_t)(int order, int trans_a, int trans_b,
                        bint m, bint n, bint k,
                        double alpha, const double *a, bint lda,
                        const double *b, bint ldb,
                        double beta, double *c, bint ldc);

/* openblas_set_num_threads(n) */
typedef void (*setnt_t)(int nthreads);

/* CBLAS constants for the layout and transpose arguments. */
enum { CblasRowMajor = 101, CblasNoTrans = 111, CblasTrans = 112 };

/* Output row count: large enough that each GEMM goes multi-threaded. */
enum { M = 262144 };

/* GEMM entry point, resolved with dlsym() at startup. */
static dgemm_t dgemm;

/* Operands. */
static double *Att;   /* 4 x M, feeds the TransA=Trans GEMM (lda=M) */
static double *Ann;   /* M x 4, feeds the TransA=NoTrans GEMM (lda=4) */
static double Bmat[8];
static double *ref;   /* expected o1 result, computed once before threads start */

/* Run configuration; may be overridden from the command line. */
static int NTHREADS = 8;
static int NITER = 200;
static int SAME_VARIANT = 0;

/* Total mismatching runs across all workers; update only while holding mlock. */
static long mismatches = 0;
static pthread_mutex_t mlock = PTHREAD_MUTEX_INITIALIZER;
/*
 * Issue the pair of GEMMs that every worker runs.
 *
 * First GEMM (its result in o1 is what the workers check):
 *   TransA=Trans, K=2, reading Att with lda=M; intended to hit a
 *   dgemm_thread_{tn/nt} variant.
 * Second GEMM (result in o2, K=4):
 *   - SAME_VARIANT == 0: TransA=NoTrans on Ann (lda=4), a different variant;
 *   - otherwise:         TransA=Trans on Att (lda=M), the same variant.
 */
static void compute(double *o1, double *o2) {
    dgemm(CblasRowMajor, CblasTrans, CblasNoTrans, M, 2, 2,
          1.0, Att, M, Bmat, 2, 0.0, o1, 2);

    if (!SAME_VARIANT) {
        /* different (NoTrans) variant */
        dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, 2, 4,
              1.0, Ann, 4, Bmat, 2, 0.0, o2, 2);
        return;
    }

    /* same Trans variant */
    dgemm(CblasRowMajor, CblasTrans, CblasNoTrans, M, 2, 4,
          1.0, Att, M, Bmat, 2, 0.0, o2, 2);
}
/*
 * Thread body: run compute() NITER times and count how many runs produced an
 * o1 that differs byte-for-byte from the reference result in `ref`.
 * Only o1 is compared; o2 is scratch space for the second GEMM.
 * The per-thread count is added to the shared `mismatches` under `mlock`.
 */
static void *worker(void *unused) {
    (void)unused;
    const size_t bytes = (size_t)M * 2 * sizeof(double);
    double *o1 = malloc(bytes);
    double *o2 = malloc(bytes);
    if (!o1 || !o2) {
        /* Handing a NULL buffer to dgemm/memcmp would be undefined behavior.
           _Exit (not exit) because several workers could fail at once, and
           calling exit() more than once is undefined. Status 2 matches
           main()'s setup-failure code. */
        perror("worker: malloc");
        _Exit(2);
    }

    long local = 0;
    for (int i = 0; i < NITER; i++) {
        compute(o1, o2);
        if (memcmp(o1, ref, bytes) != 0) {
            local++;
        }
    }
    free(o1);
    free(o2);

    pthread_mutex_lock(&mlock);
    mismatches += local;
    pthread_mutex_unlock(&mlock);
    return NULL;
}
int main(int argc, char **argv) {
if (argc > 1) NTHREADS = atoi(argv[1]);
if (argc > 2) NITER = atoi(argv[2]);
if (argc > 3) SAME_VARIANT = atoi(argv[3]);
void *h = dlopen(getenv("OB_LIB"), RTLD_NOW | RTLD_GLOBAL);
if (!h) { fprintf(stderr, "set OB_LIB to the OpenBLAS .so (%s)\n", dlerror()); return 2; }
/* try plain and scipy-suffixed symbol names */
dgemm = (dgemm_t)dlsym(h, "cblas_dgemm");
if (!dgemm) dgemm = (dgemm_t)dlsym(h, "scipy_cblas_dgemm64_");
setnt_t setnt = (setnt_t)dlsym(h, "openblas_set_num_threads");
if (!setnt) setnt = (setnt_t)dlsym(h, "scipy_openblas_set_num_threads64_");
if (!dgemm) { fprintf(stderr, "no cblas_dgemm symbol\n"); return 2; }
if (setnt) setnt(4);
Att = malloc((size_t)4*M*sizeof(double));
Ann = malloc((size_t)4*M*sizeof(double));
for (size_t i = 0; i < (size_t)4*M; i++) { Att[i] = (double)(i % 512); Ann[i] = (double)((i*7) % 512); }
for (int i = 0; i < 8; i++) Bmat[i] = 0.1*(i+1);
ref = malloc((size_t)M*2*sizeof(double));
{ double *tmp = malloc((size_t)M*2*sizeof(double)); compute(ref, tmp); free(tmp); }
pthread_t th[256];
for (int i = 0; i < NTHREADS; i++) pthread_create(&th[i], NULL, worker, NULL);
for (int i = 0; i < NTHREADS; i++) pthread_join(th[i], NULL);
printf("%-26s %d threads x %d iters: %ld mismatches (%s)\n",
SAME_VARIANT ? "same transpose variant:" : "different variants:",
NTHREADS, NITER, mismatches, mismatches ? "RACE" : "clean");
return mismatches ? 1 : 0;
}
You can compile it and run it like so:
cc -O2 openblas_gemm_variant_race.c -lpthread -ldl -o race
goldbaum at Nathans-MBP in ~/Downloads
○ OB_LIB=/opt/homebrew/opt/openblas/lib/libopenblas.dylib ./race 8 200 0
different variants: 8 threads x 200 iters: 66 mismatches (RACE)
goldbaum at Nathans-MBP in ~/Downloads
○ OB_LIB=/opt/homebrew/opt/openblas/lib/libopenblas.dylib ./race 8 200 1
same transpose variant: 8 threads x 200 iters: 0 mismatches (clean)
Cf. numpy/numpy#31618 and scikit-image/scikit-image#8212, where this was originally found.
I used Claude to help debug this; I hope that's alright.
OpenBLAS uses templating to generate several variants of gemm_driver. If more than one of these variants is called concurrently, OpenBLAS's internal state can be corrupted, violating the design intention that only one multithreaded GEMM runs at a time.

OpenBLAS serializes concurrent multithreaded Level-3 calls with a mutex declared inside gemm_driver() (driver/level3/level3_thread.c).

But gemm_driver is template-compiled once per transpose combination — GEMM_LOCAL is #defined to GEMM_NN, GEMM_NT, GEMM_TN, GEMM_TT, … (lines 58–76), producing four separate functions dgemm_thread_{nn,nt,tn,tt}. A function-local static gets one instance per compiled function, so there are four independent level3_lock mutexes.

Consider the following C reproducer:
You can compile it and run it like so: