1 module boilerplate.conditions;
2 
3 version(unittest)
4 {
5     import core.exception : AssertError;
6     import unit_threaded.should;
7 }
8 
/++
`GenerateInvariants` is a string to be mixed into a class body with `mixin(GenerateInvariants);`.
It generates a single `invariant {}` block that asserts the conditions of every field marked
with a condition attribute (`@NonEmpty`, `@NonNull`, `@NonInit` or `@AllNonNull`).
+/
public enum string GenerateInvariants = `
    import boilerplate.conditions : GenerateInvariantsTemplate;
    mixin GenerateInvariantsTemplate;
    mixin(typeof(this).generateInvariantsImpl());
`;
18 
/++
Condition attribute: the invariant generated for a field tagged `@NonEmpty`
asserts `!field.empty`.
+/
public struct NonEmpty {}
25 
///
@("throws when a NonEmpty field is initialized empty")
unittest
{
    // A class whose generated invariant requires `array_` to be non-empty.
    class Holder
    {
        @NonEmpty
        int[] array_;

        this(int[] initial) { array_ = initial; }

        mixin(GenerateInvariants);
    }

    // The invariant runs after construction, so an empty array is rejected immediately.
    shouldThrow!AssertError(new Holder(null));
}
45 
///
@("throws when a NonEmpty field is assigned empty")
unittest
{
    class Holder
    {
        @NonEmpty
        private int[] array_;

        this(int[] initial) { array_ = initial; }

        // Public setter; the invariant is checked when it returns.
        public void array(int[] replacement) { array_ = replacement; }

        mixin(GenerateInvariants);
    }

    // Construction with a non-empty array succeeds; emptying it through the setter must fail.
    shouldThrow!AssertError((new Holder([2])).array(null));
}
70 
/++
When a field is marked with `@NonNull`, the generated invariant asserts that it is not null:
`!field.isNull` if the field's type provides `isNull`, otherwise `field !is null`.
A field whose type supports neither check is rejected at compile time.
+/
public struct NonNull
{
}
77 
///
@("throws when a NonNull field is initialized null")
unittest
{
    class Holder
    {
        @NonNull
        Object obj_;

        this(Object initial) { obj_ = initial; }

        mixin(GenerateInvariants);
    }

    // Storing null in a @NonNull field must trip the invariant after construction.
    shouldThrow!AssertError(new Holder(null));
}
97 
/++
When a field is marked with `@AllNonNull`, every element of the field is asserted to be non-null:
`field.all!"a !is null"` is asserted, or `field.all!"!a.isNull"` if the elements can only be
checked through `isNull`. For associative arrays, the keys and/or the values are checked,
whichever of them can be compared to null. Other types are rejected at compile time.
+/
public struct AllNonNull
{
}
104 
///
@("throws when an AllNonNull field is initialized with an array containing null")
unittest
{
    class Holder
    {
        @AllNonNull
        Object[] objs;

        this(Object[] initial) { objs = initial; }

        mixin(GenerateInvariants);
    }

    // An empty array has no null elements, so it is accepted...
    shouldEqual((new Holder(null)).objs, null);
    // ...but any null element violates the condition.
    shouldThrow!AssertError(new Holder([null]));
    shouldThrow!AssertError(new Holder([new Object, null]));
}
126 
/// `@AllNonNull` also works on associative arrays; here their values are checked.
@("supports AllNonNull on associative arrays")
unittest
{
    class Holder
    {
        @AllNonNull
        Object[int] objs;

        this(Object[int] initial) { objs = initial; }

        mixin(GenerateInvariants);
    }

    shouldEqual((new Holder(null)).objs, null);
    // A null value under any key violates the condition.
    shouldThrow!AssertError(new Holder([0: null]));
    shouldThrow!AssertError(new Holder([0: new Object, 1: null]));
}
148 
/// On associative arrays, `@AllNonNull` checks the keys, the values, or both, depending on which can be null.
@("supports AllNonNull on associative array keys")
unittest
{
    class Holder
    {
        @AllNonNull
        int[Object] objs;

        this(int[Object] initial) { objs = initial; }

        mixin(GenerateInvariants);
    }

    shouldEqual((new Holder(null)).objs, null);
    // Only the keys can be null here, so a null key violates the condition.
    shouldThrow!AssertError(new Holder([null: 0]));
    shouldThrow!AssertError(new Holder([new Object: 0, null: 1]));
}
170 
/++
Condition attribute: the invariant generated for a field tagged `@NonInit` asserts
`field !is typeof(field).init`, i.e. the field must not hold its type's default value.
+/
public struct NonInit {}
177 
///
@("throws when a NonInit field is initialized with T.init")
unittest
{
    class Class
    {
        @NonInit
        float f_;

        this(float f) { this.f_ = f; }

        mixin(GenerateInvariants);
    }

    // float.init is NaN; the check still catches it because it compares with `!is`, not `!=`.
    (new Class(float.init)).shouldThrow!AssertError;
}
196 
/++
Conditions on a `Nullable` field apply to the value it contains, if any.
A `Nullable` that holds no value is not checked at all, so it passes every condition.
+/
@("doesn't throw when a Nullable field is null")
unittest
{
    import std.typecons : Nullable, nullable;

    class Holder
    {
        @NonInit
        Nullable!float f_;

        this(Nullable!float initial) { f_ = initial; }

        mixin(GenerateInvariants);
    }

    shouldBeFalse((new Holder(5f.nullable)).f_.isNull);
    shouldBeTrue((new Holder(Nullable!float())).f_.isNull);
    // A present value is still checked: float.init violates @NonInit.
    shouldThrow!AssertError(new Holder(float.init.nullable));
}
223 
/++
Conditions cannot be applied to static fields, since there is no support for static invariants:
`GenerateInvariants` rejects such a field with a compile-time error.
+/
@("does not allow invariants on static fields")
unittest
{
    static assert(!__traits(compiles, ()
    {
        class Class
        {
            @NonNull
            private static Object obj;

            mixin(GenerateInvariants);
        }
    }), "invariant on static field compiled when it shouldn't");
}
243 
@("works with classes inheriting from templates")
unittest
{
    interface First(T) {}
    interface Second(T) {}

    class Derived : First!ubyte, Second!ubyte {}

    // Compile-only check: generating invariants must cope with a field whose
    // class type implements instantiated template interfaces.
    class Holder
    {
        Derived child;

        mixin(GenerateInvariants);
    }
}
266 
/++
Mixin template that adds the static helper `generateInvariantsImpl` to the aggregate it is mixed
into. `GenerateInvariants` calls the helper in a string mixin, i.e. at compile time. The helper
returns the source of one `invariant {}` block containing the condition checks of all fields.
+/
mixin template GenerateInvariantsTemplate()
{
    private static string generateInvariantsImpl()
    {
        // Only meaningful during CTFE; at runtime there is nothing to generate.
        if (!__ctfe)
        {
            return null;
        }

        import boilerplate.conditions : IsConditionAttribute, generateChecksForAttributes;
        import boilerplate.util : GenNormalMemberTuple, isStatic;
        import std.format : format;
        import std.meta : StdMetaFilter = Filter;

        string result = null;

        // The generated checks call `format` and `empty`, so import them inside the invariant.
        result ~= `invariant {` ~
            `import std.format : format;` ~
            `import std.array : empty;`;

        // TODO blocked by https://issues.dlang.org/show_bug.cgi?id=18504
        // note: synchronized without lock contention is basically free
        // IMPORTANT! Do not enable this until you have a solution for reliably detecting which attributes actually
        // require synchronization! overzealous synchronize has the potential to lead to needless deadlocks.
        // (consider implementing @GuardedBy)
        enum synchronize = false;

        result ~= synchronize ? `synchronized (this) {` : ``;

        // Provides `NormalMemberTuple`, presumably the names of the aggregate's regular fields
        // (defined in boilerplate.util).
        mixin GenNormalMemberTuple;

        foreach (member; NormalMemberTuple)
        {
            mixin(`alias symbol = this.` ~ member ~ `;`);

            alias ConditionAttributes = StdMetaFilter!(IsConditionAttribute, __traits(getAttributes, symbol));

            static if (mixin(isStatic(member)) && ConditionAttributes.length > 0)
            {
                // Invariants cannot guard static fields: emit a compile-time error instead,
                // and do not generate any checks for this field.
                result ~= format!(`static assert(false, `
                    ~ `"Cannot add constraint on static field %s: no support for static invariants");`
                )(member);
            }
            else static if (__traits(compiles, typeof(symbol).init))
            {
                result ~= generateChecksForAttributes!(typeof(symbol), ConditionAttributes)(`this.` ~ member);
            }
        }

        result ~= synchronize ? ` }` : ``;

        result ~= ` }`;

        return result;
    }
}
324 
/++
Generates the D source of the condition checks for one field.

