Skip to content

Commit be6444c

Browse files
committed
feat: implement task cancellation feature and update task handling
1 parent 98ba7c5 commit be6444c

File tree

9 files changed

+103
-24
lines changed

9 files changed

+103
-24
lines changed

bot/handle_add_task.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ func AddToQueue(ctx *ext.Context, update *ext.Update) error {
163163
task.StoragePath = path.Join(dir.Path, file.FileName)
164164
}
165165

166-
queue.AddTask(task)
166+
queue.AddTask(&task)
167167

168168
entityBuilder := entity.Builder{}
169169
var entities []tg.MessageEntityClass

bot/handle_cancel_task.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package bot
2+
3+
import (
4+
"strings"
5+
6+
"github.com/celestix/gotgproto/dispatcher"
7+
"github.com/celestix/gotgproto/ext"
8+
"github.com/gotd/td/tg"
9+
"github.com/krau/SaveAny-Bot/queue"
10+
)
11+
12+
func cancelTask(ctx *ext.Context, update *ext.Update) error {
13+
key := strings.Split(string(update.CallbackQuery.Data), " ")[1]
14+
ok := queue.CancelTask(key)
15+
if ok {
16+
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
17+
QueryID: update.CallbackQuery.QueryID,
18+
Message: "任务已取消",
19+
})
20+
return dispatcher.EndGroups
21+
}
22+
ctx.AnswerCallback(&tg.MessagesSetBotCallbackAnswerRequest{
23+
QueryID: update.CallbackQuery.QueryID,
24+
Message: "任务取消失败",
25+
})
26+
return dispatcher.EndGroups
27+
}

bot/handlers.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,6 @@ func RegisterHandlers(dispatcher dispatcher.Dispatcher) {
2222
dispatcher.AddHandler(handlers.NewMessage(linkRegexFilter, handleLinkMessage))
2323
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("add"), AddToQueue))
2424
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("set_default"), setDefaultStorage))
25+
dispatcher.AddHandler(handlers.NewCallbackQuery(filters.CallbackQuery.Prefix("cancel"), cancelTask))
2526
dispatcher.AddHandler(handlers.NewMessage(filters.Message.Media, handleFileMessage))
2627
}

