Skip to content

Commit da4f1a0

Browse files
authored
Merge pull request from GHSA-4hq8-gmxx-h6w9
This change validates that the XML input we receive is safe to parse before passing it to the standard library's XML parsing functions or the etree DOM parsing functions. This validation mitigates critical vulnerabilities in `encoding/xml` - CVE-2020-29509, CVE-2020-29510, and CVE-2020-29511. TODO: is there going to be a go.mod version assigned to this on release?
1 parent a606939 commit da4f1a0

File tree

9 files changed

+203
-16
lines changed

9 files changed

+203
-16
lines changed

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ require (
1010
github.com/dgrijalva/jwt-go v3.2.0+incompatible
1111
github.com/jonboulle/clockwork v0.2.1 // indirect
1212
github.com/kr/pretty v0.2.1
13+
github.com/mattermost/xml-roundtrip-validator v0.0.0-00010101000000-000000000000
1314
github.com/pkg/errors v0.8.1 // indirect
1415
github.com/russellhaering/goxmldsig v1.1.0
1516
github.com/stretchr/testify v1.6.1

identity_provider.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020
"time"
2121

2222
"github.com/beevik/etree"
23+
xrv "github.com/mattermost/xml-roundtrip-validator"
2324
dsig "github.com/russellhaering/goxmldsig"
2425

2526
"github.com/crewjam/saml/logger"
@@ -359,13 +360,18 @@ func NewIdpAuthnRequest(idp *IdentityProvider, r *http.Request) (*IdpAuthnReques
359360
default:
360361
return nil, fmt.Errorf("method not allowed")
361362
}
363+
362364
return req, nil
363365
}
364366

