Skip to content

gemm_driver templating breaks GEMM locking when different GEMMs happen concurrently #5836

@ngoldbaum

Description

@ngoldbaum

cf. numpy/numpy#31618 and scikit-image/scikit-image#8212, where this was originally found.

I used Claude to help debug this; I hope that's alright.

OpenBLAS uses templating to generate several variants of gemm_driver. If more than one of these variants is called concurrently, this can corrupt internal state in OpenBLAS, defeating the design intention that only one GEMM runs at a time.

OpenBLAS serializes concurrent multithreaded Level-3 calls with a mutex declared inside gemm_driver() (driver/level3/level3_thread.c):

static int gemm_driver(blas_arg_t *args, ...) {
  ...
  static pthread_mutex_t level3_lock = PTHREAD_MUTEX_INITIALIZER;   // ~line 590
  ...
  pthread_mutex_lock(&level3_lock);     // wraps job setup + exec_blas  ~line 659

But gemm_driver is template-compiled once per transpose combination: GEMM_LOCAL is #defined to GEMM_NN, GEMM_NT, GEMM_TN, GEMM_TT, … (lines 58–76), producing four separate functions dgemm_thread_{nn,nt,tn,tt}. A function-local static gets one instance per compiled function, so there are four independent level3_lock mutexes. Two threads whose calls dispatch to different variants therefore take different locks and are not serialized against each other.

Consider the following C reproducer:

#define _GNU_SOURCE
#include <dlfcn.h>
#include <errno.h>
#include <limits.h>
#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Integer type passed for BLAS dimensions and leading dimensions. `long`
 * matches 64-bit-integer symbols such as scipy_cblas_dgemm64_.
 * NOTE(review): a plain LP64 libopenblas declares these parameters as int;
 * confirm the mismatch is harmless on the target calling convention. */
typedef long bint;

/* cblas_dgemm(order, transA, transB, M, N, K, alpha, A, lda,
 *             B, ldb, beta, C, ldc), looked up at runtime with dlsym(). */
typedef void (*dgemm_t)(int, int, int,
                        bint, bint, bint,
                        double, const double *, bint,
                        const double *, bint,
                        double, double *, bint);

/* openblas_set_num_threads(n), looked up at runtime with dlsym(). */
typedef void (*setnt_t)(int);

static dgemm_t dgemm;

/* Values of the CBLAS enumerators of the same names used by compute(). */
enum {
  CblasRowMajor = 101,
  CblasNoTrans  = 111,
  CblasTrans    = 112
};

#define M 262144          /* rows of each GEMM output; large enough that each GEMM goes multi-threaded */

static double *Att;       /* 4 x M, input of the TransA=Trans GEMMs (lda = M)   */
static double *Ann;       /* M x 4, input of the TransA=NoTrans GEMM (lda = 4)  */
static double Bmat[8];    /* B operand shared by both GEMMs                     */
static double *ref;       /* expected output of GEMM #1, computed up front      */

/* Run configuration; main() may override these from argv. */
static int NTHREADS = 8;
static int NITER = 200;
static int SAME_VARIANT = 0;

/* Mismatch total accumulated by all workers; guarded by mlock. */
static long mismatches = 0;
static pthread_mutex_t mlock = PTHREAD_MUTEX_INITIALIZER;

/* GEMM #1: TransA=Trans, K=2  ->  dgemm_thread_{tn/nt} variant */
/* GEMM #2: either Trans (same variant) or NoTrans (different variant), K=4 */
static void compute(double *o1, double *o2) {
  dgemm(CblasRowMajor, CblasTrans, CblasNoTrans, M, 2, 2,
        1.0, Att, M, Bmat, 2, 0.0, o1, 2);
  if (SAME_VARIANT)
    dgemm(CblasRowMajor, CblasTrans, CblasNoTrans, M, 2, 4,
          1.0, Att, M, Bmat, 2, 0.0, o2, 2);   /* same Trans variant */
  else
    dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, 2, 4,
          1.0, Ann, 4, Bmat, 2, 0.0, o2, 2);   /* different (NoTrans) variant */
}

/*
 * Thread body: runs compute() NITER times and counts the iterations whose
 * GEMM #1 output (o1) differs bytewise from the precomputed `ref`. The
 * per-thread count is added to the shared `mismatches` under `mlock`.
 *
 * o2 is not compared (no reference is kept for it); it only receives the
 * second GEMM's output.
 *
 * Aborts the process if its scratch buffers cannot be allocated.
 */
static void *worker(void *unused) {
  (void)unused;

  const size_t bytes = (size_t)M * 2 * sizeof(double);
  double *o1 = malloc(bytes);
  double *o2 = malloc(bytes);
  if (o1 == NULL || o2 == NULL) {
    /* compute() would otherwise write through NULL. abort() rather than
     * exit(): several workers may fail at once, and calling exit() more
     * than once is undefined behavior. */
    fputs("worker: out of memory\n", stderr);
    abort();
  }

  long local = 0;
  for (int i = 0; i < NITER; i++) {
    compute(o1, o2);
    if (memcmp(o1, ref, bytes) != 0) {
      local++;
    }
  }
  free(o1);
  free(o2);

  pthread_mutex_lock(&mlock);
  mismatches += local;
  pthread_mutex_unlock(&mlock);
  return NULL;
}

/* Parse s as a complete base-10 integer in [lo, hi] and store it in *out.
 * Returns 0 on success, -1 if s has no digits, has trailing characters,
 * overflows long, or lies outside [lo, hi]. */
