package ircclient

import (
	"bufio"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/base64"
	"encoding/pem"
	"fmt"
	"log"
	"math/big"
	rnd "math/rand"
	"net"
	"strings"
	"testing"
	"time"
)

// setupSocket opens a plain TCP listener on an ephemeral port of the IPv4
// loopback interface and returns it together with its "host:port" address,
// which a client under test can dial. It panics if the listen fails.
func setupSocket() (listen net.Listener, addr string) {
	l, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		panic(err)
	}
	return l, l.Addr().String()
}

// generateKeyPair creates a throwaway self-signed ECDSA (P-521) certificate
// and matching private key for the mock TLS server.
//
// The certificate is valid from one hour before until one hour after the
// time of the call, is marked as a CA, is usable for server auth, and lists
// 127.0.0.1, :: and ::1 as IP subject alternative names.
//
// Any failure panics: this is test-setup code, and silently returning a
// half-built certificate would only surface later as a confusing TLS
// handshake error.
func generateKeyPair() (keypair tls.Certificate) {
	priv, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
	if err != nil {
		panic(fmt.Errorf("generating ECDSA key: %w", err))
	}

	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 64)
	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
	if err != nil {
		panic(fmt.Errorf("generating serial number: %w", err))
	}

	// Take the clock once so the validity window is exactly two hours wide.
	now := time.Now()
	// "::1" is the IPv6 loopback a client would actually dial; "::" (the
	// unspecified address) is kept so the SAN list stays a superset.
	template := x509.Certificate{
		SerialNumber: serialNumber,
		Subject: pkix.Name{
			Organization: []string{"Test Certificate"},
		},
		NotBefore:             now.Add(-time.Hour),
		NotAfter:              now.Add(time.Hour),
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		IsCA:                  true,
		IPAddresses:           []net.IP{net.ParseIP("127.0.0.1"), net.ParseIP("::"), net.ParseIP("::1")},
	}

	// Self-signed: the template is both the subject and the issuer.
	derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
	if err != nil {
		panic(fmt.Errorf("creating certificate: %w", err))
	}

	keyBytes, err := x509.MarshalECPrivateKey(priv)
	if err != nil {
		panic(fmt.Errorf("marshalling EC private key: %w", err))
	}

	certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
	keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyBytes})

	keypair, err = tls.X509KeyPair(certPEM, keyPEM)
	if err != nil {
		panic(fmt.Errorf("loading key pair: %w", err))
	}
	return keypair
}

// setupTLSSocket opens a TLS listener on an ephemeral port of the IPv4
// loopback interface, serving a freshly generated self-signed certificate
// from generateKeyPair, and returns it together with its "host:port"
// address. It panics if the listen fails.
func setupTLSSocket() (listen net.Listener, addr string) {
	cfg := &tls.Config{
		Certificates: []tls.Certificate{generateKeyPair()},
	}

	l, err := tls.Listen("tcp", "127.0.0.1:0", cfg)
	if err != nil {
		panic(err)
	}
	return l, l.Addr().String()
}

// ircWrite sends output to the client as one IRC line, appending the CRLF
// terminator, and records it in the test log prefixed with ">>".
//
// A failed write is logged rather than reported as a test failure. The mock
// server deliberately closes its connection on a timer (see abortAfter), so
// a write racing that close is expected and must not fail the test.
func ircWrite(server net.Conn, output string, t *testing.T) {
	t.Logf(">> %s", output)
	if _, err := server.Write([]byte(output + "\r\n")); err != nil {
		t.Logf("write %q: %v", output, err)
	}
}

// abortAfter is how many milliseconds ircServer lets part 3 run before it
// closes the client connection.
var abortAfter = 150

