There is something off with the explanation.
At first, there are 16 fetches per row x column pair, 1024 in total. Then it is observed that an input row needs to be fetched only once per output row, reducing the count to 8 fetches per output row plus 8 per row x column pair, 8 * 8 + 8 * 64 = 576 in total. This requires keeping 16 numbers in registers.
But then it is claimed that by doing one quadrant at a time, all that is needed is 64 fetches per quadrant, or 256 fetches in total. But that assumes we can keep 4 rows and 4 columns in registers, at 8 numbers per row or column = 64 numbers! If we can only keep 16 numbers as above, each row of the quadrant is going to take 40 fetches, and we get 160 fetches per quadrant, or 640 fetches in total, a pessimization relative to the 576 fetches!
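To put concrete numbers on this, here's a quick back-of-the-envelope sketch (mine, not from the post) that reproduces the counts above for the 8x8 example:

    N = 8  # 8x8 matrices, as in the post's example

    # Naive: every output element fetches its full A row and B column.
    naive = sum(2 * N for i in range(N) for j in range(N))        # 1024

    # Row reuse: fetch the A row once per output row, a B column per element.
    row_reuse = sum(N + N * N for i in range(N))                  # 8*8 + 8*64 = 576

    # 4x4 quadrants with only 16 numbers of register space: each quadrant row
    # re-fetches its A row (8) plus its 4 B columns (4*8 = 32) -> 40 fetches.
    quadrant_16_regs = 4 * (4 * (N + 4 * N))                      # 4 * 160 = 640

    # 4x4 quadrants with 64 numbers of register space: 4 A rows and 4 B columns
    # are fetched once each per quadrant.
    quadrant_64_regs = 4 * (4 * N + 4 * N)                        # 4 * 64 = 256

    print(naive, row_reuse, quadrant_16_regs, quadrant_64_regs)   # 1024 576 640 256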
That’s a valid point - I’m assuming infinite register capacity at that point in the post.
The next section discusses what you're talking about, e.g. how to deal with finite register/shared capacity by splitting the k dimension. I'll mention the shared/register memory limitation sooner to avoid confusion.
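In case it helps here, a minimal numpy sketch of what I mean by splitting the k dimension (tile size and names are illustrative, not the actual kernel): each k chunk contributes a partial product that accumulates into the same output, so only a thin slice of each input has to be resident at once.

    import numpy as np

    def matmul_k_split(A, B, k_tile=4):
        # Accumulate partial products over chunks of the shared k dimension,
        # so only a k_tile-wide slice of A and B is live at any one time.
        M, K = A.shape
        K2, N = B.shape
        assert K == K2
        C = np.zeros((M, N), dtype=A.dtype)
        for k0 in range(0, K, k_tile):
            k1 = min(k0 + k_tile, K)
            C += A[:, k0:k1] @ B[k0:k1, :]  # partial product for this k chunk
        return C

    A = np.random.rand(8, 8)
    B = np.random.rand(8, 8)
    assert np.allclose(matmul_k_split(A, B), A @ B)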
The overall problem with your blog post is that it beats around the bush rather than getting to the point. It feels like the post explains tiling in the reverse order of what is needed to understand it.
"How effective is tiling?" and "Why tiling tiling is so fast" should be at the end, while the key section "Why there's a limit to tiling" which should be front and center is in the middle, followed by a subversion of the entire concept in "How to sidestep tiling limits"
It's also incredibly jarring to read this:
"Wondering how we were able to reduce memory usage "for free"? Indeed, the reduction wasn't free. In fact, we paid for this reduction a different way — by incurring more writes."
This is, again, completely backwards. Let's assume you don't have a cache at all: you'll have to write everything out to DRAM every single time. The opposite is also true. Imagine you had an infinite number of registers. Every addition operation will accumulate into a register, which is a write operation. Hence, the number of write operations doesn't change.
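To put a number on it (my own illustration, counting one write per multiply-accumulate in the same 8x8 example):

    N = 8  # same 8x8 example as above

    # Each multiply-accumulate is one write, whether C[i][j] lives in DRAM
    # (no cache, no registers) or in a register (infinitely many registers).
    writes_to_dram      = sum(1 for i in range(N) for j in range(N) for k in range(N))
    writes_to_registers = N * N * N  # identical count, only the destination differs

    print(writes_to_dram, writes_to_registers)  # 512 512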
Really, the main points should be in this order:
1. Matrix multiplication works best with square or almost-square matrices.
2. Registers and SRAM (including caches) are limited, forcing you to process matrices of finite size (aka tiles).
3. The memory hierarchy means that the biggest matrix you can store gets bigger at each successive level of the hierarchy.
4. You can split matrix multiplication using inner and outer products.
5. Outer products take few inputs and have many outputs/accumulators; inner products take many inputs and have few outputs/accumulators (sketched below).
6. You want to calculate the biggest outer product you can get away with, since this significantly reduces the memory needed to store inputs and maximizes the number of cycles spent doing calculations; once you hit the limit, you want to reuse the accumulator, so you calculate inner products of outer products.
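A tiny numpy sketch of point 5 (my own, sizes arbitrary), contrasting the two decompositions:

    import numpy as np

    A = np.random.rand(8, 8)
    B = np.random.rand(8, 8)

    # Inner-product form: each output element reads a full row of A and a full
    # column of B (many inputs) and produces a single number (one accumulator).
    C_inner = np.array([[A[i, :] @ B[:, j] for j in range(8)] for i in range(8)])

    # Outer-product form: each step reads one column of A and one row of B
    # (few inputs) and updates the entire 8x8 accumulator tile (many outputs).
    C_outer = np.zeros((8, 8))
    for k in range(8):
        C_outer += np.outer(A[:, k], B[k, :])  # rank-1 update

    assert np.allclose(C_inner, A @ B) and np.allclose(C_outer, A @ B)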
I see, thanks for the feedback - the current blog post’s flow certainly isn’t optimal. I’ll try reordering to eliminate jarring bits and see how it flows.
When thinking about block matrix multiplication, it's always fun to revisit Strassen's algorithm, which runs in less than O(n^3) time.
Normal block multiplication works like:
    [ A11 A12 ] [ B11 B12 ] = [ A11*B11 + A12*B21  A11*B12 + A12*B22 ] = [ C11 C12 ]
    [ A21 A22 ] [ B21 B22 ]   [ A21*B11 + A22*B21  A21*B12 + A22*B22 ]   [ C21 C22 ]
Which takes 8 matrix multiplications on the sub-blocks. But by cleverly defining only 7 different matrix multiplications on top of block additions and subtractions, like M3 = A11 * (B12 - B22), you can make the C blocks out of just additions and subtractions of those 7 products. https://en.wikipedia.org/wiki/Strassen_algorithm
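For reference, one level of Strassen written out as a numpy sketch (formulas as on the Wikipedia page; M3 is the example product above):

    import numpy as np

    def strassen_one_level(A, B):
        # Split into four equal blocks and use 7 block multiplications
        # instead of the usual 8.
        n = A.shape[0] // 2
        A11, A12, A21, A22 = A[:n, :n], A[:n, n:], A[n:, :n], A[n:, n:]
        B11, B12, B21, B22 = B[:n, :n], B[:n, n:], B[n:, :n], B[n:, n:]

        M1 = (A11 + A22) @ (B11 + B22)
        M2 = (A21 + A22) @ B11
        M3 = A11 @ (B12 - B22)
        M4 = A22 @ (B21 - B11)
        M5 = (A11 + A12) @ B22
        M6 = (A21 - A11) @ (B11 + B12)
        M7 = (A12 - A22) @ (B21 + B22)

        C11 = M1 + M4 - M5 + M7
        C12 = M3 + M5
        C21 = M2 + M4
        C22 = M1 - M2 + M3 + M6
        return np.block([[C11, C12], [C21, C22]])

    A = np.random.rand(8, 8)
    B = np.random.rand(8, 8)
    assert np.allclose(strassen_one_level(A, B), A @ B)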
As far as I know this is not useful in the major GPU libraries for saving bandwidth, but I have never bothered to spend the time to figure out why. It must have something to do with the ratio of bandwidth to FLOPs, which is way past my knowledge of GPUs.
The tricky parts with Strassen are that it requires some fairly large changes to your looping strategy, and that it decreases accuracy. It also only helps once you are compute- rather than bandwidth-bound, and GPUs have lots of compute.
> only helps once you are compute rather than bandwidth bound
Asymptotically, I don't think that's right: Strassen can't be performing Theta(n^3) memory operations in sub-n^3 time, so its memory traffic scales sub-cubically as well.
See also http://ulaff.net/