"""
Tests for dealing with SQL arrays in ADQL.

(and perhaps elsewhere).
"""

import math

from gavo.helpers import testhelpers

from gavo import base
from gavo import utils

import adqltest
import tresc


class _ArrayTestbed(tresc.RDDataResource):
	"""Test resource for a table with array-valued columns.

	It is built from the data element import_arr in data/ufuncex.rd
	(presumably providing test.arrsample, which the tests below query).
	"""
	dataId = "import_arr"
	rdName = "data/ufuncex.rd"

# One shared instance; the test classes below list it in their resources.
_arrayTestbed = _ArrayTestbed()


class AdditionTest(testhelpers.VerboseTest):
	"""Element-wise addition of array-valued columns in ADQL.
	"""
	resources = [
		("arrsampleable", _arrayTestbed),
		("querier", adqltest.adqlQuerier)]

	def testBasic(self):
		query = ("SELECT flux,"
			" flux+flux as fluxtwice,"
			" plc+plc as plctwice"
			" FROM test.arrsample")
		row = self.querier.queryADQL(query).rows[0]

		# adding an array to itself must keep its length
		self.assertEqual(len(row["fluxtwice"]), len(row["flux"]))
		self.assertEqual(len(row["plctwice"]), 3)
		self.assertAlmostEqual(row["fluxtwice"][0]/2, row["flux"][0])
		self.assertAlmostEqual(row["plctwice"][2], -3.2491233)

	def testUnequalLength(self):
		res = self.querier.queryADQL(
			"SELECT flux+CAST(plc AS real[]) as s FROM test.arrsample")
		summed = res.rows[0]["s"]

		self.assertEqual(len(summed), 6)
		# plc is shorter than flux, so the surplus positions come out as NaN
		for idx in (-1, -2, -3):
			self.assertTrue(math.isnan(summed[idx]))
		self.assertAlmostEqual(summed[-4], -0.86078703)
		self.assertEqual(res.tableDef.getColumnByName("s").type, "real[]")

	def testAddedMetadata(self):
		res = self.querier.queryADQL(
			"SELECT flux, flux+flux as ff, cast(flux as double precision[])+plc as fp"
			" FROM test.arrsample")
		getColumn = res.tableDef.getColumnByName

		# (column name, ucd, unit, type, description); derived columns
		# carry "traces" of their operands and a taint warning.
		expectations = [
			("flux", "phot.mag", "Jy", "real[]",
				"Some flux values in an array"),
			("ff", "phot.mag", "Jy", "real[]",
				"This field has traces of: Some flux values in an array; Some flux values in an array -- *TAINTED*: the value was operated on in a way that unit and ucd may be severely wrong"),
			("fp", "", "", "double precision[]",
				"This field has traces of: Some flux values in an array; Possibly a cartesian position -- *TAINTED*: the value was operated on in a way that unit and ucd may be severely wrong"),
		]
		for colName, ucd, unit, colType, description in expectations:
			col = getColumn(colName)
			self.assertEqual(col.ucd, ucd)
			self.assertEqual(col.unit, unit)
			self.assertEqual(col.type, colType)
			self.assertEqual(col.description, description)


class SubtractionTest(testhelpers.VerboseTest):
	"""Element-wise subtraction of array-valued columns in ADQL.
	"""
	resources = [
		("arrsampleable", _arrayTestbed),
		("querier", adqltest.adqlQuerier)]

	def testBasic(self):
		row = self.querier.queryADQL(
			"SELECT flux,"
			" flux-flux as allnull,"
			" flux-plc as indiff"
			" FROM test.arrsample").rows[0]
		zeros, diff = row["allnull"], row["indiff"]

		self.assertEqual(len(zeros), len(diff))
		self.assertEqual(zeros[5], 0)
		self.assertAlmostEqual(diff[2], 2.388336286310653)
		# where plc has no counterpart element, the difference is NaN
		self.assertTrue(math.isnan(diff[5]))


