From 2bc99398f9bfe33f67e56ce61784269585e66b70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alejandro=20Garc=C3=ADa=20Montoro?= Date: Thu, 12 Oct 2023 19:09:52 +0200 Subject: [PATCH] MM-53023: Add origin client to ObserveAPIEndpointDuration (#23631) * Add origin device to ObserveAPIEndpointDuration * Fix generation of einterfaces mocks * make einterfaces-mocks * Use request's query and headers to get origin * Add desktop to the origin device identification * Test originDevice function * Rename origin device to origin client --- server/channels/web/handlers.go | 45 ++++++++++++- server/channels/web/handlers_test.go | 68 ++++++++++++++++++++ server/einterfaces/metrics.go | 2 +- server/einterfaces/mocks/MetricsInterface.go | 6 +- 4 files changed, 116 insertions(+), 5 deletions(-) diff --git a/server/channels/web/handlers.go b/server/channels/web/handlers.go index 0b97e270b589..aa920840dde5 100644 --- a/server/channels/web/handlers.go +++ b/server/channels/web/handlers.go @@ -407,11 +407,54 @@ func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if r.URL.Path != model.APIURLSuffix+"/websocket" { elapsed := float64(time.Since(now)) / float64(time.Second) - c.App.Metrics().ObserveAPIEndpointDuration(h.HandlerName, r.Method, statusCode, elapsed) + originClient := string(originClient(r)) + c.App.Metrics().ObserveAPIEndpointDuration(h.HandlerName, r.Method, statusCode, originClient, elapsed) } } } +type OriginClient string + +const ( + OriginClientUnknown OriginClient = "unknown" + OriginClientWeb OriginClient = "web" + OriginClientMobile OriginClient = "mobile" + OriginClientDesktop OriginClient = "desktop" +) + +// originClient returns the device from which the provided request was issued. The algorithm roughly looks like: +// - If the URL contains the query mobilev2=true, then it's mobile +// - If the first field of the user agent starts with either "rnbeta" or "Mattermost", then it's mobile +// - If the last field of the user agent starts with "Mattermost", then it's desktop +// - Otherwise, it's web +func originClient(r *http.Request) OriginClient { + userAgent := r.Header.Get("User-Agent") + fields := strings.Fields(userAgent) + if len(fields) < 1 { + return OriginClientUnknown + } + + // Is mobile post v2? + queryParam := r.URL.Query().Get("mobilev2") + if queryParam == "true" { + return OriginClientMobile + } + + // Is mobile pre v2? + clientAgent := fields[0] + if strings.HasPrefix(clientAgent, "rnbeta") || strings.HasPrefix(clientAgent, "Mattermost") { + return OriginClientMobile + } + + // Is desktop? + if strings.HasPrefix(fields[len(fields)-1], "Mattermost") { + return OriginClientDesktop + } + + // Default to web + return OriginClientWeb +} + // checkCSRFToken performs a CSRF check on the provided request with the given CSRF token. Returns whether or not // a CSRF check occurred and whether or not it succeeded. func (h *Handler) checkCSRFToken(c *Context, r *http.Request, token string, tokenLocation app.TokenLocation, session *model.Session) (checked bool, passed bool) { diff --git a/server/channels/web/handlers_test.go b/server/channels/web/handlers_test.go index 8076d9a8dcae..059c612088ce 100644 --- a/server/channels/web/handlers_test.go +++ b/server/channels/web/handlers_test.go @@ -882,3 +882,71 @@ func TestCheckCSRFToken(t *testing.T) { assert.Nil(t, c.Err) }) } + +func TestOriginClient(t *testing.T) { + testCases := []struct { + name string + userAgent string + mobilev2 bool + expectedClient OriginClient + }{ + { + name: "No user agent - unknown client", + userAgent: "", + expectedClient: OriginClientUnknown, + }, + { + name: "Mozilla user agent", + userAgent: "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:109.0) Gecko/20100101 Firefox/118.0", + expectedClient: OriginClientWeb, + }, + { + name: "Chrome user agent", + userAgent: "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/118.0.0.0 Safari/537.36", + expectedClient: OriginClientWeb, + }, + { + name: "Mobile post v2", + userAgent: "someother-agent/3.2.4", + mobilev2: true, + expectedClient: OriginClientMobile, + }, + { + name: "Mobile Android", + userAgent: "rnbeta/2.0.0.441 someother-agent/3.2.4", + expectedClient: OriginClientMobile, + }, + { + name: "Mobile iOS", + userAgent: "Mattermost/2.0.0.441 someother-agent/3.2.4", + expectedClient: OriginClientMobile, + }, + { + name: "Desktop user agent", + userAgent: "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/110.0.5481.177 Electron/23.1.2 Safari/537.36 Mattermost/5.3.1", + expectedClient: OriginClientDesktop, + }, + } + + for _, tc := range testCases { + req, err := http.NewRequest(http.MethodGet, "example.com", nil) + require.NoError(t, err) + + // Set User-Agent header, if any + if tc.userAgent != "" { + req.Header.Set("User-Agent", tc.userAgent) + } + + // Set mobilev2 query if needed + if tc.mobilev2 { + q := req.URL.Query() + q.Add("mobilev2", "true") + req.URL.RawQuery = q.Encode() + } + + // Compute origin client + actualClient := originClient(req) + + require.Equal(t, tc.expectedClient, actualClient) + } +} diff --git a/server/einterfaces/metrics.go b/server/einterfaces/metrics.go index 579847547938..2cd4baf97037 100644 --- a/server/einterfaces/metrics.go +++ b/server/einterfaces/metrics.go @@ -58,7 +58,7 @@ type MetricsInterface interface { IncrementFilesSearchCounter() ObserveFilesSearchDuration(elapsed float64) ObserveStoreMethodDuration(method, success string, elapsed float64) - ObserveAPIEndpointDuration(endpoint, method, statusCode string, elapsed float64) + ObserveAPIEndpointDuration(endpoint, method, statusCode, originClient string, elapsed float64) IncrementPostIndexCounter() IncrementFileIndexCounter() IncrementUserIndexCounter() diff --git a/server/einterfaces/mocks/MetricsInterface.go b/server/einterfaces/mocks/MetricsInterface.go index 058ec05a1d4f..5998bb66ae5d 100644 --- a/server/einterfaces/mocks/MetricsInterface.go +++ b/server/einterfaces/mocks/MetricsInterface.go @@ -239,9 +239,9 @@ func (_m *MetricsInterface) IncrementWebsocketReconnectEvent(eventType string) { _m.Called(eventType) } -// ObserveAPIEndpointDuration provides a mock function with given fields: endpoint, method, statusCode, elapsed -func (_m *MetricsInterface) ObserveAPIEndpointDuration(endpoint string, method string, statusCode string, elapsed float64) { - _m.Called(endpoint, method, statusCode, elapsed) +// ObserveAPIEndpointDuration provides a mock function with given fields: endpoint, method, statusCode, originClient, elapsed +func (_m *MetricsInterface) ObserveAPIEndpointDuration(endpoint string, method string, statusCode string, originClient string, elapsed float64) { + _m.Called(endpoint, method, statusCode, originClient, elapsed) } // ObserveClusterRequestDuration provides a mock function with given fields: elapsed