Selaa lähdekoodia

Added tests.

TODO: Merge connections into one. net.Conn.
Steve Thielemann 3 vuotta sitten
vanhempi
commit
202fe36bc0
5 muutettua tiedostoa jossa 585 lisäystä ja 24 poistoa
  1. 380 0
      client_test.go
  2. 29 0
      color_test.go
  3. 50 19
      irc-client.go
  4. 21 5
      throttle.go
  5. 105 0
      throttle_test.go

+ 380 - 0
client_test.go

@@ -0,0 +1,380 @@
+package ircclient
+
+import (
+	"bufio"
+	"crypto/ecdsa"
+	"crypto/elliptic"
+	"crypto/rand"
+	"crypto/tls"
+	"crypto/x509"
+	"crypto/x509/pkix"
+	"encoding/base64"
+	"encoding/pem"
+	"fmt"
+	"math/big"
+	rnd "math/rand"
+	"net"
+	"strconv"
+	"strings"
+	"testing"
+	"time"
+)
+
+func setupSocket() (listen net.Listener, addr string) {
+	// establish network socket connection to set Comm_handle
+	var err error
+	var listener net.Listener
+	var address string
+
+	listener, err = net.Listen("tcp", "127.0.0.1:0")
+	if err != nil {
+		panic(err)
+	}
+
+	// I only need address for making the connection.
+	// Get address of listening socket
+	address = listener.Addr().String()
+	return listener, address
+}
+
+func generateKeyPair() (keypair tls.Certificate) {
+	// generate test certificate
+	priv, _ := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
+	durationBefore, _ := time.ParseDuration("-1h")
+	notBefore := time.Now().Add(durationBefore)
+	durationAfter, _ := time.ParseDuration("1h")
+	notAfter := time.Now().Add(durationAfter)
+	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 64)
+	serialNumber, _ := rand.Int(rand.Reader, serialNumberLimit)
+
+	template := x509.Certificate{
+		SerialNumber: serialNumber,
+		Subject: pkix.Name{
+			Organization: []string{"Test Certificate"},
+		},
+		NotBefore:             notBefore,
+		NotAfter:              notAfter,
+		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
+		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
+		BasicConstraintsValid: true,
+		IsCA:                  true,
+	}
+	template.IPAddresses = append(template.IPAddresses, net.ParseIP("127.0.0.1"))
+	template.IPAddresses = append(template.IPAddresses, net.ParseIP("::"))
+
+	derBytes, _ := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
+
+	c := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
+	b, _ := x509.MarshalECPrivateKey(priv)
+	k := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: b})
+
+	listenerKeyPair, _ := tls.X509KeyPair(c, k)
+	return listenerKeyPair
+}
+
+func setupTLSSocket() (listen net.Listener, addr string) {
+	// establish network socket connection to set Comm_handle
+	var err error
+	var listener net.Listener
+	var address string
+	var tlsconfig tls.Config
+	var keypair tls.Certificate = generateKeyPair()
+
+	tlsconfig.Certificates = make([]tls.Certificate, 0)
+	tlsconfig.Certificates = append(tlsconfig.Certificates, keypair)
+
+	listener, err = tls.Listen("tcp", "127.0.0.1:0", &tlsconfig)
+	if err != nil {
+		panic(err)
+	}
+
+	// I only need address for making the connection.
+	// Get address of listening socket
+	address = listener.Addr().String()
+	return listener, address
+}
+
+func ircWrite(server net.Conn, output string, t *testing.T) {
+	t.Logf(">> %s\n", output)
+	server.Write([]byte(output + "\r\n"))
+}
+
+func ircServer(listener net.Listener, t *testing.T, config *IRCConfig) {
+	var server net.Conn
+	var err error
+
+	server, err = listener.Accept()
+	if err != nil {
+		t.Error("Failed to accept connection.")
+		return
+	}
+
+	listener.Close()
+	var reader *bufio.Reader = bufio.NewReader(server)
+
+	var output, line, expect string
+	var ping int64 = rnd.Int63()
+
+	output = fmt.Sprintf("PING :%X", ping)
+	ircWrite(server, output, t)
+	var parts []string
+
+	var hasNick, hasUser, hasPing, hasPass bool
+	var capSASL bool
+	var part1 bool
+
+	for !part1 {
+
+		line, err = reader.ReadString('\n')
+		if err == nil {
+			line = strings.Trim(line, "\r\n")
+			// process the received line here
+			parts = strings.Split(line, " ")
+			t.Logf("<< %s", line)
+
+			switch parts[0] {
+			case "CAP":
+				if config.UseTLS && config.UseSASL {
+					if line == "CAP REQ :sasl" {
+						// Acknowledge we support SASL
+						ircWrite(server, ":irc.red-green.com CAP * ACK :sasl", t)
+						capSASL = true
+					}
+					if line == "CAP END" {
+						capSASL = true
+					}
+				}
+
+			case "AUTHENTICATE":
+				if capSASL {
+					if line == "AUTHENTICATE PLAIN" {
+						ircWrite(server, "AUTHENTICATE +", t)
+					} else {
+						// Process SASL auth message
+						var auth64 string = parts[1]
+						byteauth, _ := base64.StdEncoding.DecodeString(auth64)
+						var auth string = string(byteauth)
+						auth = strings.ReplaceAll(auth, "\x00", " ")
+						t.Log(auth)
+					}
+				}
+			case "PASS":
+				expect = fmt.Sprintf("PASS %s", config.ServerPassword)
+				if expect != line {
+					t.Errorf("Got %s, Expected %s", line, expect)
+				} else {
+					hasPass = true
+				}
+			case "NICK":
+				expect = fmt.Sprintf("NICK %s", config.Nick)
+				if expect != line {
+					t.Errorf("Got %s, Expected %s", line, expect)
+				} else {
+					hasNick = true
+				}
+			case "USER":
+				// USER meow-bot 0 * :Meooow!  bugz is my owner.
+				expect = fmt.Sprintf("USER %s 0 * :%s", config.Username, config.Realname)
+				if expect != line {
+					t.Errorf("Got %s, Expected %s", line, expect)
+				} else {
+					hasUser = true
+				}
+			case "PONG":
+				expect = fmt.Sprintf("PONG %X", ping)
+				if expect != line {
+					t.Errorf("Got %s, Expected %s", line, expect)
+				} else {
+					hasPing = true
+				}
+			}
+
+			if !part1 {
+				if !capSASL && hasNick && hasUser && hasPing && ((config.ServerPassword == "") || hasPass) {
+					part1 = true
+
+					// part 2:
+					var line string
+					for _, line = range []string{":irc.red-green.com 001 %s :Welcome to the RedGreen IRC Network",
+						":irc.red-green.com 002 %s :Your host is irc.red-green.com, running version UnrealIRCd-5.2.0.1",
+						":irc.red-green.com 375 %s :- irc.red-green.com Message of the Day -",
+						":irc.red-green.com 372 %s :- ",
+						":irc.red-green.com 376 %s :End of /MOTD command.",
+					} {
+						output = fmt.Sprintf(line, config.Nick)
+						ircWrite(server, output, t)
+					}
+				}
+			}
+		} else {
+			t.Error("Read Error:", err)
+			server.Close()
+			return
+		}
+	}
+
+	if !part1 {
+		t.Error("Expected to pass part1 (user/nick/pong)")
+	}
+
+	// part 2: nickserv/register
+	var part2 bool
+
+	for _, line = range []string{":[email protected] NOTICE %s :This nickname is registered and protected.  If it is your",
+		":[email protected] NOTICE %s :nick, type \x02/msg NickServ IDENTIFY \x1fpassword\x1f\x02.  Otherwise,"} {
+		output = fmt.Sprintf(line, config.Nick)
+		ircWrite(server, output, t)
+	}
+
+	for !part2 {
+
+		line, err = reader.ReadString('\n')
+		if err == nil {
+			line = strings.Trim(line, "\r\n")
+			// process the received line here
+			parts = strings.Split(line, " ")
+			t.Logf("<< %s", line)
+
+			switch parts[0] {
+			case "NS":
+				expect = fmt.Sprintf("NS IDENTIFY %s", config.Password)
+				if expect != line {
+					t.Errorf("Got %s, Expected %s", line, expect)
+				}
+				// ok, mark the user as registered
+				output = fmt.Sprintf(":[email protected] NOTICE %s :Password accepted - you are now recognized.",
+					config.Nick)
+				ircWrite(server, output, t)
+				output = fmt.Sprintf(":NickServ MODE %s :+r", config.Nick)
+				ircWrite(server, output, t)
+				part2 = true
+			}
+		} else {
+			t.Error("Read Error:", err)
+			server.Close()
+			return
+		}
+	}
+
+	if !part2 {
+		t.Error("Expected to pass part2 (ns identify/+r)")
+	}
+
+	time.AfterFunc(time.Millisecond*time.Duration(50), func() { server.Close() })
+
+	t.Log("Ok, Identified...")
+
+	for {
+		line, err = reader.ReadString('\n')
+		if err == nil {
+			line = strings.Trim(line, "\r\n")
+			// process the received line here
+			parts = strings.Split(line, " ")
+			t.Logf("<< %s", line)
+
+		} else {
+			t.Log("Read Error:", err)
+			return
+		}
+	}
+}
+
+func TestConnect(t *testing.T) {
+	var config IRCConfig = IRCConfig{Nick: "test",
+		Username:       "test",
+		Realname:       "testing",
+		Password:       "12345",
+		ServerPassword: "allow"}
+	var listen net.Listener
+	var address string
+
+	listen, address = setupSocket()
+	var parts []string = strings.Split(address, ":")
+
+	config.Hostname = parts[0]
+	config.Port, _ = strconv.Atoi(parts[1])
+	go ircServer(listen, t, &config)
+	var FromIRC chan IRCMsg
+
+	FromIRC = make(chan IRCMsg)
+	config.ReadChannel = FromIRC
+
+	config.Connect()
+	defer config.Close()
+
+	var Msg IRCMsg
+	var motd, identify bool
+	for Msg = range FromIRC {
+		if Msg.Cmd == "EndMOTD" {
+			t.Log("Got EndMOTD")
+			motd = true
+		}
+		if Msg.Cmd == "Identified" {
+			t.Log("Identified")
+			identify = true
+		}
+	}
+
+	if !motd {
+		t.Error("Missing EndMOTD")
+	}
+	if !identify {
+		t.Error("Missing Identified")
+	}
+
+	if config.MyNick != config.Nick {
+		t.Errorf("Got %s, Expected %s", config.MyNick, config.Nick)
+	}
+
+}
+func TestConnectTLS(t *testing.T) {
+	var config IRCConfig = IRCConfig{Nick: "test",
+		Username:       "test",
+		Realname:       "testing",
+		Password:       "12345",
+		UseTLS:         true,
+		UseSASL:        true,
+		Insecure:       true,
+		ServerPassword: "allow"}
+	var listen net.Listener
+	var address string
+
+	listen, address = setupTLSSocket()
+	var parts []string = strings.Split(address, ":")
+
+	config.Hostname = parts[0]
+	config.Port, _ = strconv.Atoi(parts[1])
+	go ircServer(listen, t, &config)
+	var FromIRC chan IRCMsg
+
+	FromIRC = make(chan IRCMsg)
+	config.ReadChannel = FromIRC
+
+	config.Connect()
+	defer config.Close()
+
+	var Msg IRCMsg
+	var motd, identify bool
+	for Msg = range FromIRC {
+		if Msg.Cmd == "EndMOTD" {
+			t.Log("Got EndMOTD")
+			motd = true
+		}
+		if Msg.Cmd == "Identified" {
+			t.Log("Identified")
+			identify = true
+		}
+	}
+
+	if !motd {
+		t.Error("Missing EndMOTD")
+	}
+	if !identify {
+		t.Error("Missing Identified")
+	}
+
+	if config.MyNick != config.Nick {
+		t.Errorf("Got %s, Expected %s", config.MyNick, config.Nick)
+	}
+
+}

