Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 27 additions & 21 deletions pkg/github/discussions.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import (
"github.com/github/github-mcp-server/pkg/scopes"
"github.com/github/github-mcp-server/pkg/translations"
"github.com/github/github-mcp-server/pkg/utils"
"github.com/go-viper/mapstructure/v2"
"github.com/google/go-github/v87/github"
"github.com/google/jsonschema-go/jsonschema"
"github.com/modelcontextprotocol/go-sdk/mcp"
Expand Down Expand Up @@ -313,15 +312,19 @@ func GetDiscussion(t translations.TranslationHelperFunc) inventory.ServerTool {
},
[]scopes.Scope{scopes.Repo},
func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) {
// Decode params
var params struct {
Owner string
Repo string
DiscussionNumber int32
owner, err := RequiredParam[string](args, "owner")
if err != nil {
return utils.NewToolResultError(err.Error()), nil, nil
}
if err := mapstructure.WeakDecode(args, &params); err != nil {
repo, err := RequiredParam[string](args, "repo")
if err != nil {
return utils.NewToolResultError(err.Error()), nil, nil
}
discussionNumber, err := RequiredInt(args, "discussionNumber")
if err != nil {
return utils.NewToolResultError(err.Error()), nil, nil
}

client, err := deps.GetGQLClient(ctx)
if err != nil {
return utils.NewToolResultError(fmt.Sprintf("failed to get GitHub GQL client: %v", err)), nil, nil
Expand All @@ -345,9 +348,9 @@ func GetDiscussion(t translations.TranslationHelperFunc) inventory.ServerTool {
} `graphql:"repository(owner: $owner, name: $repo)"`
}
vars := map[string]any{
"owner": githubv4.String(params.Owner),
"repo": githubv4.String(params.Repo),
"discussionNumber": githubv4.Int(params.DiscussionNumber),
"owner": githubv4.String(owner),
"repo": githubv4.String(repo),
"discussionNumber": githubv4.Int(discussionNumber),
}
if err := client.Query(ctx, &q, vars); err != nil {
return utils.NewToolResultError(err.Error()), nil, nil
Expand Down Expand Up @@ -384,7 +387,7 @@ func GetDiscussion(t translations.TranslationHelperFunc) inventory.ServerTool {
result := utils.NewToolResultText(string(out))
// Discussion content is user-authored (untrusted); confidentiality
// follows repo visibility.
result = attachRepoVisibilityIFCLabelLazy(ctx, deps, params.Owner, params.Repo, result, ifc.LabelRepoUserContent)
result = attachRepoVisibilityIFCLabelLazy(ctx, deps, owner, repo, result, ifc.LabelRepoUserContent)
return result, nil, nil
},
)
Expand Down Expand Up @@ -425,13 +428,16 @@ func GetDiscussionComments(t translations.TranslationHelperFunc) inventory.Serve
},
[]scopes.Scope{scopes.Repo},
func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) {
// Decode params
var params struct {
Owner string
Repo string
DiscussionNumber int32
owner, err := RequiredParam[string](args, "owner")
if err != nil {
return utils.NewToolResultError(err.Error()), nil, nil
}
repo, err := RequiredParam[string](args, "repo")
if err != nil {
return utils.NewToolResultError(err.Error()), nil, nil
}
if err := mapstructure.WeakDecode(args, &params); err != nil {
discussionNumber, err := RequiredInt(args, "discussionNumber")
if err != nil {
return utils.NewToolResultError(err.Error()), nil, nil
}

Expand Down Expand Up @@ -467,9 +473,9 @@ func GetDiscussionComments(t translations.TranslationHelperFunc) inventory.Serve
}

vars := map[string]any{
"owner": githubv4.String(params.Owner),
"repo": githubv4.String(params.Repo),
"discussionNumber": githubv4.Int(params.DiscussionNumber),
"owner": githubv4.String(owner),
"repo": githubv4.String(repo),
"discussionNumber": githubv4.Int(discussionNumber),
"first": githubv4.Int(*paginationParams.First),
}
if paginationParams.After != nil {
Expand Down Expand Up @@ -592,7 +598,7 @@ func GetDiscussionComments(t translations.TranslationHelperFunc) inventory.Serve
result := utils.NewToolResultText(string(out))
// Discussion comments are user-authored (untrusted); confidentiality
// follows repo visibility.
result = attachRepoVisibilityIFCLabelLazy(ctx, deps, params.Owner, params.Repo, result, ifc.LabelRepoUserContent)
result = attachRepoVisibilityIFCLabelLazy(ctx, deps, owner, repo, result, ifc.LabelRepoUserContent)
return result, nil, nil
},
)
Expand Down
82 changes: 80 additions & 2 deletions pkg/github/discussions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -590,8 +590,47 @@ func Test_GetDiscussion(t *testing.T) {
}
}

func Test_GetDiscussionRequiredParams(t *testing.T) {
t.Parallel()

toolDef := GetDiscussion(translations.NullTranslationHelper)
handler := toolDef.Handler(BaseDeps{GQLClient: githubv4.NewClient(githubv4mock.NewMockedHTTPClient())})

tests := []struct {
name string
requestArgs map[string]any
expectedErrMsg string
}{
{
name: "missing owner",
requestArgs: map[string]any{"repo": "repo", "discussionNumber": float64(1)},
expectedErrMsg: "missing required parameter: owner",
},
{
name: "missing repo",
requestArgs: map[string]any{"owner": "owner", "discussionNumber": float64(1)},
expectedErrMsg: "missing required parameter: repo",
},
{
name: "missing discussionNumber",
requestArgs: map[string]any{"owner": "owner", "repo": "repo"},
expectedErrMsg: "missing required parameter: discussionNumber",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
req := createMCPRequest(tc.requestArgs)
res, err := handler(ContextWithDeps(context.Background(), BaseDeps{}), &req)
require.NoError(t, err)
require.True(t, res.IsError)
assert.Contains(t, getTextResult(t, res).Text, tc.expectedErrMsg)
})
}
}

