// Package websocket implements a small broadcast hub on top of
// github.com/gorilla/websocket. A Server tracks connected clients and fans
// out broadcast messages; each Client runs one read goroutine and one write
// goroutine for its connection.
package websocket

import (
	"encoding/json"
	"log"
	"net/http"
	"sync"
	"time"

	"github.com/gorilla/websocket"
)

const (
	// clientWriteWait is the deadline applied to every write to a peer.
	clientWriteWait = 10 * time.Second

	// clientPongWait is how long a connection may stay silent (no message
	// and no pong) before its read deadline expires.
	clientPongWait = 60 * time.Second

	// clientPingPeriod is the interval between pings sent to the peer. It
	// must be shorter than clientPongWait so a pong can arrive in time.
	clientPingPeriod = 30 * time.Second

	// clientSendBuffer is the capacity of each client's outbound queue.
	clientSendBuffer = 256
)

// upgrader upgrades incoming HTTP requests to WebSocket connections.
//
// NOTE(review): CheckOrigin accepts every origin, which disables the
// cross-site WebSocket hijacking protection. Restrict it before exposing
// this handler outside development.
var upgrader = websocket.Upgrader{
	CheckOrigin: func(r *http.Request) bool {
		return true // Allow all origins in development
	},
}

// Server represents the WebSocket server. It owns the set of connected
// clients; all changes to that set are made by the Start loop.
type Server struct {
	clients    map[*Client]bool
	broadcast  chan []byte
	register   chan *Client
	unregister chan *Client
	mu         sync.RWMutex
}

// Client represents a WebSocket client: one connection, its outbound
// message queue, and the channels it has subscribed to.
type Client struct {
	conn          *websocket.Conn
	send          chan []byte
	server        *Server
	subscriptions map[string]bool

	// sendMu guards sendClosed and serialises sends on send against its
	// close, so the channel is closed exactly once and never written to
	// afterwards.
	sendMu     sync.Mutex
	sendClosed bool
}

// NewServer creates a new WebSocket server. Call Start (normally in its own
// goroutine) before accepting connections or broadcasting.
func NewServer() *Server {
	return &Server{
		clients:    make(map[*Client]bool),
		broadcast:  make(chan []byte),
		register:   make(chan *Client),
		unregister: make(chan *Client),
	}
}

// Start runs the server's event loop, processing registrations,
// unregistrations and broadcasts. It never returns.
func (s *Server) Start() {
	for {
		select {
		case client := <-s.register:
			s.mu.Lock()
			s.clients[client] = true
			total := len(s.clients)
			s.mu.Unlock()
			log.Printf("Client connected. Total clients: %d", total)

		case client := <-s.unregister:
			s.mu.Lock()
			if _, ok := s.clients[client]; ok {
				delete(s.clients, client)
				client.closeSend()
			}
			total := len(s.clients)
			s.mu.Unlock()
			log.Printf("Client disconnected. Total clients: %d", total)

		case message := <-s.broadcast:
			// The write lock is required: slow clients are removed from
			// the map while iterating.
			s.mu.Lock()
			for client := range s.clients {
				if !client.trySend(message) {
					// Queue full or already closed: drop the client
					// rather than block the whole broadcast.
					client.closeSend()
					delete(s.clients, client)
				}
			}
			s.mu.Unlock()
		}
	}
}

// HandleWebSocket handles WebSocket connections: it upgrades the HTTP
// request, registers the new client with the server and starts the client's
// read and write goroutines. Upgrade failures are logged.
func (s *Server) HandleWebSocket(w http.ResponseWriter, r *http.Request) {
	conn, err := upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Printf("WebSocket upgrade failed: %v", err)
		return
	}

	client := &Client{
		conn:          conn,
		send:          make(chan []byte, clientSendBuffer),
		server:        s,
		subscriptions: make(map[string]bool),
	}

	s.register <- client

	go client.writePump()
	go client.readPump()
}

// Broadcast sends a message to all connected clients. It blocks until the
// Start loop accepts the message.
func (s *Server) Broadcast(message []byte) {
	s.broadcast <- message
}

// trySend queues data for delivery without blocking. It reports false when
// the outbound queue is full or has already been closed.
func (c *Client) trySend(data []byte) bool {
	c.sendMu.Lock()
	defer c.sendMu.Unlock()

	if c.sendClosed {
		return false
	}
	select {
	case c.send <- data:
		return true
	default:
		return false
	}
}