class MultiplicationTest(testhelpers.VerboseTest):
	"""Element-wise multiplication of array-valued columns in ADQL.
	"""
	resources = [
		("arrsampleable", _arrayTestbed),
		("querier", adqltest.adqlQuerier)]

	def testBasic(self):
		row = self.querier.queryADQL(
			"SELECT flux[6]*flux[6] as fstest,"
			" flux*flux as fluxsquare,"
			" flux[3]*plc[3] as mixtest,"
			" flux*plc as mixed"
			" FROM test.arrsample").rows[0]
		squares, products = row["fluxsquare"], row["mixed"]

		self.assertEqual(len(squares), len(products))
		# ADQL indices are 1-based: flux[6] is squares[5], flux[3] is products[2]
		self.assertEqual(squares[5], row["fstest"])
		self.assertAlmostEqual(products[2], row["mixtest"])
		self.assertTrue(math.isnan(products[5]))


class DivisionTest(testhelpers.VerboseTest):
	"""Element-wise division of array-valued columns in ADQL.
	"""
	resources = [
		("arrsampleable", _arrayTestbed),
		("querier", adqltest.adqlQuerier)]

	def testBasic(self):
		res = self.querier.queryADQL(
			"SELECT flux[6]/flux[6] as fstest,"
			" flux/flux as allones,"
			" flux[3]/plc[3] as mixtest,"
			" flux/plc as mixed"
			" FROM test.arrsample")
		row = res.rows[0]
		ones, quotients = row["allones"], row["mixed"]

		self.assertEqual(len(ones), len(quotients))
		# ADQL indices are 1-based: flux[6] is ones[5], plc[3] is quotients[2]
		self.assertEqual(ones[5], row["fstest"])
		self.assertAlmostEqual(quotients[2], row["mixtest"])
		self.assertTrue(math.isnan(quotients[5]))

		# real/real stays real; dividing by the double-valued plc promotes
		for colName, expectedType in [
				("allones", "real[]"),
				("mixed", "double precision[]")]:
			self.assertEqual(
				res.tableDef.getColumnByName(colName).type, expectedType)


class ScalarProdTest(testhelpers.VerboseTest):
	"""The arr_dot scalar product of two arrays.
	"""
	resources = [
		("arrsampleable", _arrayTestbed),
		("querier", adqltest.adqlQuerier)]

	def testBasic(self):
		res = self.querier.queryADQL(
			"SELECT"
			" arr_dot(flux, flux) as fluxprod,"
			" arr_dot(flux, plc) as mixedprod"
			" FROM test.arrsample")
		row = res.rows[0]

		self.assertAlmostEqual(row["fluxprod"], 1.832107663154602)
		# flux and plc differ in length, which yields NaN here
		self.assertTrue(math.isnan(row["mixedprod"]))

		# both products are plain doubles without unit or UCD
		for colName in ("fluxprod", "mixedprod"):
			col = res.tableDef.getColumnByName(colName)
			self.assertEqual(col.type, "double precision")
			self.assertEqual(col.ucd, "")
			self.assertEqual(col.unit, "")


class ScalarMulTest(testhelpers.VerboseTest):
	"""Multiplication of arrays by scalars (literals and expressions).
	"""
	resources = [
		("arrsampleable", _arrayTestbed),
		("querier", adqltest.adqlQuerier)]

	def testWithLiteral(self):
		res = self.querier.queryADQL(
			"SELECT"
			" 2*flux-flux*2 as realzeros,"
			" 5*plc as dub"
			" FROM test.arrsample")
		getColumn = res.tableDef.getColumnByName
		row = res.rows[0]

		# scaling by a literal keeps the type, UCD and unit of the array
		zerosCol = getColumn("realzeros")
		self.assertEqual(zerosCol.type, "real[]")
		self.assertEqual(zerosCol.ucd, "phot.mag")
		self.assertEqual(zerosCol.unit, "Jy")
		# scalar multiplication works from either side
		self.assertEqual(row["realzeros"], [0., 0., 0., 0., 0., 0.])

		scaledCol = getColumn("dub")
		self.assertEqual(scaledCol.type, "double precision[]")
		self.assertEqual(scaledCol.ucd, "pos.cartesian")
		self.assertEqual(scaledCol.unit, "m")
		self.assertEqual(scaledCol.description, "Possibly a cartesian position -- *TAINTED*: the value was operated on in a way that unit and ucd may be severely wrong")
		self.assertAlmostEqual(row["dub"][2], -8.122808264515303)

	def testWithExpression(self):
		res = self.querier.queryADQL(
			"SELECT"
			" plc[1]*flux as fromleft,"
			" flux*plc[1] as fromright"
			" FROM test.arrsample")

		leftCol = res.tableDef.getColumnByName("fromleft")
		self.assertEqual(leftCol.type, "double precision[]")
		self.assertEqual(leftCol.ucd, "")
		# multiplying by a non-literal combines the units
		self.assertEqual(leftCol.unit, "m*Jy")

		row = res.rows[0]
		differences = [
			left-right
			for left, right in zip(row["fromleft"], row["fromright"])]
		self.assertEqual(differences, [0., 0., 0., 0., 0., 0.])