365367
// Validate checks that the authentication request is valid and assigns
366368
// the AuthnRequest and Metadata properties. Returns a non-nil error if the
367369
// request is not valid.
368370
func (req *IdpAuthnRequest) Validate() error {
371+
if err := xrv.Validate(bytes.NewReader(req.RequestBuffer)); err != nil {
372+
return err
373+
}
374+
369375
if err := xml.Unmarshal(req.RequestBuffer, &req.Request); err != nil {
370376
return err
371377
}

identity_provider_test.go

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
package saml
22

33
import (
4+
"bytes"
5+
"compress/flate"
46
"crypto"
57
"crypto/rsa"
68
"crypto/x509"
79
"encoding/base64"
810
"encoding/pem"
911
"encoding/xml"
1012
"fmt"
13+
"io/ioutil"
1114
"math/rand"
1215
"net/http"
1316
"net/http/httptest"
@@ -232,15 +235,41 @@ func TestIDPHTTPCanHandleMetadataRequest(t *testing.T) {
232235
func TestIDPHTTPCanHandleSSORequest(t *testing.T) {
233236
test := NewIdentifyProviderTest()
234237
w := httptest.NewRecorder()
235-
r, _ := http.NewRequest("GET", "https://idp.example.com/saml/sso?RelayState=ThisIsTheRelayState&SAMLRequest=lJJBayoxFIX%2FypC9JhnU5wszAz7lgWCLaNtFd5fMbQ1MkmnunVb%2FfUfbUqEgdhs%2BTr5zkmLW8S5s8KVD4mzvm0Cl6FIwEciRCeCRDFuznd2sTD5Upk2Ro42NyGZEmNjFMI%2BBOo9pi%2BnVWbzfrEqxY27JSEntEPfg2waHNnpJ4JtcgiWRLfoLXYBjwDfu6p%2B8JIoiWy5K4eqBUipXIzVRUwXKKtRK53qkJ3qqQVuNPUjU4TIQQ%2BBS5EqPBzofKH2ntBn%2FMervo8jWnyX%2BuVC78FwKkT1gopNKX1JUxSklXTMIfM0gsv8xeeDL%2BPGk7%2FF0Qg0GdnwQ1cW5PDLUwFDID6uquO1Dlot1bJw9%2FPLRmia%2BzRMCYyk4dSiq6205QSDXOxfy3KAq5Pkvqt4DAAD%2F%2Fw%3D%3D", nil)
238+
239+
const validRequest = `lJJBayoxFIX%2FypC9JhnU5wszAz7lgWCLaNtFd5fMbQ1MkmnunVb%2FfUfbUqEgdhs%2BTr5zkmLW8S5s8KVD4mzvm0Cl6FIwEciRCeCRDFuznd2sTD5Upk2Ro42NyGZEmNjFMI%2BBOo9pi%2BnVWbzfrEqxY27JSEntEPfg2waHNnpJ4JtcgiWRLfoLXYBjwDfu6p%2B8JIoiWy5K4eqBUipXIzVRUwXKKtRK53qkJ3qqQVuNPUjU4TIQQ%2BBS5EqPBzofKH2ntBn%2FMervo8jWnyX%2BuVC78FwKkT1gopNKX1JUxSklXTMIfM0gsv8xeeDL%2BPGk7%2FF0Qg0GdnwQ1cW5PDLUwFDID6uquO1Dlot1bJw9%2FPLRmia%2BzRMCYyk4dSiq6205QSDXOxfy3KAq5Pkvqt4DAAD%2F%2Fw%3D%3D`
240+
241+
r, _ := http.NewRequest("GET", "https://idp.example.com/saml/sso?RelayState=ThisIsTheRelayState&"+
242+
"SAMLRequest="+validRequest, nil)
236243
test.IDP.Handler().ServeHTTP(w, r)
237244
assert.Equal(t, http.StatusOK, w.Code)
238245

239246
// rejects requests that are invalid
240247
w = httptest.NewRecorder()
241-
r, _ = http.NewRequest("GET", "https://idp.example.com/saml/sso?RelayState=ThisIsTheRelayState&SAMLRequest=PEF1dGhuUmVxdWVzdA%3D%3D", nil)
248+
r, _ = http.NewRequest("GET", "https://idp.example.com/saml/sso?RelayState=ThisIsTheRelayState&"+
249+
"SAMLRequest=PEF1dGhuUmVxdWVzdA%3D%3D", nil)
242250
test.IDP.Handler().ServeHTTP(w, r)
243251
assert.Equal(t, http.StatusBadRequest, w.Code)
252+
253+
// rejects requests that contain malformed XML
254+
{
255+
a, _ := url.QueryUnescape(validRequest)
256+
b, _ := base64.StdEncoding.DecodeString(a)
257+
c, _ := ioutil.ReadAll(flate.NewReader(bytes.NewReader(b)))
258+
d := bytes.Replace(c, []byte("<AuthnRequest"), []byte("<AuthnRequest ::foo=\"bar\""), 1)
259+
f := bytes.Buffer{}
260+
e, _ := flate.NewWriter(&f, flate.DefaultCompression)
261+
e.Write(d)
262+
e.Close()
263+
g := base64.StdEncoding.EncodeToString(f.Bytes())
264+
invalidRequest := url.QueryEscape(g)
265+
266+
w = httptest.NewRecorder()
267+
r, _ = http.NewRequest("GET", "https://idp.example.com/saml/sso?RelayState=ThisIsTheRelayState&"+
268+
"SAMLRequest="+invalidRequest, nil)
269+
test.IDP.Handler().ServeHTTP(w, r)
270+
assert.Equal(t, http.StatusBadRequest, w.Code)
271+
}
272+
244273
}
245274

246275
func TestIDPCanHandleRequestWithNewSession(t *testing.T) {

samlidp/util.go

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
package samlidp
22

33
import (
4+
"bytes"
5+
"encoding/xml"
46
"errors"
7+
"io"
58
"io/ioutil"
69

7-
"encoding/xml"
8-
9-
"io"
10+
xrv "github.com/mattermost/xml-roundtrip-validator"
1011

1112
"github.com/crewjam/saml"
1213
)
@@ -20,19 +21,20 @@ func randomBytes(n int) []byte {
2021
}
2122

2223
func getSPMetadata(r io.Reader) (spMetadata *saml.EntityDescriptor, err error) {
23-
var bytes []byte
24-
25-
if bytes, err = ioutil.ReadAll(r); err != nil {
24+
var data []byte
25+
if data, err = ioutil.ReadAll(r); err != nil {
2626
return nil, err
2727
}
2828

2929
spMetadata = &saml.EntityDescriptor{}
30+
if err := xrv.Validate(bytes.NewBuffer(data)); err != nil {
31+
return nil, err
32+
}
3033

31-
if err := xml.Unmarshal(bytes, &spMetadata); err != nil {
34+
if err := xml.Unmarshal(data, &spMetadata); err != nil {
3235
if err.Error() == "expected element type <EntityDescriptor> but have <EntitiesDescriptor>" {
3336
entities := &saml.EntitiesDescriptor{}
34-
35-
if err := xml.Unmarshal(bytes, &entities); err != nil {
37+
if err := xml.Unmarshal(data, &entities); err != nil {
3638
return nil, err
3739
}
3840

samlidp/util_test.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
package samlidp
2+
3+
import (
4+
"strings"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
)
9+
10+
func TestGetSPMetadata(t *testing.T) {
11+
good := "" +
12+
"<EntityDescriptor xmlns=\"urn:oasis:names:tc:SAML:2.0:metadata\" validUntil=\"2013-03-10T00:32:19.104Z\" cacheDuration=\"PT1H\" entityID=\"http://localhost:5000/e087a985171710fb9fb30f30f41384f9/saml2/metadata/\">\n" +
13+
"</EntityDescriptor>"
14+
_, err := getSPMetadata(strings.NewReader(good))
15+
assert.NoError(t, err)
16+
17+
bad := "" +
18+
"<EntityDescriptor xmlns=\"urn:oasis:names:tc:SAML:2.0:metadata\" ::attr=\"foo\" validUntil=\"2013-03-10T00:32:19.104Z\" cacheDuration=\"PT1H\" entityID=\"http://localhost:5000/e087a985171710fb9fb30f30f41384f9/saml2/metadata/\">\n" +
19+
"</EntityDescriptor>"
20+
_, err = getSPMetadata(strings.NewReader(bad))
21+
assert.EqualError(t, err, "validator: in token starting at 1:1: roundtrip error: expected {{ EntityDescriptor} [{{ xmlns} urn:oasis:names:tc:SAML:2.0:metadata} {{ :attr} foo} {{ validUntil} 2013-03-10T00:32:19.104Z} {{ cacheDuration} PT1H} {{ entityID} http://localhost:5000/e087a985171710fb9fb30f30f41384f9/saml2/metadata/}]}, observed {{ EntityDescriptor} [{{ xmlns} urn:oasis:names:tc:SAML:2.0:metadata} {{ attr} foo} {{ validUntil} 2013-03-10T00:32:19.104Z} {{ cacheDuration} PT1H} {{ entityID} http://localhost:5000/e087a985171710fb9fb30f30f41384f9/saml2/metadata/}]}")
22+
}

samlsp/fetch_metadata.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package samlsp
22

33
import (
4+
"bytes"
45
"context"
56
"encoding/xml"
67
"errors"
@@ -9,6 +10,7 @@ import (
910
"net/url"
1011

1112
"github.com/crewjam/httperr"
13+
xrv "github.com/mattermost/xml-roundtrip-validator"
1214

1315
"github.com/crewjam/saml"
1416
)
@@ -20,6 +22,11 @@ import (
2022
// <EntityDescriptor>.
2123
func ParseMetadata(data []byte) (*saml.EntityDescriptor, error) {
2224
entity := &saml.EntityDescriptor{}
25+
26+
if err := xrv.Validate(bytes.NewBuffer(data)); err != nil {
27+
return nil, err
28+
}
29+
2330
err := xml.Unmarshal(data, entity)
2431

2532
// this comparison is ugly, but it is how the error is generated in encoding/xml

samlsp/fetch_metadata_test.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"net/http"
77
"net/http/httptest"
88
"net/url"
9+
"strings"
910
"testing"
1011

1112
"github.com/stretchr/testify/assert"
@@ -25,3 +26,19 @@ func TestFetchMetadata(t *testing.T) {
2526
assert.NoError(t, err)
2627
assert.Equal(t, "https://idp.testshib.org/idp/shibboleth", md.EntityID)
2728
}
29+
30+
func TestFetchMetadataRejectsInvalid(t *testing.T) {
31+
test := NewMiddlewareTest()
32+
test.IDPMetadata = strings.Replace(test.IDPMetadata, "<EntityDescriptor ", "<EntityDescriptor ::foo=\"bar\"", -1)
33+
34+
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
35+
assert.Equal(t, "/metadata", r.URL.String())
36+
fmt.Fprint(w, test.IDPMetadata)
37+
}))
38+
39+
fmt.Println(testServer.URL + "/metadata")
40+
u, _ := url.Parse(testServer.URL + "/metadata")
41+
md, err := FetchMetadata(context.Background(), testServer.Client(), *u)
42+
assert.EqualError(t, err, "validator: in token starting at 2:1: roundtrip error: expected {{ EntityDescriptor} [{{ :foo} bar} {{ xmlns} urn:oasis:names:tc:SAML:2.0:metadata} {{xmlns ds} http://www.w3.org/2000/09/xmldsig#} {{xmlns mdalg} urn:oasis:names:tc:SAML:metadata:algsupport} {{xmlns mdui} urn:oasis:names:tc:SAML:metadata:ui} {{xmlns shibmd} urn:mace:shibboleth:metadata:1.0} {{xmlns xsi} http://www.w3.org/2001/XMLSchema-instance} {{ Name} urn:mace:shibboleth:testshib:two} {{ entityID} https://idp.testshib.org/idp/shibboleth}]}, observed {{ EntityDescriptor} [{{ foo} bar} {{ xmlns} urn:oasis:names:tc:SAML:2.0:metadata} {{xmlns ds} http://www.w3.org/2000/09/xmldsig#} {{xmlns mdalg} urn:oasis:names:tc:SAML:metadata:algsupport} {{xmlns mdui} urn:oasis:names:tc:SAML:metadata:ui} {{xmlns shibmd} urn:mace:shibboleth:metadata:1.0} {{xmlns xsi} http://www.w3.org/2001/XMLSchema-instance} {{ Name} urn:mace:shibboleth:testshib:two} {{ entityID} https://idp.testshib.org/idp/shibboleth} {{ entityID} https://idp.testshib.org/idp/shibboleth}]}")
43+
assert.Nil(t, md)
44+
}

service_provider.go

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,14 @@ import (
1111
"errors"
1212
"fmt"
1313
"html/template"
14+
"io/ioutil"
1415
"net/http"
1516
"net/url"
1617
"regexp"
1718
"time"
1819

20+
xrv "github.com/mattermost/xml-roundtrip-validator"
21+
1922
"github.com/beevik/etree"
2023
dsig "github.com/russellhaering/goxmldsig"
2124
"github.com/russellhaering/goxmldsig/etreeutils"
@@ -553,9 +556,15 @@ func (sp *ServiceProvider) ParseXMLResponse(decodedResponseXML []byte, possibleR
553556
Response: string(decodedResponseXML),
554557
}
555558

559+
// ensure that the response XML is well formed before we parse it
560+
if err := xrv.Validate(bytes.NewReader(decodedResponseXML)); err != nil {
561+
retErr.PrivateErr = fmt.Errorf("invalid xml: %s", err)
562+
return nil, retErr
563+
}
564+
556565
// do some validation first before we decrypt
557566
resp := Response{}
558-
if err := xml.Unmarshal([]byte(decodedResponseXML), &resp); err != nil {
567+
if err := xml.Unmarshal(decodedResponseXML, &resp); err != nil {
559568
retErr.PrivateErr = fmt.Errorf("cannot unmarshal response: %s", err)
560569
return nil, retErr
561570
}
@@ -659,6 +668,12 @@ func (sp *ServiceProvider) ParseXMLResponse(decodedResponseXML []byte, possibleR
659668
}
660669
retErr.Response = string(plaintextAssertion)
661670

671+
// TODO(ross): add test case for this
672+
if err := xrv.Validate(bytes.NewReader(plaintextAssertion)); err != nil {
673+
retErr.PrivateErr = fmt.Errorf("plaintext response contains invalid XML: %s", err)
674+
return nil, retErr
675+
}
676+
662677
doc = etree.NewDocument()
663678
if err := doc.ReadFromBytes(plaintextAssertion); err != nil {
664679
retErr.PrivateErr = fmt.Errorf("cannot parse plaintext response %v", err)
@@ -673,6 +688,8 @@ func (sp *ServiceProvider) ParseXMLResponse(decodedResponseXML []byte, possibleR
673688
}
674689

675690
assertion = &Assertion{}
691+
// Note: plaintextAssertion is known to be safe to parse because
692+
// plaintextAssertion is unmodified from when xrv.Validate() was called above.
676693
if err := xml.Unmarshal(plaintextAssertion, assertion); err != nil {
677694
retErr.PrivateErr = err
678695
return nil, retErr
@@ -1001,8 +1018,12 @@ func (sp *ServiceProvider) ValidateLogoutResponseForm(postFormData string) error
10011018
return fmt.Errorf("unable to parse base64: %s", err)
10021019
}
10031020

1004-
var resp LogoutResponse
1021+
// TODO(ross): add test case for this (SLO does not have tests right now)
1022+
if err := xrv.Validate(bytes.NewReader(rawResponseBuf)); err != nil {
1023+
return fmt.Errorf("response contains invalid XML: %s", err)
1024+
}
10051025

1026+
var resp LogoutResponse
10061027
if err := xml.Unmarshal(rawResponseBuf, &resp); err != nil {
10071028
return fmt.Errorf("cannot unmarshal response: %s", err)
10081029
}
@@ -1034,9 +1055,16 @@ func (sp *ServiceProvider) ValidateLogoutResponseRedirect(queryParameterData str
10341055
return fmt.Errorf("unable to parse base64: %s", err)
10351056
}
10361057

1037-
gr := flate.NewReader(bytes.NewBuffer(rawResponseBuf))
1058+
gr, err := ioutil.ReadAll(flate.NewReader(bytes.NewBuffer(rawResponseBuf)))
1059+
if err != nil {
1060+
return err
1061+
}
1062+
1063+
if err := xrv.Validate(bytes.NewReader(gr)); err != nil {
1064+
return err
1065+
}
10381066

1039-
decoder := xml.NewDecoder(gr)
1067+
decoder := xml.NewDecoder(bytes.NewReader(gr))
10401068

10411069
var resp LogoutResponse
10421070

@@ -1050,7 +1078,7 @@ func (sp *ServiceProvider) ValidateLogoutResponseRedirect(queryParameterData str
10501078
}
10511079

10521080
doc := etree.NewDocument()
1053-
if _, err := doc.ReadFrom(gr); err != nil {
1081+
if _, err := doc.ReadFrom(bytes.NewReader(gr)); err != nil {
10541082
return err
10551083
}
10561084

0 commit comments

Comments
 (0)