import sys
import unittest
from xml import sax
from xml.sax.handler import ContentHandler

from gavo import base
from gavo import api
from gavo.protocols import adqlglue

# Row count of the val_1 table.  Negative "res" attributes on cardTest
# elements are taken relative to it (see TestSuiteMaker._make_cardTest).
# The literal 3 presumably is only a placeholder; setTotalRows() replaces it
# with the actual count when this module is run as a script.
totalRows = 3


class NamedNode(object):
	"""is a wrapper a _make_x handler can return to rename its node.

	SimpleNodeBuilder enters such a node into the parent's children as
	(name, node) rather than under the element's own name.
	"""
	def __init__(self, name, node):
		# name the parent will see, and the payload stored under it
		self.name = name
		self.node = node


class SimpleNodeBuilder(ContentHandler):
	"""a node builder is a content handler working more along the
	lines of conventional parse tree builders.

	This means that for every element we want handled, there is a
	method _make_<elementname>(name, attrs, children) that receives
	the name of the element (so you can reuse implementations for
	elements behaving analogously), the attributes of the element
	and a list of children.  The children come in a list of tuples
	(name, content), where name is the element name and content
	is whatever the _make_x method returned for that element.
	Text nodes have a name of None.

	If a method _make_default is defined, it will be called if no handler
	for a node is defined.  Otherwise, elements with no handlers are errors.

	In general, text children are deleted when they are whitespace
	only, and they are joined to form a single one.  However, you
	can define a set keepWhitespaceNames containing the names of
	elements for which this is not wanted.  Don't add to the class-level
	set, make a new one.

	On errors during node construction, the class will call a
	handleError method with a sys.exc_info tuple.  The default implementation
	exits with an error message.

	If the parser calls the setDocumentLocator method, its result is available
	as the locator attribute.
	"""
	keepWhitespaceNames = set()

	def __init__(self):
		ContentHandler.__init__(self)
		self.elementStack = []
		self.childStack = [[]]
		self.locator = None

	def handleError(self, exc_info):
		msg = ("Error while parsing XML at"
			" %d:%d (%s)"%(self.locator.getLineNumber(), 
				self.locator.getColumnNumber(), exc_info[1]))
		sys.exit(msg)

	def setDocumentLocator(self, locator):
		self.locator = locator

	def startElement(self, name, attrs):
		self.elementStack.append((name, attrs))
		self.childStack.append([])

	def _enterNewNode(self, name, attrs, newNode):
			if isinstance(newNode, NamedNode):
				newChild = (newNode.name, newNode.node)
			else:
				newChild = (name, newNode)
			self.childStack[-1].append(newChild)

	def endElement(self, name):
		_, attrs = self.elementStack.pop()
		try:
			children = self.childStack.pop()
			if name not in self.keepWhitespaceNames:
				children = self._cleanTextNodes(children)
			if hasattr(self, "_make_"+name):
				handler = getattr(self, "_make_"+name)
			elif hasattr(self, "_make_default"):
				handler = getattr(self, "_make_default")
			else:
				raise Error("No handler for %s"%name)
			newNode = handler(name, attrs, children)
			if newNode is not None:
				self._enterNewNode(name, attrs, newNode)
		except:
			self.handleError(sys.exc_info())

	def characters(self, content):
		self.childStack[-1].append((None, content))

	def getResult(self):
		return self.childStack[0][0][1]

	def _cleanTextNodes(self, children):
		"""joins adjacent text nodes and prunes whitespace-only nodes.
		"""
		cleanedChildren, text = [], []
		for type, node in children:
			if type is None:
				text.append(node)
			else:
				if text:
					chars = "".join(text)
					if chars.strip():
						cleanedChildren.append((None, chars))
					text = []
				cleanedChildren.append((type, node))
		chars = "".join(text)
		if chars.strip():
			cleanedChildren.append((None, chars))
		return cleanedChildren

	def getContentWS(self, children):
		"""returns the entire text content of the node in a string without doing
		whitespace normalization.
		"""
		return "".join([n[1] for n in children if n[0] is None])

	def getContent(self, children):
		"""returns the entire text content of the node in a string.

		This probably won't do what you want in mixed-content models.
		"""
		return self.getContentWS(children).strip()


class ADQLTest(unittest.TestCase):
	"""is a base class for test cases that run ADQL queries via adqlglue."""

	def _getResultRows(self, query):
		"""returns the rows of the result of running the ADQL query."""
		queryResult = adqlglue.query(query)
		return queryResult.rows


def makeCardTest(query, expectedLength, locator, hint):
	"""returns a TestCase class asserting that query yields expectedLength rows.

	locator (a position description) and the optional hint only serve to
	make the failure message more informative.
	"""
	hintText = " Hint: %s"%hint if hint else ""
	qInfo = " query: '%s', %s"%(query, locator) + hintText

	class CardTest(ADQLTest):
		def testCard(self):
			rowCount = len(self._getResultRows(query))
			failMessage = "%d rows in result instead of %d%s"%(
				rowCount, expectedLength, qInfo)
			self.assertEqual(rowCount, expectedLength, failMessage)

	return CardTest


