Multiplying Matrices, Fast and Slow

I recently read a very interesting blog post about exposing Intel SIMD intrinsics via a fork of the Scala compiler (scala-virtualized), which reports multiplicative improvements in throughput over HotSpot JIT compiled code. The academic paper, which has been accepted at CGO 2018, proposes a powerful alternative to the traditional JVM approach of pairing dumb programmers with a (hopefully) smart JIT compiler. Lightweight Modular Staging (LMS) allows the generation of an executable binary from a high level representation: handcrafted representations of vectorised algorithms, written in a dialect of Scala, can be compiled natively and later invoked with a single JNI call. This approach bypasses C2 without incurring excessive JNI costs. The freely available benchmarks supplied can be easily run to reproduce the results in the paper, which is an achievement in itself, but some of the Java implementations used as baselines look less efficient than they could be. This post is about improving the efficiency of the Java matrix multiplication the LMS generated code is benchmarked against. Despite finding edge cases where autovectorisation fails, I find it is possible to get performance comparable to LMS with plain Java (and a JDK upgrade).

Two implementations of Java matrix multiplication are provided in the NGen benchmarks: JMMM.baseline – a naive but cache-unfriendly matrix multiplication – and JMMM.blocked, which is supplied as an improvement. JMMM.blocked is something of a local maximum because it does manual loop unrolling: this actually removes the trigger for autovectorisation analysis. I provide a simple and cache-efficient Java implementation (with the same asymptotic complexity; the improvement is just technical) and benchmark these implementations using JDK8 and the soon-to-be-released JDK10 separately.

public void fast(float[] a, float[] b, float[] c, int n) {
   int in = 0;
   for (int i = 0; i < n; ++i) {
       int kn = 0;
       for (int k = 0; k < n; ++k) {
           float aik = a[in + k];
           // the innermost loop walks row k of b and row i of c sequentially,
           // which is cache friendly and a shape the autovectoriser recognises
           for (int j = 0; j < n; ++j) {
               c[in + j] += aik * b[kn + j];
           }
           kn += n;
       }
       in += n;
   }
}

With JDK 1.8.0_131, the “fast” implementation is only about 2x faster than the blocked algorithm; this is nowhere near fast enough to match LMS. In fact, on my Skylake laptop at 2.6GHz, LMS does a lot better than 5x blocked (more like 6x-8x), and is between 2x and 4x faster than the improved implementation. Flops/Cycle is calculated as 2 * size ^ 3 (the number of floating point operations in one multiplication) divided by the number of CPU cycles the multiplication takes, that is, 2 * size ^ 3 * throughput / CPU frequency in Hz.
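
To spell that conversion out (this is my reconstruction of the calculation, but it agrees with the JMH numbers later in the post): an N x N multiplication performs roughly 2 * N ^ 3 floating point operations, so

// My reconstruction of the Flops/Cycle calculation: ~2 * n^3 flops per multiplication,
// divided by the CPU cycles the multiplication takes (opsPerSecond is benchmark throughput).
static double flopsPerCycle(int n, double opsPerSecond, double clockHz) {
    double flopsPerInvocation = 2.0 * n * n * n;
    return flopsPerInvocation * opsPerSecond / clockHz;
}

For example, the JDK10 blocked result at size 1024 further down (0.397688 ops/s) works out to 2 * 1024 ^ 3 * 0.397688 / 2.6GHz ≈ 0.33 flops/cycle, matching the table.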

====================================================
Benchmarking MMM.jMMM.fast (JVM implementation)
----------------------------------------------------
    Size (N) | Flops / Cycle
----------------------------------------------------
           8 | 0.4994459272
          32 | 1.0666533335
          64 | 0.9429120397
         128 | 0.9692385519
         192 | 0.9796619688
         256 | 1.0141446247
         320 | 0.9894415771
         384 | 1.0046245750
         448 | 1.0221353392
         512 | 0.9943527764
         576 | 0.9952093603
         640 | 0.9854689714
         704 | 0.9947153752
         768 | 1.0197765248
         832 | 1.0479691069
         896 | 1.0060121097
         960 | 0.9937347412
        1024 | 0.9056494897
====================================================

====================================================
Benchmarking MMM.nMMM.blocked (LMS generated)
----------------------------------------------------
    Size (N) | Flops / Cycle
----------------------------------------------------
           8 | 0.2500390686
          32 | 3.9999921875
          64 | 4.1626523901
         128 | 4.4618695374
         192 | 3.9598982956
         256 | 4.3737341517
         320 | 4.2412225389
         384 | 3.9640163416
         448 | 4.0957167537
         512 | 3.3801071278
         576 | 4.1869326167
         640 | 3.8225244883
         704 | 3.8648224140
         768 | 3.5240611589
         832 | 3.7941562681
         896 | 3.1735179981
         960 | 2.5856903789
        1024 | 1.7817152313
====================================================

====================================================
Benchmarking MMM.jMMM.blocked (JVM implementation)
----------------------------------------------------
    Size (N) | Flops / Cycle
----------------------------------------------------
           8 | 0.3333854248
          32 | 0.6336670915
          64 | 0.5733484649
         128 | 0.5987433798
         192 | 0.5819900921
         256 | 0.5473562109
         320 | 0.5623263520
         384 | 0.5583823292
         448 | 0.5657882256
         512 | 0.5430879470
         576 | 0.5269635678
         640 | 0.5595204791
         704 | 0.5297557807
         768 | 0.5493631388
         832 | 0.5471832673
         896 | 0.4769554752
         960 | 0.4985080443
        1024 | 0.4014589400
====================================================

JDK10 is about to be released, so it's worth looking at the effect of recent improvements to C2, including better use of AVX2 and support for vectorised FMA. Since LMS depends on scala-virtualized, which currently only supports Scala 2.11, the LMS implementation cannot be run with a more recent JDK, so its performance on JDK10 can only be extrapolated. Since its raison d'être is to bypass C2, it can reasonably be assumed to be insulated from JVM performance improvements (or regressions). Measurements of floating point operations per cycle provide a sensible comparison, in any case.

Moving away from ScalaMeter, I created a JMH benchmark to see how matrix multiplication behaves in JDK10.

import java.util.concurrent.TimeUnit;

import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.infra.Blackhole;

@OutputTimeUnit(TimeUnit.SECONDS)
@State(Scope.Benchmark)
public class MMM {

  @Param({"8", "32", "64", "128", "192", "256", "320", "384", "448", "512" , "576", "640", "704", "768", "832", "896", "960", "1024"})
  int size;

  private float[] a;
  private float[] b;
  private float[] c;

  @Setup(Level.Trial)
  public void init() {
    a = DataUtil.createFloatArray(size * size);
    b = DataUtil.createFloatArray(size * size);
    c = new float[size * size];
  }

  @Benchmark
  public void fast(Blackhole bh) {
    fast(a, b, c, size);
    bh.consume(c);
  }

  @Benchmark
  public void baseline(Blackhole bh) {
    baseline(a, b, c, size);
    bh.consume(c);
  }

  @Benchmark
  public void blocked(Blackhole bh) {
    blocked(a, b, c, size);
    bh.consume(c);
  }

  //
  // Baseline implementation of a Matrix-Matrix-Multiplication
  //
  public void baseline (float[] a, float[] b, float[] c, int n){
    for (int i = 0; i < n; i += 1) {
      for (int j = 0; j < n; j += 1) {
        float sum = 0.0f;
        for (int k = 0; k < n; k += 1) {
          sum += a[i * n + k] * b[k * n + j];
        }
        c[i * n + j] = sum;
      }
    }
  }

  //
  // Blocked version of MMM, reference implementation available at:
  // http://csapp.cs.cmu.edu/2e/waside/waside-blocking.pdf
  //
  public void blocked(float[] a, float[] b, float[] c, int n) {
    int BLOCK_SIZE = 8;
    for (int kk = 0; kk < n; kk += BLOCK_SIZE) {
      for (int jj = 0; jj < n; jj += BLOCK_SIZE) {
        for (int i = 0; i < n; i++) {
          for (int j = jj; j < jj + BLOCK_SIZE; ++j) {
            float sum = c[i * n + j];
            for (int k = kk; k < kk + BLOCK_SIZE; ++k) {
              sum += a[i * n + k] * b[k * n + j];
            }
            c[i * n + j] = sum;
          }
        }
      }
    }
  }

  public void fast(float[] a, float[] b, float[] c, int n) {
    int in = 0;
    for (int i = 0; i < n; ++i) {
      int kn = 0;
      for (int k = 0; k < n; ++k) {
        float aik = a[in + k];
        for (int j = 0; j < n; ++j) {
          c[in + j] = Math.fma(aik,  b[kn + j], c[in + j]);
        }
        kn += n;
      }
      in += n;
    }
  }
}

Benchmark | Mode | Threads | Samples | Score | Score Error (99.9%) | Unit | Param: size | Ratio to blocked | Flops/Cycle
----------------------------------------------------------------------------------------------------------------------
baseline | thrpt | 1 | 10 | 1228544.82 | 38793.17392 | ops/s | 8 | 1.061598336 | 0.483857652
baseline | thrpt | 1 | 10 | 22973.03402 | 1012.043446 | ops/s | 32 | 1.302266947 | 0.57906183
baseline | thrpt | 1 | 10 | 2943.088879 | 221.57475 | ops/s | 64 | 1.301414733 | 0.593471609
baseline | thrpt | 1 | 10 | 358.010135 | 9.342801 | ops/s | 128 | 1.292889618 | 0.577539747
baseline | thrpt | 1 | 10 | 105.758366 | 4.275503 | ops/s | 192 | 1.246415143 | 0.575804515
baseline | thrpt | 1 | 10 | 41.465557 | 1.112753 | ops/s | 256 | 1.430003946 | 0.535135851
baseline | thrpt | 1 | 10 | 20.479081 | 0.462547 | ops/s | 320 | 1.154267894 | 0.516198866
baseline | thrpt | 1 | 10 | 11.686685 | 0.263476 | ops/s | 384 | 1.186535349 | 0.509027985
baseline | thrpt | 1 | 10 | 7.344184 | 0.269656 | ops/s | 448 | 1.166421127 | 0.507965526
baseline | thrpt | 1 | 10 | 3.545153 | 0.108086 | ops/s | 512 | 0.81796657 | 0.366017216
baseline | thrpt | 1 | 10 | 3.789384 | 0.130934 | ops/s | 576 | 1.327168294 | 0.557048123
baseline | thrpt | 1 | 10 | 1.981957 | 0.040136 | ops/s | 640 | 1.020965271 | 0.399660104
baseline | thrpt | 1 | 10 | 1.76672 | 0.036386 | ops/s | 704 | 1.168272442 | 0.474179037
baseline | thrpt | 1 | 10 | 1.01026 | 0.049853 | ops/s | 768 | 0.845514112 | 0.352024966
baseline | thrpt | 1 | 10 | 1.115814 | 0.03803 | ops/s | 832 | 1.148752171 | 0.494331667
baseline | thrpt | 1 | 10 | 0.703561 | 0.110626 | ops/s | 896 | 0.938435436 | 0.389298235
baseline | thrpt | 1 | 10 | 0.629896 | 0.052448 | ops/s | 960 | 1.081741651 | 0.428685898
baseline | thrpt | 1 | 10 | 0.407772 | 0.019079 | ops/s | 1024 | 1.025356561 | 0.336801424
blocked | thrpt | 1 | 10 | 1157259.558 | 49097.48711 | ops/s | 8 | 1 | 0.455782226
blocked | thrpt | 1 | 10 | 17640.8025 | 1226.401298 | ops/s | 32 | 1 | 0.444656782
blocked | thrpt | 1 | 10 | 2261.453481 | 98.937035 | ops/s | 64 | 1 | 0.456020355
blocked | thrpt | 1 | 10 | 276.906961 | 22.851857 | ops/s | 128 | 1 | 0.446704605
blocked | thrpt | 1 | 10 | 84.850033 | 4.441454 | ops/s | 192 | 1 | 0.461968485
blocked | thrpt | 1 | 10 | 28.996813 | 7.585551 | ops/s | 256 | 1 | 0.374219842
blocked | thrpt | 1 | 10 | 17.742052 | 0.627629 | ops/s | 320 | 1 | 0.447208892
blocked | thrpt | 1 | 10 | 9.84942 | 0.367603 | ops/s | 384 | 1 | 0.429003641
blocked | thrpt | 1 | 10 | 6.29634 | 0.402846 | ops/s | 448 | 1 | 0.435490676
blocked | thrpt | 1 | 10 | 4.334105 | 0.384849 | ops/s | 512 | 1 | 0.447472097
blocked | thrpt | 1 | 10 | 2.85524 | 0.199102 | ops/s | 576 | 1 | 0.419726816
blocked | thrpt | 1 | 10 | 1.941258 | 0.10915 | ops/s | 640 | 1 | 0.391453182
blocked | thrpt | 1 | 10 | 1.51225 | 0.076621 | ops/s | 704 | 1 | 0.40588053
blocked | thrpt | 1 | 10 | 1.194847 | 0.063147 | ops/s | 768 | 1 | 0.416344283
blocked | thrpt | 1 | 10 | 0.971327 | 0.040421 | ops/s | 832 | 1 | 0.430320551
blocked | thrpt | 1 | 10 | 0.749717 | 0.042997 | ops/s | 896 | 1 | 0.414837526
blocked | thrpt | 1 | 10 | 0.582298 | 0.016725 | ops/s | 960 | 1 | 0.39629231
blocked | thrpt | 1 | 10 | 0.397688 | 0.043639 | ops/s | 1024 | 1 | 0.328472491
fast | thrpt | 1 | 10 | 1869676.345 | 76416.50848 | ops/s | 8 | 1.615606743 | 0.736364837
fast | thrpt | 1 | 10 | 48485.47216 | 1301.926828 | ops/s | 32 | 2.748484496 | 1.222132271
fast | thrpt | 1 | 10 | 6431.341657 | 153.905413 | ops/s | 64 | 2.843897392 | 1.296875098
fast | thrpt | 1 | 10 | 840.601821 | 45.998723 | ops/s | 128 | 3.035683242 | 1.356053685
fast | thrpt | 1 | 10 | 260.386996 | 13.022418 | ops/s | 192 | 3.068790745 | 1.417684611
fast | thrpt | 1 | 10 | 107.895708 | 6.584674 | ops/s | 256 | 3.720950575 | 1.392453537
fast | thrpt | 1 | 10 | 56.245336 | 2.729061 | ops/s | 320 | 3.170170846 | 1.417728592
fast | thrpt | 1 | 10 | 32.917996 | 2.196624 | ops/s | 384 | 3.342125323 | 1.433783932
fast | thrpt | 1 | 10 | 20.960189 | 2.077684 | ops/s | 448 | 3.328948087 | 1.449725854
fast | thrpt | 1 | 10 | 14.005186 | 0.7839 | ops/s | 512 | 3.231390564 | 1.445957112
fast | thrpt | 1 | 10 | 8.827584 | 0.883654 | ops/s | 576 | 3.091713481 | 1.297675056
fast | thrpt | 1 | 10 | 7.455607 | 0.442882 | ops/s | 640 | 3.840605937 | 1.503417416
fast | thrpt | 1 | 10 | 5.322894 | 0.464362 | ops/s | 704 | 3.519850554 | 1.428638807
fast | thrpt | 1 | 10 | 4.308522 | 0.153846 | ops/s | 768 | 3.605919419 | 1.501303934
fast | thrpt | 1 | 10 | 3.375274 | 0.106715 | ops/s | 832 | 3.474910097 | 1.495325228
fast | thrpt | 1 | 10 | 2.320152 | 0.367881 | ops/s | 896 | 3.094703735 | 1.28379924
fast | thrpt | 1 | 10 | 2.057478 | 0.150198 | ops/s | 960 | 3.533376381 | 1.400249889
fast | thrpt | 1 | 10 | 1.66255 | 0.181116 | ops/s | 1024 | 4.180538513 | 1.3731919

Interestingly, the blocked algorithm is now the worst of the Java implementations. The code generated by C2 got a lot faster, but peaks at about 1.5 flops/cycle, which still doesn't compete with LMS. Why? Taking a look at the assembly, it's clear that the autovectoriser choked on the array offsets and produced scalar code (note the ss, scalar single-precision, suffixes below), just like the implementations in the paper. I wasn't expecting this.

vmovss  xmm5,dword ptr [rdi+rcx*4+10h]
vfmadd231ss xmm5,xmm6,xmm2
vmovss  dword ptr [rdi+rcx*4+10h],xmm5
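
One way to capture this kind of disassembly is JMH's perfasm profiler. The runner below is a sketch of the tooling rather than necessarily what was used here; it needs Linux perf and the hsdis disassembler on the JVM's library path:

import org.openjdk.jmh.profile.LinuxPerfAsmProfiler;
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;

public class DisassembleMMM {
  public static void main(String[] args) throws RunnerException {
    // Print the hottest compiled regions of the fast benchmark, annotated with assembly.
    Options opt = new OptionsBuilder()
        .include(MMM.class.getSimpleName() + ".fast")
        .addProfiler(LinuxPerfAsmProfiler.class)
        .build();
    new Runner(opt).run();
  }
}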

Is this the end of the story? No: with some hacks, and at the cost of an array allocation and a copy or two, autovectorisation can be tricked into working again to generate faster code:

    public void fast(float[] a, float[] b, float[] c, int n) {
        // Flat scratch buffers with simple zero-based indices give the autovectoriser
        // a loop shape it recognises again, at the cost of an allocation and some copies.
        float[] bBuffer = new float[n];
        float[] cBuffer = new float[n];
        int in = 0;
        for (int i = 0; i < n; ++i) {
            int kn = 0;
            for (int k = 0; k < n; ++k) {
                float aik = a[in + k];
                // copy row k of b into the buffer, then accumulate aik * b[k, *] into row i of c
                System.arraycopy(b, kn, bBuffer, 0, n);
                saxpy(n, aik, bBuffer, cBuffer);
                kn += n;
            }
            // write the finished row i back and reset the accumulator for the next row
            System.arraycopy(cBuffer, 0, c, in, n);
            Arrays.fill(cBuffer, 0f);
            in += n;
        }
    }

    private void saxpy(int n, float aik, float[] b, float[] c) {
        for (int i = 0; i < n; ++i) {
            c[i] += aik * b[i];
        }
    }

Adding this hack into the NGen benchmark (back in JDK 1.8.0_131) I get closer to the LMS generated code, and beat it beyond L3 cache residency (6MB). LMS is still faster when both matrices fit in L3 concurrently, but by percentage points rather than a multiple. The cost of the hacky array buffers dominates for small matrices.
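
To put a rough number on that residency point (simple arithmetic with 4-byte floats, nothing more):

// Two N x N float matrices occupy 2 * 4 * N^2 bytes; a 6MB L3 stops holding
// the pair somewhere between N = 832 and N = 896 (around N = 887).
static boolean pairFitsInL3(int n, long l3Bytes) {
    return 2L * n * n * Float.BYTES <= l3Bytes;
}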

====================================================
Benchmarking MMM.jMMM.fast (JVM implementation)
----------------------------------------------------
    Size (N) | Flops / Cycle
----------------------------------------------------
           8 | 0.2500390686
          32 | 0.7710872405
          64 | 1.1302489072
         128 | 2.5113453810
         192 | 2.9525859816
         256 | 3.1180920385
         320 | 3.1081563593
         384 | 3.1458423577
         448 | 3.0493148252
         512 | 3.0551158263
         576 | 3.1430376938
         640 | 3.2169923048
         704 | 3.1026513283
         768 | 2.4190053777
         832 | 3.3358586705
         896 | 3.0755689237
         960 | 2.9996690697
        1024 | 2.2935654309
====================================================

====================================================
Benchmarking MMM.nMMM.blocked (LMS generated)
----------------------------------------------------
    Size (N) | Flops / Cycle
----------------------------------------------------
           8 | 1.0001562744
          32 | 5.3330416826
          64 | 5.8180867784
         128 | 5.1717318641
         192 | 5.1639907462
         256 | 4.3418618628
         320 | 5.2536572701
         384 | 4.0801359215
         448 | 4.1337007093
         512 | 3.2678160754
         576 | 3.7973028890
         640 | 3.3557513664
         704 | 4.0103133240
         768 | 3.4188362575
         832 | 3.2189488327
         896 | 3.2316685219
         960 | 2.9985655539
        1024 | 1.7750946796
====================================================

With the benchmark below I calculate flops/cycle with improved JDK10 autovectorisation.

  @Benchmark
  public void fastBuffered(Blackhole bh) {
    fastBuffered(a, b, c, size);
    bh.consume(c);
  }

  public void fastBuffered(float[] a, float[] b, float[] c, int n) {
    float[] bBuffer = new float[n];
    float[] cBuffer = new float[n];
    int in = 0;
    for (int i = 0; i < n; ++i) {
      int kn = 0;
      for (int k = 0; k < n; ++k) {
        float aik = a[in + k];
        System.arraycopy(b, kn, bBuffer, 0, n);
        saxpy(n, aik, bBuffer, cBuffer);
        kn += n;
      }
      System.arraycopy(cBuffer, 0, c, in, n);
      Arrays.fill(cBuffer, 0f);
      in += n;
    }
  }

  private void saxpy(int n, float aik, float[] b, float[] c) {
    for (int i = 0; i < n; ++i) {
      c[i] = Math.fma(aik, b[i], c[i]);
    }
  }

Just as in the modified NGen benchmark, this starts paying off once the matrices have 64 rows and columns. Finally, and it took an upgrade and a hack, I breached 4 Flops per cycle:

Benchmark | Mode | Threads | Samples | Score | Score Error (99.9%) | Unit | Param: size | Flops / Cycle
--------------------------------------------------------------------------------------------------------
fastBuffered | thrpt | 1 | 10 | 1047184.034 | 63532.95095 | ops/s | 8 | 0.412429404
fastBuffered | thrpt | 1 | 10 | 58373.56367 | 3239.615866 | ops/s | 32 | 1.471373026
fastBuffered | thrpt | 1 | 10 | 12099.41654 | 497.33988 | ops/s | 64 | 2.439838038
fastBuffered | thrpt | 1 | 10 | 2136.50264 | 105.038006 | ops/s | 128 | 3.446592911
fastBuffered | thrpt | 1 | 10 | 673.470622 | 102.577237 | ops/s | 192 | 3.666730488
fastBuffered | thrpt | 1 | 10 | 305.541519 | 25.959163 | ops/s | 256 | 3.943181586
fastBuffered | thrpt | 1 | 10 | 158.437372 | 6.708384 | ops/s | 320 | 3.993596774
fastBuffered | thrpt | 1 | 10 | 88.283718 | 7.58883 | ops/s | 384 | 3.845306266
fastBuffered | thrpt | 1 | 10 | 58.574507 | 4.248521 | ops/s | 448 | 4.051345968
fastBuffered | thrpt | 1 | 10 | 37.183635 | 4.360319 | ops/s | 512 | 3.839002314
fastBuffered | thrpt | 1 | 10 | 29.949884 | 0.63346 | ops/s | 576 | 4.40270151
fastBuffered | thrpt | 1 | 10 | 20.715833 | 4.175897 | ops/s | 640 | 4.177331789
fastBuffered | thrpt | 1 | 10 | 10.824837 | 0.902983 | ops/s | 704 | 2.905333492
fastBuffered | thrpt | 1 | 10 | 8.285254 | 1.438701 | ops/s | 768 | 2.886995686
fastBuffered | thrpt | 1 | 10 | 6.17029 | 0.746537 | ops/s | 832 | 2.733582608
fastBuffered | thrpt | 1 | 10 | 4.828872 | 1.316901 | ops/s | 896 | 2.671937962
fastBuffered | thrpt | 1 | 10 | 3.6343 | 1.293923 | ops/s | 960 | 2.473381573
fastBuffered | thrpt | 1 | 10 | 2.458296 | 0.171224 | ops/s | 1024 | 2.030442485

The code generated for the core of the loop looks better now:

vmovdqu ymm1,ymmword ptr [r13+r11*4+10h]
vfmadd231ps ymm1,ymm3,ymmword ptr [r14+r11*4+10h]
vmovdqu ymmword ptr [r13+r11*4+10h],ymm1                                              

Given this improvement, it would be exciting to see how LMS can profit from JDK9 or JDK10. L3 cache, which the LMS generated code seems to depend on for throughput, is typically shared between cores: a single thread rarely enjoys exclusive access. I would like to see benchmarks for the LMS generated code in the presence of concurrency.
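
For the Java implementations at least, that experiment is easy to sketch with JMH: run the same benchmark with several threads competing for the shared L3. The snippet below is an assumption about how one might set it up, not something measured here, and the benchmark state would need to move to Scope.Thread so that each thread multiplies its own matrices.

import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;

public class ConcurrentMMM {
  public static void main(String[] args) throws RunnerException {
    // Four threads hammering the shared L3 at once; the thread count is arbitrary.
    Options opt = new OptionsBuilder()
        .include(MMM.class.getSimpleName() + ".fastBuffered")
        .threads(4)
        .build();
    new Runner(opt).run();
  }
}

How much of the single threaded throughput survives that kind of cache pressure is the interesting number.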