static int parse_int_arg(const char *s, long lo, long hi, int *out) {
  char *end;
  errno = 0;
  long v = strtol(s, &end, 10);
  if (end == s || *end != '\0' || errno == ERANGE || v < lo || v > hi) {
    return -1;
  }
  *out = (int)v;
  return 0;
}

/*
 * Reproducer entry point.
 *
 * Usage: OB_LIB=/path/to/libopenblas ./race [nthreads [niter [same_variant]]]
 *   nthreads      worker threads, 1..256 (default 8)
 *   niter         compute() calls per worker, >= 0 (default 200)
 *   same_variant  nonzero: both GEMMs use TransA=Trans;
 *                 0: the second GEMM uses TransA=NoTrans (default 0)
 *
 * Loads OpenBLAS from $OB_LIB, builds inputs and a reference result, then
 * runs NTHREADS workers concurrently and reports the mismatch count.
 *
 * Returns 0 if no mismatches were seen, 1 if any were, 2 on usage or
 * setup errors.
 */
int main(int argc, char **argv) {
  pthread_t th[256];
  const long max_threads = (long)(sizeof th / sizeof th[0]);

  /* NTHREADS indexes th[], so it must be validated before any thread is
   * created; atoi() would also silently turn garbage into 0. */
  if ((argc > 1 && parse_int_arg(argv[1], 1, max_threads, &NTHREADS) != 0) ||
      (argc > 2 && parse_int_arg(argv[2], 0, INT_MAX, &NITER) != 0) ||
      (argc > 3 && parse_int_arg(argv[3], INT_MIN, INT_MAX, &SAME_VARIANT) != 0)) {
    fprintf(stderr, "usage: %s [nthreads 1..%ld [niter >= 0 [same_variant]]]\n",
            argv[0], max_threads);
    return 2;
  }

  /* getenv() returns NULL when OB_LIB is unset, and dlopen(NULL) succeeds
   * with a handle to the main program, so check explicitly. */
  const char *lib = getenv("OB_LIB");
  if (lib == NULL) {
    fprintf(stderr, "set OB_LIB to the OpenBLAS .so\n");
    return 2;
  }
  void *h = dlopen(lib, RTLD_NOW | RTLD_GLOBAL);
  if (!h) {
    fprintf(stderr, "set OB_LIB to the OpenBLAS .so (%s)\n", dlerror());
    return 2;
  }

  /* Try plain and scipy-suffixed symbol names. Casting dlsym()'s void * to
   * a function pointer is not defined by ISO C but is required by POSIX. */
  dgemm = (dgemm_t)dlsym(h, "cblas_dgemm");
  if (!dgemm) {
    dgemm = (dgemm_t)dlsym(h, "scipy_cblas_dgemm64_");
  }
  setnt_t setnt = (setnt_t)dlsym(h, "openblas_set_num_threads");
  if (!setnt) {
    setnt = (setnt_t)dlsym(h, "scipy_openblas_set_num_threads64_");
  }
  if (!dgemm) {
    fprintf(stderr, "no cblas_dgemm symbol\n");
    return 2;
  }
  if (setnt) {
    setnt(4);
  }

  const size_t in_elems = (size_t)4 * M;
  const size_t out_bytes = (size_t)M * 2 * sizeof(double);
  Att = malloc(in_elems * sizeof *Att);
  Ann = malloc(in_elems * sizeof *Ann);
  ref = malloc(out_bytes);
  double *tmp = malloc(out_bytes);
  if (!Att || !Ann || !ref || !tmp) {
    fprintf(stderr, "out of memory\n");
    free(Att);
    free(Ann);
    free(ref);
    free(tmp);
    return 2;
  }

  for (size_t i = 0; i < in_elems; i++) {
    Att[i] = (double)(i % 512);
    Ann[i] = (double)((i * 7) % 512);
  }
  for (size_t i = 0; i < sizeof Bmat / sizeof Bmat[0]; i++) {
    Bmat[i] = 0.1 * (double)(i + 1);
  }

  /* Expected GEMM #1 output, computed before any worker starts. */
  compute(ref, tmp);
  free(tmp);

  int started = 0;
  while (started < NTHREADS) {
    int rc = pthread_create(&th[started], NULL, worker, NULL);
    if (rc != 0) {
      fprintf(stderr, "pthread_create failed: %s\n", strerror(rc));
      break;
    }
    started++;
  }
  /* Join only the threads that exist; th[] past `started` is uninitialized. */
  for (int i = 0; i < started; i++) {
    pthread_join(th[i], NULL);
  }
  if (started < NTHREADS) {
    return 2;
  }

  printf("%-26s %d threads x %d iters: %ld mismatches (%s)\n",
         SAME_VARIANT ? "same transpose variant:" : "different variants:",
         NTHREADS, NITER, mismatches, mismatches ? "RACE" : "clean");
  return mismatches ? 1 : 0;
}

You can compile it and run it like so:

cc -O2 openblas_gemm_variant_race.c -lpthread -ldl -o race 
goldbaum at Nathans-MBP in ~/Downloads
○  OB_LIB=/opt/homebrew/opt/openblas/lib/libopenblas.dylib ./race 8 200 0
different variants:        8 threads x 200 iters: 66 mismatches (RACE)

goldbaum at Nathans-MBP in ~/Downloads
○  OB_LIB=/opt/homebrew/opt/openblas/lib/libopenblas.dylib ./race 8 200 1
same transpose variant:    8 threads x 200 iters: 0 mismatches (clean)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions