@@ -160,6 +160,8 @@ def accumulator_coder(self):
160160class CacheCoder (object ):
161161 """A coder iterface for encoding and decoding cache items."""
162162
163+ __metaclass__ = abc .ABCMeta
164+
163165 def __repr__ (self ):
164166 return '<{}>' .format (self .__class__ .__name__ )
165167
@@ -327,32 +329,92 @@ def output_tensor_infos(self):
327329 ] + self .combiner .output_tensor_infos ()
328330
329331
330- class Vocabulary (
331- collections .namedtuple (
332- 'Vocabulary' ,
333- [
334- 'top_k' ,
335- 'frequency_threshold' ,
336- 'vocab_filename' ,
337- 'store_frequency' ,
338- 'vocab_ordering_type' ,
339- 'use_adjusted_mutual_info' ,
340- 'min_diff_from_avg' ,
341- 'coverage_top_k' ,
342- 'coverage_frequency_threshold' ,
343- 'key_fn' ,
344- 'label'
345- ]),
346- AnalyzerDef ):
347- """OperationDef for computing a vocabulary of unique values.
class VocabularyAccumulate(
    collections.namedtuple('VocabularyAccumulate',
                           ['vocab_ordering_type', 'label']),
    nodes.OperationDef):
  """An operation that accumulates unique words with their frequency or weight.

  This operation is implemented by
  `tensorflow_transform.beam.analyzer_impls.VocabularyAccumulateImpl`.
  """

  def __new__(cls, vocab_ordering_type, label=None):
    # When no label is supplied, derive a unique one from the current graph
    # name scope, mirroring the other OperationDef subclasses in this file.
    if label is None:
      label = '{}[{}]'.format(cls.__name__,
                              tf.get_default_graph().get_name_scope())
    return super(VocabularyAccumulate, cls).__new__(
        cls, vocab_ordering_type=vocab_ordering_type, label=label)

  @property
  def num_outputs(self):
    # A single accumulator collection is produced.
    return 1

  @property
  def is_partitionable(self):
    # Accumulation is associative, so the input may be processed in shards.
    return True

  @property
  def cache_coder(self):
    # Accumulators are cached via the JSON-based coder below.
    return _VocabularyAccumulatorCoder()
361+
class _VocabularyAccumulatorCoder(CacheCoder):
  """Coder for vocabulary accumulators.

  Encodes a `(word, count)` accumulator pair as JSON bytes, and decodes the
  bytes back into a `np.ndarray` of dtype `object`.
  """

  def encode_cache(self, accumulator):
    word, count = accumulator
    # Use tf.compat.as_text instead of word.decode('utf-8'): it accepts both
    # `bytes` and `str`, whereas calling .decode() directly raises
    # AttributeError when the word is already a decoded string.
    accumulator = (tf.compat.as_text(word), count)
    # Need to wrap in np.array and call tolist to make it JSON serializable.
    return tf.compat.as_bytes(
        json.dumps(np.array(accumulator, dtype=object).tolist()))

  def decode_cache(self, encoded_accumulator):
    # Round-trip back into an object ndarray; `count` keeps whatever numeric
    # type JSON produced (int or float).
    return np.array(json.loads(encoded_accumulator), dtype=object)
375+
class VocabularyMerge(
    collections.namedtuple('VocabularyMerge', [
        'vocab_ordering_type', 'use_adjusted_mutual_info', 'min_diff_from_avg',
        'label'
    ]), nodes.OperationDef):
  """An operation that merges the accumulators produced by VocabularyAccumulate.

  This operation operates on the output of VocabularyAccumulate and is
  implemented by `tensorflow_transform.beam.analyzer_impls.VocabularyMergeImpl`.

  See `tft.vocabulary` for a description of the parameters.
  """

  def __new__(cls,
              vocab_ordering_type,
              use_adjusted_mutual_info,
              min_diff_from_avg,
              label=None):
    # Default the label to a unique name derived from the graph scope.
    if label is None:
      label = '{}[{}]'.format(cls.__name__,
                              tf.get_default_graph().get_name_scope())
    return super(VocabularyMerge, cls).__new__(
        cls,
        vocab_ordering_type=vocab_ordering_type,
        use_adjusted_mutual_info=use_adjusted_mutual_info,
        min_diff_from_avg=min_diff_from_avg,
        label=label)

  @property
  def num_outputs(self):
    # Merging yields a single combined accumulator.
    return 1
408+
class VocabularyOrderAndFilter(
    collections.namedtuple('VocabularyOrderAndFilter', [
        'top_k', 'frequency_threshold', 'coverage_top_k',
        'coverage_frequency_threshold', 'key_fn', 'label'
    ]), nodes.OperationDef):
  """An operation that filters and orders a computed vocabulary.

  This operation operates on the output of VocabularyMerge and is implemented by
  `tensorflow_transform.beam.analyzer_impls.VocabularyOrderAndFilterImpl`.

  See `tft.vocabulary` for a description of the parameters.
  """

  def __new__(cls,
              top_k,
              frequency_threshold,
              coverage_top_k,
              coverage_frequency_threshold,
              key_fn,
              label=None):
    # Default the label to a unique name derived from the graph scope.
    if label is None:
      label = '{}[{}]'.format(cls.__name__,
                              tf.get_default_graph().get_name_scope())
    return super(VocabularyOrderAndFilter, cls).__new__(
        cls,
        top_k=top_k,
        frequency_threshold=frequency_threshold,
        coverage_top_k=coverage_top_k,
        coverage_frequency_threshold=coverage_frequency_threshold,
        key_fn=key_fn,
        label=label)

  @property
  def num_outputs(self):
    # A single ordered-and-filtered vocabulary is produced.
    return 1
446+
class VocabularyWrite(
    collections.namedtuple('VocabularyWrite',
                           ['vocab_filename', 'store_frequency', 'label']),
    AnalyzerDef):
  """An analyzer that writes vocabulary files from an accumulator.

  This operation operates on the output of VocabularyOrderAndFilter and is
  implemented by `tensorflow_transform.beam.analyzer_impls.VocabularyWriteImpl`.

  See `tft.vocabulary` for a description of the parameters.
  """

  def __new__(cls, vocab_filename, store_frequency, label=None):
    # Default the label to a unique name derived from the graph scope.
    if label is None:
      label = '{}[{}]'.format(cls.__name__,
                              tf.get_default_graph().get_name_scope())
    return super(VocabularyWrite, cls).__new__(
        cls,
        vocab_filename=vocab_filename,
        store_frequency=store_frequency,
        label=label)

  @property
  def output_tensor_infos(self):
    # Single scalar string output; the trailing True presumably marks it as an
    # asset (the written vocabulary file) — confirm against TensorInfo's fields.
    return [TensorInfo(tf.string, [], True)]
0 commit comments