199 lines
4.1 KiB
Go
199 lines
4.1 KiB
Go
package realtime
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
var upgrader = websocket.Upgrader{
|
|
CheckOrigin: func(r *http.Request) bool {
|
|
// In production, validate origin properly
|
|
return true
|
|
},
|
|
}
|
|
|
|
// Gateway handles WebRTC signaling and WebSocket connections
|
|
type Gateway struct {
|
|
connections map[string]*Connection
|
|
mu sync.RWMutex
|
|
}
|
|
|
|
// NewGateway creates a new WebRTC gateway
|
|
func NewGateway() *Gateway {
|
|
return &Gateway{
|
|
connections: make(map[string]*Connection),
|
|
}
|
|
}
|
|
|
|
// Connection represents a WebSocket connection for signaling
|
|
type Connection struct {
|
|
sessionID string
|
|
ws *websocket.Conn
|
|
send chan []byte
|
|
ctx context.Context
|
|
cancel context.CancelFunc
|
|
}
|
|
|
|
// HandleWebSocket handles WebSocket upgrade for signaling
|
|
func (g *Gateway) HandleWebSocket(w http.ResponseWriter, r *http.Request, sessionID string) error {
|
|
ws, err := upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to upgrade connection: %w", err)
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(r.Context())
|
|
conn := &Connection{
|
|
sessionID: sessionID,
|
|
ws: ws,
|
|
send: make(chan []byte, 256),
|
|
ctx: ctx,
|
|
cancel: cancel,
|
|
}
|
|
|
|
g.mu.Lock()
|
|
g.connections[sessionID] = conn
|
|
g.mu.Unlock()
|
|
|
|
// Start goroutines
|
|
go conn.writePump()
|
|
go conn.readPump(g)
|
|
|
|
return nil
|
|
}
|
|
|
|
// SendMessage sends a message to a specific session
|
|
func (g *Gateway) SendMessage(sessionID string, message interface{}) error {
|
|
g.mu.RLock()
|
|
conn, ok := g.connections[sessionID]
|
|
g.mu.RUnlock()
|
|
|
|
if !ok {
|
|
return fmt.Errorf("connection not found for session: %s", sessionID)
|
|
}
|
|
|
|
data, err := json.Marshal(message)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to marshal message: %w", err)
|
|
}
|
|
|
|
select {
|
|
case conn.send <- data:
|
|
return nil
|
|
case <-conn.ctx.Done():
|
|
return fmt.Errorf("connection closed")
|
|
}
|
|
}
|
|
|
|
// CloseConnection closes a connection
|
|
func (g *Gateway) CloseConnection(sessionID string) {
|
|
g.mu.Lock()
|
|
defer g.mu.Unlock()
|
|
|
|
if conn, ok := g.connections[sessionID]; ok {
|
|
conn.cancel()
|
|
conn.ws.Close()
|
|
delete(g.connections, sessionID)
|
|
}
|
|
}
|
|
|
|
// readPump reads messages from the WebSocket
|
|
func (c *Connection) readPump(gateway *Gateway) {
|
|
defer func() {
|
|
gateway.CloseConnection(c.sessionID)
|
|
c.ws.Close()
|
|
}()
|
|
|
|
c.ws.SetReadDeadline(time.Now().Add(60 * time.Second))
|
|
c.ws.SetPongHandler(func(string) error {
|
|
c.ws.SetReadDeadline(time.Now().Add(60 * time.Second))
|
|
return nil
|
|
})
|
|
|
|
for {
|
|
_, message, err := c.ws.ReadMessage()
|
|
if err != nil {
|
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
|
|
log.Printf("WebSocket error: %v", err)
|
|
}
|
|
break
|
|
}
|
|
|
|
// Handle incoming message (ICE candidates, SDP offers/answers, etc.)
|
|
var msg map[string]interface{}
|
|
if err := json.Unmarshal(message, &msg); err != nil {
|
|
log.Printf("Failed to unmarshal message: %v", err)
|
|
continue
|
|
}
|
|
|
|
// Route message based on type
|
|
msgType, ok := msg["type"].(string)
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
switch msgType {
|
|
case "ice-candidate":
|
|
// Handle ICE candidate
|
|
case "offer":
|
|
// Handle SDP offer
|
|
case "answer":
|
|
// Handle SDP answer
|
|
default:
|
|
log.Printf("Unknown message type: %s", msgType)
|
|
}
|
|
}
|
|
}
|
|
|
|
// writePump writes messages to the WebSocket
|
|
func (c *Connection) writePump() {
|
|
ticker := time.NewTicker(54 * time.Second)
|
|
defer func() {
|
|
ticker.Stop()
|
|
c.ws.Close()
|
|
}()
|
|
|
|
for {
|
|
select {
|
|
case message, ok := <-c.send:
|
|
c.ws.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
|
if !ok {
|
|
c.ws.WriteMessage(websocket.CloseMessage, []byte{})
|
|
return
|
|
}
|
|
|
|
w, err := c.ws.NextWriter(websocket.TextMessage)
|
|
if err != nil {
|
|
return
|
|
}
|
|
w.Write(message)
|
|
|
|
// Add queued messages
|
|
n := len(c.send)
|
|
for i := 0; i < n; i++ {
|
|
w.Write([]byte{'\n'})
|
|
w.Write(<-c.send)
|
|
}
|
|
|
|
if err := w.Close(); err != nil {
|
|
return
|
|
}
|
|
|
|
case <-ticker.C:
|
|
c.ws.SetWriteDeadline(time.Now().Add(10 * time.Second))
|
|
if err := c.ws.WriteMessage(websocket.PingMessage, nil); err != nil {
|
|
return
|
|
}
|
|
|
|
case <-c.ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
}
|