瀏覽代碼

Working signal tests.

We now properly shut down the go routines.
Steve Thielemann 2 年之前
父節點
當前提交
231d64d65b
共有 3 個文件被更改,包括 131 次插入14 次删除
  1. 79 0
      client_test.go
  2. 44 14
      irc-client.go
  3. 8 0
      ircd_test.go

+ 79 - 0
client_test.go

@@ -4,6 +4,7 @@ import (
 	"net"
 	"strconv"
 	"strings"
+	"syscall"
 	"testing"
 )
 
@@ -235,6 +236,7 @@ func TestConnectAutojoin(t *testing.T) {
 				config.Notice("echo", "Testing")
 				config.Msg("#test", "Message 1")
 				config.Msg("#test", "Message 2")
+				config.PriorityWrite("PRIVMSG #test :Message")
 			}
 		}
 		if Msg.Cmd == "CTCP" {
@@ -305,7 +307,13 @@ func TestConnectKick(t *testing.T) {
 
 	config.Hostname = parts[0]
 	config.Port, _ = strconv.Atoi(parts[1])
+
+	// Save and Restore abortAfter
+	var abortPrev = abortAfter
+	defer func() { abortAfter = abortPrev }()
+	abortAfter = 500
 	go ircServer(listen, t, &config)
+
 	var FromIRC chan IRCMsg
 
 	FromIRC = make(chan IRCMsg)
@@ -329,6 +337,9 @@ func TestConnectKick(t *testing.T) {
 		}
 		if Msg.Cmd == "JOIN" {
 			joins++
+			if joins == 4 {
+				config.PriorityWrite("NICK something")
+			}
 		}
 	}
 
@@ -342,4 +353,72 @@ func TestConnectKick(t *testing.T) {
 	if !identify {
 		t.Error("Missing Identified")
 	}
+	if config.MyNick != "something" {
+		t.Errorf("Expected nick to be something, not %s", config.MyNick)
+	}
+}
+
+func TestConnectSignal(t *testing.T) {
+	var exit bool
+
+	var onexit func() = func() { exit = true }
+
+	var config IRCConfig = IRCConfig{Nick: "test",
+		Username:    "test",
+		Realname:    "testing",
+		Password:    "12345",
+		UseTLS:      true,
+		UseSASL:     true,
+		Insecure:    true,
+		AutoJoin:    []string{"#chat"},
+		Flood_Num:   2,
+		Flood_Delay: 10,
+		OnExit:      onexit,
+	}
+	var listen net.Listener
+	var address string
+
+	listen, address = setupTLSSocket()
+	var parts []string = strings.Split(address, ":")
+
+	config.Hostname = parts[0]
+	config.Port, _ = strconv.Atoi(parts[1])
+	go ircServer(listen, t, &config)
+	var FromIRC chan IRCMsg
+
+	FromIRC = make(chan IRCMsg)
+	config.ReadChannel = FromIRC
+
+	config.Connect()
+	defer config.Close()
+
+	var Msg IRCMsg
+	var motd, identify bool
+
+	for Msg = range FromIRC {
+		if Msg.Cmd == "EndMOTD" {
+			t.Log("Got EndMOTD")
+			motd = true
+			// config.PriorityWrite("NICK something")
+		}
+		if Msg.Cmd == "Identified" {
+			t.Log("Identified")
+			identify = true
+		}
+		if Msg.Cmd == "JOIN" {
+			// Ok, we've joined.  Test the signal
+			e := syscall.Kill(syscall.Getpid(), syscall.SIGTERM)
+			t.Log("Kill:", e)
+		}
+	}
+
+	if !motd {
+		t.Error("Missing EndMOTD")
+	}
+	if !identify {
+		t.Error("Missing Identified")
+	}
+	if !exit {
+		t.Error("AtExit wasn't called.")
+	}
 }

+ 44 - 14
irc-client.go

@@ -14,6 +14,7 @@ import (
 	"os/signal"
 	"strconv"
 	"strings"
+	"sync"
 	"syscall"
 	"time"
 )
@@ -54,6 +55,7 @@ type IRCConfig struct {
 	AutoJoin       []string `json: AutoJoin`       // Channels to auto-join
 	RejoinDelay    int      `json: RejoinDelay`    // ms to rejoin
 	MyNick         string
+	OnExit         func()
 	Socket         net.Conn
 	Reader         *bufio.Reader
 	ReadChannel    chan IRCMsg
@@ -64,11 +66,15 @@ type IRCConfig struct {
 	Flood_Num      int // Number of lines sent before considered a flood
 	Flood_Time     int // Number of Seconds to track previous messages
 	Flood_Delay    int // Delay between sending when flood protection on (Milliseconds)
-	OnExit         func()
+	wg             sync.WaitGroup
 }
 
 // Writer       *bufio.Writer
 
+func (Config *IRCConfig) IsAuto(ch string) bool {
+	return StrInArray(Config.AutoJoin, ch)
+}
+
 func (Config *IRCConfig) Connect() bool {
 	var err error
 
@@ -122,11 +128,13 @@ func (Config *IRCConfig) Connect() bool {
 		Config.Reader = bufio.NewReader(Config.Socket)
 	}
 
-	Config.WriteChannel = make(chan IRCWrite)
+	// WriteChannel may contain a message when we're trying to PriorityWrite from sigChannel.
+	Config.WriteChannel = make(chan IRCWrite, 3)
 	Config.DelChannel = make(chan string)
 
 	// We are connected.
 	go Config.WriterRoutine()
+	Config.wg.Add(1)
 
 	// Registration
 	if Config.UseTLS && Config.UseSASL {
@@ -144,6 +152,7 @@ func (Config *IRCConfig) Connect() bool {
 	// Config.Writer.Flush()
 
 	go Config.ReaderRoutine()
+	Config.wg.Add(1)
 	return true
 }
 
@@ -167,6 +176,7 @@ func (Config *IRCConfig) WriterRoutine() {
 	var throttle ThrottleBuffer
 	var Flood FloodTrack
 
+	defer Config.wg.Done()
 	throttle.init()
 
 	Flood.Init(Config.Flood_Num, Config.Flood_Time)
@@ -178,6 +188,8 @@ func (Config *IRCConfig) WriterRoutine() {
 	// Change this into a select with timeout.
 	// Timeout, if there's something to be buffered.
 
+	var gotSignal bool
+
 	for {
 		if throttle.Life_sucks {
 			select {
@@ -186,6 +198,7 @@ func (Config *IRCConfig) WriterRoutine() {
 					err = Config.write(output.Output)
 					if err != nil {
 						log.Println("Writer:", err)
+						return
 					}
 					continue
 				}
@@ -193,10 +206,19 @@ func (Config *IRCConfig) WriterRoutine() {
 				throttle.push(output.To, output.Output)
 
 			case <-sigChannel:
-				if Config.OnExit != nil {
-					Config.OnExit()
+				if !gotSignal {
+					gotSignal = true
+					log.Println("SIGNAL")
+					if Config.OnExit != nil {
+						Config.OnExit()
+					}
+					Config.PriorityWrite("QUIT :Received SIGINT")
 				}
-				os.Exit(2)
+				// return
+				continue
+				// Config.Close()
+				// return
+				//os.Exit(2)
 
 			case remove := <-Config.DelChannel:
 				log.Printf("Remove: [%s]\n", remove)
@@ -219,10 +241,19 @@ func (Config *IRCConfig) WriterRoutine() {
 			// Life is good.
 			select {
 			case <-sigChannel:
-				if Config.OnExit != nil {
-					Config.OnExit()
+				if !gotSignal {
+					gotSignal = true
+					log.Println("SIGNAL")
+					if Config.OnExit != nil {
+						Config.OnExit()
+					}
+					Config.PriorityWrite("QUIT :Received SIGINT")
 				}
-				os.Exit(2)
+				// return
+				continue
+				// Config.Close()
+				// return
+				// os.Exit(2)
 
 			case remove := <-Config.DelChannel:
 				log.Printf("Remove: [%s]\n", remove)
@@ -233,6 +264,7 @@ func (Config *IRCConfig) WriterRoutine() {
 					err = Config.write(output.Output)
 					if err != nil {
 						log.Println("Writer:", err)
+						return
 					}
 					continue
 				}
@@ -255,12 +287,9 @@ func (Config *IRCConfig) WriterRoutine() {
 }
 
 func (Config *IRCConfig) Close() {
-	/*
-		if Config.UseTLS {
-			Config.TLSSocket.Close()
-		} else {
-	*/
 	Config.Socket.Close()
+	Config.PriorityWrite("")
+	Config.wg.Wait()
 }
 
 func IRCParse(line string) []string {
@@ -317,6 +346,7 @@ func (Config *IRCConfig) Action(to string, message string) {
 }
 
 func (Config *IRCConfig) ReaderRoutine() {
+	defer Config.wg.Done()
 	for {
 		var line string
 		var err error
@@ -475,7 +505,7 @@ func (Config *IRCConfig) ReaderRoutine() {
 			// 2022/04/13 20:02:52 << :[email protected] KICK #bugz meow-bot :bugz
 			// Msg: ircclient.IRCMsg{MsgParts:[]string{":[email protected]", "KICK", "#bugz", "meow-bot", "bugz"}, From:"bugz", To:"#bugz", Cmd:"KICK", Msg:"meow-bot"}
 			if strings.Contains(msg.Msg, Config.MyNick) {
-				if StrInArray(Config.AutoJoin, msg.To) {
+				if Config.IsAuto(msg.To) {
 					// Yes, we were kicked from AutoJoin channel
 					time.AfterFunc(time.Duration(Config.RejoinDelay)*time.Millisecond, func() { Config.WriteTo(msg.To, "JOIN "+msg.To) })
 				}

+ 8 - 0
ircd_test.go

@@ -365,7 +365,15 @@ func ircServer(listener net.Listener, t *testing.T, config *IRCConfig) {
 					output = fmt.Sprintf(":irc.red-green.com %d %s %s :No such nick/channel", number, config.MyNick, parts[1])
 					ircWrite(server, output, t)
 				}
+			case "NICK":
+				output = fmt.Sprintf(":%s NICK :%s", config.MyNick, parts[1])
+				ircWrite(server, output, t)
+			case "QUIT":
+				output = fmt.Sprintf("ERROR: Closing link:%s", config.MyNick)
+				ircWrite(server, output, t)
+				server.Close()
 			}
+
 		} else {
 			t.Log("Read Error:", err)
 			return