package websocket

import (
	"net/http/httptest"
	"testing"
)

// originAllowedForTest sets WEBSOCKET_ALLOWED_ORIGINS to allowedOrigins for
// the rest of the test. It then builds a GET request for
// http://example.com/ws whose Host is "example.com:8080" and whose Origin
// header is origin, and returns what websocketOriginAllowed reports for it.
func originAllowedForTest(t *testing.T, allowedOrigins, origin string) bool {
	t.Helper()
	t.Setenv("WEBSOCKET_ALLOWED_ORIGINS", allowedOrigins)

	r := httptest.NewRequest("GET", "http://example.com/ws", nil)
	r.Host = "example.com:8080"
	r.Header.Set("Origin", origin)
	return websocketOriginAllowed(r)
}

// With no allowlist configured, an Origin naming the same host as the
// request is expected to pass. The Origin uses a different scheme than the
// request URL and omits the Host's :8080 port.
func TestWebsocketOriginAllowedDefaultsToSameHost(t *testing.T) {
	allowed := originAllowedForTest(t, "", "https://example.com")
	if allowed {
		return
	}
	t.Fatal("expected same-host websocket origin to be allowed by default")
}

// With no allowlist configured, an Origin on an unrelated host is expected
// to be refused.
func TestWebsocketOriginAllowedRejectsCrossOriginByDefault(t *testing.T) {
	allowed := originAllowedForTest(t, "", "https://attacker.example")
	if !allowed {
		return
	}
	t.Fatal("expected cross-origin websocket request to be rejected by default")
}

// An Origin listed in the comma-separated allowlist is expected to be
// accepted even though it does not match the request host. The list here
// contains a space after the comma, and the second entry is the one matched.
func TestWebsocketOriginAllowedHonorsExplicitAllowlist(t *testing.T) {
	allowed := originAllowedForTest(t, "https://app.example, https://ops.example", "https://ops.example")
	if allowed {
		return
	}
	t.Fatal("expected allowlisted websocket origin to be accepted")
}