diff --git a/src/satextractor/scheduler/scheduler.py b/src/satextractor/scheduler/scheduler.py index 2a6027d..7d8f545 100644 --- a/src/satextractor/scheduler/scheduler.py +++ b/src/satextractor/scheduler/scheduler.py @@ -115,106 +115,111 @@ def create_tasks_by_splits( items = [] tasks: List[ExtractionTask] = [] - with open(item_collection, "rb") as json_file: - for item in ijson.items(json_file, "features.item"): - if counter != collection_chunks: - items.append(item) - counter += 1 - else: - counter = 0 - - stac_items = pystac.ItemCollection( - items=[pystac.Item.from_dict(item) for item in items], + if isinstance(item_collection, str): + json_file = open(item_collection, "rb") + all_items = ijson.items(json_file, "features.item") + else: + all_items = item_collection["features"] + + for item in all_items: + if counter < collection_chunks: + items.append(item) + counter += 1 + else: + counter = 0 + + stac_items = pystac.ItemCollection( + items=[pystac.Item.from_dict(it) for it in items], + ) + gdf = gpd.GeoDataFrame.from_features( + {"type": "FeatureCollection", "features": items}, + ) + items = [] + + gdf.datetime = pd.to_datetime(gdf.datetime).dt.tz_localize(None) + + tiles_gdf = cluster_tiles_in_utm(tiles, split_m) + + logger.info( + "Creating extraction tasks for each constellations, date, and band ...", + ) + + task_tracker = 0 + + for constellation in constellations: + + # Get all the date ranges for the given interval + dates = get_dates_in_range( + gdf.loc[gdf.constellation == constellation, "datetime"] + .min() + .to_pydatetime(), + gdf.loc[gdf.constellation == constellation, "datetime"] + .max() + .to_pydatetime(), + interval, ) - gdf = gpd.GeoDataFrame.from_features( - {"type": "FeatureCollection", "features": items}, - ) - items = [] - - gdf.datetime = pd.to_datetime(gdf.datetime).dt.tz_localize(None) - tiles_gdf = cluster_tiles_in_utm(tiles, split_m) + if bands is not None: + run_bands = [ + b["band"].name + for kk, b in BAND_INFO[constellation].items() + if b["band"].name in bands + ] + else: + run_bands = [ + b["band"].name for kk, b in BAND_INFO[constellation].items() + ] logger.info( - "Creating extraction tasks for each constellations, date, and band ...", + f"Getting cluster item indexes for {constellation} in parallel...", ) - - task_tracker = 0 - - for constellation in constellations: - - # Get all the date ranges for the given interval - dates = get_dates_in_range( - gdf.loc[gdf.constellation == constellation, "datetime"] - .min() - .to_pydatetime(), - gdf.loc[gdf.constellation == constellation, "datetime"] - .max() - .to_pydatetime(), - interval, - ) - - if bands is not None: - run_bands = [ - b["band"].name - for kk, b in BAND_INFO[constellation].items() - if b["band"].name in bands - ] - else: - run_bands = [ - b["band"].name for kk, b in BAND_INFO[constellation].items() - ] - - logger.info( - f"Getting cluster item indexes for {constellation} in parallel...", - ) - with tqdm_joblib( - tqdm(desc="Extraction Tasks creation.", total=len(dates)), - ): - cluster_items = Parallel(n_jobs=n_jobs, verbose=verbose)( - delayed(get_cluster_items_indexes)( - gdf[ - (gdf.datetime >= start) - & (gdf.datetime <= end) - & (gdf.constellation == constellation) - ], - tiles_gdf, - ) - for start, end in dates + with tqdm_joblib( + tqdm(desc="Extraction Tasks creation.", total=len(dates)), + ): + cluster_items = Parallel(n_jobs=n_jobs, verbose=verbose)( + delayed(get_cluster_items_indexes)( + gdf[ + (gdf.datetime >= start) + & (gdf.datetime <= end) + & (gdf.constellation == constellation) + ], + tiles_gdf, ) + for start, end in dates + ) - for i, date_cluster_item in enumerate(cluster_items): - for k, v in date_cluster_item.items(): - if v: - c_tiles = tiles_gdf[tiles_gdf["cluster_id"] == k] - c_items_geom = gdf.iloc[v].unary_union - t_indexes = c_tiles[ - c_tiles.geometry.apply(c_items_geom.contains) - ].index - if not t_indexes.empty: - c_items = pystac.ItemCollection( - [ - stac_items.items[item_index] - for item_index in v - ], + for i, date_cluster_item in enumerate(cluster_items): + for k, v in date_cluster_item.items(): + if v: + c_tiles = tiles_gdf[tiles_gdf["cluster_id"] == k] + c_items_geom = gdf.iloc[v].unary_union + t_indexes = c_tiles[ + c_tiles.geometry.apply(c_items_geom.contains) + ].index + if not t_indexes.empty: + c_items = pystac.ItemCollection( + [stac_items.items[item_index] for item_index in v], + ) + region_tiles = [tiles[t_index] for t_index in t_indexes] + sensing_time = dates[i][0] + + for b in run_bands: + tasks.append( + ExtractionTask( + task_id=str(task_tracker), + tiles=region_tiles, + item_collection=c_items, + band=b, + constellation=constellation, + sensing_time=sensing_time, + ), ) - region_tiles = [ - tiles[t_index] for t_index in t_indexes - ] - sensing_time = dates[i][0] - - for b in run_bands: - tasks.append( - ExtractionTask( - task_id=str(task_tracker), - tiles=region_tiles, - item_collection=c_items, - band=b, - constellation=constellation, - sensing_time=sensing_time, - ), - ) - task_tracker += 1 + task_tracker += 1 + + try: + json_file.close() + except Exception: + pass logger.info(f"There are a total of {len(tasks)} tasks")