class ScalarDivTest(testhelpers.VerboseTest):
	"""Division of arrays by scalars (literals and expressions).

	Computed floating-point values are compared with assertAlmostEqual,
	like in the other tests of this module.  Exact equality depends on
	rounding and on how the database serialises floats, so it is fragile.
	"""
	resources = [
		("arrsampleable", _arrayTestbed),
		("querier", adqltest.adqlQuerier)]

	def testWithLiteral(self):
		res = self.querier.queryADQL(
			"SELECT flux/2 as half FROM test.arrsample")

		# dividing by a literal keeps type, UCD and unit of the array
		col = res.tableDef.getColumnByName("half")
		self.assertEqual(col.type, "real[]")
		self.assertEqual(col.ucd, "phot.mag")
		self.assertEqual(col.unit, "Jy")
		self.assertAlmostEqual(res.rows[0]["half"][0], 0.06718212)

	def testWithExpression(self):
		res = self.querier.queryADQL(
			"SELECT flux/plc[1] as k FROM test.arrsample")

		# dividing by a non-literal drops the UCD and combines the units
		col = res.tableDef.getColumnByName("k")
		self.assertEqual(col.type, "double precision[]")
		self.assertEqual(col.ucd, "")
		self.assertEqual(col.unit, "Jy/(m)")
		self.assertAlmostEqual(res.rows[0]["k"][-1], 0.7412795)


