tinc_build/codegen/cel/functions/all.rs

1use proc_macro2::TokenStream;
2use quote::{ToTokens, quote};
3use syn::parse_quote;
4use tinc_cel::CelValue;
5
6use super::Function;
7use crate::codegen::cel::compiler::{CompileError, CompiledExpr, CompilerCtx, ConstantCompiledExpr, RuntimeCompiledExpr};
8use crate::codegen::cel::types::CelType;
9use crate::types::{ProtoModifiedValueType, ProtoType, ProtoValueType};
10
/// The CEL `all` macro, written `<this>.all(<ident>, <expr>)`.
///
/// This is a stateless marker type; all of its behavior lives in its [`Function`]
/// implementation.
#[derive(Clone, Debug, Default)]
pub(crate) struct All;
13
14fn native_impl(iter: TokenStream, item_ident: syn::Ident, compare: impl ToTokens) -> syn::Expr {
15    parse_quote!({
16        let mut iter = (#iter).into_iter();
17        loop {
18            let Some(#item_ident) = iter.next() else {
19                break true;
20            };
21
22            if !(#compare) {
23                break false;
24            }
25        }
26    })
27}
28
29// this.all(<ident>, <expr>)
30impl Function for All {
31    fn name(&self) -> &'static str {
32        "all"
33    }
34
35    fn syntax(&self) -> &'static str {
36        "<this>.all(<ident>, <expr>)"
37    }
38
39    fn compile(&self, ctx: CompilerCtx) -> Result<CompiledExpr, CompileError> {
40        let Some(this) = &ctx.this else {
41            return Err(CompileError::syntax("missing this", self));
42        };
43
44        if ctx.args.len() != 2 {
45            return Err(CompileError::syntax("invalid number of args, expected 2", self));
46        }
47
48        let cel_parser::Expression::Ident(variable) = &ctx.args[0] else {
49            return Err(CompileError::syntax("first argument must be an ident", self));
50        };
51
52        match this {
53            CompiledExpr::Runtime(RuntimeCompiledExpr { expr, ty }) => {
54                let mut child_ctx = ctx.child();
55
56                match ty {
57                    CelType::CelValue => {
58                        child_ctx.add_variable(variable, CompiledExpr::runtime(CelType::CelValue, parse_quote!(item)));
59                    }
60                    CelType::Proto(ProtoType::Modified(
61                        ProtoModifiedValueType::Repeated(ty) | ProtoModifiedValueType::Map(ty, _),
62                    )) => {
63                        child_ctx.add_variable(
64                            variable,
65                            CompiledExpr::runtime(CelType::Proto(ProtoType::Value(ty.clone())), parse_quote!(item)),
66                        );
67                    }
68                    v => {
69                        return Err(CompileError::TypeConversion {
70                            ty: Box::new(v.clone()),
71                            message: "type cannot be iterated over".to_string(),
72                        });
73                    }
74                };
75
76                let arg = child_ctx.resolve(&ctx.args[1])?.into_bool(&child_ctx);
77
78                Ok(CompiledExpr::runtime(
79                    CelType::Proto(ProtoType::Value(ProtoValueType::Bool)),
80                    match &ty {
81                        CelType::CelValue => parse_quote! {
82                            ::tinc::__private::cel::CelValue::cel_all(#expr, |item| {
83                                ::core::result::Result::Ok(
84                                    #arg
85                                )
86                            })?
87                        },
88                        CelType::Proto(ProtoType::Modified(ProtoModifiedValueType::Map(_, _))) => {
89                            native_impl(quote!((#expr).keys()), parse_quote!(item), arg)
90                        }
91                        CelType::Proto(ProtoType::Modified(ProtoModifiedValueType::Repeated(_))) => {
92                            native_impl(quote!((#expr).iter()), parse_quote!(item), arg)
93                        }
94                        _ => unreachable!(),
95                    },
96                ))
97            }
98            CompiledExpr::Constant(ConstantCompiledExpr {
99                value: value @ (CelValue::List(_) | CelValue::Map(_)),
100            }) => {
101                let compile_val = |value: CelValue<'static>| {
102                    let mut child_ctx = ctx.child();
103
104                    child_ctx.add_variable(variable, CompiledExpr::constant(value));
105
106                    child_ctx.resolve(&ctx.args[1]).map(|v| v.into_bool(&child_ctx))
107                };
108
109                let collected: Result<Vec<_>, _> = match value {
110                    CelValue::List(item) => item.iter().cloned().map(compile_val).collect(),
111                    CelValue::Map(item) => item.iter().map(|(key, _)| key).cloned().map(compile_val).collect(),
112                    _ => unreachable!(),
113                };
114
115                let collected = collected?;
116                if collected.iter().any(|c| matches!(c, CompiledExpr::Runtime(_))) {
117                    Ok(CompiledExpr::runtime(
118                        CelType::Proto(ProtoType::Value(ProtoValueType::Bool)),
119                        native_impl(quote!([#(#collected),*]), parse_quote!(item), quote!(item)),
120                    ))
121                } else {
122                    Ok(CompiledExpr::constant(CelValue::Bool(collected.into_iter().all(
123                        |c| match c {
124                            CompiledExpr::Constant(ConstantCompiledExpr { value }) => value.to_bool(),
125                            _ => unreachable!("all values must be constant"),
126                        },
127                    ))))
128                }
129            }
130            CompiledExpr::Constant(ConstantCompiledExpr { value }) => Err(CompileError::TypeConversion {
131                ty: Box::new(CelType::CelValue),
132                message: format!("{value:?} cannot be iterated over"),
133            }),
134        }
135    }
136}
137
// Tests for the CEL `all` function: argument/receiver validation, constant folding, and
// generated code for runtime receivers (compiled and run through `postcompile`).
#[cfg(test)]
#[cfg(feature = "prost")]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use quote::quote;
    use syn::parse_quote;
    use tinc_cel::{CelValue, CelValueConv};

    use crate::codegen::cel::compiler::{CompiledExpr, Compiler, CompilerCtx};
    use crate::codegen::cel::functions::{All, Function};
    use crate::codegen::cel::types::CelType;
    use crate::types::{ProtoModifiedValueType, ProtoType, ProtoTypeRegistry, ProtoValueType};

    /// Error paths and constant folding, checked with inline debug snapshots.
    #[test]
    fn test_all_syntax() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, crate::extern_paths::ExternPaths::new(crate::Mode::Prost));
        let compiler = Compiler::new(&registry);
        // No receiver (`this`).
        insta::assert_debug_snapshot!(All.compile(CompilerCtx::new(compiler.child(), None, &[])), @r#"
        Err(
            InvalidSyntax {
                message: "missing this",
                syntax: "<this>.all(<ident>, <expr>)",
            },
        )
        "#);

        // Receiver present but no arguments.
        insta::assert_debug_snapshot!(All.compile(CompilerCtx::new(compiler.child(), Some(CompiledExpr::constant(CelValue::List(Default::default()))), &[])), @r#"
        Err(
            InvalidSyntax {
                message: "invalid number of args, expected 2",
                syntax: "<this>.all(<ident>, <expr>)",
            },
        )
        "#);

        // Constant receiver that is neither a list nor a map.
        insta::assert_debug_snapshot!(All.compile(CompilerCtx::new(compiler.child(), Some(CompiledExpr::constant(CelValue::String("hi".into()))), &[
            cel_parser::parse("x").unwrap(),
            cel_parser::parse("dyn(x >= 1)").unwrap(),
        ])), @r#"
        Err(
            TypeConversion {
                ty: CelValue,
                message: "String(Borrowed(\"hi\")) cannot be iterated over",
            },
        )
        "#);

        // Runtime receiver whose type (a scalar bool) cannot be iterated over.
        insta::assert_debug_snapshot!(All.compile(CompilerCtx::new(compiler.child(), Some(CompiledExpr::runtime(CelType::Proto(ProtoType::Value(ProtoValueType::Bool)), parse_quote!(input))), &[
            cel_parser::parse("x").unwrap(),
            cel_parser::parse("dyn(x >= 1)").unwrap(),
        ])), @r#"
        Err(
            TypeConversion {
                ty: Proto(
                    Value(
                        Bool,
                    ),
                ),
                message: "type cannot be iterated over",
            },
        )
        "#);

        // The first argument must be a bare identifier.
        insta::assert_debug_snapshot!(All.compile(CompilerCtx::new(compiler.child(), Some(CompiledExpr::constant(CelValue::List(Default::default()))), &[
            cel_parser::parse("1 + 1").unwrap(), // not an ident
            cel_parser::parse("x + 2").unwrap(),
        ])), @r#"
        Err(
            InvalidSyntax {
                message: "first argument must be an ident",
                syntax: "<this>.all(<ident>, <expr>)",
            },
        )
        "#);

        // Constant list where every element satisfies the predicate: folds to `true`.
        insta::assert_debug_snapshot!(All.compile(CompilerCtx::new(compiler.child(), Some(CompiledExpr::constant(CelValue::List([
            CelValueConv::conv(4),
            CelValueConv::conv(3),
            CelValueConv::conv(10),
        ].into_iter().collect()))), &[
            cel_parser::parse("x").unwrap(),
            cel_parser::parse("x > 2").unwrap(),
        ])), @r"
        Ok(
            Constant(
                ConstantCompiledExpr {
                    value: Bool(
                        true,
                    ),
                },
            ),
        )
        ");

        // Constant list with a failing element: folds to `false`.
        insta::assert_debug_snapshot!(All.compile(CompilerCtx::new(compiler.child(), Some(CompiledExpr::constant(CelValue::List([
            CelValueConv::conv(2),
        ].into_iter().collect()))), &[
            cel_parser::parse("x").unwrap(),
            cel_parser::parse("x > 2").unwrap(),
        ])), @r"
        Ok(
            Constant(
                ConstantCompiledExpr {
                    value: Bool(
                        false,
                    ),
                },
            ),
        )
        ");

        // Constant map: the predicate is applied to the keys (key 2 fails `x > 2`).
        insta::assert_debug_snapshot!(All.compile(CompilerCtx::new(compiler.child(), Some(CompiledExpr::constant(CelValue::Map([
            (CelValueConv::conv(2), CelValue::Null),
        ].into_iter().collect()))), &[
            cel_parser::parse("x").unwrap(),
            cel_parser::parse("x > 2").unwrap(),
        ])), @r"
        Ok(
            Constant(
                ConstantCompiledExpr {
                    value: Bool(
                        false,
                    ),
                },
            ),
        )
        ");

        // Constant number receiver cannot be iterated over.
        insta::assert_debug_snapshot!(All.compile(CompilerCtx::new(compiler.child(), Some(CompiledExpr::constant(CelValueConv::conv(1))), &[
            cel_parser::parse("x").unwrap(),
            cel_parser::parse("x > 2").unwrap(),
        ])), @r#"
        Err(
            TypeConversion {
                ty: CelValue,
                message: "Number(I64(1)) cannot be iterated over",
            },
        )
        "#);
    }

    /// Runtime `CelValue` receiver: the generated code is compiled and its tests run with
    /// `postcompile`, and the output is snapshotted.
    // NOTE(review): the `not(valgrind)` gate presumably skips these because compiling the
    // generated code is too slow under valgrind — TODO confirm.
    #[test]
    #[cfg(not(valgrind))]
    fn test_all_cel_value() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, crate::extern_paths::ExternPaths::new(crate::Mode::Prost));
        let compiler = Compiler::new(&registry);

        let map = CompiledExpr::runtime(CelType::CelValue, parse_quote!(input));

        let result = All
            .compile(CompilerCtx::new(
                compiler.child(),
                Some(map),
                &[
                    cel_parser::parse("x").unwrap(), // the iteration variable
                    cel_parser::parse("x > 2").unwrap(),
                ],
            ))
            .unwrap();

        let result = postcompile::compile_str!(
            postcompile::config! {
                test: true,
                dependencies: vec![
                    postcompile::Dependency::version("tinc", "*"),
                ],
            },
            quote! {
                #[allow(dead_code)]
                fn all<'a>(
                    input: ::tinc::__private::cel::CelValue<'a>,
                ) -> Result<bool, ::tinc::__private::cel::CelError<'a>> {
                    Ok(
                        #result
                    )
                }

                #[test]
                fn test_all() {
                    assert_eq!(all(::tinc::__private::cel::CelValueConv::conv(&[0, 1, 2] as &[i32])).unwrap(), false);
                    assert_eq!(all(::tinc::__private::cel::CelValueConv::conv(&[3, 4, 5] as &[i32])).unwrap(), true);
                    assert_eq!(all(::tinc::__private::cel::CelValueConv::conv(&[] as &[i32])).unwrap(), true);
                }
            },
        );

        insta::assert_snapshot!(result);
    }

    /// Runtime proto map receiver: the predicate is checked against the map's keys.
    #[test]
    #[cfg(not(valgrind))]
    fn test_all_proto_map() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, crate::extern_paths::ExternPaths::new(crate::Mode::Prost));
        let compiler = Compiler::new(&registry);

        let map = CompiledExpr::runtime(
            CelType::Proto(ProtoType::Modified(ProtoModifiedValueType::Map(
                ProtoValueType::Int32,
                ProtoValueType::Float,
            ))),
            parse_quote!(input),
        );

        let result = All
            .compile(CompilerCtx::new(
                compiler.child(),
                Some(map),
                &[
                    cel_parser::parse("x").unwrap(), // the iteration variable
                    cel_parser::parse("x > 2").unwrap(),
                ],
            ))
            .unwrap();

        let result = postcompile::compile_str!(
            postcompile::config! {
                test: true,
                dependencies: vec![
                    postcompile::Dependency::version("tinc", "*"),
                ],
            },
            quote! {
                #[allow(dead_code)]
                fn all(
                    input: &std::collections::BTreeMap<i32, f32>,
                ) -> Result<bool, ::tinc::__private::cel::CelError<'static>> {
                    Ok(
                        #result
                    )
                }

                #[test]
                fn test_all() {
                    assert_eq!(all(&{
                        let mut map = std::collections::BTreeMap::new();
                        map.insert(3, 2.0);
                        map.insert(4, 2.0);
                        map.insert(5, 2.0);
                        map
                    }).unwrap(), true);
                    assert_eq!(all(&{
                        let mut map = std::collections::BTreeMap::new();
                        map.insert(3, 2.0);
                        map.insert(1, 2.0);
                        map.insert(5, 2.0);
                        map
                    }).unwrap(), false);
                    assert_eq!(all(&std::collections::BTreeMap::new()).unwrap(), true)
                }
            },
        );

        insta::assert_snapshot!(result);
    }

    /// Runtime repeated (list) receiver.
    #[test]
    #[cfg(not(valgrind))]
    fn test_all_proto_repeated() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, crate::extern_paths::ExternPaths::new(crate::Mode::Prost));
        let compiler = Compiler::new(&registry);

        let repeated = CompiledExpr::runtime(
            CelType::Proto(ProtoType::Modified(ProtoModifiedValueType::Repeated(ProtoValueType::Int32))),
            parse_quote!(input),
        );

        let result = All
            .compile(CompilerCtx::new(
                compiler.child(),
                Some(repeated),
                &[
                    cel_parser::parse("x").unwrap(), // the iteration variable
                    cel_parser::parse("x > 2").unwrap(),
                ],
            ))
            .unwrap();

        let result = postcompile::compile_str!(
            postcompile::config! {
                test: true,
                dependencies: vec![
                    postcompile::Dependency::version("tinc", "*"),
                ],
            },
            quote! {
                #[allow(dead_code)]
                fn all(
                    input: &Vec<i32>,
                ) -> Result<bool, ::tinc::__private::cel::CelError<'static>> {
                    Ok(
                        #result
                    )
                }

                #[test]
                fn test_all() {
                    assert_eq!(all(&vec![1, 2, 3]).unwrap(), false);
                    assert_eq!(all(&vec![3, 4, 60]).unwrap(), true);
                    assert_eq!(all(&vec![]).unwrap(), true);
                }
            },
        );

        insta::assert_snapshot!(result);
    }

    /// Constant list whose per-element predicate does not fold to a constant, so the results
    /// are combined at runtime. Presumably `dyn(...)` is what prevents folding — TODO confirm.
    #[test]
    #[cfg(not(valgrind))]
    fn test_all_const_needs_runtime() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, crate::extern_paths::ExternPaths::new(crate::Mode::Prost));
        let compiler = Compiler::new(&registry);

        let list = CompiledExpr::constant(CelValue::List([CelValue::Number(0.into())].into_iter().collect()));

        let result = All
            .compile(CompilerCtx::new(
                compiler.child(),
                Some(list),
                &[
                    cel_parser::parse("x").unwrap(), // the iteration variable
                    cel_parser::parse("dyn(x > 2)").unwrap(),
                ],
            ))
            .unwrap();

        let result = postcompile::compile_str!(
            postcompile::config! {
                test: true,
                dependencies: vec![
                    postcompile::Dependency::version("tinc", "*"),
                ],
            },
            quote! {
                #[allow(dead_code)]
                fn all() -> Result<bool, ::tinc::__private::cel::CelError<'static>> {
                    Ok(
                        #result
                    )
                }

                #[test]
                fn test_all() {
                    assert_eq!(all().unwrap(), false);
                }
            },
        );

        insta::assert_snapshot!(result);
    }

    /// Checks that the code generated for a repeated receiver compiles against both a slice
    /// and a `Vec`, for empty, all-passing and partially failing inputs.
    #[test]
    #[cfg(not(valgrind))]
    fn test_all_runtime() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, crate::extern_paths::ExternPaths::new(crate::Mode::Prost));
        let compiler = Compiler::new(&registry);

        let list = CompiledExpr::runtime(
            CelType::Proto(ProtoType::Modified(ProtoModifiedValueType::Repeated(ProtoValueType::Int32))),
            parse_quote!(input),
        );

        let result = All
            .compile(CompilerCtx::new(
                compiler.child(),
                Some(list),
                &[
                    cel_parser::parse("x").unwrap(), // the iteration variable
                    cel_parser::parse("x > 2").unwrap(),
                ],
            ))
            .unwrap();

        insta::assert_snapshot!(postcompile::compile_str!(
            postcompile::config! {
                test: true,
                dependencies: vec![
                    postcompile::Dependency::version("tinc", "*"),
                ],
            },
            quote! {
                #[allow(dead_code)]
                fn runtime_slice(
                    input: &[i32],
                ) -> Result<bool, ::tinc::__private::cel::CelError<'static>> {
                    Ok(
                        #result
                    )
                }

                #[allow(dead_code)]
                fn runtime_vec(
                    input: &Vec<i32>,
                ) -> Result<bool, ::tinc::__private::cel::CelError<'static>> {
                    Ok(
                        #result
                    )
                }

                #[test]
                fn test_empty_lists() {
                    assert!(runtime_slice(&[]).unwrap());
                    assert!(runtime_vec(&vec![]).unwrap());
                    assert!(runtime_slice(&[3, 4, 5]).unwrap());
                    assert!(runtime_vec(&vec![3, 4, 5]).unwrap());
                    assert!(!runtime_slice(&[3, 4, 5, 2]).unwrap());
                    assert!(!runtime_vec(&vec![3, 4, 5, 2]).unwrap());
                }
            },
        ));
    }
}
546        ));
547    }
548}