Aggregate Median/C code

From PostgreSQL wiki
Jump to navigation | Jump to search

by Chris Mayfield

I needed a quick and easy median aggregate for millions of groups of around 10 values each. Here's a C version that is application-specific; however, one could change the data type (currently float4), the buffer length (14), etc. This code is a contrib module named statfunc, which may someday implement other things (e.g., quantiles). See also http://poststat.projects.postgresql.org/ and http://www.joeconway.com/plr/ for doing statistics in PostgreSQL.

After posting this code I discovered a related discussion on pgsql-hackers: http://markmail.org/message/wr6rmfd56rhpwnij

-- SQL interface to the statfunc library: the median aggregate med(real).
--
-- Drop any previous definitions first; the aggregate goes before its
-- support functions because it depends on them.
DROP AGGREGATE IF EXISTS med(real);
DROP FUNCTION IF EXISTS med_trans(state bytea, next real);
DROP FUNCTION IF EXISTS med_final(state bytea);

-- Transition function.  Not STRICT: the C code receives NULL inputs and
-- skips them itself.
CREATE FUNCTION med_trans(state bytea, next real)
    RETURNS bytea
    AS '$libdir/statfunc', 'med_trans'
    LANGUAGE C
    IMMUTABLE;
COMMENT ON FUNCTION med_trans(state bytea, next real)
    IS 'accumulates non-null values into an array';

-- Final function.  RETURNS NULL ON NULL INPUT (i.e. STRICT), so a group
-- that never produced a state yields NULL.
CREATE FUNCTION med_final(state bytea)
    RETURNS real
    AS '$libdir/statfunc', 'med_final'
    LANGUAGE C
    IMMUTABLE
    RETURNS NULL ON NULL INPUT;
COMMENT ON FUNCTION med_final(state bytea)
    IS 'sorts the array and returns the median';

-- The aggregate itself: state starts as NULL (no INITCOND).
CREATE AGGREGATE med(real) (
    sfunc     = med_trans,
    stype     = bytea,
    finalfunc = med_final
);
COMMENT ON AGGREGATE med(real)
    IS 'median of all input values, excluding nulls';
/*
 * statfunc.c - statistical functions for PostgreSQL
 */

#include "postgres.h"

#include <math.h>

#include "funcapi.h"

/*
 * Capacity of the per-group value buffer.  A group with more non-null values
 * than this makes c_med_trans() raise an error.  Defaults to 14; it may be
 * overridden at compile time with -DMAXLEN=N (for a PGXS build, e.g. through
 * PG_CPPFLAGS) instead of editing this file.
 */
#ifndef MAXLEN
#define MAXLEN 14
#endif

#if MAXLEN < 1
#error "MAXLEN must be at least 1"
#endif

/*
 * Internal aggregate state, passed around as a bytea (varlena) datum:
 * a fixed-size array of floats plus the number of slots in use.
 */
typedef struct {
	int32 vl_len_;        /* varlena header (do not touch directly!) */
	int32 nval;           /* number of values accumulated */
	float4 vals[MAXLEN];  /* array of non-null values */
} State;

/*
 * Module magic block: lets the server check, when it loads this library,
 * that the library was built for a compatible server version.  The guard
 * keeps the file compiling against headers that do not define the macro.
 */
#if defined(PG_MODULE_MAGIC)
PG_MODULE_MAGIC;
#endif

/*----------------------------------------------------------------------------*\
  C Implementation
\*----------------------------------------------------------------------------*/

/*
 * Appends one non-null value to the aggregate state.
 *
 * state: current state, or NULL if this group has no value yet; in that case
 *        a zeroed State is palloc'd and given a varlena header.
 * next:  value to append.
 *
 * Returns the (possibly newly allocated) state.  An existing state is updated
 * in place.  Raises ERROR when the state already holds MAXLEN values.
 */
static State *c_med_trans(State *state, float4 next) {
	if (state == NULL) {
		/* first value for this group: start from an empty state */
		state = palloc0(sizeof(State));
		SET_VARSIZE(state, sizeof(State));
	} else if (state->nval >= MAXLEN) {
		/* the buffer size is fixed at compile time, so this is a hard limit */
		ereport(ERROR, (errcode(ERRCODE_PROGRAM_LIMIT_EXCEEDED),
						errmsg("need to increase MAXLEN in statfunc.c"),
						errdetail("At most %d non-null values per group are supported.",
								  MAXLEN)));
	}
	state->vals[state->nval] = next;
	state->nval++;
	return state;
}