For every condition attribute in `Attributes` (`NonEmpty`, `NonNull`, `NonInit`, `AllNonNull`),
an `assert` statement over `member_expression` is emitted. If `T` is a `Nullable`, the checks
apply to the wrapped value and only run while the field holds a value.

When a condition cannot be checked for the field's type, the result is a
`static assert(false, ...)` statement instead. Mixing it in then fails compilation with a
readable message.

Params:
    T = declared type of the field.
    Attributes = condition attributes attached to the field.
    member_expression = D expression that evaluates to the field, e.g. `this.field`.
    info = extra text inserted into the assertion messages right after "failed".

Returns: the generated statements, or `null` if no checks apply.
+/
public string generateChecksForAttributes(T, Attributes...)(string member_expression, string info = "")
{
    import boilerplate.conditions : AllNonNull, NonEmpty, NonInit, NonNull;
    import boilerplate.util : udaIndex;
    import std.array : empty;
    import std.format : format;
    import std.traits : ConstOf, isAssociativeArray;
    import std.typecons : Nullable;

    enum isNullable = is(T: Nullable!Args, Args...);

    // For a Nullable field, every check operates on the wrapped value.
    static if (isNullable)
    {
        enum access = `%s.get`;
    }
    else
    {
        enum access = `%s`;
    }

    alias MemberType = typeof(mixin(format!access(`T.init`)));

    string expression = format!access(member_expression);

    // Whether a const value of the member type can be printed in the assertion message.
    enum canFormat = __traits(compiles, format(`%s`, ConstOf!MemberType.init));

    string checks;

    static if (udaIndex!(NonEmpty, Attributes) != -1)
    {
        static if (!__traits(compiles, MemberType.init.empty()))
        {
            return format!`static assert(false, "Cannot call std.array.empty() on '%s'");`(expression);
        }
        else static if (canFormat)
        {
            checks ~= format!(`assert(!%s.empty, `
                ~ `format("@NonEmpty: assert(!%s.empty) failed%s: %s = %%s", %s));`)
                (expression, expression, info, expression, expression);
        }
        else
        {
            checks ~= format!`assert(!%s.empty(), "@NonEmpty: assert(!%s.empty) failed%s");`
                (expression, expression, info);
        }
    }

    static if (udaIndex!(NonNull, Attributes) != -1)
    {
        static if (__traits(compiles, MemberType.init.isNull))
        {
            checks ~= format!`assert(!%s.isNull, "@NonNull: assert(!%s.isNull) failed%s");`
                (expression, expression, info);
        }
        else static if (__traits(compiles, MemberType.init !is null))
        {
            // Nothing good can come of printing something that is null.
            checks ~= format!`assert(%s !is null, "@NonNull: assert(%s !is null) failed%s");`
                (expression, expression, info);
        }
        else
        {
            return format!`static assert(false, "Cannot compare '%s' to null");`(expression);
        }
    }

    static if (udaIndex!(NonInit, Attributes) != -1)
    {
        auto reference = `typeof(` ~ expression ~ `).init`;

        // Identity comparison (`!is`) also detects NaN defaults such as float.init, unlike `!=`.
        static if (!__traits(compiles, MemberType.init !is MemberType.init))
        {
            return format!`static assert(false, "Cannot compare '%s' to %s.init");`(expression, MemberType.stringof);
        }
        else static if (canFormat)
        {
            checks ~=
                format!(`assert(%s !is %s, `
                    ~ `format("@NonInit: assert(%s !is %s.init) failed%s: %s = %%s", %s));`)
                    (expression, reference, expression, MemberType.stringof, info, expression, expression);
        }
        else
        {
            checks ~=
                format!`assert(%s !is %s, "@NonInit: assert(%s !is %s.init) failed%s");`
                    (expression, reference, expression, MemberType.stringof, info);
        }
    }

    static if (udaIndex!(AllNonNull, Attributes) != -1)
    {
        import std.algorithm: all;

        checks ~= `import std.algorithm: all;`;

        static if (__traits(compiles, MemberType.init.all!"a !is null"))
        {
            static if (canFormat)
            {
                checks ~=
                    format!(`assert(%s.all!"a !is null", format(`
                        ~ `"@AllNonNull: assert(%s.all!\"a !is null\") failed%s: %s = %%s", %s));`)
                        (expression, expression, info, expression, expression);
            }
            else
            {
                checks ~= format!(`assert(%s.all!"a !is null", `
                    ~ `"@AllNonNull: assert(%s.all!\"a !is null\") failed%s");`)
                    (expression, expression, info);
            }
        }
        else static if (__traits(compiles, MemberType.init.all!"!a.isNull"))
        {
            static if (canFormat)
            {
                checks ~=
                    format!(`assert(%s.all!"!a.isNull", format(`
                        ~ `"@AllNonNull: assert(%s.all!\"!a.isNull\") failed%s: %s = %%s", %s));`)
                        (expression, expression, info, expression, expression);
            }
            else
            {
                checks ~= format!(`assert(%s.all!"!a.isNull", `
                    ~ `"@AllNonNull: assert(%s.all!\"!a.isNull\") failed%s");`)
                    (expression, expression, info);
            }
        }
        // Test the trait's value itself: `__traits(compiles, isAssociativeArray!X)` is always true,
        // which previously routed every non-range type into this branch.
        else static if (isAssociativeArray!MemberType)
        {
            enum checkValues = __traits(compiles, MemberType.init.byValue.all!"a !is null");
            enum checkKeys = __traits(compiles, MemberType.init.byKey.all!"a !is null");

            static if (!checkKeys && !checkValues)
            {
                return format!(`static assert(false, "Neither key nor value of associative array `
                    ~ `'%s' can be checked against null.");`)(expression);
            }
            else
            {
                static if (checkValues)
                {
                    checks ~=
                        format!(`assert(%s.byValue.all!"a !is null", `
                            ~ `"@AllNonNull: assert(%s.byValue.all!\"a !is null\") failed%s");`)
                            (expression, expression, info);
                }

                static if (checkKeys)
                {
                    checks ~=
                        format!(`assert(%s.byKey.all!"a !is null", `
                            ~ `"@AllNonNull: assert(%s.byKey.all!\"a !is null\") failed%s");`)
                            (expression, expression, info);
                }
            }
        }
        else
        {
            return format!`static assert(false, "Cannot compare all '%s' to null");`(expression);
        }
    }

    if (checks.empty)
    {
        return null;
    }

    static if (isNullable)
    {
        // Skip the checks while the Nullable holds no value.
        return `if (!` ~ member_expression ~ `.isNull) {` ~ checks ~ `}`;
    }
    else
    {
        return checks;
    }
}
501 
/++
True if `A` is one of this module's condition attributes:
`NonEmpty`, `NonNull`, `NonInit` or `AllNonNull`.
+/
public template IsConditionAttribute(alias A)
{
    enum IsConditionAttribute =
        __traits(isSame, A, NonEmpty)
        || __traits(isSame, A, NonNull)
        || __traits(isSame, A, NonInit)
        || __traits(isSame, A, AllNonNull);
}