// closeSend closes the outbound queue, which tells writePump to send a
// close frame and exit. It is safe to call more than once and from any
// goroutine.
func (c *Client) closeSend() {
	c.sendMu.Lock()
	defer c.sendMu.Unlock()

	if !c.sendClosed {
		c.sendClosed = true
		close(c.send)
	}
}

// readPump reads messages from the WebSocket connection until a read fails,
// then unregisters the client and closes the connection. Messages that are
// not valid JSON objects are skipped.
func (c *Client) readPump() {
	defer func() {
		c.server.unregister <- c
		c.conn.Close()
	}()

	c.conn.SetReadDeadline(time.Now().Add(clientPongWait))
	c.conn.SetPongHandler(func(string) error {
		// Each pong proves the peer is alive; extend the read deadline.
		c.conn.SetReadDeadline(time.Now().Add(clientPongWait))
		return nil
	})

	for {
		_, message, err := c.conn.ReadMessage()
		if err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
				log.Printf("WebSocket error: %v", err)
			}
			break
		}

		// Handle message
		var msg map[string]interface{}
		if err := json.Unmarshal(message, &msg); err != nil {
			continue
		}
		c.handleMessage(msg)
	}
}

// writePump writes messages to the WebSocket connection. It forwards queued
// messages to the peer, sends a ping every clientPingPeriod, and closes the
// connection when the queue is closed or any write fails.
func (c *Client) writePump() {
	ticker := time.NewTicker(clientPingPeriod)
	defer func() {
		ticker.Stop()
		c.conn.Close()
	}()

	for {
		select {
		case message, ok := <-c.send:
			c.conn.SetWriteDeadline(time.Now().Add(clientWriteWait))
			if !ok {
				// The queue was closed: say goodbye to the peer.
				c.conn.WriteMessage(websocket.CloseMessage, []byte{})
				return
			}

			w, err := c.conn.NextWriter(websocket.TextMessage)
			if err != nil {
				return
			}
			if _, err := w.Write(message); err != nil {
				return
			}

			// Batch whatever is already queued into the same frame,
			// newline-separated. This goroutine is the only receiver, so
			// the n messages counted here are guaranteed to be available.
			n := len(c.send)
			for i := 0; i < n; i++ {
				if _, err := w.Write([]byte{'\n'}); err != nil {
					return
				}
				if _, err := w.Write(<-c.send); err != nil {
					return
				}
			}

			if err := w.Close(); err != nil {
				return
			}

		case <-ticker.C:
			c.conn.SetWriteDeadline(time.Now().Add(clientWriteWait))
			if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil {
				return
			}
		}
	}
}

// handleMessage handles incoming WebSocket messages. It dispatches on the
// string "type" field ("subscribe", "unsubscribe" or "ping"); messages with
// a missing, non-string or unknown type are ignored.
func (c *Client) handleMessage(msg map[string]interface{}) {
	msgType, ok := msg["type"].(string)
	if !ok {
		return
	}

	switch msgType {
	case "subscribe":
		// A missing or non-string channel yields the empty channel name.
		channel, _ := msg["channel"].(string)
		c.subscriptions[channel] = true
		c.sendMessage(map[string]interface{}{
			"type":    "subscribed",
			"channel": channel,
		})

	case "unsubscribe":
		// A missing or non-string channel yields the empty channel name.
		channel, _ := msg["channel"].(string)
		delete(c.subscriptions, channel)
		c.sendMessage(map[string]interface{}{
			"type":    "unsubscribed",
			"channel": channel,
		})

	case "ping":
		c.sendMessage(map[string]interface{}{
			"type":      "pong",
			"timestamp": time.Now().Unix(),
		})
	}
}

// sendMessage sends a message to the client by JSON-encoding msg and
// queueing it. If the outbound queue is full the client is treated as too
// slow and its queue is closed, which makes writePump close the connection.
func (c *Client) sendMessage(msg map[string]interface{}) {
	data, err := json.Marshal(msg)
	if err != nil {
		log.Printf("WebSocket marshal failed: %v", err)
		return
	}

	if !c.trySend(data) {
		c.closeSend()
	}
}