// ircServer mocks up an IRC server for exercising the client.
//
// It accepts a single connection on listener, closes the listener, and then
// drives the client through three phases:
//
//   - part 1: registration. The server sends a PING and checks the PONG,
//     NICK, USER and (when config.ServerPassword is set) PASS lines. When
//     config.UseTLS and config.UseSASL are both set, it also handles
//     CAP REQ/END and SASL PLAIN AUTHENTICATE.
//   - part 2: NickServ IDENTIFY. This is skipped when SASL succeeded or
//     config.Password is empty.
//   - part 3: the server answers JOIN, PRIVMSG and NOTICE until the
//     connection is closed abortAfter milliseconds later.
//
// Mismatches against config are reported with t.Error/t.Errorf. A line that
// lacks the parameters a handler needs is ignored rather than indexed past
// its end, so a malformed client line cannot panic the test binary.
func ircServer(listener net.Listener, t *testing.T, config *IRCConfig) {
	server, err := listener.Accept()
	if err != nil {
		t.Error("Failed to accept connection.")
		return
	}
	// Releases the connection on every return path. A second Close after
	// the part 3 timer has fired merely returns an error.
	defer server.Close()

	listener.Close()
	reader := bufio.NewReader(server)

	// readLine reads one line from the client, strips the CR/LF, logs it and
	// returns it together with its space-separated fields.
	readLine := func() (string, []string, error) {
		raw, err := reader.ReadString('\n')
		if err != nil {
			return "", nil, err
		}
		line := strings.Trim(raw, "\r\n")
		t.Logf("<< %s", line)
		return line, strings.Split(line, " "), nil
	}

	ping := rnd.Int63()
	ircWrite(server, fmt.Sprintf("PING :%X", ping), t)

	var hasNick, hasUser, hasPing, hasPass bool
	var capSASL, successSASL bool
	var part1 bool

	// part 1 :  User, Nick, ServerPass and Ping reply
	for !part1 {
		line, parts, err := readLine()
		if err != nil {
			t.Error("Read Error:", err)
			return
		}

		switch parts[0] {
		case "CAP":
			if config.UseTLS && config.UseSASL {
				switch line {
				case "CAP REQ :sasl":
					// Acknowledge we support SASL
					ircWrite(server, ":irc.red-green.com CAP * ACK :sasl", t)
					capSASL = true
				case "CAP END":
					capSASL = true
					if successSASL {
						part1 = true
					}
				}
			}

		case "AUTHENTICATE":
			switch {
			case !capSASL:
				// SASL was never negotiated; ignore the line.
			case line == "AUTHENTICATE PLAIN":
				ircWrite(server, "AUTHENTICATE +", t)
			case len(parts) > 1:
				// The payload is base64 of NUL-separated fields; the client
				// is expected to send "\x00<nick>\x00<password>". A decode
				// error leaves a partial result that fails the comparison.
				byteauth, _ := base64.StdEncoding.DecodeString(parts[1])
				auth := strings.ReplaceAll(string(byteauth), "\x00", " ")
				t.Log(auth)
				expect := fmt.Sprintf(" %s %s", config.Nick, config.Password)
				if expect != auth {
					t.Errorf("Got %s, Expected %s", auth, expect)
					ircWrite(server, fmt.Sprintf(":irc.red-green.com 904 %s :SASL authentication failed",
						config.Nick), t)
				} else {
					// Success!
					ircWrite(server, fmt.Sprintf(":irc.red-green.com 900 %s %s!%s@127.0.0.1 %s :You are now logged in as %s.",
						config.Nick, config.Nick, config.Username, config.Nick, config.Nick), t)
					ircWrite(server, fmt.Sprintf(":irc.red-green.com 903 %s :SASL authentication successful",
						config.Nick), t)
					successSASL = true
				}
			}

		case "PASS":
			expect := fmt.Sprintf("PASS %s", config.ServerPassword)
			if expect != line {
				t.Errorf("Got %s, Expected %s", line, expect)
			} else {
				hasPass = true
			}

		case "NICK":
			expect := fmt.Sprintf("NICK %s", config.MyNick)
			if expect != line {
				t.Errorf("Got %s, Expected %s", line, expect)
			} else {
				if config.MyNick == "bad" {
					// throw bad nick here
					ircWrite(server, ":irc.red-green.com 433 :Nick already in use.", t)
				}
				hasNick = true
			}

		case "USER":
			// USER meow-bot 0 * :Meooow!  bugz is my owner.
			expect := fmt.Sprintf("USER %s 0 * :%s", config.Username, config.Realname)
			if expect != line {
				t.Errorf("Got %s, Expected %s", line, expect)
			} else {
				hasUser = true
			}

		case "PONG":
			expect := fmt.Sprintf("PONG %X", ping)
			if expect != line {
				t.Errorf("Got %s, Expected %s", line, expect)
			} else {
				hasPing = true
			}
		}

		// Without SASL, registration completes once every required line has
		// arrived; with SASL it completes on "CAP END" above.
		if !part1 && !capSASL && hasNick && hasUser && hasPing &&
			(config.ServerPassword == "" || hasPass) {
			part1 = true
		}
	}

	// Display MOTD
	for _, format := range []string{
		":irc.red-green.com 001 %s :Welcome to the RedGreen IRC Network",
		":irc.red-green.com 002 %s :Your host is irc.red-green.com, running version UnrealIRCd-5.2.0.1",
		":irc.red-green.com 375 %s :- irc.red-green.com Message of the Day -",
		":irc.red-green.com 372 %s :- ",
		":irc.red-green.com 376 %s :End of /MOTD command.",
	} {
		ircWrite(server, fmt.Sprintf(format, config.Nick), t)
	}

	if config.UseSASL && !successSASL {
		log.Println("Failed SASL Authentication.")
	}

	// part 2: nickserv/register  (if not already registered with SASL)
	var part2 bool

	switch {
	case successSASL:
		ircWrite(server, fmt.Sprintf(":NickServ MODE %s :+r", config.Nick), t)
		part2 = true
	case config.Password != "":
		for _, format := range []string{
			":NickServ!services@services.red-green.com NOTICE %s :This nickname is registered and protected.  If it is your",
			":NickServ!services@services.red-green.com NOTICE %s :nick, type \x02/msg NickServ IDENTIFY \x1fpassword\x1f\x02.  Otherwise,",
		} {
			ircWrite(server, fmt.Sprintf(format, config.Nick), t)
		}
	default:
		// No password, so we can't register.  Skip this part.
		part2 = true
	}

	for !part2 {
		line, parts, err := readLine()
		if err != nil {
			t.Error("Read Error:", err)
			return
		}
		if parts[0] != "NS" {
			continue
		}

		expect := fmt.Sprintf("NS IDENTIFY %s", config.Password)
		if expect != line {
			t.Errorf("Got %s, Expected %s", line, expect)
		}
		// ok, mark the user as registered
		ircWrite(server, fmt.Sprintf(":NickServ!services@services.red-green.com NOTICE %s :Password accepted - you are now recognized.",
			config.Nick), t)
		ircWrite(server, fmt.Sprintf(":NickServ MODE %s :+r", config.Nick), t)
		part2 = true
	}

	// part 3: closing the connection after abortAfter ms makes the read loop
	// below fail, which ends the mock server.
	time.AfterFunc(time.Millisecond*time.Duration(abortAfter), func() { server.Close() })

	t.Log("Ok, Identified...")
	var kicked bool

	for {
		_, parts, err := readLine()
		if err != nil {
			t.Log("Read Error:", err)
			return
		}

		switch parts[0] {
		case "JOIN":
			if len(parts) < 2 {
				continue
			}
			for _, channel := range strings.Split(parts[1], ",") {
				ircWrite(server, fmt.Sprintf(":%s JOIN :%s", config.MyNick, channel), t)
				ircWrite(server, fmt.Sprintf(":irc.server 332 %s %s :Topic for (%s)", config.MyNick, channel, channel), t)
				ircWrite(server, fmt.Sprintf(":irc.server 333 %s %s user %d", config.MyNick, channel, time.Now().Unix()), t)
				// Kick the client from the first channel whose name
				// contains "kick", only once per connection.
				if strings.Contains(channel, "kick") && !kicked {
					kicked = true
					ircWrite(server, fmt.Sprintf("user KICK %s %s :Get out", channel, config.MyNick), t)
				}
			}

		case "PRIVMSG", "NOTICE":
			if len(parts) < 2 {
				continue
			}
			target := parts[1]
			if target == "echo" && len(parts) > 2 {
				// echo user, return whatever was sent back to them. Only a
				// leading ':' (the trailing-parameter marker) is stripped.
				text := strings.TrimPrefix(strings.Join(parts[2:], " "), ":")
				ircWrite(server, fmt.Sprintf(":%s %s %s :%s", "echo", parts[0], config.MyNick, text), t)
			}
			if strings.Contains(target, "missing") {
				// Sending to missing user or channel: 404 for a channel
				// target, 401 otherwise.
				number := 401
				if strings.Contains(target, "#") {
					number = 404
				}
				ircWrite(server, fmt.Sprintf(":irc.red-green.com %d %s %s :No such nick/channel", number, config.MyNick, target), t)
			}
		}
	}
}