Skip to content

Commit

Permalink
Add WithCtx option.
Browse files Browse the repository at this point in the history
  • Loading branch information
onrik committed Jun 16, 2023
1 parent e0e79a3 commit 32b2e05
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 29 deletions.
30 changes: 14 additions & 16 deletions bot.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ type Bot struct {
token string
updates chan Update
offset uint64
ctx context.Context
cancelFunc context.CancelFunc
}

Expand All @@ -41,21 +40,19 @@ func NewBot(token string, opts ...Option) (*Bot, error) {
logger: newLogger("[micha] "),
apiServer: defaultAPIServer,
httpClient: http.DefaultClient,
ctx: context.Background(),
}

for _, opt := range opts {
opt(&options)
}

ctx, cancelFunc := context.WithCancel(context.Background())

bot := Bot{
Options: options,
token: token,
updates: make(chan Update),
ctx: ctx,
cancelFunc: cancelFunc,
Options: options,
token: token,
updates: make(chan Update),
}
bot.ctx, bot.cancelFunc = context.WithCancel(options.ctx)

if bot.apiServer == "" {
bot.apiServer = defaultAPIServer
Expand Down Expand Up @@ -101,7 +98,7 @@ func (bot *Bot) decodeResponse(data []byte, target interface{}) error {

// Send GET request to Telegram API
func (bot *Bot) get(method string, params url.Values, target interface{}) error {
request, err := newGetRequest(bot.buildURL(method), params)
request, err := newGetRequest(bot.ctx, bot.buildURL(method), params)
if err != nil {
return err
}
Expand All @@ -121,7 +118,7 @@ func (bot *Bot) get(method string, params url.Values, target interface{}) error

// Send POST request to Telegram API
func (bot *Bot) post(method string, data, target interface{}) error {
request, err := newPostRequest(bot.buildURL(method), data)
request, err := newPostRequest(bot.ctx, bot.buildURL(method), data)
if err != nil {
return err
}
Expand Down Expand Up @@ -184,16 +181,17 @@ func (bot *Bot) Start(allowedUpdates ...string) {
bot.logger.Printf("Get updates error (%s)\n", err.Error())
}

for _, update := range updates {
bot.updates <- update
bot.offset = update.UpdateID
}

select {
case <-bot.ctx.Done():
close(bot.updates)
return
default:
}

for _, update := range updates {
bot.updates <- update
bot.offset = update.UpdateID
}
}
}

Expand Down Expand Up @@ -250,7 +248,7 @@ func (bot *Bot) DeleteWebhook() error {
// but will not be able to log in back to the cloud Bot API server for 10 minutes.
func (bot *Bot) Logout() error {
url := defaultAPIServer + fmt.Sprintf("/bot%s/logOut", bot.token)
request, err := newGetRequest(url, nil)
request, err := newGetRequest(bot.ctx, url, nil)
if err != nil {
return err
}
Expand Down
8 changes: 4 additions & 4 deletions bot_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"bytes"
"context"
"fmt"
"io/ioutil"
"io"
"log"
"net/http"
"net/url"
Expand Down Expand Up @@ -64,7 +64,7 @@ func (s *BotTestSuite) registerResultWithRequestCheck(method, result, exceptedRe

httpmock.RegisterResponder("POST", url, func(request *http.Request) (*http.Response, error) {
defer request.Body.Close()
body, err := ioutil.ReadAll(request.Body)
body, err := io.ReadAll(request.Body)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -100,12 +100,12 @@ func (s *BotTestSuite) registeMultipartrRequestCheck(method string, exceptedValu
}

defer file.Close()
data, err := ioutil.ReadAll(file)
data, err := io.ReadAll(file)
if err != nil {
return nil, err
}

exceptedData, err := ioutil.ReadAll(exceptedFile.Source)
exceptedData, err := io.ReadAll(exceptedFile.Source)
if err != nil {
return nil, err
}
Expand Down
8 changes: 7 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
module github.com/onrik/micha

go 1.16
go 1.19

require (
github.com/jarcoal/httpmock v1.0.8
github.com/stretchr/testify v1.7.0
)

require (
github.com/davecgh/go-spew v1.1.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c // indirect
)
22 changes: 15 additions & 7 deletions http.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@ package micha

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"mime/multipart"
"net/http"
"net/url"
Expand All @@ -16,6 +16,14 @@ type HttpClient interface {
Do(*http.Request) (*http.Response, error)
}

type HTTPError struct {
StatusCode int
}

func (e HTTPError) Error() string {
return fmt.Sprintf("HTTP status: %d", e.StatusCode)
}

type fileField struct {
Source io.Reader
Fieldname string
Expand All @@ -25,28 +33,28 @@ type fileField struct {
func handleResponse(response *http.Response) ([]byte, error) {
defer response.Body.Close()
if response.StatusCode > http.StatusBadRequest {
return nil, fmt.Errorf("HTTP status: %d", response.StatusCode)
return nil, HTTPError{response.StatusCode}
}

return ioutil.ReadAll(response.Body)
return io.ReadAll(response.Body)
}

func newGetRequest(url string, params url.Values) (*http.Request, error) {
func newGetRequest(ctx context.Context, url string, params url.Values) (*http.Request, error) {
if params != nil {
url += fmt.Sprintf("?%s", params.Encode())
}
return http.NewRequest(http.MethodGet, url, nil)
return http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
}

func newPostRequest(url string, data interface{}) (*http.Request, error) {
func newPostRequest(ctx context.Context, url string, data interface{}) (*http.Request, error) {
body := new(bytes.Buffer)
if data != nil {
if err := json.NewEncoder(body).Encode(data); err != nil {
return nil, fmt.Errorf("Encode data error (%s)", err.Error())
}
}

request, err := http.NewRequest(http.MethodPost, url, body)
request, err := http.NewRequestWithContext(ctx, http.MethodPost, url, body)
if err != nil {
return nil, err
}
Expand Down
13 changes: 12 additions & 1 deletion options.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
package micha

import "strings"
import (
"context"
"strings"
)

type Options struct {
limit int
timeout int
logger Logger
apiServer string
httpClient HttpClient
ctx context.Context
}

type Option func(*Options)
Expand Down Expand Up @@ -49,6 +53,13 @@ func WithAPIServer(url string) Option {
}
}

// WithAPIServer - set custom context
func WithCtx(ctx context.Context) Option {
return func(o *Options) {
o.ctx = ctx
}
}

// SendMessageOptions optional params SendMessage method
type SendMessageOptions struct {
ParseMode ParseMode `json:"parse_mode,omitempty"`
Expand Down

0 comments on commit 32b2e05

Please sign in to comment.