/*
 * Comparison function for qsort.
 */
static int float4cmp(const void *p1, const void *p2) {
	float4 f1 = *(const float4 *) p1;
	float4 f2 = *(const float4 *) p2;
	if (f1 < f2)
		return -1;
	if (f1 > f2)
		return 1;
	return 0;
}

/*
 * Sorts the array and returns the median.
 */
static float4 c_med_final(State *state) {
	int32 mid;
	float4 ret;
	qsort(state->vals, state->nval, sizeof(float4), float4cmp);
	mid = state->nval / 2;
	if (state->nval % 2) {
		/* odd number of elements */
		ret = state->vals[mid];
	} else {
		/* even number of elements */
		ret = (state->vals[mid] + state->vals[mid - 1]) / 2;
	}
	return ret;
}

/*----------------------------------------------------------------------------*\
  Postgres V1 Function Prototypes
\*----------------------------------------------------------------------------*/

/*
 * Prototypes for the two SQL-callable entry points.  Each is marked with
 * PG_FUNCTION_INFO_V1 so the server calls it using the version-1 calling
 * convention (arguments read through PG_GETARG_* and results returned
 * through PG_RETURN_*).
 */
Datum med_trans(PG_FUNCTION_ARGS);
PG_FUNCTION_INFO_V1(med_trans);

Datum med_final(PG_FUNCTION_ARGS);
PG_FUNCTION_INFO_V1(med_final);

/*----------------------------------------------------------------------------*\
  SQL Interface (Wrappers)
\*----------------------------------------------------------------------------*/

/*
 * SQL-callable transition function: med_trans(state bytea, next real).
 *
 * Appends next to the group's state; NULL inputs are skipped.  Returns NULL
 * while no non-null value has been seen.  Because med_final is STRICT, a
 * group with only NULL inputs then yields NULL.
 *
 * Must be called as an aggregate.  AggCheckCallContext() accepts both plain
 * and window aggregate contexts.  In those contexts it is safe to modify the
 * state in place, which c_med_trans() does.
 */
Datum med_trans(PG_FUNCTION_ARGS) {
	State *state;

	if (!AggCheckCallContext(fcinfo, NULL)) {
		ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
						errmsg("med_trans() - must call from aggregate")));
	}

	/* a NULL state means this is the first row of a new group */
	state = PG_ARGISNULL(0) ? NULL : (State *) PG_GETARG_BYTEA_P(0);

	/* discard NULL input values */
	if (!PG_ARGISNULL(1))
		state = c_med_trans(state, PG_GETARG_FLOAT4(1));

	/* return the updated state */
	if (state == NULL)
		PG_RETURN_NULL();
	PG_RETURN_BYTEA_P(state);
}

Datum med_final(PG_FUNCTION_ARGS) {
	State *state;
	float4 ret;
	if (!fcinfo->context || !IsA(fcinfo->context, AggState)) {
		ereport(ERROR, (errcode(ERRCODE_FEATURE_NOT_SUPPORTED),
						errmsg("med_final() - must call from aggregate")));
	}
	state = (State *) PG_GETARG_BYTEA_P(0);
	ret = c_med_final(state);
	PG_RETURN_FLOAT4(ret);
}
#
# Makefile for building PostgreSQL extension modules
#
# Builds the shared library "statfunc" from statfunc.o.  Two modes:
#   make USE_PGXS=1   - out-of-tree build against an installed server, using
#                       the PGXS makefile located via pg_config (on PATH)
#   make              - in-tree build, expecting to live in contrib/statfunc
#                       of a PostgreSQL source tree
# The variables below must be set before the include lines at the bottom.
#

MODULE_big = statfunc
OBJS = statfunc.o

# No SQL scripts or documentation files are installed by this makefile.
DATA = 
DOCS = 

# No regression tests are defined.
REGRESS = 

ifdef USE_PGXS
PG_CONFIG = pg_config
PGXS := $(shell $(PG_CONFIG) --pgxs)
include $(PGXS)
else
subdir = contrib/statfunc
top_builddir = ../..
include $(top_builddir)/src/Makefile.global
include $(top_srcdir)/contrib/contrib-global.mk
endif