def makeResTest(query, expectedResult, locator, hint):
	"""returns a TestCase class asserting query yields exactly expectedResult.

	The query must produce a single row with a single field; that field's
	value is compared with expectedResult.  locator and the optional hint
	are only appended to failure messages.
	"""
	context = " Query: '%s', %s"%(query, locator)
	if hint:
		context = context + " Hint: %s"%hint

	class ResTest(ADQLTest):
		def testRes(self):
			resultRows = self._getResultRows(query)
			nRows = len(resultRows)
			self.assertEqual(nRows, 1,
				"ResTests must return exactly one row, not %d.%s"%(nRows, context))

			onlyRow = resultRows[0]
			nFields = len(onlyRow)
			self.assertEqual(nFields, 1,
				"Result row has length %d, but resTests must only return exactly"
				" one row with exactly one field.%s"%(nFields, context))

			# the row is a mapping with a single item; take its value
			_, actual = onlyRow.popitem()
			self.assertEqual(actual, expectedResult,
				"Result is '%s' instead of '%s'%s"%(actual, expectedResult, context))

	return ResTest


class ADQLSuite(unittest.TestSuite):
	"""is a TestSuite carrying a human-readable name.

	The name is returned by shortDescription; all other arguments are passed
	on to unittest.TestSuite.
	"""
	def __init__(self, name, *args, **kwargs):
		self.__name = name
		super(ADQLSuite, self).__init__(*args, **kwargs)

	def shortDescription(self):
		"""returns the name given at construction."""
		return self.__name


class TestSuiteMaker(SimpleNodeBuilder):
	"""is a node builder turning an XML test specification into unittest
	objects.

	Leaf elements (cardTest, intResTest, charResTest) become TestCase
	classes built from their text content (the query) and their res and
	optional hint attributes.  testGroup collects such classes into a named
	ADQLSuite, and tests collects its children into a plain TestSuite.
	Elements without a handler are reported and dropped.
	"""
	def handleError(self, excTuple):
		"""re-raises the exception currently being handled.

		Unlike the base class, this does not exit, so specification errors
		propagate to whoever called sax.parse.  This relies on being called
		from within an except clause (as SimpleNodeBuilder.endElement does).
		"""
		raise

	def getPos(self):
		"""returns a description of the current line in the source."""
		return "Line %d"%self.locator.getLineNumber()

	def _make_cardTest(self, name, attrs, children):
		# res is the expected row count; negative values are relative to
		# the module-level totalRows.
		query = self.getContent(children)
		expected = int(attrs["res"])
		if expected<0:
			expected += totalRows
		return makeCardTest(query, expected, self.getPos(), attrs.get("hint"))

	def _make_intResTest(self, name, attrs, children):
		query = self.getContent(children)
		expected = int(attrs["res"])
		return makeResTest(query, expected, self.getPos(), attrs.get("hint"))

	def _make_charResTest(self, name, attrs, children):
		query = self.getContent(children)
		return makeResTest(query, attrs["res"], self.getPos(),
			attrs.get("hint"))

	def _make_testGroup(self, name, attrs, children):
		loader = unittest.TestLoader()
		return ADQLSuite(attrs["name"], [
			loader.loadTestsFromTestCase(test) for _, test in children])

	def _make_tests(self, name, attrs, children):
		newSuite = unittest.TestSuite()
		# child names are irrelevant here; don't shadow the name parameter
		for _, test in children:
			newSuite.addTest(test)
		return newSuite

	def _make_default(self, name, attrs, children):
		# sys.stdout.write instead of the Python 2 print statement: same
		# output line, but valid syntax on both Python 2 and Python 3.
		# Returning None means nothing is entered into the parent.
		sys.stdout.write("Ignoring %s\n"%name)


def getSuite(srcName):
	"""returns the test suite described by the XML file srcName.

	The file is opened in binary mode so the XML parser itself decodes it
	according to the document's encoding declaration, and it is closed
	once parsing has finished (also when parsing fails).
	"""
	parser = TestSuiteMaker()
	with open(srcName, "rb") as srcFile:
		sax.parse(srcFile, parser)
	return parser.getResult()


def setTotalRows():
	"""sets the module-level totalRows to the number of rows in val_1.

	Rows are counted via the result's rows attribute, as
	ADQLTest._getResultRows does.  This avoids relying on the query result
	object itself supporting len().
	"""
	global totalRows
	totalRows = len(adqlglue.query("select * from val_1").rows)


if __name__=="__main__":
	base.setDBProfile("trustedquery")
	setTotalRows()
	suite = getSuite("level0.queries")
	result = unittest.TextTestRunner().run(suite)
	# Reflect the outcome in the exit status (0 on success, 1 otherwise) so
	# shells and CI jobs can tell when tests failed.
	sys.exit(not result.wasSuccessful())