class DerivedAggFunctionTest(testhelpers.VerboseTest):
	"""The arr_* functions that reduce an array to a scalar per row.
	"""
	resources = [
		("arrsampleable", _arrayTestbed),
		("querier", adqltest.adqlQuerier)]

	def testMessage(self):
		# arr_min accepts a single argument; a second one is a parse error
		expected = utils.EqualingRE("Field query: Could not parse your query: Expected ['\"]\\)[\"'], found ','  \\(at char 18\\), \\(line:1, col:19\\)")
		badQuery = "SELECT arr_min(plc, plc) FROM test.arrsample"
		self.assertRaisesWithMsg(base.ValidationError,
			expected,
			self.querier.queryADQL,
			(badQuery,))

	def testMin(self):
		res = self.querier.queryADQL(
			"SELECT"
			" arr_min(plc) as minplc,"
			" arr_min(flux) as minflux"
			" FROM test.arrsample")
		rows = res.rows

		self.assertAlmostEqual(rows[0]["minplc"], -1.6245616529030604)
		self.assertAlmostEqual(rows[1]["minplc"], -0.7674541696434232)
		self.assertAlmostEqual(rows[0]["minflux"], 0.13436425)

		# the result keeps the element type and unit; stat.min is
		# prepended to the UCD
		for colName, colType, ucd, unit in [
				("minplc", "double precision", "stat.min;pos.cartesian", "m"),
				("minflux", "real", "stat.min;phot.mag", "Jy")]:
			col = res.tableDef.getColumnByName(colName)
			self.assertEqual(col.type, colType)
			self.assertEqual(col.ucd, ucd)
			self.assertEqual(col.unit, unit)

	def testOnNonArray(self):
		nonArrayQuery = "SELECT arr_min(plc[0]) as minplc FROM test.arrsample"
		self.assertRaisesWithMsg(
			base.ValidationError,
			"Field query: Scalar array function called on non-array?",
			self.querier.queryADQL,
			(nonArrayQuery,))

	def testMax(self):
		res = self.querier.queryADQL(
			"SELECT"
			" arr_max(plc) as maxplc,"
			" arr_max(flux) as maxflux"
			" FROM test.arrsample")
		rows = res.rows

		self.assertAlmostEqual(rows[1]["maxplc"], 0.6789216057608836)
		self.assertAlmostEqual(rows[0]["maxflux"], 0.84743375)

		getColumn = res.tableDef.getColumnByName
		self.assertEqual(getColumn("maxplc").ucd, "stat.max;pos.cartesian")
		self.assertEqual(getColumn("maxflux").ucd, "stat.max;phot.mag")

	def testSum(self):
		# the mixed-case function names check case-insensitive matching
		res = self.querier.queryADQL(
			"SELECT"
			" ARR_sum(plc) as splc,"
			" arr_SUM(flux) as sflux"
			" FROM test.arrsample")
		rows = res.rows

		self.assertAlmostEqual(rows[1]["splc"], 0.3352440988313101)
		self.assertAlmostEqual(rows[0]["sflux"], 2.9455676)

		getColumn = res.tableDef.getColumnByName
		self.assertEqual(getColumn("splc").ucd, "pos.cartesian;arith.sum")
		self.assertEqual(getColumn("sflux").ucd, "phot.mag;arith.sum")

	def testAvg(self):
		res = self.querier.queryADQL(
			"SELECT"
			" arr_avg(plc) as mplc,"
			" arr_avg(flux) as mflux"
			" FROM test.arrsample")
		rows = res.rows

		self.assertAlmostEqual(rows[1]["mplc"], 0.11174803294377005)
		self.assertAlmostEqual(rows[0]["mflux"], 0.49092796)

		getColumn = res.tableDef.getColumnByName
		self.assertEqual(getColumn("mplc").ucd, "pos.cartesian;stat.mean")
		self.assertEqual(getColumn("mflux").ucd, "phot.mag;stat.mean")

	def testStddev(self):
		res = self.querier.queryADQL(
			"SELECT"
			" power(arr_stddev(plc), 2) as varce, plc,"
			" arr_stddev(flux) as fluxerr, flux"
			" FROM test.arrsample")
		rows = res.rows

		self.assertAlmostEqual(rows[0]["fluxerr"], 0.27786547)
		self.assertAlmostEqual(rows[1]["varce"], 0.596022120)

		# wrapping the result in power() taints the metadata
		varianceCol = res.tableDef.getColumnByName("varce")
		self.assertEqual(varianceCol.ucd, "")
		self.assertEqual(varianceCol.description, "Possibly a cartesian position -- *TAINTED*: the value was operated on in a way that unit and ucd may be severely wrong")
		self.assertEqual(
			res.tableDef.getColumnByName("fluxerr").ucd,
			"stat.stdev;phot.mag")


