diff --git a/src/unit_tests b/src/unit_tests
deleted file mode 100755
index ff035a93ea140e83b2f57439eb874e1619dfceeb..0000000000000000000000000000000000000000
Binary files a/src/unit_tests and /dev/null differ
diff --git a/tests/cunit/README.md b/tests/cunit/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..12abd42cf0b666a1544bf3e9d2c2656e905f1d38
--- /dev/null
+++ b/tests/cunit/README.md
@@ -0,0 +1,3 @@
+Custom made C Unit Testing Library by Kapenga.
+
+Not to be confused with libcunit1.
\ No newline at end of file
diff --git a/tests/cunit/cunit.c b/tests/cunit/cunit.c
new file mode 100644
index 0000000000000000000000000000000000000000..07b520c38ccf5083c058d862dc62771c4abcccb2
--- /dev/null
+++ b/tests/cunit/cunit.c
@@ -0,0 +1,24 @@
+// cunit.c
+// The global constants for cunit and cunit_init(),
+// which should be called befoer using the unit test 
+// macros in cunit.h
+#include <stdio.h>
+#include "cunit.h"
+
+FILE    *cunit_log; 	// when cunit test macors print resu
+double  cunit_dmacheps;
+
+int cunit_init() {
+double dm;
+double dmacheps = 0.5;
+
+cunit_log = stderr;
+// a crude way for getting close to macheps 
+// should use a standard lib cal here 
+while( (1.0 + (dm = dmacheps/2.0) ) != 1.0  ) {
+   dmacheps = dm;
+}
+cunit_dmacheps = dmacheps;
+return 0; 
+}
+
diff --git a/tests/cunit/cunit.h b/tests/cunit/cunit.h
new file mode 100644
index 0000000000000000000000000000000000000000..92b438aa0357db7fd988e273e6b7d244b9c0a7ea
--- /dev/null
+++ b/tests/cunit/cunit.h
@@ -0,0 +1,113 @@
+#include <stdio.h>
+#include <assert.h>
+#include <math.h> 
+
+extern FILE *cunit_log; 
+extern double cunit_dmacheps;
+
+int cunit_init();
+
+#define cunit_open(log) { \
+    if((cunit_log=fopen( log, "w") == NULL ) { \
+      cunit_log = stderr; \
+    } \
+}
+#define cunit_close fclose(cunit_log)
+#define cunit_flush() fflush(cunit_log)
+#define cunit_date(str) { \
+    fprintf(cunit_log, "%s  LINE %d:  DATE:%s TIME:%s :%s\n", \
+              __FILE__ , __LINE__ , __DATE__ , __TIME__ , str ); \
+}
+#define cunit_print(str) { \
+    fprintf(cunit_log, "%s\n", str ); \
+}
+
+#define assert_true(str, a) { \
+  if( !( a ) ) { \
+    fprintf(cunit_log, "%s  LINE %d: %s, %d != TRUE\n", \
+              __FILE__ , __LINE__ , str ,  a ); \
+  } \
+}
+
+#define assert_false(str, a) { \
+  if( a ) { \
+    fprintf(cunit_log, "%s  LINE %d: %s, %d != FALSE\n", \
+              __FILE__ , __LINE__ , str ,  a ); \
+  } \
+}
+
+#define assert_eq(str,a,b) { \
+  if( a != b ) { \
+    fprintf(cunit_log, "%s  LINE %d: %s, %d !== %d\n", \
+              __FILE__ , __LINE__ , str ,  a , b ); \
+  } \
+}
+#define assert_neq(str,a,b) { \
+  if( a == b ) { \
+    fprintf(cunit_log, "%s  LINE %d: %s, %d !!= %d\n", \
+              __FILE__ , __LINE__ , str ,  a , b ); \
+  } \
+}
+#define assert_ge(str,a,b) { \
+  if( a < b ) { \
+    fprintf(cunit_log, "%s  LINE %d: %s, %d !>= %d\n", \
+              __FILE__ , __LINE__ , str ,  a , b ); \
+  } \
+}
+#define assert_gt(str,a,b) { \
+  if( a <= b ) { \
+    fprintf(cunit_log, "%s  LINE %d: %s, %d !> %d\n", \
+              __FILE__ , __LINE__ , str ,  a , b ); \
+  } \
+}
+#define assert_le(str,a,b) { \
+  if( a > b ) { \
+    fprintf(cunit_log, "%s  LINE %d: %s, %d !<= %d\n", \
+              __FILE__ , __LINE__ , str ,  a , b ); \
+  } \
+}
+#define assert_lt(str,a,b) { \
+  if( a >= b ) { \
+    fprintf(cunit_log, "%s  LINE %d: %s, %d !< %d\n", \
+              __FILE__ , __LINE__ , str ,  a , b ); \
+  } \
+}
+#define assert_feq(str,a,b) { \
+  if( a != b ) { \
+    fprintf(cunit_log, "%s  LINE %d: %s, %24.16f !== %24.16f\n", \
+              __FILE__ , __LINE__ , str , a , b ); \
+  } \
+}
+// eq subject to absolute error
+#define assert_feqaerr(str,a,b,aerr) { \
+  if( fabs(a - b) > aerr ) { \
+    fprintf(cunit_log, "%s  LINE %d: %s, %24.16f !== %24.16f err=%10.6e\n", \
+              __FILE__ , __LINE__ , str , a , b , err); \
+  } \
+}
+// eq subject to relative error
+// Perhaps it should check if a == b == 0.0 
+#define assert_feqrerr(str,a,b,rerr) { \
+  if( fabs(a - b)/(fabs(a) + fabs(b)) > rerr ) { \
+    fprintf(cunit_log, "%s  LINE %d: %s, %24.16f !>= %24.16f rerr=%10.6e\n", \
+              __FILE__ , __LINE__ , str , a , b , rerr ); \
+  } \
+}
+#define assert_fgt(str,a,b) { \
+  if( a <= b ) { \
+    fprintf(cunit_log, "%s  LINE %d: %s, %24.16f !> %24.16f\n", \
+              __FILE__ , __LINE__ , str ,  a , b ); \
+  } \
+}
+#define assert_fle(str,a,b) { \
+  if( a > b ) { \
+    fprintf(cunit_log, "%s  LINE %d: %s, %24.16f !<= %24.16f\n", \
+              __FILE__ , __LINE__ , str , a , b ); \
+  } \
+}
+#define assert_flt(str,a,b) { \
+  if( a >= b ) { \
+    fprintf(cunit_log, "%s  LINE %d: %s, %24.16f !< %24.16f\n", \
+              __FILE__ , __LINE__ , str , a , b ); \
+  } \
+}
diff --git a/tests/unit_tests/Makefile b/tests/unit_tests/Makefile
index 93e370f4b77e7c7a79883c5452bc7bbe28cf52b9..2a78f1460d3466fc84ff2a3001bdfe9f1d8d553b 100644
--- a/tests/unit_tests/Makefile
+++ b/tests/unit_tests/Makefile
@@ -4,7 +4,7 @@ LIBS = -lm
 CC = gcc
 CFLAGS = -Wall -pedantic -std=c99 -g
 SRC_DIR = ../../src