bot/utils.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ func HandleSilentAddTask(ctx *ext.Context, update *ext.Update, user *dao.User, t
264264
})
265265
return dispatcher.EndGroups
266266
}
267-
queue.AddTask(*task)
267+
queue.AddTask(task)
268268
ctx.EditMessage(update.EffectiveChat().GetID(), &tg.MessagesEditMessageRequest{
269269
Message: fmt.Sprintf("已添加到队列: %s\n当前排队任务数: %d", task.FileName(), queue.Len()),
270270
ID: task.ReplyMessageID,

core/core.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@ func worker(queue *queue.TaskQueue, semaphore chan struct{}) {
2222
switch task.Status {
2323
case types.Pending:
2424
logger.L.Infof("Processing task: %s", task.String())
25-
if err := processPendingTask(&task); err != nil {
26-
logger.L.Errorf("Failed to do task: %s", err)
25+
if err := processPendingTask(task); err != nil {
2726
task.Error = err
2827
if errors.Is(err, context.Canceled) {
2928
logger.L.Debugf("Task canceled: %s", task.String())
3029
task.Status = types.Canceled
3130
} else {
31+
logger.L.Errorf("Failed to do task: %s", err)
3232
task.Status = types.Failed
3333
}
3434
} else {

core/download.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package core
22

33
import (
4+
"context"
45
"fmt"
56
"path/filepath"
67
"time"
@@ -48,11 +49,16 @@ func processPendingTask(task *types.Task) error {
4849
return fmt.Errorf("context is not *ext.Context: %T", task.Ctx)
4950
}
5051

52+
cancelCtx, cancel := context.WithCancel(ctx)
53+
task.Cancel = cancel
54+
task.Ctx = cancelCtx
55+
5156
text, entities := buildProgressMessageEntity(task, 0, task.StartTime, 0)
5257
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
53-
Message: text,
54-
Entities: entities,
55-
ID: task.ReplyMessageID,
58+
Message: text,
59+
Entities: entities,
60+
ID: task.ReplyMessageID,
61+
ReplyMarkup: getCancelTaskMarkup(task),
5662
})
5763
progressCallback := buildProgressCallback(ctx, task, getProgressUpdateCount(task.File.FileSize))
5864

@@ -63,7 +69,7 @@ func processPendingTask(task *types.Task) error {
6369
defer dest.Close()
6470
task.StartTime = time.Now()
6571
downloadBuider := Downloader.Download(bot.Client.API(), task.File.Location).WithThreads(getTaskThreads(task.File.FileSize))
66-
_, err = downloadBuider.Parallel(ctx, dest)
72+
_, err = downloadBuider.Parallel(cancelCtx, dest)
6773
if err != nil {
6874
return fmt.Errorf("下载文件失败: %w", err)
6975
}

core/utils.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -139,13 +139,20 @@ func buildProgressCallback(ctx *ext.Context, task *types.Task, updateCount int)
139139
}
140140
text, entities := buildProgressMessageEntity(task, bytesRead, task.StartTime, progress)
141141
ctx.EditMessage(task.ReplyChatID, &tg.MessagesEditMessageRequest{
142-
Message: text,
143-
Entities: entities,
144-
ID: task.ReplyMessageID,
142+
Message: text,
143+
Entities: entities,
144+
ID: task.ReplyMessageID,
145+
ReplyMarkup: getCancelTaskMarkup(task),
145146
})
146147
}
147148
}
148149

150+
func getCancelTaskMarkup(task *types.Task) *tg.ReplyInlineMarkup {
151+
return &tg.ReplyInlineMarkup{
152+
Rows: []tg.KeyboardButtonRow{{Buttons: []tg.KeyboardButtonClass{&tg.KeyboardButtonCallback{Text: "取消任务", Data: fmt.Appendf(nil, "cancel %s", task.Key())}}}},
153+
}
154+
}
155+
149156
func fixTaskFileExt(task *types.Task, localFilePath string) {
150157
if path.Ext(task.FileName()) == "" {
151158
mimeType, err := mimetype.DetectFile(localFilePath)

queue/queue.go

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,30 +8,58 @@ import (
88
)
99

1010
type TaskQueue struct {
11-
list *list.List
12-
cond *sync.Cond
13-
mutex *sync.Mutex
11+
list *list.List
12+
cond *sync.Cond
13+
mutex *sync.Mutex
14+
activeMap map[string]*types.Task
1415
}
1516

16-
func (q *TaskQueue) AddTask(task types.Task) {
17+
func (q *TaskQueue) AddTask(task *types.Task) {
1718
q.mutex.Lock()
1819
defer q.mutex.Unlock()
19-
q.list.PushBack(task)
20-
q.cond.Signal()
20+
if task.Status == types.Pending {
21+
q.list.PushBack(task)
22+
q.cond.Signal()
23+
} else {
24+
delete(q.activeMap, task.Key())
25+
}
2126
}
2227

23-
func (q *TaskQueue) GetTask() types.Task {
28+
func (q *TaskQueue) GetTask() *types.Task {
2429
q.mutex.Lock()
2530
defer q.mutex.Unlock()
2631
for q.list.Len() == 0 {
2732
q.cond.Wait()
2833
}
2934
e := q.list.Front()
30-
task := e.Value.(types.Task)
35+
task := e.Value.(*types.Task)
3136
q.list.Remove(e)
37+
q.activeMap[task.Key()] = task
3238
return task
3339
}
3440

41+
func (q *TaskQueue) CancelTask(key string) bool {
42+
q.mutex.Lock()
43+
defer q.mutex.Unlock()
44+
if task, ok := q.activeMap[key]; ok {
45+
if task.Cancel != nil {
46+
task.Cancel()
47+
return true
48+
}
49+
}
50+
for e := q.list.Front(); e != nil; e = e.Next() {
51+
task := e.Value.(*types.Task)
52+
if task.Key() == key {
53+
if task.Cancel != nil {
54+
task.Cancel()
55+
}
56+
q.list.Remove(e)
57+
return true
58+
}
59+
}
60+
return false
61+
}
62+
3563
func (q *TaskQueue) Len() int {
3664
q.mutex.Lock()
3765
defer q.mutex.Unlock()
@@ -47,20 +75,25 @@ func init() {
4775
func NewQueue() *TaskQueue {
4876
m := &sync.Mutex{}
4977
return &TaskQueue{
50-
list: list.New(),
51-
cond: sync.NewCond(m),
52-
mutex: m,
78+
list: list.New(),
79+
cond: sync.NewCond(m),
80+
mutex: m,
81+
activeMap: make(map[string]*types.Task),
5382
}
5483
}
5584

56-
func AddTask(task types.Task) {
85+
func AddTask(task *types.Task) {
5786
Queue.AddTask(task)
5887
}
5988

60-
func GetTask() types.Task {
89+
func GetTask() *types.Task {
6190
return Queue.GetTask()
6291
}
6392

6493
func Len() int {
6594
return Queue.Len()
6695
}
96+
97+
func CancelTask(key string) bool {
98+
return Queue.CancelTask(key)
99+
}

types/types.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ var StorageTypeDisplay = map[StorageType]string{
3636

3737
type Task struct {
3838
Ctx context.Context
39+
Cancel context.CancelFunc
3940
Error error
4041
Status TaskStatus
4142
File *File
@@ -52,6 +53,10 @@ type Task struct {
5253
UserID int64
5354
}
5455

56+
func (t Task) Key() string {
57+
return fmt.Sprintf("%d:%d", t.FileChatID, t.FileMessageID)
58+
}
59+
5560
func (t Task) String() string {
5661
return fmt.Sprintf("[%d:%d]:%s", t.FileChatID, t.FileMessageID, t.File.FileName)
5762
}

0 commit comments

Comments
 (0)