#
# Copyright (c) 2010 Canonical
#
# Written by Gustavo Niemeyer <gustavo@niemeyer.net>
#
# This file is part of the Xpresser GUI automation library.
#
# Xpresser is free software; you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License version 3,
# as published by the Free Software Foundation.
#
# Xpresser is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
import opencv
import gtk

from xpresser.image import Image
from xpresser.opencvfinder import OpenCVFinder

from xpresser.lib.testing import TestCase
from xpresser.tests.images import get_image_path


class OpenCVFinderTest(TestCase):
    """Tests for OpenCVFinder.find() and OpenCVFinder.find_all().

    Every expected size and coordinate below depends on the contents of the
    fixture PNG files resolved through get_image_path().
    """

    def setUp(self):
        # Fresh Image objects per test, so per-test tweaks (similarity
        # thresholds, caches, filename/array swaps) never leak between tests.
        self.screen_image = self._load_image("screen.png")
        self.red_square = self._load_image("red-square.png")
        self.red_circle = self._load_image("red-circle.png")
        self.red_ellipse = self._load_image("red-ellipse.png")
        self.green_square = self._load_image("green-square.png")
        self.yellow_square = self._load_image("yellow-square.png")
        self.yellow_circle = self._load_image("yellow-circle.png")
        self.red_circle_with_blue_circle = \
            self._load_image("red-circle-with-blue-circle.png")
        self.finder = OpenCVFinder()

    def _load_image(self, filename):
        """Return a new Image for the fixture file called `filename`."""
        return Image(filename=get_image_path(filename))

    def _assert_match(self, match, image, x, y, similarity=None):
        """Assert that `match` is for `image` and is located at (x, y).

        If `similarity` is given, the match's similarity must equal it too.
        """
        self.assertEqual(match.image, image)
        self.assertEqual(match.x, x)
        self.assertEqual(match.y, y)
        if similarity is not None:
            self.assertEqual(match.similarity, similarity)

    def test_loads_widths_and_heights(self):
        # The dimensions are checked only after find() has run on both
        # images.
        self.finder.find(self.screen_image, self.red_ellipse)
        self.assertEqual(self.screen_image.width, 300)
        self.assertEqual(self.screen_image.height, 300)
        self.assertEqual(self.red_ellipse.width, 40)
        self.assertEqual(self.red_ellipse.height, 50)

    def test_find_perfect_match(self):
        match = self.finder.find(self.screen_image, self.green_square)
        self._assert_match(match, self.green_square, 200, 0, similarity=1.0)

    def test_find_perfect_match_with_low_threshold(self):
        # Lowering the threshold must still report the exact match location.
        self.green_square.similarity = 0.7
        match = self.finder.find(self.screen_image, self.green_square)
        self._assert_match(match, self.green_square, 200, 0)

    def test_find_all_with_perfect_match(self):
        matches = self.finder.find_all(self.screen_image, self.green_square)
        self.assertEqual(len(matches), 1)
        self._assert_match(matches[0], self.green_square, 200, 0,
                           similarity=1.0)

    def test_find_all_with_low_threshold_containing_perfect_match(self):
        # With a lax threshold, several matches are returned, and none of
        # them falls below that threshold.
        self.red_circle.similarity = 0.8
        matches = self.finder.find_all(self.screen_image, self.red_circle)
        self.assertTrue(len(matches) > 1)
        self.assertTrue(min(m.similarity for m in matches) >= 0.8)

    def test_no_matches(self):
        # At the default threshold this composite image is not found at all.
        match = self.finder.find(self.screen_image,
                                 self.red_circle_with_blue_circle)
        self.assertEqual(match, None)

    def test_fuzzy_match(self):
        # The same composite image is found once the threshold is relaxed.
        self.red_circle_with_blue_circle.similarity = 0.9
        match = self.finder.find(self.screen_image,
                                 self.red_circle_with_blue_circle)
        self._assert_match(match, self.red_circle_with_blue_circle, 100, 200)

    def test_self_match(self):
        """
        Exercise a quirk in the Python OpenCV bindings: when the search
        yields a single result (image and template have the same size),
        the bindings handle the dimensions of the result matrix
        differently.
        """
        match = self.finder.find(self.red_circle, self.red_circle)
        self.assertEqual(match.x, 0)
        self.assertEqual(match.y, 0)

    def test_opencv_image_cache(self):
        # Even when nothing is found, the converted OpenCV matrix is stored
        # in the image's cache.
        match = self.finder.find(self.red_circle, self.yellow_circle)
        opencv_image = self.red_circle.cache.get("opencv_image")
        self.assertEqual(match, None)
        self.assertNotEqual(opencv_image, None)
        self.assertTrue(isinstance(opencv_image, opencv.CvMat))

        # Let's ensure the cache is *actually* in use: replace the red
        # circle's cached matrix with the yellow circle's one. If find()
        # reads the cache, it now searches yellow within yellow and matches.
        self.red_circle.cache["opencv_image"] = \
            self.yellow_circle.cache["opencv_image"]

        match = self.finder.find(self.red_circle, self.yellow_circle)
        self.assertNotEqual(match, None)

    def test_filtering_of_similar_matches(self):
        """
        Without filtering by proximity and match quality, this example
        would produce hundreds of matches. The filtering algorithm is not
        entirely trivial, and it probably has other cases that should be
        covered by individual unit tests too.
        """
        self.red_circle.similarity = 0.8
        matches = self.finder.find_all(self.screen_image, self.red_circle)
        # Best match first.
        matches.sort(key=lambda match: -match.similarity)
        self.assertEqual(len(matches), 2)
        self.assertEqual(matches[0].x, 100)
        self.assertEqual(matches[0].y, 200)
        self.assertEqual(matches[1].x, 198)
        self.assertEqual(matches[1].y, 100)

    def test_find_with_array_image(self):
        # Reset the image, including the cache (shouldn't be needed, but
        # just to be 100% sure), so that it can only be loaded from the
        # array set below.
        self.green_square.cache.clear()
        filename = self.green_square.filename
        self.green_square.filename = None

        # Use gtk to transform the image into a numpy array, and set it
        # back into the image.
        pixbuf = gtk.image_new_from_file(filename).get_pixbuf()
        self.green_square.array = pixbuf.get_pixels_array()

        # Matching must behave exactly as with the file-backed image.
        match = self.finder.find(self.screen_image, self.green_square)
        self._assert_match(match, self.green_square, 200, 0, similarity=1.0)

