/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysds.test.functions.builtin.part2;

import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.junit.Test;

import java.io.IOException;

/**
 * Test for the {@code tSNE} DML builtin: runs {@code functions/builtin/tSNE.dml} on a
 * normalized, de-duplicated copy of the iris data set ({@link #X_IRIS}).
 *
 * <p>Currently this only checks that the script executes; the produced embedding is not yet
 * compared against {@link #Y_REFERENCE} (see the TODO in {@link #runTSNETest}).
 */
public class BuiltinTSNETest extends AutomatedTestBase
{
	private static final String TEST_NAME = "tSNE";
	private static final String TEST_DIR = "functions/builtin/";
	private static final String TEST_CLASS_DIR = TEST_DIR + BuiltinTSNETest.class.getSimpleName() + "/";

	@Override
	public void setUp() {
		// The script is invoked with Y=output("Y"), so "Y" is the output registered here.
		addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"Y"}));
	}

	@Test
	public void testTSNECP() throws IOException {
		runTSNETest(2, 30, 300., 0.9, 1000, 42, "FALSE", ExecType.CP);
	}

	/**
	 * Writes {@link #X_IRIS} as input "X" and runs tSNE.dml with the given named arguments.
	 *
	 * @param reduced_dims value passed to the script as {@code reduced_dims}
	 * @param perplexity   value passed to the script as {@code perplexity}
	 * @param lr           value passed to the script as {@code lr} (presumably the learning rate)
	 * @param momentum     value passed to the script as {@code momentum}
	 * @param max_iter     value passed to the script as {@code max_iter}
	 * @param seed         value passed to the script as {@code seed}
	 * @param is_verbose   DML boolean literal (e.g. "FALSE"), passed through verbatim
	 * @param instType     execution type used to select the runtime platform
	 * @throws IOException if writing the input matrix fails
	 */
	private void runTSNETest(int reduced_dims, int perplexity, double lr,
		double momentum, int max_iter, int seed, String is_verbose, ExecType instType)
		throws IOException
	{
		ExecMode platformOld = setExecMode(instType);

		try
		{
			loadTestConfiguration(getTestConfiguration(TEST_NAME));

			String home = SCRIPT_DIR + TEST_DIR;
			fullDMLScriptName = home + TEST_NAME + ".dml";
			programArgs = new String[]{
				"-nvargs", "X=" + input("X"), "Y=" + output("Y"),
				"reduced_dims=" + reduced_dims,
				"perplexity=" + perplexity,
				"lr=" + lr,
				"momentum=" + momentum,
				"max_iter=" + max_iter,
				"seed=" + seed,
				"is_verbose=" + is_verbose};

			writeInputMatrixWithMTD("X", X_IRIS, true);

			runTest(true, false, null, -1);

			// TODO verify the result: the embedding is implementation dependent, so the former
			// cell-wise comparison of readDMLMatrixFromOutputDir("Y") against Y_REFERENCE
			// (absolute tolerance 3) is disabled until the hard-coded reference is updated.
		}
		finally {
			rtplatform = platformOld;
		}
	}

	/**
	 * Input matrix (149 x 4). The values were calculated with the following R script:
	 * <pre>
	 * library(Rtsne)
	 * set.seed(42)
	 * iris_unique &lt;- unique(iris)
	 * iris_matrix &lt;- as.matrix(iris_unique[,1:4])
	 * X &lt;- normalize_input(iris_matrix) # the values used for the test
	 * </pre>
	 * TODO create via dml operations, avoid inlining data
	 */
	private static final double[][] X_IRIS = {
		{-0.23599574, 0.13972311, -0.74547391, -0.31565495},
		{-0.29946752, -0.01895634, -0.74547391, -0.31565495},
		{-0.36293930, 0.04451544, -0.77720980, -0.31565495},
		{-0.39467519, 0.01277955, -0.71373802, -0.31565495},
		{-0.26773163, 0.17145900, -0.74547391, -0.31565495},
		{-0.14078807, 0.26666667, -0.65026624, -0.25218317},
		{-0.39467519, 0.10798722, -0.74547391, -0.28391906},
		{-0.26773163, 0.10798722, -0.71373802, -0.31565495},
		{-0.45814696, -0.05069223, -0.74547391, -0.31565495},
		{-0.29946752, 0.01277955, -0.71373802, -0.34739084},
		{-0.14078807, 0.20319489, -0.71373802, -0.31565495},
		{-0.33120341, 0.10798722, -0.68200213, -0.31565495},
		{-0.33120341, -0.01895634, -0.74547391, -0.34739084},
		{-0.48988285, -0.01895634, -0.84068158, -0.34739084},
		{-0.01384452, 0.29840256, -0.80894569, -0.31565495},
		{-0.04558040, 0.42534611, -0.71373802, -0.25218317},
		{-0.14078807, 0.26666667, -0.77720980, -0.25218317},
		{-0.23599574, 0.13972311, -0.74547391, -0.28391906},
		{-0.04558040, 0.23493078, -0.65026624, -0.28391906},
		{-0.23599574, 0.23493078, -0.71373802, -0.28391906},
		{-0.14078807, 0.10798722, -0.65026624, -0.31565495},
		{-0.23599574, 0.20319489, -0.71373802, -0.25218317},
		{-0.39467519, 0.17145900, -0.87241747, -0.31565495},
		{-0.23599574, 0.07625133, -0.65026624, -0.22044728},
		{-0.33120341, 0.10798722, -0.58679446, -0.31565495},
		{-0.26773163, -0.01895634, -0.68200213, -0.31565495},
		{-0.26773163, 0.10798722, -0.68200213, -0.25218317},
		{-0.20425985, 0.13972311, -0.71373802, -0.31565495},
		{-0.20425985, 0.10798722, -0.74547391, -0.31565495},
		{-0.36293930, 0.04451544, -0.68200213, -0.31565495},
		{-0.33120341, 0.01277955, -0.68200213, -0.31565495},
		{-0.14078807, 0.10798722, -0.71373802, -0.25218317},
		{-0.20425985, 0.33013845, -0.71373802, -0.34739084},
		{-0.10905218, 0.36187433, -0.74547391, -0.31565495},
		{-0.29946752, 0.01277955, -0.71373802, -0.31565495},
		{-0.26773163, 0.04451544, -0.80894569, -0.31565495},
		{-0.10905218, 0.13972311, -0.77720980, -0.31565495},
		{-0.29946752, 0.17145900, -0.74547391, -0.34739084},
		{-0.45814696, -0.01895634, -0.77720980, -0.31565495},
		{-0.23599574, 0.10798722, -0.71373802, -0.31565495},
		{-0.26773163, 0.13972311, -0.77720980, -0.28391906},
		{-0.42641108, -0.24110756, -0.77720980, -0.28391906},
		{-0.45814696, 0.04451544, -0.77720980, -0.31565495},
		{-0.26773163, 0.13972311, -0.68200213, -0.18871140},
		{-0.23599574, 0.23493078, -0.58679446, -0.25218317},
		{-0.33120341, -0.01895634, -0.74547391, -0.28391906},
		{-0.23599574, 0.23493078, -0.68200213, -0.31565495},
		{-0.39467519, 0.04451544, -0.74547391, -0.31565495},
		{-0.17252396, 0.20319489, -0.71373802, -0.31565495},
		{-0.26773163, 0.07625133, -0.74547391, -0.31565495},
		{0.36698616, 0.04451544, 0.30181044, 0.06517572},
		{0.17657082, 0.04451544, 0.23833866, 0.09691161},
		{0.33525027, 0.01277955, 0.36528222, 0.09691161},
		{-0.10905218, -0.24110756, 0.07965921, 0.03343983},
		{0.20830671, -0.08242812, 0.27007455, 0.09691161},
		{-0.04558040, -0.08242812, 0.23833866, 0.03343983},
		{0.14483493, 0.07625133, 0.30181044, 0.12864750},
		{-0.29946752, -0.20937167, -0.14249201, -0.06176784},
		{0.24004260, -0.05069223, 0.27007455, 0.03343983},
		{-0.20425985, -0.11416400, 0.04792332, 0.06517572},
		{-0.26773163, -0.33631523, -0.07902023, -0.06176784},
		{0.01789137, -0.01895634, 0.14313099, 0.09691161},
		{0.04962726, -0.27284345, 0.07965921, -0.06176784},
		{0.08136315, -0.05069223, 0.30181044, 0.06517572},
		{-0.07731629, -0.05069223, -0.04728435, 0.03343983},
		{0.27177849, 0.01277955, 0.20660277, 0.06517572},
		{-0.07731629, -0.01895634, 0.23833866, 0.09691161},
		{-0.01384452, -0.11416400, 0.11139510, -0.06176784},
		{0.11309904, -0.27284345, 0.23833866, 0.09691161},
		{-0.07731629, -0.17763578, 0.04792332, -0.03003195},
		{0.01789137, 0.04451544, 0.33354633, 0.19211928},
		{0.08136315, -0.08242812, 0.07965921, 0.03343983},
		{0.14483493, -0.17763578, 0.36528222, 0.09691161},
		{0.08136315, -0.08242812, 0.30181044, 0.00170394},
		{0.17657082, -0.05069223, 0.17486688, 0.03343983},
		{0.24004260, -0.01895634, 0.20660277, 0.06517572},
		{0.30351438, -0.08242812, 0.33354633, 0.06517572},
		{0.27177849, -0.01895634, 0.39701810, 0.16038339},
		{0.04962726, -0.05069223, 0.23833866, 0.09691161},
		{-0.04558040, -0.14589989, -0.07902023, -0.06176784},
		{-0.10905218, -0.20937167, 0.01618743, -0.03003195},
		{-0.10905218, -0.20937167, -0.01554846, -0.06176784},
		{-0.01384452, -0.11416400, 0.04792332, 0.00170394},
		{0.04962726, -0.11416400, 0.42875399, 0.12864750},
		{-0.14078807, -0.01895634, 0.23833866, 0.09691161},
		{0.04962726, 0.10798722, 0.23833866, 0.12864750},
		{0.27177849, 0.01277955, 0.30181044, 0.09691161},
		{0.14483493, -0.24110756, 0.20660277, 0.03343983},
		{-0.07731629, -0.01895634, 0.11139510, 0.03343983},
		{-0.10905218, -0.17763578, 0.07965921, 0.03343983},
		{-0.10905218, -0.14589989, 0.20660277, 0.00170394},
		{0.08136315, -0.01895634, 0.27007455, 0.06517572},
		{-0.01384452, -0.14589989, 0.07965921, 0.00170394},
		{-0.26773163, -0.24110756, -0.14249201, -0.06176784},
		{-0.07731629, -0.11416400, 0.14313099, 0.03343983},
		{-0.04558040, -0.01895634, 0.14313099, 0.00170394},
		{-0.04558040, -0.05069223, 0.14313099, 0.03343983},
		{0.11309904, -0.05069223, 0.17486688, 0.03343983},
		{-0.23599574, -0.17763578, -0.23769968, -0.03003195},
		{-0.04558040, -0.08242812, 0.11139510, 0.03343983},
		{0.14483493, 0.07625133, 0.71437700, 0.41427050},
		{-0.01384452, -0.11416400, 0.42875399, 0.22385517},
		{0.39872204, -0.01895634, 0.68264111, 0.28732694},
		{0.14483493, -0.05069223, 0.58743344, 0.19211928},
		{0.20830671, -0.01895634, 0.65090522, 0.31906283},
		{0.55740149, -0.01895634, 0.90479233, 0.28732694},
		{-0.29946752, -0.17763578, 0.23833866, 0.16038339},
		{0.46219382, -0.05069223, 0.80958466, 0.19211928},
		{0.27177849, -0.17763578, 0.65090522, 0.19211928},
		{0.43045793, 0.17145900, 0.74611289, 0.41427050},
		{0.20830671, 0.04451544, 0.42875399, 0.25559105},
		{0.17657082, -0.11416400, 0.49222577, 0.22385517},
		{0.30351438, -0.01895634, 0.55569755, 0.28732694},
		{-0.04558040, -0.17763578, 0.39701810, 0.25559105},
		{-0.01384452, -0.08242812, 0.42875399, 0.38253461},
		{0.17657082, 0.04451544, 0.49222577, 0.35079872},
		{0.20830671, -0.01895634, 0.55569755, 0.19211928},
		{0.58913738, 0.23493078, 0.93652822, 0.31906283},
		{0.58913738, -0.14589989, 1.00000000, 0.35079872},
		{0.04962726, -0.27284345, 0.39701810, 0.09691161},
		{0.33525027, 0.04451544, 0.61916933, 0.35079872},
		{-0.07731629, -0.08242812, 0.36528222, 0.25559105},
		{0.58913738, -0.08242812, 0.93652822, 0.25559105},
		{0.14483493, -0.11416400, 0.36528222, 0.19211928},
		{0.27177849, 0.07625133, 0.61916933, 0.28732694},
		{0.43045793, 0.04451544, 0.71437700, 0.19211928},
		{0.11309904, -0.08242812, 0.33354633, 0.19211928},
		{0.08136315, -0.01895634, 0.36528222, 0.19211928},
		{0.17657082, -0.08242812, 0.58743344, 0.28732694},
		{0.43045793, -0.01895634, 0.65090522, 0.12864750},
		{0.49392971, -0.08242812, 0.74611289, 0.22385517},
		{0.65260916, 0.23493078, 0.84132055, 0.25559105},
		{0.17657082, -0.08242812, 0.58743344, 0.31906283},
		{0.14483493, -0.08242812, 0.42875399, 0.09691161},
		{0.08136315, -0.14589989, 0.58743344, 0.06517572},
		{0.58913738, -0.01895634, 0.74611289, 0.35079872},
		{0.14483493, 0.10798722, 0.58743344, 0.38253461},
		{0.17657082, 0.01277955, 0.55569755, 0.19211928},
		{0.04962726, -0.01895634, 0.33354633, 0.19211928},
		{0.33525027, 0.01277955, 0.52396166, 0.28732694},
		{0.27177849, 0.01277955, 0.58743344, 0.38253461},
		{0.33525027, 0.01277955, 0.42875399, 0.35079872},
		{0.30351438, 0.04451544, 0.68264111, 0.35079872},
		{0.27177849, 0.07625133, 0.61916933, 0.41427050},
		{0.27177849, -0.01895634, 0.46048988, 0.35079872},
		{0.14483493, -0.17763578, 0.39701810, 0.22385517},
		{0.20830671, -0.01895634, 0.46048988, 0.25559105},
		{0.11309904, 0.10798722, 0.52396166, 0.35079872},
		{0.01789137, -0.01895634, 0.42875399, 0.19211928}};

	/**
	 * Reference output (149 x 2). It was created by running the builtin function with seed 42
	 * and visually inspecting the result with the following addition to the R script
	 * documented on {@link #X_IRIS}:
	 * <pre>
	 * plot(Y, col = iris_unique$Species)
	 * </pre>
	 * Not used yet; kept for when the result verification in {@link #runTSNETest} is re-enabled.
	 */
	@SuppressWarnings("unused")
	private static final double[][] Y_REFERENCE = {
		{18.220536548250042, -12.846498524536738},
		{15.927903386925026, -14.212023388236792},
		{16.769777454402725, -14.867104469807458},
		{16.290613410318578, -14.971912325413014},
		{18.534108527624923, -13.081965971299278},
		{19.46702930119709, -11.107384827606543},
		{17.196995022994926, -14.952457676596161},
		{17.531360762128234, -13.133905834551287},
		{15.996750713161672, -15.670577143806288},
		{16.36534147032176, -13.94640444049381},
		{19.094349767837077, -11.557657039778153},
		{16.909635859249846, -13.432332957158627},
		{15.964241411757008, -14.583849627922195},
		{16.313709524761837, -16.090269929669734},
		{20.285962611966927, -10.881660944862407},
		{20.554758173661426, -11.06392329099603},
		{19.81906056722687, -11.547188487333667},
		{18.156895316378105, -12.705433004326382},
		{19.567551886989456, -10.747111880219315},
		{19.250142947939032, -12.360380239016678},
		{17.98651875041291, -11.373400820175243},
		{18.85701724837406, -12.344092918634603},
		{15.003071782449434, -15.134107990259263},
		{16.906273660742716, -12.081578738297303},
		{16.191062837166225, -12.743134302084322},
		{15.882718808144244, -13.620171658276757},
		{17.162351058283594, -12.52282386909434},
		{18.222170807106863, -12.25960056343841},
		{17.965249267850382, -12.497678277213915},
		{16.61308662972421, -14.338853140723119},
		{16.181343729377808, -14.009495314657007},
		{18.02370742093665, -11.5469978987216},
		{20.11583352615154, -11.971703104743623},
		{20.34766483388781, -11.368324149466424},
		{16.425063301904068, -13.920590257944497},
		{17.456451755948965, -14.167821170017678},
		{18.596836114631266, -11.38629820258609},
		{18.433917702277277, -13.54006232724752},
		{16.220662327912546, -15.705004632184767},
		{17.693153385837373, -12.795456099765842},
		{18.171476611546456, -13.323284884616521},
		{15.440053805476438, -16.14242285042382},
		{16.61041486346442, -15.64274785656758},
		{17.151462077196808, -11.907401620659614},
		{18.626998818277606, -10.74416620713765},
		{15.98116134477208, -14.570088706192026},
		{19.310896130733955, -12.328582468776172},
		{16.638185072116197, -15.041546247833768},
		{19.00600656884603, -11.840914144209874},
		{17.403065129854806, -13.533995008873486},
		{-7.038268225948856, 7.536962272871995},
		{-7.3914411713029935, 5.908404973751881},
		{-7.496137452637684, 7.80002229833549},
		{-6.460592719894593, 0.9062234315612744},
		{-7.616915840513189, 6.367666534180176},
		{-6.73555772449111, 3.123913125534304},
		{-8.080845938411253, 6.030180083529746},
		{-6.003581004054341, -0.5426273138978279},
		{-7.169789369755973, 6.530569753225136},
		{-5.74311471355638, 1.109042979038742},
		{-6.159692868338611, -0.47672936959379775},
		{-7.224663019809244, 3.2705583465396795},
		{-7.80311490665107, 0.8971687420050435},
		{-8.171493899007842, 4.945491553039871},
		{-6.055076176778027, 0.8068408811777839},
		{-6.841744219152401, 6.563990176275155},
		{-6.4688066027110995, 3.3600337950542185},
		{-7.290016820260793, 1.6539424879073648},
		{-9.427559060500517, 3.747497254956788},
		{-6.870144371536429, 0.8988782592514124},
		{-9.136986152714911, 5.784098425135586},
		{-7.68317760808878, 2.43574703565265},
		{-9.740448694574955, 5.387219781790028},
		{-8.073455284093917, 4.572592063175601},
		{-6.941003784668771, 5.2940569383973815},
		{-6.92987101320006, 6.2327815077976485},
		{-7.48373098301499, 7.341059972461793},
		{-8.40332894242311, 8.176817619574457},
		{-7.681170155620538, 4.334864702224226},
		{-6.956305041180643, 0.2075870963686768},
		{-6.655261285197925, 0.5228359022734664},
		{-6.666350115623071, 0.2717264575985321},
		{-7.0998853298193145, 1.4238408353689687},
		{-10.2308151080835, 6.427565359129248},
		{-6.01229671664338, 3.082044215406439},
		{-8.034938379871582, 5.334134289261261},
		{-7.359172540935799, 7.109704156745161},
		{-9.171082430966738, 3.5626753890742657},
		{-6.535162057220474, 2.312414150179807},
		{-6.3789219417924405, 1.2216485793591474},
		{-6.067256940518085, 2.1167290391035074},
		{-7.812215578205332, 4.875684399693769},
		{-7.1374342502995605, 1.4628113062889514},
		{-6.070547359306737, -0.5122472965230369},
		{-6.46343725557135, 1.979317360428436},
		{-6.847076431902512, 2.5392196334816846},
		{-6.808114424651112, 2.498147826359703},
		{-7.202776337109158, 4.407291480436176},
		{-6.008168593188563, -0.6619390315875436},
		{-6.831119010530879, 2.045771384669187},
		{-12.091830779150229, 11.265359769601078},
		{-10.854204065227533, 6.234604499378284},
		{-10.722649555263748, 12.399149192535997},
		{-11.024368675274824, 9.162144335131195},
		{-11.303048023519185, 10.732650608737089},
		{-11.09821732195217, 13.779405512286052},
		{-5.0247832693311425, 2.1742184328563603},
		{-10.619343405836839, 13.298677973987925},
		{-11.79686116542718, 9.494394976289728},
		{-11.586950451837295, 12.746499445013628},
		{-9.748309996567489, 9.564317509433048},
		{-10.481918971993704, 8.61651017140811},
		{-10.462599559239374, 10.911028525489696},
		{-11.130930155766606, 5.964285024103242},
		{-11.553937238933786, 6.4713896554450985},
		{-10.681883406818072, 10.318460208361941},
		{-10.683549192372697, 9.495275009653032},
		{-10.78603000302186, 14.25264787910952},
		{-11.488550071992856, 13.974915029487743},
		{-10.319674004978886, 4.889390024695409},
		{-10.860193039087495, 11.683572442191727},
		{-11.051151523346409, 5.890167964580233},
		{-11.229564353462594, 13.877567340046305},
		{-9.673077426769021, 6.894318103609319},
		{-10.96913834998139, 11.322206899793283},
		{-10.490035372117925, 12.825714370138753},
		{-9.462007223927465, 6.4675755622386655},
		{-9.51159790297749, 6.337629581532479},
		{-11.162432115640609, 9.827386758865119},
		{-10.151930853808723, 12.721404715075673},
		{-10.604601715753107, 13.20203200422886},
		{-10.661315153978798, 14.213091123997359},
		{-11.20216179567001, 10.027518789368527},
		{-9.545267493844456, 7.1692559024456735},
		{-11.159867362151168, 7.8902037332855635},
		{-11.088628333620173, 13.451554645466246},
		{-11.888514670355377, 10.819305597527286},
		{-10.69716594839169, 9.44576929331764},
		{-9.388828314326277, 6.02029293546396},
		{-10.20160261480168, 11.075164265002408},
		{-10.984952128353312, 11.125251232506113},
		{-9.734560763746329, 10.70872555677429},
		{-11.239599613172867, 11.820790907417845},
		{-11.469876956620716, 11.460531594176354},
		{-10.037039090761771, 10.470461014159099},
		{-10.182738080227073, 7.165787375813321},
		{-10.04300898871791, 9.494772412049684},
		{-11.912720679251729, 10.441038042708657},
		{-10.30731700479772, 6.343742643599125}};
}