+ 29 - 0
color_test.go

@@ -0,0 +1,29 @@
+package ircclient
+
+import (
+	"testing"
+)
+
+type colortext struct {
+	Text  string
+	Color int
+}
+
+func TestColor(t *testing.T) {
+	var sent colortext
+	var expect string
+	var got string
+	var plain string
+	for sent, expect = range map[colortext]string{{Text: "One", Color: 1}: "\x0301One\x0f",
+		{"Two", 2}:   "\x0302Two\x0f",
+		{"Three", 3}: "\x0303Three\x0f"} {
+		got = Color(sent.Color, sent.Text)
+		if got != expect {
+			t.Errorf("Got %s, expected %s", got, expect)
+		}
+		plain = Stripper(got)
+		if plain != sent.Text {
+			t.Errorf("Got %s, expected %s", plain, sent.Text)
+		}
+	}
+}

+ 50 - 19
irc-client.go

@@ -18,6 +18,15 @@ import (
 	"time"
 )
 
+func StrInArray(strings []string, str string) bool {
+	for _, s := range strings {
+		if s == str {
+			return true
+		}
+	}
+	return false
+}
+
 type IRCMsg struct {
 	MsgParts []string
 	From     string
@@ -32,17 +41,19 @@ type IRCWrite struct {
 }
 
 type IRCConfig struct {
-	Port           int
-	Hostname       string
-	UseTLS         bool // Use TLS Secure connection
-	UseSASL        bool // Authenticate via SASL
-	Insecure       bool // Allow self-signed certificates
-	Nick           string
+	Port           int      `json: Port`
+	Hostname       string   `json: Hostname`
+	UseTLS         bool     `json: UseTLS`   // Use TLS Secure connection
+	UseSASL        bool     `json: UseSASL`  // Authenticate via SASL
+	Insecure       bool     `json: Insecure` // Allow self-signed certificates
+	Nick           string   `json: Nick`
+	Username       string   `json: Username`
+	Realname       string   `json: Realname`
+	Password       string   `json: Password`       // Password for nickserv
+	ServerPassword string   `json: ServerPassword` // Password for server
+	AutoJoin       []string `json: AutoJoin`       // Channels to auto-join
+	RejoinDelay    int      `json: RejoinDelay`    // ms to rejoin
 	MyNick         string
-	Username       string
-	Realname       string
-	Password       string // Password for nickserv
-	ServerPassword string // Password for server
 	Socket         net.Conn
 	TLSSocket      *tls.Conn
 	Reader         *bufio.Reader
@@ -150,7 +161,7 @@ func (Config *IRCConfig) WriterRoutine() {
 
 	throttle.init()
 
-	Flood.Track = make([]time.Time, Config.Flood_Num)
+	Flood.Init(Config.Flood_Num, Config.Flood_Time)
 
 	// signal handler
 	var sigChannel chan os.Signal = make(chan os.Signal, 1)
@@ -213,10 +224,7 @@ func (Config *IRCConfig) WriterRoutine() {
 					}
 					continue
 				}
-				if Flood.Pos > 0 {
-					Flood.Expire(Config.Flood_Time)
-				}
-				if Flood.Pos == Config.Flood_Num {
+				if Flood.Full() {
 					throttle.push(output.To, output.Output)
 				} else {
 					// Flood limits not reached
@@ -304,7 +312,7 @@ func (Config *IRCConfig) ReaderRoutine() {
 		line, err = Config.Reader.ReadString('\n')
 		if err == nil {
 			line = strings.Trim(line, "\r\n")
-			log.Println("<< ", line)
+			log.Println("<<", line)
 
 			results = IRCParse(line)
 
@@ -414,9 +422,9 @@ func (Config *IRCConfig) ReaderRoutine() {
 			}
 
 			/*
-				2022/04/06 19:12:11 <<  :[email protected] NOTICE meow :This nickname is registered and protected.  If it is your
-				2022/04/06 19:12:11 <<  :[email protected] NOTICE meow :nick, type /msg NickServ IDENTIFY password.  Otherwise,
-				2022/04/06 19:12:11 <<  :[email protected] NOTICE meow :please choose a different nick.
+				2022/04/06 19:12:11 << :[email protected] NOTICE meow :This nickname is registered and protected.  If it is your
+				2022/04/06 19:12:11 << :[email protected] NOTICE meow :nick, type /msg NickServ IDENTIFY password.  Otherwise,
+				2022/04/06 19:12:11 << :[email protected] NOTICE meow :please choose a different nick.
 			*/
 			if (msg.From == "NickServ") && (msg.Cmd == "NOTICE") {
 				if strings.Contains(msg.Msg, "IDENTIFY") {
@@ -437,6 +445,21 @@ func (Config *IRCConfig) ReaderRoutine() {
 			if (msg.Msg[0] == '+') && (strings.Contains(msg.Msg, "r")) {
 				Config.ReadChannel <- IRCMsg{Cmd: "Identified"}
 			}
+			if len(Config.AutoJoin) > 0 {
+				Config.PriorityWrite("JOIN " + strings.Join(Config.AutoJoin, ","))
+			}
+		}
+
+		if msg.Cmd == "KICK" {
+			// Were we kicked, is channel in AutoJoin?
+			// 2022/04/13 20:02:52 << :[email protected] KICK #bugz meow-bot :bugz
+			// Msg: ircclient.IRCMsg{MsgParts:[]string{":[email protected]", "KICK", "#bugz", "meow-bot", "bugz"}, From:"bugz", To:"#bugz", Cmd:"KICK", Msg:"meow-bot"}
+			if strings.Contains(msg.Msg, Config.MyNick) {
+				if StrInArray(Config.AutoJoin, msg.To) {
+					// Yes, we were kicked from AutoJoin channel
+					time.AfterFunc(time.Duration(Config.RejoinDelay)*time.Millisecond, func() { Config.WriteTo(msg.To, "JOIN "+msg.To) })
+				}
+			}
 		}
 
 		if Config.ReadChannel != nil {
@@ -449,5 +472,13 @@ func (Config *IRCConfig) ReaderRoutine() {
 			Config.ReadChannel <- reg
 		}
 
+		if msg.Cmd == "NICK" {
+			// :meow NICK :meow-bot
+
+			if msg.From == Config.MyNick {
+				Config.MyNick = msg.To
+				log.Println("Nick is now:", Config.MyNick)
+			}
+		}
 	}
 }

+ 21 - 5
throttle.go

@@ -6,17 +6,32 @@ import (
 )
 
 type FloodTrack struct {
-	Pos   int
-	Track []time.Time
+	Pos     int
+	Size    int
+	Timeout int
+	Track   []time.Time
 }
 
-func (F *FloodTrack) Expire(timeout int) {
+func (F *FloodTrack) Init(size int, timeout int) {
+	F.Size = size
+	F.Timeout = timeout
+	F.Track = make([]time.Time, size)
+}
+
+func (F *FloodTrack) Full() bool {
+	if F.Pos > 0 {
+		F.Expire()
+	}
+	return F.Pos == F.Size
+}
+
+func (F *FloodTrack) Expire() {
 	var idx int
 
 ReCheck:
 	for idx = 0; idx < F.Pos; idx++ {
 		// log.Println(idx, time.Since(F.Track[idx]).Seconds())
-		if time.Since(F.Track[idx]).Seconds() > float64(timeout) {
+		if time.Since(F.Track[idx]).Seconds() > float64(F.Timeout) {
 			// Remove this from the list
 			F.Pos--
 			for pos := idx; pos < F.Pos; pos++ {
@@ -121,7 +136,8 @@ func (T *ThrottleBuffer) delete(To string) {
 			for x := 0; x < len(T.targets); x++ {
 				if T.targets[x] == To {
 					T.targets = append(T.targets[:x], T.targets[x+1:]...)
-					if x <= T.last {
+					// Don't decrement if already at first item
+					if (x <= T.last) && (T.last != 0) {
 						T.last--
 					}
 					break

+ 105 - 0
throttle_test.go

@@ -0,0 +1,105 @@
+package ircclient
+
+import (
+	"bytes"
+	"fmt"
+	"log"
+	"os"
+	"testing"
+)
+
+func TestFloodTrack(t *testing.T) {
+	var Flood FloodTrack
+	Flood.Init(3, 2)
+
+	if Flood.Full() {
+		t.Error("Expected Track to be empty")
+	}
+
+	for x := 1; x < 3; x++ {
+		Flood.Save()
+		if Flood.Pos != x {
+			t.Error(fmt.Sprintf("Expected Track Pos to be %d", x))
+		}
+		if Flood.Full() {
+			t.Error("Expected Track to be empty")
+		}
+	}
+
+	Flood.Save()
+	if Flood.Pos != 3 {
+		t.Error("Expected Track Pos to be 3")
+	}
+	if !Flood.Full() {
+		t.Error("Expected Track to be full")
+	}
+}
+
+func TestThrottleBuffer(t *testing.T) {
+	// eat log output
+	var logbuff bytes.Buffer
+	log.SetOutput(&logbuff)
+	defer func() {
+		log.SetOutput(os.Stderr)
+	}()
+
+	var buff ThrottleBuffer
+	buff.init()
+	buff.push("#chat", "msg1")
+	if !buff.Life_sucks {
+		t.Error("Flood control should be enabled here.")
+	}
+	buff.push("#chat", "msg2")
+	buff.push("user", "msg3")
+
+	// verify output order
+
+	var str string
+	var expect string
+	for _, expect = range []string{"msg1", "msg3", "msg2"} {
+		str = buff.pop()
+		if str != expect {
+			t.Error(fmt.Sprintf("Expected %s, got %s", expect, str))
+		}
+	}
+
+	if buff.Life_sucks {
+		t.Error("Flood control should not be enabled here.")
+	}
+
+	// verify deleting 1st item works
+
+	buff.push("#chat", "msg1")
+	if !buff.Life_sucks {
+		t.Error("Flood control should be enabled here.")
+	}
+	buff.push("#chat", "msg2")
+	buff.push("user", "msg3")
+	buff.delete("#chat")
+
+	str = buff.pop()
+	expect = "msg3"
+	if str != expect {
+		t.Error(fmt.Sprintf("Expected %s, got %s", expect, str))
+	}
+
+	if buff.Life_sucks {
+		t.Error("Flood control should not be enabled here.")
+	}
+
+	// verify deleting 2nd item works
+
+	buff.push("#chat", "txt1")
+	buff.push("user", "txt2")
+	buff.delete("user")
+	str = buff.pop()
+	expect = "txt1"
+	if str != expect {
+		t.Error(fmt.Sprintf("Expected %s, got %s", expect, str))
+	}
+
+	if buff.Life_sucks {
+		t.Error("Flood control should not be enabled here.")
+	}
+
+}