-WRAPPED = -Wl,--wrap=sqrt
+CUNIT_DIR = ../cunit
 
 .PHONY: default all clean test_unit_all
 
@@ -14,15 +14,18 @@ all: default
 test_unit_all: test_my_sqrt.out
 	./test_my_sqrt.out
 
-test_my_sqrt.out: my_sqrt.o test_my_sqrt.o mock_sqrt.o
-	${CC} -o $@ $^ $(LIBS) $(WRAPPED)
+test_my_sqrt.out: my_sqrt.o test_my_sqrt.o mock_sqrt.o cunit.o
+	${CC} -o $@ $^ $(LIBS) -Wl,--wrap=sqrt
 
 my_sqrt.o: $(SRC_DIR)/my_sqrt.c $(SRC_DIR)/my_sqrt.h
 	${CC} -c -o $@ $< ${CFLAGS}
 
-test_my_sqrt.o: test_my_sqrt.c
+test_my_sqrt.o: test_my_sqrt.c mock_sqrt.h
 
 mock_sqrt.o: mock_sqrt.c
 
+cunit.o: $(CUNIT_DIR)/cunit.c $(CUNIT_DIR)/cunit.h
+	${CC} -c -o $@ $< ${CFLAGS}
+
 clean:
 	-rm -f *.o *.out
diff --git a/tests/unit_tests/mock_sqrt.c b/tests/unit_tests/mock_sqrt.c
index 376b8b4d30d125bdab3ea3f9036c6134d9b9d8a8..2c04e79a1c4cc839566c0c438b311798341ade4b 100644
--- a/tests/unit_tests/mock_sqrt.c
+++ b/tests/unit_tests/mock_sqrt.c
@@ -1,4 +1,6 @@
-static short x = 0;
+static double test_x = 0;
+static double test_sqrt_x = 0;
+static double test_x_input = 0;
 
 /**
  * Mocks sqrt function.
@@ -6,5 +8,16 @@ static short x = 0;
  * Requires '-Wl,--wrap=sqrt' flag when linking.
  */
 double __wrap_sqrt(double x) {
-    return 1.0f;
+    test_x_input = x;
+    return test_sqrt_x;
 }
