Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 66 additions & 9 deletions pkg/parser/import_field_extractor.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ type importAccumulator struct {
caches []string
features []map[string]any
models []map[string][]string // model alias maps from each imported file (appended in import order)
modelPolicies []map[string][]string // model policy sets from each imported file (appended in import order)
modelCosts []map[string]any // model pricing overlays from each imported file (appended in import order)
runInstallScripts bool // true if any imported workflow sets runtimes.node.run-install-scripts: true
agentFile string
Expand Down Expand Up @@ -89,6 +90,12 @@ type importAccumulator struct {
warnings []string
}

const (
modelPolicyAllowedKey = "allowed"
modelPolicyDisallowedKey = "disallowed"
modelPolicyBlockedKey = "blocked"
)

// newImportAccumulator creates and initializes a new importAccumulator.
// Maps (botsSet, etc.) are explicitly initialized to prevent nil map panics
// during deduplication. Slices are left as nil, which is valid for append operations.
Expand Down Expand Up @@ -621,6 +628,10 @@ func (acc *importAccumulator) appendModelsField(fm map[string]any) {
if jsonErr := json.Unmarshal([]byte(modelsContent), &rawModels); jsonErr != nil {
return
}
if modelPolicy := normalizeModelPolicies(rawModels); len(modelPolicy) > 0 {
acc.modelPolicies = append(acc.modelPolicies, modelPolicy)
parserLog.Printf("Extracted model policy from import: allowed=%d, disallowed=%d, blocked=%d", len(modelPolicy["allowed"]), len(modelPolicy["disallowed"]), len(modelPolicy["blocked"]))
}
if _, hasProviders := rawModels["providers"]; hasProviders {
acc.modelCosts = append(acc.modelCosts, rawModels)
if providers, ok := rawModels["providers"].(map[string]any); ok {
Expand All @@ -631,31 +642,76 @@ func (acc *importAccumulator) appendModelsField(fm map[string]any) {
return
}

modelsMap := normalizeModelAliases(rawModels)
aliasModels := make(map[string]any, len(rawModels))
for key, value := range rawModels {
if isModelPolicyKey(key) {
continue
}
aliasModels[key] = value
}
if len(aliasModels) == 0 {
return
}
modelsMap := normalizeModelAliases(aliasModels)
if len(modelsMap) > 0 {
acc.models = append(acc.models, modelsMap)
parserLog.Printf("Extracted model aliases from import: %d entries", len(modelsMap))
}
}

func normalizeModelPolicies(rawModels map[string]any) map[string][]string {
parse := func(key string) []string {
return parseStringSliceField(rawModels[key], false)
}
allowed := parse(modelPolicyAllowedKey)
disallowed := parse(modelPolicyDisallowedKey)
blocked := parse(modelPolicyBlockedKey)
if len(allowed) == 0 && len(disallowed) == 0 && len(blocked) == 0 {
return nil
}
return map[string][]string{
modelPolicyAllowedKey: allowed,
modelPolicyDisallowedKey: disallowed,
modelPolicyBlockedKey: blocked,
}
}

func normalizeModelAliases(rawModels map[string]any) map[string][]string {
modelsMap := make(map[string][]string, len(rawModels))
for k, v := range rawModels {
patterns, ok := v.([]any)
if !ok {
strs := parseStringSliceField(v, true)
if len(strs) == 0 {
continue
}
strs := make([]string, 0, len(patterns))
for _, p := range patterns {
if s, ok := p.(string); ok {
strs = append(strs, s)
}
}
modelsMap[k] = strs
}
return modelsMap
}

func parseStringSliceField(value any, keepEmpty bool) []string {
values, ok := value.([]any)
if !ok {
return nil
}
result := make([]string, 0, len(values))
for _, v := range values {
if s, ok := v.(string); ok {
if s == "" && !keepEmpty {
continue
}
result = append(result, s)
}
}
if len(result) == 0 {
return nil
}
return result
}

func isModelPolicyKey(key string) bool {
return key == modelPolicyAllowedKey || key == modelPolicyDisallowedKey || key == modelPolicyBlockedKey
}

func (acc *importAccumulator) extractRunInstallScripts(fm map[string]any, fullPath string) {
if acc.runInstallScripts {
return
Expand Down Expand Up @@ -737,6 +793,7 @@ func (acc *importAccumulator) toImportsResult(topologicalOrder []string) *Import
MergedEnvSources: acc.envSources,
MergedFeatures: acc.features,
MergedModels: acc.models,
MergedModelPolicies: acc.modelPolicies,
MergedModelCosts: acc.modelCosts,
MergedObservability: mergeObservabilityConfigs(acc.observabilityConfigs),
ImportedFiles: topologicalOrder,
Expand Down
43 changes: 43 additions & 0 deletions pkg/parser/import_field_extractor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -699,3 +699,46 @@ func TestExtractConfigFields_FirstWinsAndAccumulates(t *testing.T) {
assert.Contains(t, acc.secretMaskingBuilder.String(), "enabled")
assert.Contains(t, acc.secretMaskingBuilder.String(), "log-mask")
}

func TestAppendModelsField_ExtractsModelPolicySets(t *testing.T) {
acc := newImportAccumulator()
fm := map[string]any{
"models": map[string]any{
"allowed": []any{"gpt-5", "claude-sonnet"},
"disallowed": []any{"gpt-5-pro"},
"blocked": []any{"claude-opus"},
},
}

acc.appendModelsField(fm)

require.Len(t, acc.modelPolicies, 1, "expected one model policy set")
assert.Equal(t, []string{"gpt-5", "claude-sonnet"}, acc.modelPolicies[0]["allowed"])
assert.Equal(t, []string{"gpt-5-pro"}, acc.modelPolicies[0]["disallowed"])
assert.Equal(t, []string{"claude-opus"}, acc.modelPolicies[0]["blocked"])
assert.Empty(t, acc.models, "policy fields should not be interpreted as model aliases")
}

func TestAppendModelsField_ExtractsModelCostsAndPolicyTogether(t *testing.T) {
acc := newImportAccumulator()
fm := map[string]any{
"models": map[string]any{
"allowed": []any{"gpt-5-mini"},
"providers": map[string]any{
"openai": map[string]any{
"models": map[string]any{
"gpt-5-mini": map[string]any{
"cost": map[string]any{"input": "1e-6"},
},
},
},
},
},
}

acc.appendModelsField(fm)

require.Len(t, acc.modelCosts, 1, "expected one model cost overlay")
require.Len(t, acc.modelPolicies, 1, "expected one model policy set")
assert.Equal(t, []string{"gpt-5-mini"}, acc.modelPolicies[0]["allowed"])
}
1 change: 1 addition & 0 deletions pkg/parser/import_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ type ImportsResult struct {
MergedEnvSources map[string]string // env var name → source import path (for conflict detection and lock file header listing)
MergedFeatures []map[string]any // Merged features configuration from all imports (parsed YAML structures)
MergedModels []map[string][]string // Merged model alias definitions from all imports (first import to define a key wins among imports)
MergedModelPolicies []map[string][]string // Merged model policy sets from all imports (models.allowed/disallowed/blocked)
MergedModelCosts []map[string]any // Merged model pricing overlays (models.json provider structure) from all imports
MergedObservability string // Merged observability config (JSON) from all imports as an endpoint array (deduped by URL)
MergedEngineMCPToolTimeout string // First engine.mcp.tool-timeout found across all imports (Go duration string, e.g. "10m")
Expand Down
24 changes: 22 additions & 2 deletions pkg/parser/schemas/main_workflow_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -2693,10 +2693,30 @@
]
},
"models": {
"description": "Custom model pricing data in the same structure as models.json. Merged with the built-in models.json at runtime; frontmatter entries override matching models and fill gaps for unknown models. Useful for custom or private models, or to adjust pricing for AI Credits cost accounting.",
"description": "Model policy and optional pricing configuration. The policy fields (allowed/disallowed/blocked) are merged as unions across imports. The providers field is optional and supplies pricing data merged by provider/model key.",
"type": "object",
"required": ["providers"],
"properties": {
"allowed": {
"type": "array",
"description": "Allowlist of model names/patterns. Mapped to AWF apiProxy.allowedModels.",
"items": {
"type": "string"
}
},
"disallowed": {
"type": "array",
"description": "Denylist of model names/patterns. Mapped to AWF apiProxy.disallowedModels.",
"items": {
"type": "string"
}
},
"blocked": {
"type": "array",
"description": "Alias denylist of model names/patterns. Unioned with disallowed and mapped to AWF apiProxy.disallowedModels.",
"items": {
"type": "string"
}
},
"providers": {
"type": "object",
"description": "Provider-keyed map of model pricing data.",
Expand Down
33 changes: 33 additions & 0 deletions pkg/workflow/awf_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ import (
"github.com/github/gh-aw/pkg/jsonutil"
"github.com/github/gh-aw/pkg/logger"
"github.com/github/gh-aw/pkg/setutil"
"github.com/github/gh-aw/pkg/workflow/compilerenv"
)

//go:embed schemas/awf-config.schema.json
Expand Down Expand Up @@ -242,6 +243,11 @@ type AWFAPIProxyConfig struct {
// AWF resolves aliases recursively; loops are not permitted.
// Per the AWF config schema, this lives under apiProxy.models.
Models map[string][]string `json:"models,omitempty"`

// AllowedModels is the explicit allowlist policy for model names/patterns.
AllowedModels []string `json:"allowedModels,omitempty"`
// DisallowedModels is the explicit denylist policy for model names/patterns.
DisallowedModels []string `json:"disallowedModels,omitempty"`
}

// AWFModelFallbackConfig is the "apiProxy.modelFallback" section of the AWF config file.
Expand Down Expand Up @@ -492,6 +498,15 @@ func BuildAWFConfigJSON(config AWFCommandConfig) (string, error) {
apiProxy.Models = config.WorkflowData.ModelMappings
awfConfigLog.Printf("Models section: %d alias entries", len(config.WorkflowData.ModelMappings))
}
allowedModels, disallowedModels := resolveModelPolicyForAWFConfig(config.WorkflowData)
if len(allowedModels) > 0 {
apiProxy.AllowedModels = allowedModels
awfConfigLog.Printf("Models policy: %d allowed model pattern(s)", len(allowedModels))
}
if len(disallowedModels) > 0 {
apiProxy.DisallowedModels = disallowedModels
awfConfigLog.Printf("Models policy: %d disallowed model pattern(s)", len(disallowedModels))
}

awfConfig.APIProxy = apiProxy

Expand Down Expand Up @@ -550,6 +565,24 @@ func splitDomainList(domains string) []string {
return result
}

func resolveModelPolicyForAWFConfig(workflowData *WorkflowData) ([]string, []string) {
envAllowed, hasAllowedOverride := compilerenv.ResolvePolicyModelsAllowed()
envBlocked, hasBlockedOverride := compilerenv.ResolvePolicyModelsBlocked()
var allowed []string
var blocked []string
if hasAllowedOverride {
allowed = envAllowed
} else if workflowData != nil {
allowed = workflowData.ModelPolicyAllowed
}
if hasBlockedOverride {
blocked = envBlocked
} else if workflowData != nil {
blocked = workflowData.ModelPolicyBlocked
}
return allowed, blocked
}

func extractModelMultipliers(workflowData *WorkflowData) map[string]float64 {
if workflowData == nil || workflowData.EngineConfig == nil || workflowData.EngineConfig.TokenWeights == nil {
return nil
Expand Down
45 changes: 45 additions & 0 deletions pkg/workflow/awf_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1619,3 +1619,48 @@ func TestBuildAWFTopologyAttachList(t *testing.T) {
assert.Equal(t, []string{"awmg-mcpg", "awmg-cli-proxy"}, targets)
})
}

func TestBuildAWFConfigJSON_EmitsModelPolicyFromWorkflowData(t *testing.T) {
config := AWFCommandConfig{
EngineName: "copilot",
AllowedDomains: "github.com",
WorkflowData: &WorkflowData{
EngineConfig: &EngineConfig{ID: "copilot"},
NetworkPermissions: &NetworkPermissions{
Firewall: &FirewallConfig{Enabled: true},
},
ModelPolicyAllowed: []string{"gpt-5", "claude-sonnet"},
ModelPolicyBlocked: []string{"gpt-5-pro", "claude-opus"},
},
}

jsonStr, err := BuildAWFConfigJSON(config)
require.NoError(t, err)
assert.Contains(t, jsonStr, `"allowedModels":["gpt-5","claude-sonnet"]`)
assert.Contains(t, jsonStr, `"disallowedModels":["gpt-5-pro","claude-opus"]`)
}

func TestBuildAWFConfigJSON_ModelPolicyEnvOverridePrecedence(t *testing.T) {
t.Setenv(compilerenv.PolicyModelsAllowed, "gemini-pro,gpt-5-mini")
t.Setenv(compilerenv.PolicyModelsBlocked, "claude-opus, gpt-5-pro")

config := AWFCommandConfig{
EngineName: "copilot",
AllowedDomains: "github.com",
WorkflowData: &WorkflowData{
EngineConfig: &EngineConfig{ID: "copilot"},
NetworkPermissions: &NetworkPermissions{
Firewall: &FirewallConfig{Enabled: true},
},
ModelPolicyAllowed: []string{"frontmatter-allowed"},
ModelPolicyBlocked: []string{"frontmatter-blocked"},
},
}

jsonStr, err := BuildAWFConfigJSON(config)
require.NoError(t, err)
assert.Contains(t, jsonStr, `"allowedModels":["gemini-pro","gpt-5-mini"]`)
assert.Contains(t, jsonStr, `"disallowedModels":["claude-opus","gpt-5-pro"]`)
assert.NotContains(t, jsonStr, "frontmatter-allowed")
assert.NotContains(t, jsonStr, "frontmatter-blocked")
}
2 changes: 2 additions & 0 deletions pkg/workflow/compiler_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,8 @@ type WorkflowData struct {
KnownActionCredentialEnvVars map[string]struct{} // env vars for clean_known_action_credentials.sh; keyed by GH_AW_CLEAN_* names; nil when no known credential-leaking actions are detected
ModelMappings map[string][]string // merged model alias map (builtins + imported workflow aliases + main frontmatter overrides, in priority order); NOT yet emitted to AWF config JSON — pending AWF firewall support (config.models)
ModelCosts map[string]any // model pricing data from frontmatter `models` field (providers structure); merged with built-in models.json at runtime by generate_aw_info.cjs
ModelPolicyAllowed []string // merged models.allowed policy list (union across imports + main frontmatter)
ModelPolicyBlocked []string // merged denylist from models.disallowed/models.blocked (union across imports + main frontmatter)
ActionPinMappings map[string]string // action-pin redirect table from aw.json action_pins: maps "owner/repo@version" → "owner/repo@version"
}

Expand Down
47 changes: 47 additions & 0 deletions pkg/workflow/compilerenv/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ const (
// PolicyStrict enables runtime enforcement that workflows must be compiled in strict mode
// when GH_AW_POLICY_STRICT is set to the string value "true".
PolicyStrict = "GH_AW_POLICY_STRICT"
// PolicyModelsAllowed centrally overrides models.allowed frontmatter policy.
PolicyModelsAllowed = "GHAW_POLICY_MODELS_ALLOWED"
// PolicyModelsBlocked centrally overrides models.disallowed/models.blocked frontmatter policy.
PolicyModelsBlocked = "GHAW_POLICY_MODELS_BLOCKED"
)

// ResolveDefaultMaxTurns returns fallback when the env var is unset/invalid,
Expand Down Expand Up @@ -161,3 +165,46 @@ func BuildModelOverrideExpression(primaryVar, enterpriseDefaultVar, builtinFallb
func BuildModelOverrideExpressionEmptyFallback(primaryVar, enterpriseDefaultVar string) string {
return fmt.Sprintf("${{ vars.%s || vars.%s || '' }}", primaryVar, enterpriseDefaultVar)
}

// ResolvePolicyModelsAllowed returns configured allowed model policy entries.
// When the env var is unset/empty, ok=false and callers should use frontmatter policy.
func ResolvePolicyModelsAllowed() ([]string, bool) {
return resolveModelListEnv(PolicyModelsAllowed)
}

// ResolvePolicyModelsBlocked returns configured blocked model policy entries.
// When the env var is unset/empty, ok=false and callers should use frontmatter policy.
func ResolvePolicyModelsBlocked() ([]string, bool) {
return resolveModelListEnv(PolicyModelsBlocked)
}

func resolveModelListEnv(name string) ([]string, bool) {
raw := strings.TrimSpace(os.Getenv(name))
if raw == "" {
return nil, false
}
parts := strings.FieldsFunc(raw, func(r rune) bool {
return r == ',' || r == '\n' || r == '\r'
})
if len(parts) == 0 {
return nil, false
}
result := make([]string, 0, len(parts))
seen := map[string]struct{}{}
for _, part := range parts {
model := strings.TrimSpace(part)
if model == "" {
continue
}
if _, exists := seen[model]; exists {
continue
}
seen[model] = struct{}{}
result = append(result, model)
}
if len(result) == 0 {
return nil, false
}
managerLog.Printf("Applying model policy override %s with %d model(s)", name, len(result))
return result, true
}
Loading
Loading