Coding: Four Greatest

I went ahead and let myself get a little distracted today after reading a tweet from Daniel Collin over at DICE. He posted a bit of code that uses SSE intrinsics to find the four highest valued floats in an array of floats (along with their indices). The original code looked like this:

void find_four(const float * a, size_t sz, float * fres, int * ires)
{
	__declspec(align(16)) float sinit = -FLT_MAX;
	__declspec(align(16)) int iinit[4] = {-1, -1, -1, -1};

	// Initialize all the scores to -FLT_MAX
	__m128 s = _mm_load_ps1(&sinit);

	// We just do shuffles and blends of the indices, so we store the ints as floats.
	__m128 index = _mm_load_ps((float*)iinit);

	int i = 0;
	for(const float* pa = a, *paend = a + sz; pa != paend; ++pa, ++i)
	{
		// Load the index into all 4 elements of im
		__m128 im = _mm_load_ps1((float*)&i);

		// Load a value from the array into all 4 elements in v
		__m128 v = _mm_load_ps1(pa);

		// Compare with the currently best scores
		__m128 cmp = _mm_cmpge_ps(v, s);

		// Convert to a mask which is one of 0000, 1000, 1100, 1110 or 1111
		// Switch on the mask and shuffle/blend as appropriate.
		// The same operation is done on both s and index to keep them in sync.
		switch(_mm_movemask_ps(cmp))
		{
		case 0x0:
			// dcba -> dcba
		break;
		case 0x8:
			// dcba -> Vcba
			s = _mm_blend_ps(s, v, 8);
			index = _mm_blend_ps(index, im, 8);
		break;
		case 0xc:
			// dcba -> cVba
			s = _mm_shuffle_ps(s, s, _MM_SHUFFLE(2, 2, 1, 0));
			s = _mm_blend_ps(s, v, 4);
			index = _mm_shuffle_ps(index, index, _MM_SHUFFLE(2, 2, 1, 0));
			index = _mm_blend_ps(index, im, 4);
		break;
		case 0xe:
			// dcba -> cbVa
			s = _mm_shuffle_ps(s, s, _MM_SHUFFLE(2, 1, 1, 0));
			s = _mm_blend_ps(s, v, 2);
			index = _mm_shuffle_ps(index, index, _MM_SHUFFLE(2, 1, 1, 0));
			index = _mm_blend_ps(index, im, 2);
		break;
		case 0xf:
			// dcba -> cbaV
			s = _mm_shuffle_ps(s, s, _MM_SHUFFLE(2, 1, 0, 0));
			s = _mm_blend_ps(s, v, 1);
			index = _mm_shuffle_ps(index, index, _MM_SHUFFLE(2, 1, 0, 0));
			index = _mm_blend_ps(index, im, 1);
		break;
		default:
			assert(0);
		break;
		}
	}

	_mm_store_ps(fres, s);
	_mm_store_ps((float*)ires, index);
}

You can write up a more straightforward plain scalar version of this code:

// Scalar reference implementation: keeps the four largest values of a[0..sz)
// in fres (descending, fres[0] is the largest) and their positions in ires.
//
// Slots that are never filled keep the score -FLT_MAX and the index -1.
// Because the test is >=, a value equal to an existing entry is placed ahead
// of it. A value that is not >= any current entry is dropped.
void find_four_scalar(const float * a, size_t sz, float * fres, int * ires)
{
	enum { TOP_COUNT = 4 };

	for (int slot = 0; slot < TOP_COUNT; ++slot)
	{
		fres[slot] = -FLT_MAX;
		ires[slot] = -1;
	}

	for (size_t pos = 0; pos < sz; ++pos)
	{
		const float candidate = a[pos];

		// Find the first (best) rank the candidate is allowed to take.
		int rank = 0;
		while (rank < TOP_COUNT && !(candidate >= fres[rank]))
		{
			++rank;
		}
		if (rank == TOP_COUNT)
		{
			continue;
		}

		// Push the lower-ranked entries down one slot; the last one falls off.
		for (int slot = TOP_COUNT - 1; slot > rank; --slot)
		{
			fres[slot] = fres[slot - 1];
			ires[slot] = ires[slot - 1];
		}

		fres[rank] = candidate;
		ires[rank] = (int)pos;
	}
}

Given an array of 2^25 random floats, and testing on my i5-3317U I get the following times:

find_four: 54 ms
find_four_scalar: 79 ms
We can provoke the worst case behaviour for the SSE implementation by making the array of values be monotonically increasing - giving us:
find_four: 148 ms
find_four_scalar: 68 ms
And best case behaviour by making sure the four highest values are at the very start:
find_four: 54 ms
find_four_scalar: 79 ms

So the question is - can we do better? As it turns out we can make a simple adjustment that improves performance quite a bit for the random and best case and has only a small impact on the worst case. The SSE version still works on one float at a time, but we can adjust it to potentially reject groups of floats at a time (in this case I will try 8). Specifically, I load up 8 floats and using some shuffles and _mm_max_ps I get the maximum value of those 8 floats; if the maximum is less than our current 4 best then we can just skip to the next 8. Simple. The code:

inline void cmp_one_to_four(int i, __m128 v, __m128 & s, __m128 & index)
{
	// Load the index into all 4 elements of im
	__m128 im = _mm_load_ps1((float*)&i);
 
	// Compare with the currently best scores
	__m128 cmp = _mm_cmpge_ps(v, s);
 
	// Convert to a mask which is one of 0000, 1000, 1100, 1110 or 1111
	// Switch on the mask and shuffle/blend as appropriate.
	// The same operation is done on both s and index to keep them in sync.
	switch(_mm_movemask_ps(cmp))
	{
	case 0x0:
		// dcba -> dcba
	break;
	case 0x8:
		// dcba -> Vcba
		s = _mm_blend_ps(s, v, 8);
		index = _mm_blend_ps(index, im, 8);
	break;
	case 0xc:
		// dcba -> cVba
		s = _mm_shuffle_ps(s, s, _MM_SHUFFLE(2, 2, 1, 0));
		s = _mm_blend_ps(s, v, 4);
		index = _mm_shuffle_ps(index, index, _MM_SHUFFLE(2, 2, 1, 0));
		index = _mm_blend_ps(index, im, 4);
	break;
	case 0xe:
		// dcba -> cbVa
		s = _mm_shuffle_ps(s, s, _MM_SHUFFLE(2, 1, 1, 0));
		s = _mm_blend_ps(s, v, 2);
		index = _mm_shuffle_ps(index, index, _MM_SHUFFLE(2, 1, 1, 0));
		index = _mm_blend_ps(index, im, 2);
	break;
	case 0xf:
		// dcba -> cbaV
		s = _mm_shuffle_ps(s, s, _MM_SHUFFLE(2, 1, 0, 0));
		s = _mm_blend_ps(s, v, 1);
		index = _mm_shuffle_ps(index, index, _MM_SHUFFLE(2, 1, 0, 0));
		index = _mm_blend_ps(index, im, 1);
	break;
	default:
		assert(0);
	break;
	}
}
		
// Finds the four largest values in a[0..sz) and their positions, rejecting
// whole groups of 8 floats at a time when none of them can enter the top four.
//
//   a    - input values; read with unaligned loads, so no alignment is required.
//   sz   - number of elements in a; any length is accepted. Positions are
//          tracked in an int, so sz is assumed to fit in int.
//   fres - receives 4 floats in descending order (fres[0] is the largest).
//   ires - receives the 4 matching positions in a.
//
// Slots that are never filled keep the score -FLT_MAX and the index -1.
// fres and ires are written with unaligned stores.
//
// NOTE(review): the group rejection relies on _mm_max_ps, whose result depends
// on operand order when a NaN is involved, so the input is presumed NaN-free.
void find_four_mod(const float * a, size_t sz, float * fres, int * ires)
{
	// Best scores start at -FLT_MAX; best indices start at -1, carried as raw
	// int bits inside float lanes because they are only shuffled and blended.
	__m128 s = _mm_set1_ps(-FLT_MAX);
	__m128 index = _mm_castsi128_ps(_mm_set1_epi32(-1));

	const size_t group = 8;
	// End of the last complete group; the remainder is handled after the loop.
	const size_t grouped_end = sz - (sz % group);

	size_t pos = 0;
	for(; pos < grouped_end; pos += group)
	{
		const float * pa = a + pos;
		const int i = (int)pos;

		// Reduce the 8 floats to their maximum, replicated in every lane of m.
		__m128 m = _mm_max_ps(_mm_loadu_ps(pa), _mm_loadu_ps(pa + 4));
		m = _mm_max_ps(_mm_max_ps(_mm_shuffle_ps(m,m,_MM_SHUFFLE(0,0,0,0)), _mm_shuffle_ps(m,m,_MM_SHUFFLE(1,1,1,1))),
				_mm_max_ps(_mm_shuffle_ps(m,m,_MM_SHUFFLE(2,2,2,2)), _mm_shuffle_ps(m,m,_MM_SHUFFLE(3,3,3,3))));

		// The group maximum is below all four current bests: skip the whole group.
		if (_mm_movemask_ps(_mm_cmpge_ps(m, s)) == 0)
			continue;

		// First half of the group: re-test its own maximum before the per-element work.
		__m128 v0 = _mm_load1_ps(pa);
		__m128 v1 = _mm_load1_ps(pa + 1);
		__m128 v2 = _mm_load1_ps(pa + 2);
		__m128 v3 = _mm_load1_ps(pa + 3);
		m = _mm_max_ps(_mm_max_ps(v0,v1), _mm_max_ps(v2,v3));
		if (_mm_movemask_ps(_mm_cmpge_ps(m, s)) != 0)
		{
			cmp_one_to_four(i, v0, s, index);
			cmp_one_to_four(i + 1, v1, s, index);
			cmp_one_to_four(i + 2, v2, s, index);
			cmp_one_to_four(i + 3, v3, s, index);
		}

		// Second half, tested against s as updated by the first half.
		v0 = _mm_load1_ps(pa + 4);
		v1 = _mm_load1_ps(pa + 5);
		v2 = _mm_load1_ps(pa + 6);
		v3 = _mm_load1_ps(pa + 7);
		m = _mm_max_ps(_mm_max_ps(v0,v1), _mm_max_ps(v2,v3));
		if (_mm_movemask_ps(_mm_cmpge_ps(m, s)) != 0)
		{
			cmp_one_to_four(i + 4, v0, s, index);
			cmp_one_to_four(i + 5, v1, s, index);
			cmp_one_to_four(i + 6, v2, s, index);
			cmp_one_to_four(i + 7, v3, s, index);
		}
	}

	// Tail: fewer than 8 elements remain, so offer them one at a time
	// instead of reading past the end of the array.
	for(; pos < sz; ++pos)
	{
		cmp_one_to_four((int)pos, _mm_load1_ps(a + pos), s, index);
	}

	_mm_storeu_ps(fres, s);
	_mm_storeu_si128((__m128i*)ires, _mm_castps_si128(index));
}

How does this clock in? Running the same test with this code yields:

Random-Case: 12 ms
Worst-Case: 159 ms
Best-Case: 12 ms
So over a 4x improvement for the best/random cases, and only slightly slower in the worst case scenario.

Anyway, that was a fun little distraction from the networking code I was otherwise working on...