+
+void mock_sqrt_setup(double x, double sqrt_x) {
+    test_x = x;
+    test_sqrt_x = sqrt_x;
+}
+
+void mock_sqrt_get(double *x, double *sqrt_x) {
+    *x = test_x_input;
+    *sqrt_x = test_sqrt_x;
+}
\ No newline at end of file
diff --git a/tests/unit_tests/mock_sqrt.h b/tests/unit_tests/mock_sqrt.h
new file mode 100644
index 0000000000000000000000000000000000000000..0f64be5d9a6038c25ec76d14f37a4a6704032488
--- /dev/null
+++ b/tests/unit_tests/mock_sqrt.h
@@ -0,0 +1,3 @@
+// mock_sqrt.h
+void mock_sqrt_setup(double x, double sqrt_x);
+void mock_sqrt_get(double *x, double *sqrt_x);
\ No newline at end of file
diff --git a/tests/unit_tests/test_my_sqrt.c b/tests/unit_tests/test_my_sqrt.c
index aaad03ae01deb2fe84bbe92090d906c10d6db603..54ed2a613b9b9c15cd30c6bbc7a5d6a4789ac78e 100644
--- a/tests/unit_tests/test_my_sqrt.c
+++ b/tests/unit_tests/test_my_sqrt.c
@@ -1,23 +1,77 @@
 #include <stdio.h>
 #include "../../src/my_sqrt.h"
+#include "../cunit/cunit.h"
+#include "mock_sqrt.h"
+
+void test_my_sqrt_4_returns_2() {
+    int my_sqrt_error = 0;
+    double mock_x_in = 4.0f;
+    double mock_x_out = 2.0f;
+    double passed_x_in = 0;
+    double passed_x_out = 0;
+    double x_out = 0;
+
+    mock_sqrt_setup(mock_x_in, mock_x_out);
+
+    my_sqrt_error = my_sqrt(mock_x_in, &x_out);
+
+    mock_sqrt_get(&passed_x_in, &passed_x_out);
+
+    assert_eq("my_sqrt returned error", 0, my_sqrt_error);
+    assert_feq("mock_sqrt did not get right argument", passed_x_in, mock_x_in);
+    assert_feq("my_sqrt modified mock_sqrt return value", passed_x_out, x_out);
+}
+
+void test_my_sqrt_nan_has_error() {
+    int my_sqrt_error = 0;
+    double x_in = NAN;
+    double x_out = 0;
+
+    mock_sqrt_setup(0, 0);
+
+    my_sqrt_error = my_sqrt(x_in, &x_out);
+
+    assert_neq("my_sqrt didn't return error", 0, my_sqrt_error);
+}
+
+void test_my_sqrt_inf_has_error() {
+    int my_sqrt_error = 0;
+    double x_in = INFINITY;
+    double x_out = 0;
+
+    mock_sqrt_setup(0, 0);
+
+    my_sqrt_error = my_sqrt(x_in, &x_out);
+
+    assert_neq("my_sqrt didn't return error", 0, my_sqrt_error);
+}
+
+void test_my_sqrt_negative_returns_nan_and_sqrt_not_called() {
+    int my_sqrt_error = 0;
+    double x_in = -1.0f;
+    double x_out = 0;
+
+    mock_sqrt_setup(0, 0); // zero will be returned if mock_sqrt called
+
+    my_sqrt_error = my_sqrt(x_in, &x_out);
+
+    assert_eq("my_sqrt returned error", 0, my_sqrt_error);
+    assert_true("Expected NAN", isnan(x_out));
+}
 
 /**
- * TODO: This is a work in progress.
- *
- * Currently this shows that my_sqrt's sqrt() function has been mocked.
- * It will always return 1, instead of computing the correct value of 2.
+ * Test my_sqrt by mocking the math library sqrt() function.
  */
 int main(int argc, char *argv[]) {
     int error = 0;
-    int my_sqrt_error = 0;
-    double x = 4.0f;
-    double out = 0;
 
-    printf("Hello!\n");
+    // init cunit
+    cunit_init();
 
-    my_sqrt_error = my_sqrt(x, &out);
-    printf("my_sqrt_error: %d\n", my_sqrt_error);
-    printf("Computed value: %f\n", out);
+    test_my_sqrt_4_returns_2();
+    test_my_sqrt_nan_has_error();
+    test_my_sqrt_inf_has_error();
+    test_my_sqrt_negative_returns_nan_and_sqrt_not_called();
 
     return error;
 }