|
1 | 1 | //! Parsing and matching APIs for GitHub Actions expressions
|
2 | 2 | //! contexts (e.g. `github.event.name`).
|
| 3 | +
|
3 | 4 | use super::Expr;
|
4 | 5 |
|
5 | 6 | /// Represents a context in a GitHub Actions expression.
|
@@ -47,6 +48,54 @@ impl<'src> Context<'src> {
|
47 | 48 | _ => None,
|
48 | 49 | }
|
49 | 50 | }
|
| 51 | + |
| 52 | + /// Returns the "pattern equivalent" of this context. |
| 53 | + /// |
| 54 | + /// This is a string that can be used to efficiently match the context, |
| 55 | + /// such as is done in `zizmor`'s template-injection audit via a |
| 56 | + /// finite state transducer. |
| 57 | + /// |
| 58 | + /// Returns None if the context doesn't have a sensible pattern |
| 59 | + /// equivalent, e.g. if it starts with a call. |
| 60 | + pub fn as_pattern(&self) -> Option<String> { |
| 61 | + fn push_part(part: &Expr<'_>, pattern: &mut String) { |
| 62 | + match part { |
| 63 | + Expr::Identifier(ident) => pattern.push_str(ident.0), |
| 64 | + Expr::Star => pattern.push('*'), |
| 65 | + Expr::Index(idx) => match idx.as_ref() { |
| 66 | + // foo['bar'] -> foo.bar |
| 67 | + Expr::String(idx) => pattern.push_str(idx), |
| 68 | + // any kind of numeric or computed index, e.g.: |
| 69 | + // foo[0], foo[1 + 2], foo[bar] |
| 70 | + _ => pattern.push('*'), |
| 71 | + }, |
| 72 | + _ => unreachable!("unexpected part in context pattern"), |
| 73 | + } |
| 74 | + } |
| 75 | + |
| 76 | + // TODO: Optimization ideas: |
| 77 | + // 1. Add a happy path for contexts that contain only |
| 78 | + // identifiers? Problem: case normalization. |
| 79 | + // 2. Use `regex-automata` to return a case insensitive |
| 80 | + // automation here? |
| 81 | + let mut pattern = String::with_capacity(self.raw.len()); |
| 82 | + |
| 83 | + let mut parts = self.parts.iter().peekable(); |
| 84 | + |
| 85 | + let head = parts.next()?; |
| 86 | + if matches!(head, Expr::Call { .. }) { |
| 87 | + return None; |
| 88 | + } |
| 89 | + |
| 90 | + push_part(head, &mut pattern); |
| 91 | + for part in parts { |
| 92 | + pattern.push('.'); |
| 93 | + push_part(part, &mut pattern); |
| 94 | + } |
| 95 | + |
| 96 | + pattern.make_ascii_lowercase(); |
| 97 | + Some(pattern) |
| 98 | + } |
50 | 99 | }
|
51 | 100 |
|
52 | 101 | impl PartialEq for Context<'_> {
|
@@ -120,33 +169,28 @@ impl<'src> ContextPattern<'src> {
|
120 | 169 | }
|
121 | 170 | }
|
122 | 171 |
|
| 172 | + fn compare_part(pattern: &str, part: &Expr<'src>) -> bool { |
| 173 | + if pattern == "*" { |
| 174 | + true |
| 175 | + } else { |
| 176 | + match part { |
| 177 | + Expr::Identifier(part) => pattern.eq_ignore_ascii_case(part.0), |
| 178 | + Expr::Index(part) => match part.as_ref() { |
| 179 | + Expr::String(part) => pattern.eq_ignore_ascii_case(part), |
| 180 | + _ => false, |
| 181 | + }, |
| 182 | + _ => false, |
| 183 | + } |
| 184 | + } |
| 185 | + } |
| 186 | + |
123 | 187 | fn compare(&self, ctx: &Context<'src>) -> Option<Comparison> {
|
124 | 188 | let mut pattern_parts = self.0.split('.').peekable();
|
125 | 189 | let mut ctx_parts = ctx.parts.iter().peekable();
|
126 | 190 |
|
127 | 191 | while let (Some(pattern), Some(part)) = (pattern_parts.peek(), ctx_parts.peek()) {
|
128 |
| - // TODO: Refactor this; it's way too hard to read. |
129 |
| - match (*pattern, part) { |
130 |
| - // Calls can't be compared to patterns. |
131 |
| - (_, Expr::Call { .. }) => return None, |
132 |
| - // "*" matches any part. |
133 |
| - ("*", _) => {} |
134 |
| - (_, Expr::Star) => return None, |
135 |
| - (pattern, Expr::Identifier(part)) if !pattern.eq_ignore_ascii_case(part.0) => { |
136 |
| - return None; |
137 |
| - } |
138 |
| - (pattern, Expr::Index(idx)) => { |
139 |
| - // Anything other than a string index is invalid |
140 |
| - // for part-wise comparison. |
141 |
| - let Expr::String(part) = idx.as_ref() else { |
142 |
| - return None; |
143 |
| - }; |
144 |
| - |
145 |
| - if !pattern.eq_ignore_ascii_case(part) { |
146 |
| - return None; |
147 |
| - } |
148 |
| - } |
149 |
| - _ => {} |
| 192 | + if !Self::compare_part(pattern, part) { |
| 193 | + return None; |
150 | 194 | }
|
151 | 195 |
|
152 | 196 | pattern_parts.next();
|
@@ -253,6 +297,45 @@ mod tests {
|
253 | 297 | }
|
254 | 298 | }
|
255 | 299 |
|
| 300 | + #[test] |
| 301 | + fn test_context_as_pattern() { |
| 302 | + for (case, expected) in &[ |
| 303 | + // Basic cases. |
| 304 | + ("foo", Some("foo")), |
| 305 | + ("foo.bar", Some("foo.bar")), |
| 306 | + ("foo.bar.baz", Some("foo.bar.baz")), |
| 307 | + ("foo.bar.baz_baz", Some("foo.bar.baz_baz")), |
| 308 | + ("foo.bar.baz-baz", Some("foo.bar.baz-baz")), |
| 309 | + ("foo.*", Some("foo.*")), |
| 310 | + ("foo.bar.*", Some("foo.bar.*")), |
| 311 | + ("foo.*.baz", Some("foo.*.baz")), |
| 312 | + ("foo.*.*", Some("foo.*.*")), |
| 313 | + // Case sensitivity. |
| 314 | + ("FOO", Some("foo")), |
| 315 | + ("FOO.BAR", Some("foo.bar")), |
| 316 | + ("FOO.BAR.BAZ", Some("foo.bar.baz")), |
| 317 | + ("FOO.BAR.BAZ_BAZ", Some("foo.bar.baz_baz")), |
| 318 | + ("FOO.BAR.BAZ-BAZ", Some("foo.bar.baz-baz")), |
| 319 | + ("FOO.*", Some("foo.*")), |
| 320 | + ("FOO.BAR.*", Some("foo.bar.*")), |
| 321 | + ("FOO.*.BAZ", Some("foo.*.baz")), |
| 322 | + ("FOO.*.*", Some("foo.*.*")), |
| 323 | + // Indexes. |
| 324 | + ("foo.bar.baz[0]", Some("foo.bar.baz.*")), |
| 325 | + ("foo.bar.baz['abc']", Some("foo.bar.baz.abc")), |
| 326 | + ("foo.bar.baz[0].qux", Some("foo.bar.baz.*.qux")), |
| 327 | + ("foo.bar.baz[0].qux[1]", Some("foo.bar.baz.*.qux.*")), |
| 328 | + ("foo[1][2][3]", Some("foo.*.*.*")), |
| 329 | + ("foo.bar[abc]", Some("foo.bar.*")), |
| 330 | + ("foo.bar[abc()]", Some("foo.bar.*")), |
| 331 | + // Invalid cases |
| 332 | + ("foo().bar", None), |
| 333 | + ] { |
| 334 | + let ctx = Context::try_from(*case).unwrap(); |
| 335 | + assert_eq!(ctx.as_pattern().as_deref(), *expected); |
| 336 | + } |
| 337 | + } |
| 338 | + |
256 | 339 | #[test]
|
257 | 340 | fn test_contextpattern_new() {
|
258 | 341 | for (case, expected) in &[
|
|
0 commit comments