func Test_GetDiscussionWithStringNumber(t *testing.T) {
// Test that WeakDecode handles string discussionNumber from MCP clients
// Test that RequiredInt handles string discussionNumber from MCP clients
toolDef := GetDiscussion(translations.NullTranslationHelper)

qGetDiscussion := "query($discussionNumber:Int!$owner:String!$repo:String!){repository(owner: $owner, name: $repo){discussion(number: $discussionNumber){number,title,body,createdAt,closed,isAnswered,answerChosenAt,url,category{name}}}}"
Expand Down Expand Up @@ -723,8 +762,47 @@ func Test_GetDiscussionComments(t *testing.T) {
assert.Equal(t, "This is the second comment", response.Comments[1].Body)
}

func Test_GetDiscussionCommentsRequiredParams(t *testing.T) {
t.Parallel()

toolDef := GetDiscussionComments(translations.NullTranslationHelper)
handler := toolDef.Handler(BaseDeps{GQLClient: githubv4.NewClient(githubv4mock.NewMockedHTTPClient())})

tests := []struct {
name string
requestArgs map[string]any
expectedErrMsg string
}{
{
name: "missing owner",
requestArgs: map[string]any{"repo": "repo", "discussionNumber": float64(1)},
expectedErrMsg: "missing required parameter: owner",
},
{
name: "missing repo",
requestArgs: map[string]any{"owner": "owner", "discussionNumber": float64(1)},
expectedErrMsg: "missing required parameter: repo",
},
{
name: "missing discussionNumber",
requestArgs: map[string]any{"owner": "owner", "repo": "repo"},
expectedErrMsg: "missing required parameter: discussionNumber",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
req := createMCPRequest(tc.requestArgs)
res, err := handler(ContextWithDeps(context.Background(), BaseDeps{}), &req)
require.NoError(t, err)
require.True(t, res.IsError)
assert.Contains(t, getTextResult(t, res).Text, tc.expectedErrMsg)
})
}
}

func Test_GetDiscussionCommentsWithStringNumber(t *testing.T) {
// Test that WeakDecode handles string discussionNumber from MCP clients
// Test that RequiredInt handles string discussionNumber from MCP clients
toolDef := GetDiscussionComments(translations.NullTranslationHelper)

qGetComments := "query($after:String$discussionNumber:Int!$first:Int!$owner:String!$repo:String!){repository(owner: $owner, name: $repo){discussion(number: $discussionNumber){comments(first: $first, after: $after){nodes{id,body,isAnswer},pageInfo{hasNextPage,hasPreviousPage,startCursor,endCursor},totalCount}}}}"
Expand Down