class ArrayAggregationTest(testhelpers.VerboseTest):
	"""Standard SQL aggregate functions applied element-wise to arrays.
	"""
	resources = [
		("arrsampleable", _arrayTestbed),
		("querier", adqltest.adqlQuerier)]

	def testBasicSum(self):
		# note: allflux is aliased without AS
		res = self.querier.queryADQL(
			"SELECT"
			" sum(flux) allflux,"
			" sum(plc) as allplc"
			" FROM test.arrsample")
		row = res.rows[0]

		# aggregation keeps the array lengths
		self.assertEqual(len(row["allflux"]), 6)
		self.assertEqual(len(row["allplc"]), 3)
		self.assertAlmostEqual(row["allflux"][0], 1.0903986)
		self.assertAlmostEqual(row["allplc"][-1], -1.2007849901)

		self.assertEqual(res.tableDef.getColumnByName("allflux").ucd, "phot.mag")

	def testBasicMin(self):
		res = self.querier.queryADQL(
			"SELECT"
			" min(flux) minflux,"
			" min(plc) as minplc"
			" FROM test.arrsample")
		row = res.rows[0]

		self.assertEqual(len(row["minflux"]), 6)
		self.assertEqual(len(row["minplc"]), 3)
		self.assertAlmostEqual(row["minflux"][-1], 0.44949105)
		self.assertAlmostEqual(row["minplc"][0], 0.606371890)

		minCol = res.tableDef.getColumnByName("minflux")
		self.assertEqual(minCol.ucd, "stat.min;phot.mag")

	def testMixedMin(self):
		# aggregating a UNION of arrays with different lengths and types
		res = self.querier.queryADQL(
			"SELECT min(arr) AS mixedmin FROM ("
			" SELECT flux AS arr FROM test.arrsample"
			" UNION"
			" SELECT plc AS arr FROM test.arrsample) as q")
		mixed = res.rows[0]["mixedmin"]

		self.assertEqual(len(mixed), 6)
		self.assertAlmostEqual(mixed[-1], 0.44949105)
		self.assertAlmostEqual(mixed[0], 0.1343642473)

		# Laziness alert: the UCD should probably be "" rather than None,
		# but nobody has yet investigated why it is not.
		mixedCol = res.tableDef.getColumnByName("mixedmin")
		self.assertEqual(mixedCol.ucd, None)
		self.assertEqual(
			res.tableDef.getColumnByName("mixedmin").unit,
			"")

	def testBasicMax(self):
		res = self.querier.queryADQL(
			"SELECT"
			" max(flux) maxflux,"
			" max(plc) as maxplc"
			" FROM test.arrsample")
		row = res.rows[0]

		self.assertEqual(len(row["maxflux"]), 6)
		self.assertEqual(len(row["maxplc"]), 3)
		self.assertAlmostEqual(row["maxflux"][-1], 0.73596996)
		self.assertAlmostEqual(row["maxplc"][0], 0.678921605)

		maxCol = res.tableDef.getColumnByName("maxflux")
		self.assertEqual(maxCol.ucd, "stat.max;phot.mag")

	def testBasicAvg(self):
		res = self.querier.queryADQL(
			"SELECT"
			" avg(flux) avgflux,"
			" avg(plc) as avgplc"
			" FROM test.arrsample")
		row = res.rows[0]

		self.assertEqual(len(row["avgflux"]), 6)
		self.assertEqual(len(row["avgplc"]), 3)

		def pgScalar(sql):
			# first column of the first row of a plain database query
			return list(self.querier.connection.query(sql))[0][0]

		# cross-check against scalar avg on single elements (1-based indices)
		self.assertAlmostEqual(row["avgflux"][-1],
			pgScalar("select avg(flux[6]) from test.arrsample"))
		self.assertAlmostEqual(row["avgplc"][0],
			pgScalar("select avg(plc[1]) from test.arrsample"))

		avgCol = res.tableDef.getColumnByName("avgflux")
		self.assertEqual(avgCol.ucd, "phot.mag;stat.mean")

	def testBasicStddev(self):
		res = self.querier.queryADQL(
			"SELECT"
			" stddev(flux[4]) as fluxtest,"
			" stddev(flux) fluxerr,"
			" stddev(plc[1]) as plctest,"
			" stddev(plc) as plcerr"
			" FROM test.arrsample")
		row = res.rows[0]

		# ADQL index 4 is Python index 3, ADQL index 1 is Python index 0
		self.assertAlmostEqual(row["fluxerr"][3], row["fluxtest"])
		self.assertAlmostEqual(row["plcerr"][0], row["plctest"])

		testCol = res.tableDef.getColumnByName("fluxtest")
		self.assertEqual(testCol.ucd, "stat.stdev;phot.mag")


class ArrMapTest(testhelpers.VerboseTest):
	"""arr_map, which evaluates an expression in x for every array element.
	"""
	resources = [
		("arrsampleable", _arrayTestbed),
		("querier", adqltest.adqlQuerier)]

	def testBasic(self):
		res = self.querier.queryADQL(
			"SELECT arr_map(sin(x)+1, flux) as x, sin(flux[3])+1 as xt,"
			"  arr_map(round(x*100), plc) as p, round(plc[1]*100) as pt"
			" FROM test.arrsample")
		row = res.rows[0]

		# compare mapped elements with the same expression applied by hand
		# to a single element (ADQL indices are 1-based)
		self.assertAlmostEqual(row["x"][2], row["xt"])
		self.assertAlmostEqual(row["p"][0], row["pt"])

		mappedCol = res.tableDef.getColumnByName("x")
		self.assertEqual(mappedCol.ucd, "")
		self.assertEqual(mappedCol.type, "double precision[]")


if __name__=="__main__":
	testhelpers.main(AdditionTest)
