// tinc_build/codegen/cel/functions/contains.rs

1use quote::quote;
2use syn::parse_quote;
3use tinc_cel::CelValue;
4
5use super::Function;
6use crate::codegen::cel::compiler::{CompileError, CompiledExpr, CompilerCtx, ConstantCompiledExpr, RuntimeCompiledExpr};
7use crate::codegen::cel::types::CelType;
8use crate::types::{ProtoModifiedValueType, ProtoType, ProtoValueType};
9
/// Stateless marker type implementing the CEL `contains` function (`this.contains(arg)` / `arg in this`).
#[derive(Clone, Debug, Default)]
pub(crate) struct Contains;
12
13// this.contains(arg)
14// arg in this
15impl Function for Contains {
16    fn name(&self) -> &'static str {
17        "contains"
18    }
19
20    fn syntax(&self) -> &'static str {
21        "<this>.contains(<arg>)"
22    }
23
24    fn compile(&self, mut ctx: CompilerCtx) -> Result<CompiledExpr, CompileError> {
25        let Some(this) = ctx.this.take() else {
26            return Err(CompileError::syntax("missing this", self));
27        };
28
29        if ctx.args.len() != 1 {
30            return Err(CompileError::syntax("takes exactly one argument", self));
31        }
32
33        let arg = ctx.resolve(&ctx.args[0])?.into_cel()?;
34
35        if let CompiledExpr::Runtime(RuntimeCompiledExpr {
36            expr,
37            ty:
38                ty @ CelType::Proto(ProtoType::Modified(
39                    ProtoModifiedValueType::Repeated(item) | ProtoModifiedValueType::Map(item, _),
40                )),
41        }) = &this
42        {
43            if !matches!(item, ProtoValueType::Message { .. } | ProtoValueType::Enum(_)) {
44                let op = match &ty {
45                    CelType::Proto(ProtoType::Modified(ProtoModifiedValueType::Repeated(_))) => {
46                        quote! { array_contains }
47                    }
48                    CelType::Proto(ProtoType::Modified(ProtoModifiedValueType::Map(_, _))) => {
49                        quote! { map_contains }
50                    }
51                    _ => unreachable!(),
52                };
53
54                return Ok(CompiledExpr::runtime(
55                    CelType::Proto(ProtoType::Value(ProtoValueType::Bool)),
56                    parse_quote! {
57                        ::tinc::__private::cel::#op(
58                            #expr,
59                            #arg,
60                        )
61                    },
62                ));
63            }
64        }
65
66        let this = this.clone().into_cel()?;
67
68        match (this, arg) {
69            (
70                CompiledExpr::Constant(ConstantCompiledExpr { value: this }),
71                CompiledExpr::Constant(ConstantCompiledExpr { value: arg }),
72            ) => Ok(CompiledExpr::constant(CelValue::cel_contains(this, arg)?)),
73            (this, arg) => Ok(CompiledExpr::runtime(
74                CelType::Proto(ProtoType::Value(ProtoValueType::Bool)),
75                parse_quote! {
76                    ::tinc::__private::cel::CelValue::cel_contains(
77                        #this,
78                        #arg,
79                    )?
80                },
81            )),
82        }
83    }
84}
85
// Unit tests for `Contains`. Gated on the `prost` feature; every test builds its
// `ProtoTypeRegistry` in `Mode::Prost`.
#[cfg(test)]
#[cfg(feature = "prost")]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use quote::quote;
    use syn::parse_quote;
    use tinc_cel::CelValue;

    use crate::codegen::cel::compiler::{CompiledExpr, Compiler, CompilerCtx};
    use crate::codegen::cel::functions::{Contains, Function};
    use crate::codegen::cel::types::CelType;
    use crate::types::{ProtoModifiedValueType, ProtoType, ProtoTypeRegistry, ProtoValueType};

    // Covers both syntax errors and compile-time folding when receiver and argument are constants.
    #[test]
    fn test_contains_syntax() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, crate::extern_paths::ExternPaths::new(crate::Mode::Prost));
        let compiler = Compiler::new(&registry);
        // No receiver (`this` is `None`).
        insta::assert_debug_snapshot!(Contains.compile(CompilerCtx::new(compiler.child(), None, &[])), @r#"
        Err(
            InvalidSyntax {
                message: "missing this",
                syntax: "<this>.contains(<arg>)",
            },
        )
        "#);

        // Receiver present but no arguments.
        insta::assert_debug_snapshot!(Contains.compile(CompilerCtx::new(compiler.child(), Some(CompiledExpr::constant(CelValue::String("hi".into()))), &[])), @r#"
        Err(
            InvalidSyntax {
                message: "takes exactly one argument",
                syntax: "<this>.contains(<arg>)",
            },
        )
        "#);

        // Constant empty list receiver with constant argument `1 + 1`: folded at compile time to `false`.
        insta::assert_debug_snapshot!(Contains.compile(CompilerCtx::new(compiler.child(), Some(CompiledExpr::constant(CelValue::List(Default::default()))), &[
            cel_parser::parse("1 + 1").unwrap(),
        ])), @r"
        Ok(
            Constant(
                ConstantCompiledExpr {
                    value: Bool(
                        false,
                    ),
                },
            ),
        )
        ");
    }

    // A runtime `string` receiver is neither repeated nor a map, so it takes the generic
    // `CelValue::cel_contains` path. The argument `(1 + 1).string()` yields "2", which is why
    // "in2dastring" is expected to match and "in3dastring" is not. The generated expression is
    // wrapped in a helper fn, compiled through `postcompile`, and the output is snapshotted.
    #[test]
    #[cfg(not(valgrind))]
    fn test_contains_runtime_string() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, crate::extern_paths::ExternPaths::new(crate::Mode::Prost));
        let compiler = Compiler::new(&registry);

        let string_value =
            CompiledExpr::runtime(CelType::Proto(ProtoType::Value(ProtoValueType::String)), parse_quote!(input));

        let output = Contains
            .compile(CompilerCtx::new(
                compiler.child(),
                Some(string_value),
                &[cel_parser::parse("(1 + 1).string()").unwrap()],
            ))
            .unwrap();

        insta::assert_snapshot!(postcompile::compile_str!(
            postcompile::config! {
                test: true,
                dependencies: vec![
                    postcompile::Dependency::version("tinc", "*"),
                ],
            },
            quote! {
                fn contains(input: &String) -> Result<bool, ::tinc::__private::cel::CelError<'_>> {
                    Ok(#output)
                }

                #[test]
                fn test_contains() {
                    assert_eq!(contains(&"in2dastring".into()).unwrap(), true);
                    assert_eq!(contains(&"in3dastring".into()).unwrap(), false);
                }
            },
        ));
    }

    // A runtime `map<string, bool>` receiver has a scalar key type, so it takes the `map_contains`
    // fast path. The embedded test expects a match on the key "value".
    // Note: the local `string_value` holds the map receiver despite its name.
    #[test]
    #[cfg(not(valgrind))]
    fn test_contains_runtime_map() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, crate::extern_paths::ExternPaths::new(crate::Mode::Prost));
        let compiler = Compiler::new(&registry);

        let string_value = CompiledExpr::runtime(
            CelType::Proto(ProtoType::Modified(ProtoModifiedValueType::Map(
                ProtoValueType::String,
                ProtoValueType::Bool,
            ))),
            parse_quote!(input),
        );

        let output = Contains
            .compile(CompilerCtx::new(
                compiler.child(),
                Some(string_value),
                &[cel_parser::parse("'value'").unwrap()],
            ))
            .unwrap();

        insta::assert_snapshot!(postcompile::compile_str!(
            postcompile::config! {
                test: true,
                dependencies: vec![
                    postcompile::Dependency::version("tinc", "*"),
                ],
            },
            quote! {
                fn contains(input: &std::collections::HashMap<String, bool>) -> Result<bool, ::tinc::__private::cel::CelError<'_>> {
                    Ok(#output)
                }

                #[test]
                fn test_contains() {
                    assert_eq!(contains(&{
                        let mut map = std::collections::HashMap::new();
                        map.insert("value".to_string(), true);
                        map
                    }).unwrap(), true);
                    assert_eq!(contains(&{
                        let mut map = std::collections::HashMap::new();
                        map.insert("not_value".to_string(), true);
                        map
                    }).unwrap(), false);
                    assert_eq!(contains(&{
                        let mut map = std::collections::HashMap::new();
                        map.insert("xd".to_string(), true);
                        map.insert("value".to_string(), true);
                        map
                    }).unwrap(), true);
                }
            },
        ));
    }

    // A runtime `repeated string` receiver has a scalar item type, so it takes the
    // `array_contains` fast path.
    // Note: the local `string_value` holds the repeated receiver despite its name.
    #[test]
    #[cfg(not(valgrind))]
    fn test_contains_runtime_repeated() {
        let registry = ProtoTypeRegistry::new(crate::Mode::Prost, crate::extern_paths::ExternPaths::new(crate::Mode::Prost));
        let compiler = Compiler::new(&registry);

        let string_value = CompiledExpr::runtime(
            CelType::Proto(ProtoType::Modified(ProtoModifiedValueType::Repeated(ProtoValueType::String))),
            parse_quote!(input),
        );

        let output = Contains
            .compile(CompilerCtx::new(
                compiler.child(),
                Some(string_value),
                &[cel_parser::parse("'value'").unwrap()],
            ))
            .unwrap();

        insta::assert_snapshot!(postcompile::compile_str!(
            postcompile::config! {
                test: true,
                dependencies: vec![
                    postcompile::Dependency::version("tinc", "*"),
                ],
            },
            quote! {
                fn contains(input: &Vec<String>) -> Result<bool, ::tinc::__private::cel::CelError<'_>> {
                    Ok(#output)
                }

                #[test]
                fn test_contains() {
                    assert_eq!(contains(&vec!["value".into()]).unwrap(), true);
                    assert_eq!(contains(&vec!["not_value".into()]).unwrap(), false);
                    assert_eq!(contains(&vec!["xd".into(), "value".into()]).unwrap(), true);
                }
            },
        ));
    }
}